import { Data } from "@elara/select"
import { useVirtualizer, VirtualItem } from "@tanstack/react-virtual"
import React, { ReactElement, useContext } from "react"

import { DataViewChunkedFetcherContext } from "./data-view-chunked-fetcher"
import { GroupBySummaryItem, RowItem } from "./data-view-types"

/**
 * Renders one entry of the virtualized list.
 *
 * `useGetItem` is a hook handed down by the parent renderer. It is invoked on
 * every render, before any early exit, so the hook call order stays stable.
 * When it yields no item, nothing is rendered (and therefore no element is
 * handed to `measureElement`). Otherwise the row produced by `render` is
 * wrapped in a `div` tagged with `data-index` and registered for measurement.
 */
function VirtualizedRow<D extends Data>(props: {
  virtualItem: VirtualItem
  measureElement: (el: HTMLDivElement | null) => void
  useGetItem: (vItem: VirtualItem) => RowItem<D> | null

  render: (row: RowItem<D>) => JSX.Element
}) {
  const vItem = props.virtualItem
  const rowItem = props.useGetItem(vItem)

  return rowItem ? (
    <div key={vItem.key} data-index={vItem.index} ref={props.measureElement}>
      {props.render(rowItem)}
    </div>
  ) : null
}

/**
 * Offset of `list`'s top edge from the top of `scroller`'s scrollable content,
 * in scroll coordinates (independent of the current scroll position). Anything
 * laid out above the list inside the scroll container — the header slot, or
 * top padding on the container — contributes to this value.
 */
function measureListOffset(scroller: HTMLElement, list: HTMLElement): number {
  // `clientTop` is the top border width; scroll coordinates start inside the border.
  const contentTop = scroller.getBoundingClientRect().top + scroller.clientTop
  return Math.round(list.getBoundingClientRect().top - contentTop + scroller.scrollTop)
}

// `useLayoutEffect` can warn when rendered without a DOM (server rendering);
// fall back to `useEffect` there.
const useIsomorphicLayoutEffect =
  typeof window === "undefined" ? React.useEffect : React.useLayoutEffect

/**
 * Renders the entries of the surrounding `DataViewChunkedFetcherContext` as a
 * virtualized, scrollable list: only the entries the virtualizer reports as
 * visible (plus `overscan` extra ones) are mounted.
 *
 * @param props.render     Renders one resolved entry (a row or a group summary).
 * @param props.rowHeight  Estimated row height in px, used until a row has been measured.
 * @param props.style      Merged over the default `{ height: 200, overflow: "auto" }`
 *                         of the scroll container.
 * @param props.className  Class name for the scroll container.
 * @param props.slot       Optional header/footer rendered inside the scroll container,
 *                         above and below the rows.
 * @param props.overscan   Extra entries rendered beyond the viewport (default 10).
 */
export const DataViewChunkedVirtualizedRowRenderer = <D extends Data>(props: {
  render: (row: RowItem<D>) => ReactElement
  rowHeight: number
  style?: React.CSSProperties
  className?: string
  slot?: { header?: ReactElement; footer?: ReactElement }
  overscan?: number
}) => {
  // The scrollable element for the list.
  const parentRef = React.useRef<HTMLDivElement>(null)
  // The full-height list container the virtual rows are positioned in.
  const listRef = React.useRef<HTMLDivElement>(null)
  // Distance from the top of the scroll content to the list (e.g. the header height).
  // Without it the virtualizer assumes the list starts at scroll offset 0 and computes
  // the visible range shifted by everything rendered above the list.
  const [scrollMargin, setScrollMargin] = React.useState(0)

  const ctx = useContext(DataViewChunkedFetcherContext)
  const count = ctx.size
  const rowVirtualizer = useVirtualizer({
    count,
    getScrollElement: () => parentRef.current,
    estimateSize: () => props.rowHeight,
    overscan: props.overscan ?? 10,
    scrollMargin,
  })

  // Re-measure after every commit, because the header slot may change height on any
  // render. React bails out of re-rendering when the value is unchanged, so this
  // does not loop.
  useIsomorphicLayoutEffect(() => {
    const scroller = parentRef.current
    const list = listRef.current
    if (!scroller || !list) return
    setScrollMargin(measureListOffset(scroller, list))
  })

  /**
   * Maps a flat list index to the entry to render.
   * - "group" entries become group summary items.
   * - "row" entries become level-0 rows that keep their group.
   * - "child" entries become rows without a group, at the entry's `level`.
   * Returns null for out-of-range indices and unrecognised entry types.
   */
  const getItem = (index: number): RowItem<D> | null => {
    if (index < 0) return null
    if (index >= ctx.size) return null

    const item = ctx.getItem(index)
    if (item.type === "group") {
      return { type: "group" as const, group: item.group as GroupBySummaryItem<D>, rowIndex: index }
    }

    if (item.type === "row") {
      // NOTE(review): `row` is typed `D | null`; presumably null while its chunk is
      // not yet loaded — confirm against DataViewChunkedFetcherContext.
      const row = ctx.getRow(item)
      return {
        type: "row" as const,
        row: row as D | null,
        group: item.group as GroupBySummaryItem<D>,
        rowIndex: index,
        level: 0,
      }
    } else if (item.type === "child") {
      const row = ctx.getRow(item)
      return {
        type: "row" as const,
        row: row as D | null,
        rowIndex: index,
        group: null,
        level: item.level,
      }
    }
    return null
  }

  // A hook (it calls `ctx.useRegisterRow`) handed to each VirtualizedRow, which must
  // call it unconditionally. NOTE(review): registration presumably tells the fetcher
  // which indices are on screen so their chunks get loaded — confirm.
  const useGetItem = (virtualItem: VirtualItem) => {
    const item = getItem(virtualItem.index)

    ctx.useRegisterRow(virtualItem.index)
    return item
  }

  const items = rowVirtualizer.getVirtualItems()
  const firstItem = items[0]
  // Virtual item `start` values include `scrollMargin`. The rows are positioned relative
  // to the list container, which already sits `scrollMargin` below the top of the scroll
  // content, so the margin is subtracted here.
  const offsetY = firstItem ? firstItem.start - rowVirtualizer.options.scrollMargin : 0

  return (
    <div
      style={{ height: 200, overflow: "auto", ...props.style }}
      className={props.className}
      ref={parentRef}>
      {props.slot?.header}
      <div
        ref={listRef}
        style={{
          height: rowVirtualizer.getTotalSize(),
          width: "100%",
          position: "relative",
        }}>
        <div
          style={{
            position: "absolute",
            top: 0,
            left: 0,
            width: "100%",
            transform: `translateY(${offsetY}px)`,
          }}>
          {/* Only the visible items in the virtualizer, shifted down as a group to the first item's offset */}
          {items.map((virtualItem) => {
            return (
              <VirtualizedRow<D>
                key={virtualItem.key}
                virtualItem={virtualItem}
                useGetItem={useGetItem}
                render={props.render}
                measureElement={rowVirtualizer.measureElement}
              />
            )
          })}
        </div>
      </div>
      {props.slot?.footer}
    </div>
  )
}
