import { z } from 'zod'

/**
 * Fields common to every statistics entry: the `operation_name` string and
 * the `fn` string. The specific statistics schemas in this file extend this
 * object and narrow `fn` to a literal.
 */
export const baseStatisticsValueSchema = z.object({
  operation_name: z.string(),
  fn: z.string(),
})
/**
 * Optional array whose elements may each be a string, `null` or a number.
 * Reused by the result rows of the `count` and `mean` schemas.
 */
const partition = z
  .union([z.string(), z.null(), z.number()])
  .array()
  .optional()

/**
 * Statistics entry with `fn: 'count'`.
 *
 * - `result`: rows with a string-or-number `value`, an optional numeric
 *   `percent` and an optional `partition`.
 * - `args`: the `by` string and an optional list of `partition_variables`.
 */
const statisticsCountSchema = baseStatisticsValueSchema.extend({
  fn: z.literal('count'),
  result: z
    .object({
      value: z.string().or(z.number()),
      percent: z.number().optional(),
      partition,
    })
    .array(),
  args: z.object({
    by: z.string(),
    partition_variables: z.string().array().optional(),
  }),
})

/**
 * Statistics entry with `fn: 'count_filter'`.
 *
 * - `result`: rows with a string-or-number `value` and a required numeric
 *   `percent`.
 * - `args`: the `by` string, an unvalidated `filter` (accepted as-is) and an
 *   optional `filter_column` string.
 */
const statisticsCountFilterSchema = baseStatisticsValueSchema.extend({
  fn: z.literal('count_filter'),
  result: z
    .object({
      value: z.string().or(z.number()),
      percent: z.number(),
    })
    .array(),
  args: z.object({
    by: z.string(),
    filter: z.any(),
    filter_column: z.string().optional(),
  }),
})

/** A single bound of a min/max result row: string, number or `null`. */
const minMaxBoundSchema = z.union([z.string(), z.number(), z.null()])

/**
 * Statistics entry with `fn: 'min_max_value'`.
 *
 * - `result`: rows holding a `min` and a `max`, each a string, number or
 *   `null`.
 * - `args`: `by` is an object with `min` and `max` strings.
 */
const statisticsMinMaxValueSchema = baseStatisticsValueSchema.extend({
  fn: z.literal('min_max_value'),
  result: z
    .object({
      min: minMaxBoundSchema,
      max: minMaxBoundSchema,
    })
    .array(),
  args: z.object({
    by: z.object({ min: z.string(), max: z.string() }),
  }),
})

/**
 * Statistics entry with `fn: 'avg_time'`.
 *
 * - `result`: rows with a single numeric `value`.
 * - `args`: `by` is an object with `start` and `end` strings, alongside a
 *   `time_unit` string.
 */
const statisticsAvgTimeSchema = baseStatisticsValueSchema.extend({
  fn: z.literal('avg_time'),
  result: z.object({ value: z.number() }).array(),
  args: z.object({
    by: z.object({ start: z.string(), end: z.string() }),
    time_unit: z.string(),
  }),
})

/**
 * Statistics entry with `fn: 'readmissions'`.
 *
 * - `result`: rows with a numeric `value` and a numeric `percent`.
 * - `args`: an unvalidated `filter`, a required `filter_column` string, a
 *   `time_unit` string, a `by` object with `start`/`end` strings, and a
 *   numeric `window`.
 */
const statisticsReadmissionsSchema = baseStatisticsValueSchema.extend({
  fn: z.literal('readmissions'),
  result: z
    .object({
      value: z.number(),
      percent: z.number(),
    })
    .array(),
  args: z.object({
    filter: z.any(),
    filter_column: z.string(),
    time_unit: z.string(),
    by: z.object({ start: z.string(), end: z.string() }),
    window: z.number(),
  }),
})

/**
 * Statistics entry with `fn: 'mean'`.
 *
 * - `result`: rows with a numeric `value` and an optional `partition`.
 * - `args`: the `by` string and an optional list of `partition_variables`.
 */
const statisticsMeanSchema = baseStatisticsValueSchema.extend({
  fn: z.literal('mean'),
  result: z
    .object({
      value: z.number(),
      partition,
    })
    .array(),
  args: z.object({
    by: z.string(),
    partition_variables: z.string().array().optional(),
  }),
})

/**
 * Statistics entry with `fn: 'top_counts'`.
 *
 * - `result`: rows with a `key` (`null` or string), a numeric `value` and an
 *   optional numeric `percent`.
 * - `args`: the `by` string, a numeric `max_counts`, an optional
 *   `percent_denominator` string and an optional `hide_key` boolean.
 */
const statisticsTopCountsSchema = baseStatisticsValueSchema.extend({
  fn: z.literal('top_counts'),
  result: z
    .object({
      key: z.null().or(z.string()),
      value: z.number(),
      percent: z.number().optional(),
    })
    .array(),
  args: z.object({
    by: z.string(),
    max_counts: z.number(),
    percent_denominator: z.string().optional(),
    hide_key: z.boolean().optional(),
  }),
})

/**
 * Union of every statistics entry schema in this file, discriminated on the
 * `fn` literal so that each entry is validated against the one schema whose
 * `fn` value it carries.
 */
const statisticsValueSchema = z.discriminatedUnion('fn', [
  statisticsCountSchema,
  statisticsMeanSchema,
  statisticsTopCountsSchema,
  statisticsCountFilterSchema,
  statisticsMinMaxValueSchema,
  statisticsAvgTimeSchema,
  statisticsReadmissionsSchema,
])

// Inferred types, one per statistics schema, listed in the same order as the
// options of `statisticsValueSchema`.

/** Parsed shape of a `count` statistics entry. */
export type StatisticsCount = z.infer<typeof statisticsCountSchema>
/** Parsed shape of a `mean` statistics entry. */
export type StatisticsMean = z.infer<typeof statisticsMeanSchema>
/** Parsed shape of a `top_counts` statistics entry. */
export type StatisticsTopCounts = z.infer<typeof statisticsTopCountsSchema>
/** Parsed shape of a `count_filter` statistics entry. */
export type StatisticsCountFilter = z.infer<typeof statisticsCountFilterSchema>
/** Parsed shape of a `min_max_value` statistics entry. */
export type StatisticsMinMaxValue = z.infer<typeof statisticsMinMaxValueSchema>
/** Parsed shape of an `avg_time` statistics entry. */
export type StatisticsAvgTime = z.infer<typeof statisticsAvgTimeSchema>
/** Parsed shape of a `readmissions` statistics entry. */
export type StatisticsReadmissions = z.infer<
  typeof statisticsReadmissionsSchema
>
/** Any statistics entry; narrow on `fn` to reach a specific variant. */
export type StatisticsValue = z.infer<typeof statisticsValueSchema>
/**
 * Object with a single `result` property holding an array of statistics
 * entries (any variant of the `fn`-discriminated union).
 */
export const statisticsDataSchema = z.object({
  result: statisticsValueSchema.array(),
})
/** Parsed shape of {@link statisticsDataSchema}. */
export type StatisticsData = z.infer<typeof statisticsDataSchema>
/** Array of statistics data objects, each validated by `statisticsDataSchema`. */
export const statisticsClusteringSchema = statisticsDataSchema.array()
/** Parsed shape of {@link statisticsClusteringSchema}. */
export type StatisticsClusteringData = z.infer<
  typeof statisticsClusteringSchema
>

/**
 * String-keyed dictionary whose values are arrays of strings.
 *
 * NOTE(review): presumably maps an encoding dictionary name to its list of
 * labels — confirm against callers.
 */
export interface EncodingDict {
  [dictName: string]: string[]
}
