import * as React from "react"
import ReactDOM from "react-dom"
import { useEvent } from "react-use"
import {
  Flow,
  type MonetaryValue,
  type MoneyFlow,
  type Period,
} from "@digits-graphql/frontend/graphql-bearer"
import { hasFirstElement } from "@digits-shared/helpers/arrayHelper"
import moneyFlowHelper from "@digits-shared/helpers/moneyFlowHelper"
import numberHelper from "@digits-shared/helpers/numberHelper"
import objectHelper from "@digits-shared/helpers/objectHelper"
import useConstant from "@digits-shared/hooks/useConstant"
import { useModalRoot } from "@digits-shared/hooks/useModalRoot"
import { useIsPrintTheme } from "@digits-shared/themes"
import colors from "@digits-shared/themes/colors"
import { useThemedConstant } from "@digits-shared/themes/themedFunctions"
import fonts from "@digits-shared/themes/typography"
import { animated, type PickAnimated, type SpringValues, useSpring } from "@react-spring/web"
import { localPoint } from "@visx/event"
import { Group } from "@visx/group"
import { ParentSize } from "@visx/responsive"
import { scaleBand, scaleLinear } from "@visx/scale"
import { Bar, BarStack, Line, LinePath } from "@visx/shape"
import { Tooltip } from "@visx/tooltip"
import { extent, max, min } from "d3-array"
import { type ScaleBand, type ScaleLinear } from "d3-scale"
import styled, { css } from "styled-components"
import { v4 as generateUUID } from "uuid"
import {
  HorizontalGradient,
  VerticalGradient,
} from "src/frontend/components/OS/Shared/Charts/Coloring"
import { GlowFilter } from "src/frontend/components/OS/Shared/Charts/GlowFilter"
import {
  SharedBarChartStyles,
  SharedLineChartStyles,
  type TimeseriesBarChartStyle,
  type TimeseriesLineChartStyle,
} from "src/frontend/components/OS/Shared/Charts/styles"
import {
  ChartContainer,
  SVGContainer,
} from "src/frontend/components/Shared/Layout/Components/Charts/shared"
import {
  ChartGrid,
  ChartXAxis,
  ChartYAxis,
  DOUBLE_Y_AXIS_WIDTH,
  X_AXIS_HEIGHT,
  Y_AXIS_WIDTH,
} from "src/frontend/components/Shared/Layout/Components/Charts/TimeseriesChartAxis"
import {
  type Timeseries,
  type TimeseriesValue,
  type TimeseriesValues,
} from "src/frontend/components/Shared/Layout/Components/Charts/toTimeseries"
import customKeyframes from "src/shared/config/customKeyframes"
import zIndexes from "src/shared/config/zIndexes"

// Delay passed to the bar spring (`useSpring({ delay })`) before the bars start
// growing; set to 0 when animations are skipped. Presumably milliseconds, per
// react-spring's convention — TODO confirm.
const ANIMATION_INITIAL_DELAY = 120
// Delay, in milliseconds, before the summary line's CSS fade-in animation starts.
const ANIMATION_LINE_DELAY = 500

/*
 STYLES
*/

// Animated version of visx's <Bar>, so spring values can be passed as props.
//
// The <React.ElementType> generic is a workaround for a TS compiler error triggered by
// react-spring v9 together with styled-components:
// https://github.com/pmndrs/react-spring/issues/1515
// By setting this generic React element type we lose the Bar's property inference.
// Once the issue is fixed, or there is a better workaround, the generic type should be removed.
const AnimatedBar = animated<React.ElementType>(Bar)

// Hover tooltip container, positioned with `position: fixed`.
//
// Transient props:
//   $hoverRight - when true the tooltip is shifted right by `$barWidth / 1.8` px
//                 (translateX(0% + offset)); otherwise it is shifted left by its own
//                 full width plus that same offset (translateX(-100% - offset)).
//   $barWidth   - width of one bar in px; used to compute the horizontal offset above.
//
// `pointer-events: none` keeps the tooltip from intercepting mouse events.
const StyledTooltip = styled(Tooltip)<{ $hoverRight?: boolean; $barWidth: number }>`
  position: fixed;
  z-index: ${zIndexes.modalOverlay};
  padding: 12px 16px;
  background-color: ${colors.translucentWhite80};
  border-radius: 8px;
  backdrop-filter: blur(8px);
  box-shadow: 0 7px 21px 0 ${colors.translucentBlack10};
  pointer-events: none;
  transform: translateX(
    ${({ $hoverRight, $barWidth }) =>
      `calc(${$hoverRight ? 0 : -100}% ${$hoverRight ? "+" : "-"} ${$barWidth / 1.8}px)`}
  );
`

// Thin horizontal rule rendered between the breakdown rows and the total row of
// the tooltip; the visible line is the 1px top border.
//
// The rule previously declared `height` twice (`1px`, then `100%`). The later
// declaration always won, so the dead `height: 1px` was removed; rendering is unchanged.
// NOTE(review): the second declaration may have been meant to be `width: 100%` — confirm.
const TooltipSeparator = styled.div`
  height: 100%;
  border-top: 1px solid ${colors.secondary10};
`

// One tooltip row: a right-aligned, non-wrapping grid holding two <span>s.
// The first span (the amount, see TooltipRow) is rendered at 12px in the heavy weight,
// the last span (the label) at 10px. Every row except the first child gets a 4px top margin.
const TooltipRowContainer = styled.div`
  display: grid;
  color: ${colors.secondary};
  padding: 0 4px;
  white-space: nowrap;
  justify-items: flex-end;
  &:not(:first-child) {
    margin-top: 4px;
  }
  span {
    &:first-child {
      font-size: 12px;
      font-weight: ${fonts.weight.heavy};
    }
    &:last-child {
      font-size: 10px;
    }
  }
`

/*
  INTERFACES
*/

/** Props accepted by TimeseriesBreakdownChart. */
interface ChartProps {
  /** Series whose values are grouped by period start into the inflow/outflow stacked bars. */
  breakdownTimeseries: Timeseries[]
  /** Series drawn as the summary line; its label is also used for the tooltip's total row. */
  totalTimeseries: Timeseries
  /** Called when a bar is clicked, with that bar's stack total and the bar's index. */
  onClick?: (value: TimeseriesValue, index: number) => void
  /** Called on mouse move over the chart with the total-series value at the hovered index. */
  onMouseOver?: (value: TimeseriesValue, index: number) => void
  /** Called (with `undefined`) when the mouse leaves the chart container. */
  onMouseOut?: (value?: TimeseriesValue) => void
  /** Forwarded to the outer ChartContainer. */
  className?: string
  /** When true the background grid is not rendered. */
  hideGrid?: boolean
  /** When true the x/y axes and the dashed hover line are not rendered and no space is reserved for them. */
  hideAxis?: boolean
  /** When true the hover tooltip is not rendered. */
  noTooltip?: boolean
  /** Total chart width; the chart renders nothing when this is 0. */
  width: number
  /** Total chart height; the chart renders nothing when this is 0. */
  height: number
  /** Bar/axis styling; falls back to SharedBarChartStyles when omitted. */
  barChartStyle?: TimeseriesBarChartStyle
  /** Line styling; falls back to SharedLineChartStyles when omitted. */
  lineChartStyle?: TimeseriesLineChartStyle
  /** When true the bar spring runs immediately and the line fade-in is skipped. */
  skipAnimations?: boolean
}

/**
 * Common chart properties. In the code visible in this file it is only used as the
 * source for the `Pick<>`-based BarProps and TooltipProps types.
 */
interface SharedProps {
  /** Per-chart id used to namespace SVG gradient ids. */
  chartId: string
  /** Band scale mapping period start values to x positions. */
  xScale: ScaleBand<number>
  xTicks: number[]
  /** Linear scale mapping values to y positions. */
  yScale: ScaleLinear<number, number>
  yTicks: number[]
  timeseries: TimeseriesValues
  width: number
  height: number
  parentWidth: number
  barChartStyle?: TimeseriesBarChartStyle
  lineChartStyle?: TimeseriesLineChartStyle
  skipAnimations?: boolean
}

/** Props for ValueBar, a single animated bar inside the stacked bar chart. */
type BarProps = Pick<SharedProps, "chartId" | "xScale" | "yScale"> & {
  /** The value drawn by this bar (the inflow or outflow total for one period). */
  value: TimeseriesValue
  /** Currently hovered value. NOTE(review): passed by the chart but not read by ValueBar. */
  selectedValue: HoverValue | undefined
  /** Index of the bar within its stack; forwarded to `onClick`. */
  index: number
  /** Shared spring whose `scale` drives the bar's grow-in animation. */
  spring: SpringValues<PickAnimated<{ scale: number }>>
  /** Optional click handler; when present the bar gets a pointer cursor. */
  onClick?: (value: TimeseriesValue, index: number) => void
}

/** Props for SummaryLine, the line drawn over the bars. */
interface LineProps {
  /** Chart id; used to reference the `${uuid}-line-stroke-gradient` gradient. */
  uuid: string
  /** Points of the line, one per period. */
  timeseries: TimeseriesValues
  /** Band scale; each point is placed at the horizontal center of its band. */
  xScale: ScaleBand<number>
  /** Linear scale mapping each point's value to a y position. */
  yScale: ScaleLinear<number, number>
  strokeWidth: number
  /** When true the line is shown without the fade-in animation. */
  skipAnimations?: boolean
}

/** Props shared by HoverTooltip and DashedLine. */
type TooltipProps = Pick<SharedProps, "xScale" | "height" | "lineChartStyle"> & {
  /** Currently hovered value; both components render nothing when it is undefined. */
  value: HoverValue | undefined
  /** Width reserved for the left y-axis; DashedLine adds it to the line's x position. */
  yAxisWidth: number
}

/** Keys of FlowTimeseriesValue that hold stackable values (everything except `period`). */
type FlowTimeseriesKey = "inflows" | "outflows"

/** All breakdown values for one period, split into inflows and outflows. */
interface FlowTimeseriesValue {
  period: Period
  inflows: FlowValues
  outflows: FlowValues
}

/** The individual values of one flow direction for a period, plus their running total. */
interface FlowValues {
  values: FlowValue[]
  total: TimeseriesValue // timestamp is redundant but makes it easier to reuse yValue
}

/** One breakdown entry: the label of the series it came from and its money flow. */
interface FlowValue {
  name: string
  moneyFlow: MoneyFlow
}

/** State describing the period currently hovered in the chart. */
interface HoverValue {
  /** Mouse position (taken from the event's pageY/pageX) where the tooltip is anchored. */
  position: { top: number; left: number }
  /** Label of the total series; shown as the label of the tooltip's total row. */
  totalName: string
  /** Value of the total series at the hovered index. */
  totalValue: TimeseriesValue
  /** Breakdown (inflows/outflows) at the hovered index. */
  flowValue: FlowTimeseriesValue
}

/*
  COMPONENTS
*/

/** X accessor: the start of the value's period, used as the band-scale domain key. */
const xValue = (point: TimeseriesValue | FlowTimeseriesValue) => point.period.startedAt

/**
 * Y accessor: the magnitude of the money flow's amount divided by its currency
 * multiplier, positive when the flow is normal and negative otherwise.
 */
const yValue = (point: TimeseriesValue) => {
  const { value, isNormal } = point.moneyFlow
  const sign = isNormal ? 1 : -1
  return (Math.abs(value.amount) * sign) / value.currencyMultiplier
}

/**
 * Convenience wrapper around TimeseriesBreakdownChart that takes its width and
 * height from the enclosing element via visx's <ParentSize>. All other props are
 * forwarded; missing chart styles fall back to the shared defaults.
 */
export const ParentSizedTimeseriesBreakdownChart: React.FC<
  Omit<ChartProps, "width" | "height">
> = (props) => {
  const {
    breakdownTimeseries,
    totalTimeseries,
    className,
    onClick,
    onMouseOver,
    onMouseOut,
    hideGrid,
    hideAxis,
    noTooltip,
    barChartStyle,
    lineChartStyle,
    skipAnimations,
  } = props

  return (
    <ParentSize>
      {({ width, height }) => (
        <TimeseriesBreakdownChart
          breakdownTimeseries={breakdownTimeseries}
          totalTimeseries={totalTimeseries}
          className={className}
          onClick={onClick}
          onMouseOver={onMouseOver}
          onMouseOut={onMouseOut}
          hideGrid={hideGrid}
          hideAxis={hideAxis}
          noTooltip={noTooltip}
          barChartStyle={barChartStyle || SharedBarChartStyles}
          lineChartStyle={lineChartStyle || SharedLineChartStyles}
          width={width}
          height={height}
          skipAnimations={skipAnimations}
        />
      )}
    </ParentSize>
  )
}

// Timeseries chart that shows a line for the total value (e.g. Total Cash) and
// stacked bars for the breakdowns of that value (e.g. inflows and outflows that
// sum to the delta in total cash).
//
// The bars and the left y-axis use a scale derived from the breakdown totals; the
// line and the right y-axis use a separate scale derived from the total series.
// Renders nothing when either series is empty or when width/height is 0.
export const TimeseriesBreakdownChart: React.FC<ChartProps> = ({
  breakdownTimeseries: propBreakdownTimeseries,
  totalTimeseries: { label: totalName, values: propTotalTimeseries },
  className,
  onClick,
  onMouseOver,
  onMouseOut,
  hideGrid,
  hideAxis,
  noTooltip,
  width,
  height,
  barChartStyle: propChartStyle,
  lineChartStyle: propLineChartStyle,
  skipAnimations,
}) => {
  const chartStyle = propChartStyle || SharedBarChartStyles
  const lineChartStyle = propLineChartStyle || SharedLineChartStyles

  // Per-instance id used to namespace the SVG gradient/filter ids defined below.
  const chartId = useConstant<string>(generateUUID)
  const [selectedValue, setSelectedValue] = React.useState<HoverValue | undefined>()

  // Fills are chosen per bar in ValueBar, so the BarStack color accessor returns "".
  const noopColor = React.useCallback(() => "", [])

  // Transform into a shape that is easier to use with stacked bar charts:
  // one entry per period start, holding the inflow and outflow values and totals.
  const breakdownTimeseries = React.useMemo(() => {
    const m = new Map<number, FlowTimeseriesValue>()
    propBreakdownTimeseries.forEach(({ label: name, values: timeseriesList }) => {
      timeseriesList.forEach(({ period, moneyFlow }) => {
        const key = period.startedAt
        const flowValue = m.get(key) || {
          period,
          inflows: {
            values: [] as FlowValue[],
            total: {
              period,
              moneyFlow: {
                ...moneyFlowHelper.buildZeroMoneyFlow(),
                businessFlow: Flow.Inbound,
              },
              deltaPrevious: undefined,
              deltaYearAgo: undefined,
            },
          },
          outflows: {
            values: [] as FlowValue[],
            total: {
              period,
              moneyFlow: {
                ...moneyFlowHelper.buildZeroMoneyFlow(),
                businessFlow: Flow.Outbound,
              },
              deltaPrevious: undefined,
              deltaYearAgo: undefined,
            },
          },
        }

        if (moneyFlow.businessFlow === Flow.Inbound && moneyFlow.isNormal) {
          // Normal inbound flows are added to the inflow stack.
          flowValue.inflows.values.push({ name, moneyFlow })
          flowValue.inflows.total.moneyFlow = moneyFlowHelper.add(
            flowValue.inflows.total.moneyFlow,
            moneyFlow
          )
        } else if (
          (moneyFlow.businessFlow === Flow.Outbound && moneyFlow.isNormal) ||
          moneyFlow.businessFlow === Flow.Inbound
        ) {
          // Normal outbound flows, and inbound flows that are not normal, are
          // subtracted from the outflow total. Any other flow (e.g. an outbound
          // flow that is not normal) is not added to either stack.
          flowValue.outflows.values.push({ name, moneyFlow })
          flowValue.outflows.total.moneyFlow = moneyFlowHelper.subtract(
            flowValue.outflows.total.moneyFlow,
            moneyFlow
          )
        }
        m.set(key, flowValue)
      })
    })
    return Array.from(m.values()).toSorted((a, b) => a.period.startedAt - b.period.startedAt)
  }, [propBreakdownTimeseries])

  // Flattened for axes: two entries per period (inflow total, then outflow total).
  const flattenedBreakdownTimeseries: TimeseriesValues = React.useMemo(
    () =>
      breakdownTimeseries.flatMap(({ period, inflows, outflows }) => [
        {
          period,
          moneyFlow: inflows.total.moneyFlow,
          deltaPrevious: undefined,
          deltaYearAgo: undefined,
        },
        {
          period,
          moneyFlow: outflows.total.moneyFlow,
          deltaPrevious: undefined,
          deltaYearAgo: undefined,
        },
      ]),
    [breakdownTimeseries]
  )

  // Sorted copy of the total series so indexes line up with the sorted breakdown.
  const totalTimeseries = React.useMemo(
    () => propTotalTimeseries.toSorted((a, b) => a.period.startedAt - b.period.startedAt),
    [propTotalTimeseries]
  )

  // Memoized so that the x scale (and the mouse-move handler that depends on it)
  // keeps a stable identity across re-renders triggered by hovering.
  const xValues = React.useMemo(() => breakdownTimeseries.map(xValue), [breakdownTimeseries])

  const breakdownBarKeys = React.useMemo(
    () =>
      hasFirstElement(breakdownTimeseries)
        ? objectHelper.keysOf(breakdownTimeseries[0]).filter((k) => k !== "period")
        : [],
    [breakdownTimeseries]
  )

  const breakdownBarValue = React.useCallback(
    (value: FlowTimeseriesValue, key: FlowTimeseriesKey) => yValue(value[key].total),
    []
  )

  const xMin = hideAxis ? 0 : DOUBLE_Y_AXIS_WIDTH / 3
  const xMaxPadding = hideAxis ? Y_AXIS_WIDTH : Y_AXIS_WIDTH * 1.3 * 2
  const xMax = hideAxis ? width : width - xMaxPadding

  const xScale = React.useMemo(
    () =>
      scaleBand({
        range: [xMin, xMax],
        round: true,
        domain: xValues,
        paddingInner: 0.5,
      }),
    [xMin, xMax, xValues]
  )

  const onMouseMove = React.useCallback(
    (event: React.MouseEvent) => {
      const point = localPoint(event)
      const x = point ? point.x : 0
      const yAxisOffset = hideAxis ? 0 : DOUBLE_Y_AXIS_WIDTH
      const index = Math.floor((x - yAxisOffset + xScale.bandwidth() / 2) / xScale.step())

      if (index < 0 || index >= xValues.length) {
        return
      }

      // The index is only bounds-checked against the breakdown series above; the
      // total series can be shorter, so make sure both lookups actually produced
      // a value before publishing a hover state that children dereference.
      const totalValue: TimeseriesValue | undefined = totalTimeseries[index]
      const flowValue: FlowTimeseriesValue | undefined = breakdownTimeseries[index]
      if (!totalValue || !flowValue) {
        return
      }

      // Keep the tooltip anchored where the hover on this value started instead of
      // following the cursor while it stays on the same value.
      // NOTE(review): the tooltip is `position: fixed` but is fed page coordinates;
      // these differ from client coordinates when the document itself is scrolled — confirm.
      const position =
        selectedValue && selectedValue.totalValue === totalValue
          ? selectedValue.position
          : {
              top: event.pageY,
              left: event.pageX,
            }

      const hoverValue: HoverValue = {
        position,
        totalName,
        totalValue,
        flowValue,
      }
      setSelectedValue(hoverValue)
      onMouseOver?.(totalValue, index)
    },
    [
      breakdownTimeseries,
      hideAxis,
      onMouseOver,
      selectedValue,
      totalName,
      totalTimeseries,
      xScale,
      xValues.length,
    ]
  )

  const onMouseLeave = React.useCallback(() => {
    setSelectedValue(undefined)
    onMouseOut?.(undefined)
  }, [onMouseOut])

  // Dismiss the hover state when the user scrolls. Uses the standard "wheel" event:
  // the legacy "mousewheel" event is non-standard and is not fired by Firefox.
  // The handler is memoized so the listener is not re-attached on every render.
  const clearSelectedValue = React.useCallback(() => setSelectedValue(undefined), [])
  useEvent("wheel", clearSelectedValue, document.body)

  // Breakdown y-domain always includes 0 so bars have a baseline.
  const breakdownMaxYValue = React.useMemo(
    () =>
      Math.max(
        max(breakdownTimeseries, (v) =>
          Math.max(yValue(v.inflows.total), yValue(v.outflows.total))
        ) || 0,
        0
      ),
    [breakdownTimeseries]
  )

  const breakdownMinYValue = React.useMemo(
    () =>
      Math.min(
        min(breakdownTimeseries, (v) =>
          Math.min(yValue(v.inflows.total), yValue(v.outflows.total))
        ) || 0,
        0
      ),
    [breakdownTimeseries]
  )

  // Total y-domain, also widened to include 0.
  const [totalMinYValue, totalMaxYValue] = React.useMemo(() => {
    const [minVal, maxVal] = extent(totalTimeseries, yValue)
    return [Math.min(minVal ?? 0, 0), Math.max(maxVal ?? 0, 0)]
  }, [totalTimeseries])

  // Single spring shared by all bars; `scale` goes from 0.1 to 1.
  const immediate = skipAnimations
  const delay = immediate ? 0 : ANIMATION_INITIAL_DELAY
  const spring = useSpring({
    config: {
      tension: 250,
      friction: 70,
      mass: 3,
    },
    from: { scale: 0.1 },
    to: { scale: 1 },
    delay,
    immediate,
  })

  // Early returns must stay below every hook call above.
  if (!propBreakdownTimeseries?.length || !breakdownTimeseries.length || !totalTimeseries.length) {
    return null
  }

  if (width === 0 || height === 0) return null

  const svgHeight = hideAxis ? height : height - X_AXIS_HEIGHT

  const yTotalScale: ScaleLinear<number, number> = scaleLinear({
    range: [svgHeight, 0],
    round: true,
    domain: [totalMinYValue, totalMaxYValue],
  })

  const yBreakdownScale: ScaleLinear<number, number> = scaleLinear({
    range: [svgHeight, 0],
    round: true,
    domain: [breakdownMinYValue, breakdownMaxYValue],
  })

  const barsLeft = hideAxis ? 0 : Y_AXIS_WIDTH

  return (
    <ChartContainer className={className} width={width} height={height} onMouseLeave={onMouseLeave}>
      <SVGContainer width={width} height={svgHeight} onMouseMove={onMouseMove}>
        <GlowFilter id={`${chartId}-glow`} standardDeviation={hideAxis ? 1 : 2.5} />
        <VerticalGradient
          id={`${chartId}-bar-fill`}
          color={chartStyle.barFillColor}
          bottomColor={chartStyle.barFillBottomColor}
        />
        <VerticalGradient
          id={`${chartId}-bar-negative-fill`}
          color={chartStyle.barFillNegativeColor}
          bottomColor={chartStyle.barFillNegativeBottomColor}
        />
        <VerticalGradient
          id={`${chartId}-bar-fill-inactive`}
          color={chartStyle.barFillInactiveColor}
          bottomColor={chartStyle.barFillInactiveBottomColor}
        />
        <VerticalGradient
          id={`${chartId}-bar-fill-selected`}
          color={chartStyle.barFillSelectedColor}
          bottomColor={chartStyle.barFillSelectedBottomColor}
        />
        <HorizontalGradient
          id={`${chartId}-line-stroke-gradient`}
          color="#6B3CF2"
          bottomColor="#C533EA"
        />

        {!hideGrid && (
          <ChartGrid
            yScale={yBreakdownScale}
            width={width - Y_AXIS_WIDTH}
            height={svgHeight}
            axisStyle={chartStyle}
            yAxisWidth={Y_AXIS_WIDTH}
          />
        )}
        <Group top={0} left={barsLeft}>
          <BarStack<FlowTimeseriesValue, string>
            data={breakdownTimeseries}
            keys={breakdownBarKeys}
            xScale={xScale}
            yScale={yBreakdownScale}
            color={noopColor}
            x={xValue}
            value={breakdownBarValue}
            offset="diverging"
          >
            {(barStacks) =>
              barStacks.map((barStack, stackIndex) =>
                barStack.bars.map((value, barIndex) => (
                  <ValueBar
                    chartId={chartId}
                    key={`bar-${stackIndex}-${barIndex}`}
                    xScale={xScale}
                    yScale={yBreakdownScale}
                    value={value.bar.data[value.key as FlowTimeseriesKey].total}
                    selectedValue={selectedValue}
                    index={barIndex}
                    onClick={onClick}
                    spring={spring}
                  />
                ))
              )
            }
          </BarStack>
          <SummaryLine
            uuid={chartId}
            timeseries={totalTimeseries}
            xScale={xScale}
            yScale={yTotalScale}
            strokeWidth={2}
            skipAnimations={skipAnimations}
          />
        </Group>
        {!hideAxis && (
          <>
            <ChartXAxis
              timeseries={flattenedBreakdownTimeseries}
              xScale={xScale}
              height={svgHeight}
              width={width}
              axisStyle={chartStyle}
              yAxisWidth={Y_AXIS_WIDTH}
            />
            <DashedLine
              value={selectedValue}
              xScale={xScale}
              height={svgHeight}
              lineChartStyle={lineChartStyle}
              yAxisWidth={Y_AXIS_WIDTH}
            />
            <ChartYAxis
              timeseries={flattenedBreakdownTimeseries}
              yScale={yBreakdownScale}
              height={svgHeight}
              width={width}
              axisStyle={chartStyle}
              yAxisWidth={Y_AXIS_WIDTH}
            />
            <ChartYAxis
              timeseries={totalTimeseries}
              yScale={yTotalScale}
              height={svgHeight}
              width={width}
              axisStyle={chartStyle}
              yAxisWidth={Y_AXIS_WIDTH}
              axisRight
            />
          </>
        )}
      </SVGContainer>
      {!noTooltip && (
        <HoverTooltip
          value={selectedValue}
          xScale={xScale}
          height={svgHeight}
          lineChartStyle={lineChartStyle}
          yAxisWidth={Y_AXIS_WIDTH}
        />
      )}
    </ChartContainer>
  )
}

/**
 * A single animated bar of the stacked chart.
 *
 * The bar's height is its value's distance from the zero line (at least 0.5px),
 * multiplied by the shared spring's `scale`. Positive bars grow upwards from the
 * zero line; negative bars start at the zero line and extend downwards. Negative
 * values and outbound flows use the "negative" gradient fill.
 */
const ValueBar: React.FC<BarProps> = (props) => {
  const { chartId, xScale, yScale, value, index, onClick, spring } = props

  const handleClick = React.useCallback(() => onClick?.(value, index), [onClick, value, index])

  const amount = yValue(value)
  const isNegative = amount < 0
  const baselineY = yScale(0) ?? 0

  const bandWidth = xScale.bandwidth()
  const fullHeight = Math.max(baselineY - (yScale(Math.abs(amount)) ?? 0), 0.5)
  const left = xScale(xValue(value))

  // Negative bars are pinned to the zero line; positive bars move their top edge
  // upwards as the spring progresses.
  const top = isNegative ? baselineY : spring.scale.to((s: number) => baselineY - s * fullHeight)
  const animatedHeight = spring.scale.to((s: number) => s * fullHeight)

  const radius = bandWidth > 4 ? 2 : 1

  let fill = `url(#${chartId}-bar-fill)`
  if (isNegative || value.moneyFlow.businessFlow === Flow.Outbound) {
    fill = `url(#${chartId}-bar-negative-fill)`
  }

  return (
    <AnimatedBar
      width={bandWidth}
      height={animatedHeight}
      x={left}
      y={top}
      rx={radius}
      ry={radius}
      fill={fill}
      strokeWidth={0}
      onClick={handleClick}
      css={onClick ? { cursor: "pointer" } : undefined}
    />
  )
}

/**
 * Line drawn through the center of each band for the total series.
 *
 * The stroke is the chart's horizontal gradient in the light and dark themes and a
 * translucent black in the print theme. Unless animations are skipped (or the
 * print theme is active) the line fades in after ANIMATION_LINE_DELAY ms.
 */
const SummaryLine: React.FC<LineProps> = ({
  uuid,
  timeseries,
  xScale,
  yScale,
  strokeWidth,
  skipAnimations,
}) => {
  const isPrintMode = useIsPrintTheme()

  const gradientStroke = `url(#${uuid}-line-stroke-gradient)`
  const stroke = useThemedConstant({
    print: colors.translucentBlack50,
    light: gradientStroke,
    dark: gradientStroke,
  })

  // Horizontal center of the band that belongs to the point's period.
  const pointX = React.useCallback(
    (point: TimeseriesValue) => {
      const bandStart = xScale(xValue(point)) ?? 0
      return xScale.bandwidth() / 2 + bandStart
    },
    [xScale]
  )

  const pointY = React.useCallback(
    (point: TimeseriesValue, index: number) => {
      const scaled = yScale(yValue(point)) ?? 0
      if (index) {
        return scaled
      }
      // The very first point is nudged up by a tiny amount.
      // NOTE(review): presumably so a perfectly flat line still has a non-empty
      // bounding box for the gradient stroke — confirm.
      return scaled - 0.001
    },
    [yScale]
  )

  const withoutAnimation = skipAnimations || isPrintMode

  const fadeIn =
    !withoutAnimation &&
    css`
      opacity: 0;
      animation: ${customKeyframes.fadeIn} 1500ms ${ANIMATION_LINE_DELAY}ms forwards;
    `

  return (
    <LinePath
      css={fadeIn}
      data={timeseries}
      x={pointX}
      y={pointY}
      stroke={stroke}
      strokeWidth={strokeWidth}
      pointerEvents="none"
    />
  )
}

/**
 * Tooltip for the hovered period, rendered into the modal root through a portal.
 *
 * Lists every non-zero inflow/outflow entry sorted by amount (largest first),
 * followed by a separator and a row labelled with the total series' name that
 * shows the sum of the listed amounts. Renders nothing when there is no hovered
 * value or no non-zero entry.
 */
const HoverTooltip: React.FC<TooltipProps> = ({ value, xScale }) => {
  const modalRoot = useModalRoot()

  const flowValue = value?.flowValue
  const rows = React.useMemo(() => {
    const inflowRows = flowValue?.inflows.values ?? []
    const outflowRows = flowValue?.outflows.values ?? []
    return [...inflowRows, ...outflowRows]
      .filter((row) => row.moneyFlow.value.amount !== 0)
      .toSorted((a, b) => b.moneyFlow.value.amount - a.moneyFlow.value.amount)
  }, [flowValue])

  // Sum of all listed amounts; the remaining monetary fields come from the first row.
  const deltaValue = React.useMemo(() => {
    const [first, ...rest] = rows.map((row) => row.moneyFlow.value)
    if (!first) return undefined
    return rest.reduce((acc, v) => ({ ...acc, amount: acc.amount + v.amount }), first)
  }, [rows])

  if (!value || rows.length === 0) return null

  const { position, totalValue } = value

  // Hovering the left half of the chart opens the tooltip to the right of the
  // cursor, and vice versa.
  const domain = xScale.domain()
  const isLeftOfCenter = domain.indexOf(xValue(totalValue)) < domain.length / 2

  const lastRowIndex = rows.length - 1
  const hasDelta = deltaValue !== undefined

  const tooltip = (
    <StyledTooltip
      left={position.left - 20}
      top={position.top}
      $hoverRight={isLeftOfCenter}
      $barWidth={xScale.bandwidth()}
      unstyled
    >
      {rows.map(({ name, moneyFlow }, i) => (
        <TooltipRow
          key={name}
          label={name}
          value={moneyFlow.value}
          separator={hasDelta && i === lastRowIndex}
        />
      ))}
      {deltaValue && <TooltipRow label={value.totalName} value={deltaValue} />}
    </StyledTooltip>
  )

  return <>{ReactDOM.createPortal(tooltip, modalRoot)}</>
}

/**
 * One tooltip line: the currency-formatted amount above its label, optionally
 * followed by a separator rule.
 */
const TooltipRow: React.FC<{ label: string; value: MonetaryValue; separator?: boolean }> = (
  props
) => {
  const { label, value, separator } = props
  const formattedAmount = numberHelper.currency(value)

  return (
    <>
      <TooltipRowContainer>
        <span>{formattedAmount}</span>
        <span>{label}</span>
      </TooltipRowContainer>
      {separator ? <TooltipSeparator css="margin-top: 4px;" /> : null}
    </>
  )
}

/**
 * Vertical dashed guide line through the center of the hovered period's band,
 * spanning the full chart height. Renders nothing when no value is hovered.
 */
const DashedLine: React.FC<TooltipProps> = (props) => {
  const { value, xScale, height, lineChartStyle, yAxisWidth } = props

  if (!value) {
    return null
  }

  // Offset by the y-axis width, then move to the middle of the band.
  const bandStart = xScale(xValue(value.totalValue)) ?? 0
  const centerX = yAxisWidth + bandStart + xScale.bandwidth() / 2

  const top = { x: centerX, y: 0 }
  const bottom = { x: centerX, y: height }

  return (
    <Line
      from={top}
      to={bottom}
      stroke={lineChartStyle?.lineStroke}
      strokeWidth={1}
      strokeDasharray="2,2"
      style={{ pointerEvents: "none" }}
    />
  )
}
