import { Ref, forwardRef, useEffect, useImperativeHandle, useRef } from "react"
import {
  CellsStoreState,
  selectCellConnections,
  selectCellLayouts,
  selectGridBounds,
} from "@/app/store/cellsSlice"
import { useAppSelector } from "../../../store"
import { Box, Rect, Vec2 } from "@/packages/util/geometry"
import { ArrowDescriptor, getBoxToBoxArrow } from "@/packages/util/arrows"
import { CellId } from "@/engine/state/types"

/**
 * Imperative API exposed by `BoardConnections` through its ref. It draws temporary
 * arrows / layout changes on top of the connections coming from the store.
 *
 * In this file's implementation every method redraws the whole canvas from the current
 * store state plus ONLY the temporary content passed to that call; temporary content from
 * an earlier call is not retained.
 */
export interface BoardConnectionsHandle {
  // draw a temporary arrow from a cell to a free point (e.g. while creating a new connection)
  addArrowToPoint(cellId: CellId, to: Vec2): void
  // draw a temporary arrow from a cell to another cell (e.g. while creating a new connection)
  addArrowToCell(cellId: CellId, to: CellId): void
  // draw connections as if the listed cells had the given layout boxes (e.g. while moving cells);
  // cells not present in `override` keep their stored layout.
  // NOTE: the "Cel" misspelling is part of the public method name; renaming would break callers.
  setCelLayoutOverride(override: Record<CellId, Box>): void
  // drop all temporary content and redraw from the stored state only
  reset(): void
}

// the component takes no props
type Props = unknown

/**
 * Canvas overlay that renders the connection arrows between cells, plus temporary arrows
 * and layout overrides requested imperatively through `BoardConnectionsHandle`.
 *
 * The canvas covers the grid bounds plus `canvasPadding` on every side and is shifted with
 * negative margins so grid coordinates line up with the parent. Temporary content outside
 * the grid grows the canvas (see `resizeCanvasToInclude`). Every redraw therefore has to use
 * the same bounds that were last applied to the canvas geometry.
 */
export default forwardRef(function BoardConnections(_: Props, ref: Ref<BoardConnectionsHandle>) {
  const cellConnections = useAppSelector(selectCellConnections)
  const cellLayouts = useAppSelector(selectCellLayouts)
  const gridBounds = useAppSelector(selectGridBounds)

  const canvasRef = useRef<HTMLCanvasElement>(null)
  // extra space (px) added on every side of the drawn bounds
  const canvasPadding = 50

  /**
   * Grows `baseBounds` so it contains every point in `points`, then sizes and positions the
   * canvas to cover those bounds plus padding.
   *
   * Assigning canvas.width/height clears the bitmap, so callers must redraw afterwards.
   *
   * @returns the bounds the canvas now covers (excluding padding); pass them to `draw`.
   */
  function resizeCanvasToInclude(canvas: HTMLCanvasElement, baseBounds: Rect, points: Vec2[]) {
    let bounds: Rect = baseBounds

    for (const [x, y] of points) {
      if (x < bounds.left || x > bounds.right || y < bounds.top || y > bounds.bottom) {
        bounds = {
          left: Math.min(bounds.left, x),
          top: Math.min(bounds.top, y),
          right: Math.max(bounds.right, x),
          bottom: Math.max(bounds.bottom, y),
        }
      }
    }

    // shift the canvas so that grid coordinates keep lining up with the parent
    canvas.style.marginLeft = `${bounds.left - baseBounds.left - canvasPadding}px`
    canvas.style.marginTop = `${bounds.top - baseBounds.top - canvasPadding}px`
    canvas.width = bounds.right - bounds.left + canvasPadding * 2
    canvas.height = bounds.bottom - bounds.top + canvasPadding * 2

    return bounds
  }

  useImperativeHandle(ref, () => ({
    addArrowToPoint(cellId: CellId, to: Vec2) {
      const canvas = canvasRef.current
      if (!canvas) return

      // the free point may lie outside the grid, so grow the canvas to include it
      const bounds = resizeCanvasToInclude(canvas, gridBounds, [to])
      draw(canvas, cellLayouts, cellConnections, bounds, canvasPadding, {}, { [cellId]: to })
    },
    addArrowToCell(cellId: CellId, to: CellId) {
      const canvas = canvasRef.current
      if (!canvas) return

      // Restore grid-sized geometry first. A preceding addArrowToPoint/setCelLayoutOverride may
      // have grown the canvas; drawing in grid coordinates on a canvas whose margins still
      // reflect the larger bounds would shift every arrow.
      const bounds = resizeCanvasToInclude(canvas, gridBounds, [])
      draw(canvas, cellLayouts, cellConnections, bounds, canvasPadding, {}, {}, { [cellId]: to })
    },
    setCelLayoutOverride(override) {
      const canvas = canvasRef.current
      if (!canvas) return

      // top-left and bottom-right corners of every overridden box must fit on the canvas
      const points: Vec2[] = Object.values(override).flatMap((box) => {
        return [
          [box.x, box.y],
          [box.x + box.width, box.y + box.height],
        ]
      })

      const bounds = resizeCanvasToInclude(canvas, gridBounds, points)

      draw(canvas, cellLayouts, cellConnections, bounds, canvasPadding, override)
    },
    reset() {
      const canvas = canvasRef.current
      if (!canvas) return

      const bounds = resizeCanvasToInclude(canvas, gridBounds, [])
      draw(canvas, cellLayouts, cellConnections, bounds, canvasPadding)
    },
  }))

  useEffect(() => {
    const canvas = canvasRef.current
    if (!canvas) return

    const width = gridBounds.right - gridBounds.left + canvasPadding * 2
    const height = gridBounds.bottom - gridBounds.top + canvasPadding * 2

    // Margins are always reset: an imperative override may have shifted them even if the
    // resulting size happens to match. The size is only reassigned when it changes, because
    // assigning canvas.width/height reallocates and clears the bitmap.
    canvas.style.marginLeft = `${-canvasPadding}px`
    canvas.style.marginTop = `${-canvasPadding}px`
    if (canvas.width !== width || canvas.height !== height) {
      canvas.width = width
      canvas.height = height
    }

    draw(canvas, cellLayouts, cellConnections, gridBounds, canvasPadding)
  }, [cellConnections, cellLayouts, gridBounds])

  return <canvas ref={canvasRef} className="!pointer-events-none absolute inset-0 touch-none" />
})

/**
 * Clears `canvas` and redraws every connection arrow onto it.
 *
 * Board coordinates are translated so that (`bounds.left`, `bounds.top`) lands at
 * (`canvasPadding`, `canvasPadding`) on the canvas.
 *
 * Entries whose source or target cell has no layout are skipped rather than throwing.
 *
 * @param canvas - target canvas; nothing is drawn if no 2D context is available
 * @param cellLayouts - stored layout box per cell id
 * @param cellInputs - for each target cell id, the ids of its source cells (arrow: source -> target)
 * @param bounds - board area the canvas currently covers, excluding padding
 * @param canvasPadding - space added on every side of `bounds`
 * @param overrideLayouts - boxes that take precedence over `cellLayouts` for stored connections
 * @param freeformArrow - temporary arrows from a cell (key) to a free point (value), drawn solid
 * @param pendingConnection - temporary arrows from a cell (key) to another cell (value), drawn solid
 */
const draw = (
  canvas: HTMLCanvasElement,
  cellLayouts: CellsStoreState["layouts"],
  cellInputs: CellsStoreState["inputs"],
  bounds: Rect,
  canvasPadding: number,
  overrideLayouts: Record<CellId, Box> = {},
  freeformArrow: Record<CellId, Vec2> = {},
  pendingConnection: Record<CellId, CellId> = {},
) => {
  const linePadding = 5
  const arrowHeadSize = 5
  const circleRadius = 5

  // maps board coordinates to canvas pixel coordinates
  const translate = ([x, y]: Vec2) => {
    return [x - bounds.left + canvasPadding, y - bounds.top + canvasPadding]
  }

  // bezier line with a dot at the start and an arrow head at the end
  const drawArrow = (
    ctx: CanvasRenderingContext2D,
    {
      start: [sx, sy],
      c1: [c1x, c1y],
      c2: [c2x, c2y],
      end: [ex, ey],
      angleStart,
      angleEnd,
    }: ArrowDescriptor,
    arrowHeadSize: number,
  ) => {
    // draw line
    ctx.beginPath()
    ctx.moveTo(sx, sy)
    ctx.bezierCurveTo(c1x, c1y, c2x, c2y, ex, ey)
    ctx.stroke()
    ctx.closePath()

    // draw start: a dot offset 8px along the start direction (angles are in degrees)
    ctx.save()
    ctx.translate(sx, sy)
    ctx.rotate((angleStart * Math.PI) / 180)
    ctx.translate(8, 0)
    ctx.beginPath()
    ctx.ellipse(0, 0, circleRadius, circleRadius, 0, 0, 2 * Math.PI)
    ctx.fill()
    ctx.closePath()
    ctx.restore()

    // draw arrow head, rotated to the end direction
    ctx.save()
    ctx.beginPath()
    ctx.translate(ex, ey)
    ctx.rotate((angleEnd * Math.PI) / 180)
    ctx.translate(-2, 0)
    ctx.moveTo(0, 0)
    ctx.lineTo(-arrowHeadSize / 2, -arrowHeadSize)
    ctx.lineTo(arrowHeadSize * 2, 0)
    ctx.lineTo(-arrowHeadSize / 2, arrowHeadSize)
    ctx.lineTo(0, 0)
    ctx.fill()
    ctx.closePath()
    ctx.restore()
  }

  const drawConnection = (ctx: CanvasRenderingContext2D, from: Box, to: Box) => {
    const [fromX, fromY] = translate([from.x, from.y])
    const [toX, toY] = translate([to.x, to.y])

    const arrow = getBoxToBoxArrow(
      { x: fromX, y: fromY, width: from.width, height: from.height },
      { x: toX, y: toY, width: to.width, height: to.height },
      { padding: linePadding },
    )
    drawArrow(ctx, arrow, arrowHeadSize)
  }

  const drawFreeformArrow = (ctx: CanvasRenderingContext2D, from: Box, to: Vec2) => {
    const [fromX, fromY] = translate([from.x, from.y])
    const [toX, toY] = translate(to)

    // the free point is modelled as a zero-size box that may be entered from any side
    const arrow = getBoxToBoxArrow(
      { x: fromX, y: fromY, width: from.width, height: from.height },
      { x: toX, y: toY, width: 0, height: 0 },
      { padding: linePadding, endSides: ["top", "right", "bottom", "left"] },
    )

    // drawArrow balances its own save()/restore() pairs; no extra restore() is needed here
    drawArrow(ctx, arrow, arrowHeadSize)
  }

  const ctx = canvas.getContext("2d")
  if (!ctx) return

  ctx.clearRect(0, 0, canvas.width, canvas.height)

  // stored connections: dashed, dark
  ctx.lineWidth = 2
  ctx.strokeStyle = "#075985" // sky-800
  ctx.fillStyle = "#075985"
  ctx.setLineDash([5, 5])

  for (const targetId of Object.keys(cellInputs)) {
    const target = overrideLayouts[targetId] ?? cellLayouts[targetId]
    // a connection to a cell without a layout has nothing to anchor to
    if (!target) continue

    for (const sourceId of cellInputs[targetId]) {
      const source = overrideLayouts[sourceId] ?? cellLayouts[sourceId]
      if (!source) continue
      drawConnection(ctx, source, target)
    }
  }

  // temporary arrows: solid, light
  ctx.strokeStyle = "#38bdf8" // sky-400
  ctx.fillStyle = "#38bdf8"
  ctx.setLineDash([])

  for (const sourceId of Object.keys(freeformArrow)) {
    const source = cellLayouts[sourceId]
    if (!source) continue
    drawFreeformArrow(ctx, source, freeformArrow[sourceId])
  }

  for (const sourceId of Object.keys(pendingConnection)) {
    const source = cellLayouts[sourceId]
    const target = cellLayouts[pendingConnection[sourceId]]
    if (!source || !target) continue
    drawConnection(ctx, source, target)
  }
}
