import React, { useRef, useEffect, useMemo } from 'react';
import { useSelector } from 'react-redux';
import { selectors } from 'state';
import { blue, yellow } from '@ant-design/colors';
import { Button } from 'antd';

/** Props for the correlation matrix canvas component. */
interface Props {
  /**
   * One entry per matrix slot: `row` is the equity id labelling a matrix row
   * and `col` the equity id labelling a matrix column. Either may be absent.
   */
  correlationFields: Array<{ row?: number; col?: number }>;
}

/** Size (px) of the margin reserved at the left and top of the canvas for labels. */
const OFFSET = 100;

/**
 * Maps a correlation value to a swatch from the ant-design palettes: `yellow`
 * for values >= 0 and `blue` for negative values.
 *
 * The magnitude is clamped to [0, 1] and split into buckets of width 0.2,
 * giving palette indices 0-5. |value| >= 1 (including the diagonal, which is
 * always 1) maps to index 5. NaN is treated as magnitude 0 so the cell still
 * receives a valid fill colour.
 */
const getColor = (value: number) => {
  const palette = value >= 0 ? yellow : blue;
  const magnitude = Number.isNaN(value) ? 0 : Math.min(Math.abs(value), 1);
  // Multiply rather than divide by 0.2: in floating point `0.6 / 0.2` is
  // 2.9999999999999996, which would floor into the wrong bucket.
  return palette[Math.floor(magnitude * 5)];
};

/**
 * Renders a correlation matrix onto a 600x600 canvas with a PNG download button.
 *
 * Each entry of `correlationFields` contributes one matrix row (its `row`
 * equity id) and one matrix column (its `col` equity id). Rows are labelled on
 * the left and columns along the top with the equities' tickers. Each cell
 * shows the correlation between its row and column equity (1 when they are the
 * same equity), coloured by `getColor`. Cells with no known correlation, or
 * whose row/column id is missing, are left blank.
 */
const MatrixCanvas: React.FC<Props> = ({ correlationFields }) => {
  const equities = useSelector(selectors.equities);
  const correlationValues = useSelector(selectors.correlationValues);
  const canvasRef = useRef<HTMLCanvasElement>(null);

  // Memoised so the drawing effect re-runs only when its inputs change; arrays
  // rebuilt on every render would make the dependency list useless.
  const rows = useMemo(
    () => correlationFields.map(c => c.row),
    [correlationFields]
  );
  const columns = useMemo(
    () => correlationFields.map(c => c.col),
    [correlationFields]
  );
  const correlationList = useMemo(
    () => Object.values(correlationValues),
    [correlationValues]
  );

  useEffect(() => {
    const canvas = canvasRef.current;
    if (canvas === null) {
      return;
    }
    const ctx = canvas.getContext('2d');
    if (ctx === null) {
      return;
    }

    // Start every redraw from a blank white canvas.
    ctx.fillStyle = '#fff';
    ctx.fillRect(0, 0, canvas.width, canvas.height);
    ctx.fillStyle = '#000';

    if (!columns.length || !rows.length) {
      return;
    }

    // The grid occupies everything right of / below the label margin:
    // columns run left-to-right, rows run top-to-bottom.
    const cellWidth = (canvas.width - OFFSET) / columns.length;
    const cellHeight = (canvas.height - OFFSET) / rows.length;

    ctx.font = '11px Arial';
    ctx.textAlign = 'center';
    ctx.textBaseline = 'middle';

    // Draws an equity's ticker centred on (x, y), rotated by -PI/6 (30 degrees).
    const drawLabel = (equityId: number, x: number, y: number) => {
      ctx.save();
      ctx.translate(x, y);
      ctx.rotate(-Math.PI / 6);
      ctx.fillText(equities[equityId]?.ticker ?? '', 0, 0);
      ctx.restore();
    };

    // Row labels on the left side.
    for (let i = 0; i < rows.length; i++) {
      const equityId = rows[i];
      if (equityId !== undefined) {
        drawLabel(equityId, OFFSET / 2, OFFSET + i * cellHeight + cellHeight / 2);
      }
    }

    // Column labels along the top. Compare against undefined (not truthiness)
    // so that an equity id of 0 still gets a label.
    for (let j = 0; j < columns.length; j++) {
      const equityId = columns[j];
      if (equityId !== undefined) {
        drawLabel(equityId, OFFSET + j * cellWidth + cellWidth / 2, OFFSET / 2);
      }
    }

    // Cells: row index i selects the vertical position and column index j the
    // horizontal one, matching the label placement above.
    for (let i = 0; i < rows.length; i++) {
      for (let j = 0; j < columns.length; j++) {
        const row = rows[i];
        const col = columns[j];
        if (row === undefined || col === undefined) {
          continue;
        }

        const value =
          row === col
            ? 1
            : correlationList.find(
                c => c.equities.includes(row) && c.equities.includes(col)
              )?.value;
        if (value === undefined) {
          continue;
        }

        const cellX = OFFSET + j * cellWidth;
        const cellY = OFFSET + i * cellHeight;

        ctx.fillStyle = getColor(value);
        ctx.fillRect(cellX, cellY, cellWidth, cellHeight);

        ctx.fillStyle = '#000';
        ctx.fillText(
          value.toFixed(2),
          cellX + cellWidth / 2,
          cellY + cellHeight / 2
        );
      }
    }

    // Outer border of the grid.
    ctx.beginPath();
    ctx.moveTo(OFFSET, OFFSET);
    ctx.lineTo(canvas.width, OFFSET);
    ctx.moveTo(OFFSET, canvas.height);
    ctx.lineTo(canvas.width, canvas.height);
    ctx.moveTo(OFFSET, OFFSET);
    ctx.lineTo(OFFSET, canvas.height);
    ctx.moveTo(canvas.width, OFFSET);
    ctx.lineTo(canvas.width, canvas.height);
    ctx.stroke();

    // Horizontal lines between rows.
    ctx.beginPath();
    for (let i = 0; i < rows.length; i++) {
      const y = OFFSET + i * cellHeight;
      ctx.moveTo(OFFSET, y);
      ctx.lineTo(canvas.width, y);
    }
    ctx.stroke();

    // Vertical lines between columns.
    ctx.beginPath();
    for (let j = 0; j < columns.length; j++) {
      const x = OFFSET + j * cellWidth;
      ctx.moveTo(x, OFFSET);
      ctx.lineTo(x, canvas.height);
    }
    ctx.stroke();
  }, [rows, columns, correlationList, equities]);

  return (
    <div style={{ display: 'flex', flexDirection: 'column', gap: 16 }}>
      <canvas
        ref={canvasRef}
        style={{ display: 'block' }}
        width={600}
        height={600}
      />
      {/*
        No `disabled` prop: the ref is unset during the first render and
        setting a ref never re-renders, so a ref-based disabled state would
        stay stuck. The click handler guards against a missing canvas instead.
      */}
      <Button
        type="primary"
        style={{ alignSelf: 'end' }}
        onClick={() => {
          const canvas = canvasRef.current;
          if (!canvas) {
            return;
          }
          const link = document.createElement('a');
          link.download = 'correlation_matrix.png';
          link.href = canvas.toDataURL();
          link.click();
        }}
      >
        Download
      </Button>
    </div>
  );
};

export default MatrixCanvas;
