import { Rect, Vec2 } from "@/packages/util/geometry"
import { Ref, forwardRef, useEffect, useImperativeHandle, useRef } from "react"

/** Props accepted by the `ScrollViewBackground` component. */
interface ScrollViewBackgroundProps {
  /** Optional CSS class name; the props object is spread onto the rendered `<canvas>` element. */
  className?: string
}

/**
 * Imperative handle exposed through the component's ref. It lets the owner
 * update the drawn grid without going through React state.
 */
export interface ScrollViewBackgroundHandle {
  /**
   * Stores the new transform and schedules a redraw on the next animation
   * frame, with grid lines limited to the part of the canvas inside the window.
   *
   * @param transform `x`/`y` are subtracted from the grid origin and `scale`
   *   multiplies the grid step when drawing.
   */
  setTransform(transform: { x: number; y: number; scale: number }): void
  /**
   * Stores the inset, grid bounds and transform, cancels the pending
   * animation-frame redraw and redraws the whole canvas synchronously.
   *
   * @param inset Offset added to the grid origin (index 0 horizontal, index 1 vertical).
   * @param gridBounds Its `left`/`top`, multiplied by the transform scale, position the grid origin.
   * @param transform Same meaning as in `setTransform`.
   */
  setBounds(inset: Vec2, gridBounds: Rect, transform: { x: number; y: number; scale: number }): void
}

/**
 * Canvas element that paints a solid background with two grid layers
 * (every 20 and every 100 units, scaled by the current transform).
 *
 * The canvas is sized to its parent element on every draw. The owner drives it
 * through the imperative {@link ScrollViewBackgroundHandle}: `setTransform`
 * schedules a viewport-limited redraw on the next animation frame, `setBounds`
 * redraws everything immediately. The canvas is also redrawn after every
 * commit of this component.
 */
export default forwardRef(function ScrollViewBackground(
  props: ScrollViewBackgroundProps,
  ref: Ref<ScrollViewBackgroundHandle>,
) {
  const canvasRef = useRef<HTMLCanvasElement>(null)
  // Id of the most recently requested animation frame, kept so it can be cancelled.
  const requestAnimationFrameRef = useRef<number>(0)

  // Drawing inputs live in refs so that updating them never triggers a React render.
  const transformRef = useRef<{ x: number; y: number; scale: number }>({ x: 0, y: 0, scale: 1 })
  const insetRef = useRef<Vec2>([0, 0])
  const gridBoundsRef = useRef<Rect>({ left: 0, top: 0, right: 0, bottom: 0 })

  useImperativeHandle(ref, () => ({
    setTransform(transform: { x: number; y: number; scale: number }) {
      transformRef.current = transform
      invalidate(true)
    },
    setBounds(inset: Vec2, gridBounds: Rect, transform: { x: number; y: number; scale: number }) {
      insetRef.current = inset
      transformRef.current = transform
      gridBoundsRef.current = gridBounds

      // A synchronous full redraw supersedes any pending frame.
      cancelAnimationFrame(requestAnimationFrameRef.current)
      draw()
    },
  }))

  // Redraw after every commit. Drawing is a side effect, so it runs in an
  // effect instead of the render body; the cleanup also guarantees that no
  // frame stays scheduled once the component unmounts.
  useEffect(() => {
    cancelAnimationFrame(requestAnimationFrameRef.current)
    draw()
    return () => cancelAnimationFrame(requestAnimationFrameRef.current)
  })

  /**
   * Schedules a redraw on the next animation frame. Repeated calls within one
   * frame are coalesced: the previously scheduled frame is cancelled first.
   *
   * @param viewportOnly When true, grid lines are only drawn for the region of
   *   the canvas derived from its bounding rect and the window size.
   */
  function invalidate(viewportOnly: boolean) {
    cancelAnimationFrame(requestAnimationFrameRef.current)
    requestAnimationFrameRef.current = requestAnimationFrame(() => {
      if (!canvasRef.current) return

      if (viewportOnly) {
        const canvasBounds = canvasRef.current.getBoundingClientRect()
        // A negative left/top means the canvas starts before the window edge;
        // that overshoot is where the visible region begins in canvas space.
        const visibleLeft = -Math.min(0, canvasBounds.left)
        const visibleTop = -Math.min(0, canvasBounds.top)
        const visibleRight = window.innerWidth + visibleLeft
        const visibleBottom = window.innerHeight + visibleTop

        draw({
          left: visibleLeft,
          top: visibleTop,
          right: visibleRight,
          bottom: visibleBottom,
        })
      } else {
        draw()
      }
    })
  }

  // TODO: chrome really struggles when the canvas is too big, should probably switch to a static image in that case
  /**
   * Paints the background and both grid layers.
   *
   * @param drawingBounds Region (in canvas coordinates) that grid lines are
   *   restricted to; `null` means the whole canvas. The background fill always
   *   covers the whole canvas.
   */
  function draw(drawingBounds: Rect | null = null) {
    const canvas = canvasRef.current
    const parent = canvas?.parentElement
    if (!canvas || !parent) return

    // Assigning width/height resets the canvas bitmap even when the value is
    // unchanged, so only touch them on an actual size change.
    const targetWidth = parent.offsetWidth
    const targetHeight = parent.offsetHeight
    if (canvas.width !== targetWidth) canvas.width = targetWidth
    if (canvas.height !== targetHeight) canvas.height = targetHeight

    const ctx = canvas.getContext("2d", {
      alpha: false,
    })
    if (!ctx) return

    const scale = transformRef.current.scale
    const gridStartLeft =
      insetRef.current[0] + gridBoundsRef.current.left * scale - transformRef.current.x
    const gridStartTop =
      insetRef.current[1] + gridBoundsRef.current.top * scale - transformRef.current.y

    const minTop = drawingBounds?.top ?? 0
    const maxBottom = drawingBounds?.bottom ?? canvas.height
    const minLeft = drawingBounds?.left ?? 0
    const maxRight = drawingBounds?.right ?? canvas.width

    ctx.fillStyle = "#25252a"
    ctx.fillRect(0, 0, canvas.width, canvas.height)

    const drawGridLines = (step: number, color: string, lineWidth: number) => {
      const stepSize = step * scale
      // Skip layers whose lines would be closer than 3px (or whose step is not a usable number).
      if (!Number.isFinite(stepSize) || stepSize < 3) return

      // Grid lines sit at `offset + k * stepSize` for integer k.
      const xStartOffset = -mod(gridStartLeft, stepSize)
      const yStartOffset = -mod(gridStartTop, stepSize)
      // First line at or after the start of the drawn region. Iterating from
      // here up to the region's end covers every line inside it; a fixed
      // count of ceil(size / stepSize) could miss the last one.
      const firstX = minLeft + mod(xStartOffset - minLeft, stepSize)
      const firstY = minTop + mod(yStartOffset - minTop, stepSize)

      ctx.beginPath()
      ctx.strokeStyle = color
      ctx.lineWidth = lineWidth

      // Vertical lines. Positions are computed from the index to avoid accumulating float error.
      for (let i = 0, x = firstX; x <= maxRight; i++, x = firstX + i * stepSize) {
        ctx.moveTo(x, minTop)
        ctx.lineTo(x, maxBottom)
      }

      // Horizontal lines.
      for (let i = 0, y = firstY; y <= maxBottom; i++, y = firstY + i * stepSize) {
        ctx.moveTo(minLeft, y)
        ctx.lineTo(maxRight, y)
      }

      ctx.stroke()
    }

    drawGridLines(20, "#28282d", 1)
    drawGridLines(100, "#2d2d32", 1)
  }

  return <canvas ref={canvasRef} {...props} />
})

/**
 * Modulo that folds the remainder into the range of the divisor: for a
 * positive `m` the result lies in `[0, m)` even when `n` is negative,
 * unlike the plain `%` operator whose result takes the sign of `n`.
 */
function mod(n: number, m: number) {
  const signedRemainder = n % m
  // Shift by one full period, then reduce again so non-negative remainders are unchanged.
  const shifted = signedRemainder + m
  return shifted % m
}
