import React, { useContext, useEffect } from 'react'
import { round, orderBy, sum } from 'lodash'
import Table from '@mui/material/Table'
import TableBody from '@mui/material/TableBody'
import TableHead from '@mui/material/TableHead'
import TableRow from '@mui/material/TableRow'
import TableCell from '@mui/material/TableCell'
import makeStyles from '@mui/styles/makeStyles'
import Grid from '@mui/material/Grid'
import { useTranslation } from 'react-i18next'
import { ExportDataContext } from '../../../../components/widget/WidgetFactory/ExportDataContext'

/**
 * Trajectory model data consumed by the table helpers in this file.
 *
 * Per-state arrays are indexed by the numeric state id; the helpers convert a
 * string state id with `parseInt` before indexing.
 */
export interface Trajectory {
  /** Emission probabilities: `E_hat[state][i]` is the probability of condition `diseases_dict[i]` in `state`. */
  E_hat: number[][]
  /** Per-state probability shown as the `start` incoming transition and as starting-patient share. */
  Pi_hat: number[]
  /** Per-state number of patients shown for the `start` transition / starting patients. */
  Pi_patients: number[]
  /** Per-state value displayed as the mean stay with a ` d` suffix (presumably days — confirm). */
  R_hat: number[]
  /** Transition probabilities: row `T_hat[s]` is read as outgoing from `s`, column `s` as incoming to `s`. */
  T_hat: number[][]
  /**
   * Transition patient counts with the same layout as `T_hat`. `getStateTable`
   * uses the diagonal entry `T_patients[s][s]` as the patient total for state `s`.
   */
  T_patients: number[][]
  /** Condition keys; for keys of the form `a=b` only the segment after the first `=` is displayed. */
  diseases_dict: string[]
  /**
   * Named per-state statistics, matched by `id` against the selected state.
   * `getStatisticsTable` treats the last `num_absorbing_events` entries as absorbing states.
   */
  statistics: Array<{
    id: string
    statistics: Array<{ name: string; value: number | number[] }>
  }>
}

/** Lookup of display text for condition keys. */
interface Dictionary {
  content: {
    /** Maps an entry of `Trajectory.diseases_dict` to the text appended to its condition label. */
    list: {
      [id: string]: string
    }
  }
}

/** Visibility thresholds applied when building the state and transition tables. */
interface TrajectoryOptions {
  /** Emissions with a probability below this value are omitted from the state table. */
  relEmissionThreshold: number
  /** Transitions (including `start`) with fewer patients than this are omitted. */
  absTransitionThreshold: number
  /** Transitions (including `start`) with a probability below this value are omitted. */
  relTransitionThreshold: number
}

/**
 * Lists the conditions emitted by a trajectory state, most probable first.
 *
 * @param trajectory - Model data; `E_hat[state]` holds one probability per entry of `diseases_dict`.
 * @param state - State index, as a number or a numeric string.
 * @param dictionary - `content.list` maps a condition key to the text appended to its label.
 * @param options - Rows with `prob < relEmissionThreshold` are dropped.
 * @returns One row per remaining condition: display label, rounded percentage
 *   text, estimated patient count (`T_patients[state][state] * prob`, rounded)
 *   and the raw probability. An unknown / out-of-range state yields `[]`.
 */
export function getStateTable(
  trajectory: Trajectory,
  state: number | string,
  dictionary: Dictionary,
  options: TrajectoryOptions,
): Array<{
  condition: string
  probText: string
  nPatients: number
  prob: number
}> {
  const stateIndex = typeof state === 'string' ? parseInt(state, 10) : state
  // Missing rows (e.g. a non-numeric or out-of-range state) produce an empty
  // table instead of a TypeError.
  const totalPatients = trajectory.T_patients[stateIndex]?.[stateIndex] ?? 0
  const stateEmissions = trajectory.E_hat[stateIndex] ?? []
  const emissionsTable = stateEmissions.map((prob, index) => {
    // Resolve the key once, with the fallback applied *before* it is
    // dereferenced: `diseases_dict` may be shorter than the emission row.
    const diseaseKey = trajectory.diseases_dict[index] ?? ''
    // Keys of the form `a=b` display only the segment after the first `=`.
    const code = diseaseKey.includes('=')
      ? diseaseKey.split('=')[1]
      : diseaseKey
    const description = dictionary.content.list[diseaseKey] ?? ''
    return {
      condition: `${code} ${description} `,
      probText: `${Math.round(prob * 100)}%`,
      nPatients: Math.round(totalPatients * prob),
      prob,
    }
  })
  return orderBy(
    emissionsTable.filter(({ prob }) => prob >= options.relEmissionThreshold),
    ['prob'],
    ['desc'],
  )
}

/**
 * Pairs each relative transition probability with its patient count and keeps
 * only entries meeting both thresholds in `options`.
 *
 * @param stateTransitionsRel - Probability per target/source state index.
 * @param stateTransitionsAbs - Patient count at the same indices.
 * @param options - Supplies `relTransitionThreshold` and `absTransitionThreshold`.
 * @returns Surviving entries in index order; `state` is the index as a string.
 */
const transitionsToTable = (
  stateTransitionsRel: number[],
  stateTransitionsAbs: number[],
  options: TrajectoryOptions,
): Array<{
  prob: number
  nPatients: number
  state: string
}> => {
  const { relTransitionThreshold, absTransitionThreshold } = options
  const kept: Array<{ prob: number; nPatients: number; state: string }> = []
  stateTransitionsRel.forEach((prob, position) => {
    const nPatients = stateTransitionsAbs[position]
    // Written as a negated conjunction so a missing count (undefined) is
    // rejected exactly as the `>=` comparison would reject it.
    if (!(prob >= relTransitionThreshold && nPatients >= absTransitionThreshold)) {
      return
    }
    kept.push({ prob, nPatients, state: String(position) })
  })
  return kept
}

/**
 * Lists transitions arriving at `state` (plus a synthetic `start` row for
 * patients who begin there), ordered by their share of incoming patients.
 *
 * Entries are filtered on the model's relative probability and patient count
 * (see `options`). The returned `prob` is then re-expressed as
 * `nPatients / totalIncomingPatients`, where the total is the patients coming
 * from lower-indexed states plus those starting in `state`, floored at 1.
 * The self-transition is omitted.
 *
 * @param trajectory - Model data; column `state` of `T_hat` / `T_patients` is read.
 * @param state - State index, as a number or a numeric string.
 * @param options - Transition thresholds.
 */
export function getIncomingTable(
  trajectory: Trajectory,
  state: number | string,
  options: TrajectoryOptions,
): Array<{
  prob: number
  nPatients: number
  state: string
}> {
  const stateIndex = typeof state === 'string' ? parseInt(state, 10) : state
  const stateTransitionsRel = trajectory.T_hat.map((row) => row[stateIndex])
  const stateTransitionsAbs = trajectory.T_patients.map(
    (row) => row[stateIndex],
  )
  // Floored at 1 so the share below never divides by zero.
  const totalIncomingPatients = Math.max(
    sum(stateTransitionsAbs.slice(0, stateIndex))
      + trajectory.Pi_patients[stateIndex],
    1,
  )

  const incoming = transitionsToTable(
    stateTransitionsRel,
    stateTransitionsAbs,
    options,
  ).filter((trans) => trans.state !== String(stateIndex))
  if (
    trajectory.Pi_hat[stateIndex] >= options.relTransitionThreshold
    && trajectory.Pi_patients[stateIndex] >= options.absTransitionThreshold
  ) {
    incoming.unshift({
      prob: trajectory.Pi_hat[stateIndex],
      nPatients: trajectory.Pi_patients[stateIndex],
      state: 'start',
    })
  }
  // Recompute the displayed probability *before* sorting so the table is
  // ordered by the value it shows (previously it was sorted by the model's
  // relative probability and then overwritten, leaving it out of order).
  return orderBy(
    incoming.map((transition) => ({
      ...transition,
      prob: transition.nPatients / totalIncomingPatients,
    })),
    ['prob'],
    ['desc'],
  )
}

/**
 * Lists transitions leaving `state`, highest relative probability first.
 *
 * Row `state` of `T_hat` / `T_patients` is filtered by the thresholds in
 * `options`; the self-transition is omitted.
 *
 * @param trajectory - Model data.
 * @param state - State index, as a number or a numeric string.
 * @param options - Transition thresholds.
 */
export function getOutgoingTable(
  trajectory: Trajectory,
  state: number | string,
  options: TrajectoryOptions,
): Array<{
  prob: number
  nPatients: number
  state: string
}> {
  const from = typeof state === 'string' ? parseInt(state) : state
  const candidates = transitionsToTable(
    trajectory.T_hat[from],
    trajectory.T_patients[from],
    options,
  )
  const ranked = orderBy(candidates, ['prob'], ['desc'])
  const selfId = String(from)
  return ranked.filter((row) => row.state !== selfId)
}

/**
 * JSS styles shared by the tables in this file.
 *
 * `tableTitle` is the title label above each table (primary-colour background,
 * rounded top corners); `tableTitleContainer` draws the rule beneath it.
 */
const useStyles = makeStyles((theme) => ({
  tableroot: {
    width: '100%',
  },
  tableHeader: {
    fontWeight: 'bold',
  },
  tableTitleContainer: { borderBottom: '1px solid #aaaaaa' },
  tableTitle: {
    background: theme.palette.primary.main,
    // Only the top corners are rounded so the title sits flush on the rule below.
    borderRadius: '10px 10px 0px 0px',
    width: 180,
    fontWeight: 900,
    color: '#ffffff',
    fontSize: 18,
    padding: '2px 13px',
    margin: 0,
  },
}))

/** Props for the three-column transition / condition tables. */
interface TableComponentProps {
  /** One entry per row: label, preformatted probability text and patient count. */
  data: Array<{
    label: string
    probabilityText: string
    nPatients: number
  }>
  /** Header cell texts, in column order. */
  headers: string[]
  /** Text shown in the title above the table. */
  title: string
}

/**
 * Title label rendered above each table.
 *
 * @param title - Text to display.
 */
const TableTitle = ({ title }: { title: string }): React.JSX.Element => {
  const { tableTitleContainer, tableTitle } = useStyles()
  const heading = <p className={tableTitle}>{title}</p>
  return <div className={tableTitleContainer}>{heading}</div>
}

/**
 * Renders a titled three-column table: label, probability text, patient count.
 *
 * @param data - One entry per body row, rendered in column order.
 * @param headers - Header cell texts; expected to line up with the three columns.
 * @param title - Text for the title above the table.
 */
const TableComponent = ({
  data,
  headers,
  title,
}: TableComponentProps): React.JSX.Element => {
  const classes = useStyles()
  return (
    <>
      <TableTitle title={title} />
      <Table className={classes.tableroot}>
        <TableHead>
          {/* Header cells must be wrapped in a row: <thead> may only contain <tr>. */}
          <TableRow>
            {headers.map((header) => (
              <TableCell key={header} className={classes.tableHeader}>
                {header}
              </TableCell>
            ))}
          </TableRow>
        </TableHead>
        <TableBody>
          {data.map((row, index) => (
            <TableRow key={index}>
              <TableCell>{row.label}</TableCell>
              <TableCell>{row.probabilityText}</TableCell>
              <TableCell>{row.nPatients}</TableCell>
            </TableRow>
          ))}
        </TableBody>
      </Table>
    </>
  )
}

/**
 * Renders a titled two-column table of translated statistic names and values.
 *
 * @param data - Rows; `key` is translated via `trajectories.statistics.<key>`.
 * @param headers - Header cell texts, in column order.
 * @param title - Text for the title above the table.
 * @param experimentId - Combined with `folderId` into the table's React `key`,
 *   so the table is remounted when either changes.
 * @param folderId - See `experimentId`.
 */
const StatisticsTableComponent = ({
  data,
  headers,
  title,
  experimentId,
  folderId,
}: {
  experimentId: string
  data: Array<{ key: string; value: string }>
  headers: string[]
  title: string
  folderId: string
}): React.JSX.Element => {
  const { t } = useTranslation()
  const classes = useStyles()
  return (
    <>
      <TableTitle title={title} />
      <Table
        key={`table-${experimentId}-${folderId}`}
        className={classes.tableroot}
      >
        <TableHead>
          {/* Header cells must be wrapped in a row: <thead> may only contain <tr>. */}
          <TableRow>
            {headers.map((header) => (
              <TableCell key={header} className={classes.tableHeader}>
                {header}
              </TableCell>
            ))}
          </TableRow>
        </TableHead>
        <TableBody>
          {data.map((row, index) => (
            <TableRow key={index}>
              <TableCell>{t(`trajectories.statistics.${row.key}`)}</TableCell>
              <TableCell>{row.value}</TableCell>
            </TableRow>
          ))}
        </TableBody>
      </Table>
    </>
  )
}

/**
 * Builds the statistics rows (`key` = statistic name, `value` = display text)
 * for `selectedState`.
 *
 * Output order: the first statistic, the last statistic (if there is more than
 * one), `mean_stay`, `starting_patients`, then the remaining statistics.
 * Returns `[]` when no `trajectory.statistics` entry has `id === selectedState`.
 *
 * @param trajectory - Model data.
 * @param selectedState - State id, matched against `statistics[].id` and parsed
 *   as an index into the per-state arrays.
 * @param metadata - `num_sequences` is the denominator for `count`;
 *   `num_absorbing_events` is how many trailing states count as absorbing.
 */
const getStatisticsTable = (
  trajectory: Trajectory,
  selectedState: string,
  metadata: TrajectoriesMetadata,
): Array<{
  key: string
  value: string
}> => {
  const numSequences = metadata?.num_sequences ?? 0
  const numAbsorbing = metadata?.num_absorbing_events ?? 0
  // The last `numAbsorbing` statistics entries are the absorbing states.
  // `slice(-numAbsorbing)` must not be used: `slice(-0)` is `slice(0)` and would
  // treat *every* state as absorbing when there are none. Clamped at 0 so more
  // absorbing events than states still selects all of them.
  const firstAbsorbingIndex = Math.max(
    trajectory.statistics.length - numAbsorbing,
    0,
  )
  const absorbedPatients = trajectory.statistics
    .slice(firstAbsorbingIndex)
    .reduce(
      (acc, cur) =>
        acc
        + ((cur.statistics.find((stat) => stat.name === 'last_state')
          ?.value as number) ?? 0),
      0,
    )
  const nonAbsorbedPatients = numSequences - absorbedPatients

  // "<value> (<share>%)" with the share rounded to `precision` decimals.
  const withShare = (value: number, total: number, precision: number): string =>
    `${value} (${round((value / total) * 100, precision)}%)`

  const transformValue = (
    name: string,
    value: number[] | number,
    count: number | undefined,
    numState: number,
  ): string => {
    // Guard count > 0 (like the `count` / `last_state` branches) to avoid NaN%.
    if (
      count !== undefined
      && count > 0
      && ['icu', 'exitus', 'casa', 'derivacio'].includes(name)
    ) {
      return withShare(value as number, count, 0)
    }
    if (name === 'sex' && Array.isArray(value) && value.length === 2) {
      const total = value[0] + value[1]
      if (total <= 0) {
        return '-'
      }
      return `${round((value[0] / total) * 100)}% ${round((value[1] / total) * 100)}%`
    }
    if (name === 'mean_age') {
      return String(round(value as number, 2))
    }
    if (name === 'count' && numSequences > 0) {
      return withShare(value as number, numSequences, 2)
    }
    if (name === 'last_state' && nonAbsorbedPatients > 0) {
      // Absorbing states have no meaningful "ended here" share.
      return numState >= firstAbsorbingIndex
        ? '-'
        : withShare(value as number, nonAbsorbedPatients, 2)
    }
    return String(value)
  }

  const stateStatistics = trajectory.statistics.find(
    (state) => state.id === selectedState,
  )?.statistics
  if (stateStatistics === undefined) {
    return []
  }
  const count = stateStatistics.find((st) => st.name === 'count')?.value as
    | number
    | undefined
  const numState = parseInt(selectedState, 10)
  const stats = stateStatistics.map((row) => ({
    key: row.name,
    value: transformValue(row.name, row.value, count, numState),
  }))
  const meanStay = {
    key: 'mean_stay',
    value:
      trajectory.R_hat.length <= numState
        ? '-'
        : `${round(trajectory.R_hat[numState], 2)} d`,
  }
  const startingPatients = {
    key: 'starting_patients',
    value:
      trajectory.Pi_hat.length <= numState
        ? '-'
        : `${trajectory.Pi_patients[numState]} (${round(
            trajectory.Pi_hat[numState] * 100,
            2,
          )}%)`,
  }
  return [
    ...stats.slice(0, 1),
    // With a single statistic, slice(-1) is the same row as slice(0, 1);
    // skip it so the row is not listed twice.
    ...(stats.length > 1 ? stats.slice(-1) : []),
    meanStay,
    startingPatients,
    ...stats.slice(1, -1),
  ]
}

/** Dataset-level metadata used by `getStatisticsTable`; absent fields are treated as 0. */
interface TrajectoriesMetadata {
  /** Total number of sequences; denominator for the `count` percentage. */
  num_sequences?: number
  /** Number of trailing `Trajectory.statistics` entries treated as absorbing states. */
  num_absorbing_events?: number
}

/** Props for `TrajectoriesTable`. */
interface TrajectoriesTableProps {
  trajectory: Trajectory
  dictionary: Dictionary
  /** Selected state id; when undefined a "no state selected" message is shown. */
  selectedState?: string
  options: TrajectoryOptions
  metadata: TrajectoriesMetadata
  /** Used only (with `folderId`) to key the statistics table. */
  experimentId: string
  /** Used only (with `experimentId`) to key the statistics table. */
  folderId: string
}

/**
 * Shows the statistics, incoming, outgoing and emission tables for the
 * selected trajectory state, and publishes the same rows through
 * `ExportDataContext.onDataChange`.
 *
 * When `selectedState` is undefined, a "no state selected" message is rendered
 * and `null` is published instead.
 */
const TrajectoriesTable = ({
  trajectory,
  dictionary,
  selectedState,
  options,
  metadata,
  experimentId,
  folderId,
}: TrajectoriesTableProps): React.JSX.Element => {
  const { t } = useTranslation()
  const { onDataChange } = useContext(ExportDataContext)

  // NOTE(review): the row labels below are hard-coded Catalan strings rather
  // than going through `t()` like the rest of this view.
  const tables
    = selectedState === undefined
      ? undefined
      : {
          statsTable: getStatisticsTable(trajectory, selectedState, metadata),
          incomingTable: getIncomingTable(
            trajectory,
            selectedState,
            options,
          ).map((row) => ({
            label: `des de l'estat ${row.state}`,
            probabilityText: `${Math.round(row.prob * 100)}%`,
            nPatients: row.nPatients,
          })),
          outgoingTable: getOutgoingTable(
            trajectory,
            selectedState,
            options,
          ).map((row) => ({
            label: `a l'estat ${row.state}`,
            probabilityText: `${Math.round(row.prob * 100)}%`,
            nPatients: row.nPatients,
          })),
          stateTable: getStateTable(
            trajectory,
            selectedState,
            dictionary,
            options,
          ).map((row) => ({
            label: row.condition,
            probabilityText: row.probText,
            nPatients: row.nPatients,
          })),
        }

  const toExportRow = (row: {
    label: string
    probabilityText: string
    nPatients: number
  }): Array<string | number> => [row.label, row.probabilityText, row.nPatients]

  const exportData
    = tables === undefined
      ? null
      : {
          state: tables.stateTable.map(toExportRow),
          statistics: tables.statsTable.map((row) => [
            t(`trajectories.statistics.${row.key}`),
            row.value,
          ]),
          incoming: tables.incomingTable.map(toExportRow),
          outgoing: tables.outgoingTable.map(toExportRow),
        }
  // The tables are rebuilt on every render, so object identity cannot drive the
  // effect. A serialized snapshot changes exactly when the exported content does
  // (state, thresholds, dictionary, metadata or language), so the export never
  // goes stale and is not re-published on every render.
  const exportKey = JSON.stringify(exportData)

  useEffect(() => {
    // Published from an effect, not during render: updating context state
    // while rendering is a React anti-pattern. `onDataChange` is intentionally
    // left out of the deps to avoid a publish loop if its identity is not
    // stable — assumes the context provides a stable callback; confirm.
    onDataChange(exportData)
  }, [exportKey])

  if (tables === undefined) {
    return <div>{t('trajectories.noStateSelected')}</div>
  }

  const { statsTable, incomingTable, outgoingTable, stateTable } = tables
  return (
    <Grid container spacing={2}>
      <Grid item xs={12} lg={6} xl={3}>
        <StatisticsTableComponent
          experimentId={experimentId}
          folderId={folderId}
          title={t('trajectories.tableHeaders.statistics')}
          headers={['statistic', 'value'].map((v) =>
            t(`trajectories.tableHeaders.${v}`),
          )}
          data={statsTable}
        />
      </Grid>
      <Grid item xs={12} lg={6} xl={3}>
        <TableComponent
          title={t('trajectories.tableHeaders.incomingNodes')}
          headers={['transition', 'probability', 'nPatients'].map((v) =>
            t(`trajectories.tableHeaders.${v}`),
          )}
          data={incomingTable}
        />
      </Grid>
      <Grid item xs={12} lg={6} xl={3}>
        <TableComponent
          title={t('trajectories.tableHeaders.outgoingNodes')}
          headers={['transition', 'probability', 'nPatients'].map((v) =>
            t(`trajectories.tableHeaders.${v}`),
          )}
          data={outgoingTable}
        />
      </Grid>
      <Grid item xs={12} lg={6} xl={3}>
        <TableComponent
          title={t('trajectories.tableHeaders.state')}
          headers={['condition', 'probability', 'nPatients'].map((v) =>
            t(`trajectories.tableHeaders.${v}`),
          )}
          data={stateTable}
        />
      </Grid>
    </Grid>
  )
}

/** Default export: the combined trajectory-state tables view. */
export default TrajectoriesTable
