/* eslint-disable react-hooks/exhaustive-deps */
import { useEffect, useMemo, useRef } from "react";
import * as d3 from "d3";
import { Box, Card, CircularProgress } from "@mui/material";
import { RootState } from "@/redux/store";
import { useSelector } from "react-redux";
import { isEmpty } from "lodash";

// Fixed drawing-surface size; the <svg> scales to its container via viewBox.
const SVG_WIDTH = 1200;
const SVG_HEIGHT = 400;
const MARGIN = { top: 50, right: 0, bottom: 40, left: 0 };

/**
 * Coerces a cell value to a number for a continuous axis. null, undefined and
 * "" become NaN (rather than 0), as do non-numeric strings such as a "-"
 * placeholder; d3.extent skips NaN.
 */
const toNumeric = (value: unknown): number =>
  value === null || value === undefined || value === "" ? NaN : Number(value);

/**
 * Parallel-coordinates plot: one vertical axis per entry of `dimensions` and
 * one polyline per row of `data`. Hovering a line shows its values in a
 * tooltip attached to <body>.
 *
 * @param data            Rows keyed by dimension name.
 * @param dimensions      Axis names, drawn left to right in this order.
 * @param dimensionsTypes Map of dimension name -> "continuous" for a linear
 *                        axis over the numeric extent of that column; any other
 *                        value draws a point axis over the distinct values in
 *                        first-seen order.
 */
const ParallelCoordinates = ({ data, dimensions, dimensionsTypes }) => {
  const svgRef = useRef(null);

  useEffect(() => {
    const svg = d3.select(svgRef.current);
    // Clear unconditionally so an empty `data` does not leave a stale chart.
    svg.selectAll("*").remove();
    if (!data?.length || !dimensions?.length) return undefined;

    const x = d3
      .scalePoint()
      .range([MARGIN.left, SVG_WIDTH - MARGIN.right])
      .padding(1)
      .domain(dimensions);

    const yRange = [SVG_HEIGHT - MARGIN.bottom, MARGIN.top];
    const y = {};
    for (const dimension of dimensions) {
      if (dimensionsTypes?.[dimension] === "continuous") {
        // Coerce before taking the extent: raw string values would be compared
        // lexicographically. Fall back to [0, 1] when nothing on this axis is numeric.
        const [lo = 0, hi = 1] = d3.extent(data, (d) => toNumeric(d[dimension]));
        y[dimension] = d3.scaleLinear().domain([lo, hi]).range(yRange);
      } else {
        // Set preserves first-occurrence order, matching the original dedupe.
        y[dimension] = d3
          .scalePoint()
          .domain(Array.from(new Set(data.map((d) => d[dimension]))))
          .range(yRange)
          .padding(0.5);
      }
    }

    svg
      .selectAll("g.axis")
      .data(dimensions)
      .enter()
      .append("g")
      .attr("class", "axis")
      .attr("transform", (d) => `translate(${x(d)})`)
      .each(function (this: SVGGElement, d) {
        d3.select(this).call(d3.axisLeft(y[d]));
      })
      .append("text")
      .style("text-anchor", "middle")
      .attr("y", MARGIN.top - 30)
      .text((d) => d)
      .style("fill", "black");

    // Deterministic spread over the Cool palette so colors stay stable across redraws.
    const lineColor = (i: number) =>
      d3.interpolateCool(data.length > 1 ? i / (data.length - 1) : 0.5);

    // One tooltip per drawn chart; removed in this effect's cleanup below.
    const tooltip = d3
      .select("body")
      .append("div")
      .attr("class", "tooltip")
      .style("opacity", 0)
      .style("position", "absolute")
      .style("padding", "10px")
      .style("background", "white")
      .style("border", "1px solid black")
      .style("border-radius", "5px")
      .style("pointer-events", "none");

    // Skip points a scale cannot place (missing/non-numeric values) instead of
    // writing "NaN" into the path data.
    const line = d3.line().defined(([, py]) => Number.isFinite(py));

    svg
      .selectAll("path.trial")
      .data(data)
      .enter()
      .append("path")
      .attr("class", "trial")
      .attr("d", (d) => line(dimensions.map((p) => [x(p), y[p](d[p])])))
      .style("fill", "none")
      .style("stroke", (_, i) => lineColor(i))
      .style("opacity", 0.5)
      .on("mouseenter", function (this: SVGPathElement, event, d) {
        d3.select(this).style("stroke-width", "4px");
        tooltip.transition().duration(200).style("opacity", 0.9);
        // Build content from text nodes (not .html()) so values are never parsed as markup.
        tooltip.text("");
        dimensions.forEach((p, idx) => {
          if (idx > 0) tooltip.append("br");
          tooltip.append("strong").text(`${p}:`);
          tooltip.append("span").text(` ${d[p]}`);
        });
        tooltip
          .style("left", event.pageX + "px")
          .style("top", event.pageY - 28 + "px");
      })
      .on("mouseleave", function (this: SVGPathElement) {
        d3.select(this).style("stroke-width", "1px");
        tooltip.transition().duration(500).style("opacity", 0);
      });

    // Runs before every redraw and on unmount, so exactly this chart's tooltip is removed.
    return () => {
      tooltip.remove();
    };
  }, [data, dimensions, dimensionsTypes]);

  return (
    <svg
      ref={svgRef}
      width="100%"
      height="100%"
      viewBox={`0 0 ${SVG_WIDTH} ${SVG_HEIGHT}`}
    />
  );
};

// Value stored for a dimension when a trial reports nothing for it.
const MISSING_VALUE = "-";

type DimensionType = "continuous" | "ordinal";

interface ChartInputs {
  data: Record<string, unknown>[];
  dimensions: string[];
  dimensionsTypes: Record<string, DimensionType>;
}

/**
 * Returns `value`, or the "-" placeholder when it is null, undefined or "".
 * Unlike `value || "-"`, a legitimate 0 is kept.
 */
const valueOrPlaceholder = (value: unknown): unknown =>
  value === undefined || value === null || value === "" ? MISSING_VALUE : value;

/**
 * Converts raw tuning trials into ParallelCoordinates props.
 *
 * Only trials with `status === "Succeeded"` are used. Each row maps every
 * objective metric `name` to its `latest` value and every parameter `name` to
 * its `value`; a parameter with the same name as a metric overwrites it.
 * Dimensions are the objective names followed by the parameter names, in
 * first-seen order across all succeeded trials. A dimension is "continuous"
 * when every present value parses as a finite number, otherwise "ordinal"
 * (e.g. an optimizer name such as "adam").
 */
const buildChartInputs = (trials: any[] | null | undefined): ChartInputs => {
  const succeeded = (trials ?? []).filter((trial) => trial?.status === "Succeeded");
  if (isEmpty(succeeded)) {
    return { data: [], dimensions: [], dimensionsTypes: {} };
  }

  const objectiveNames = new Set<string>();
  const parameterNames = new Set<string>();
  const data = succeeded.map((trial) => {
    const row: Record<string, unknown> = {};
    (trial.objective ?? []).forEach((metric) => {
      objectiveNames.add(metric.name);
      row[metric.name] = valueOrPlaceholder(metric?.latest);
    });
    (trial.parameters ?? []).forEach((param) => {
      parameterNames.add(param.name);
      row[param.name] = valueOrPlaceholder(param?.value);
    });
    return row;
  });

  const dimensions: string[] = Array.from(objectiveNames);
  parameterNames.forEach((name) => {
    if (!objectiveNames.has(name)) dimensions.push(name);
  });

  const dimensionsTypes: Record<string, DimensionType> = {};
  dimensions.forEach((name) => {
    // Missing cells don't decide the type; any non-numeric value makes the axis ordinal.
    const allNumeric = data.every((row) => {
      const value = row[name];
      return (
        value === undefined ||
        value === MISSING_VALUE ||
        Number.isFinite(Number(value))
      );
    });
    dimensionsTypes[name] = allNumeric ? "continuous" : "ordinal";
  });

  return { data, dimensions, dimensionsTypes };
};

/**
 * Card that shows a parallel-coordinates view of the succeeded trials in
 * `state.experiment.hptunerResult`, or a spinner while `loading` is set.
 */
const TrailChart = () => {
  const {
    hptunerResult: { data, loading },
  } = useSelector((state: RootState) => state.experiment);
  const trials = data?.trials;

  // Memoized so ParallelCoordinates gets stable references and redraws only
  // when the trials change, not on every re-render of this card.
  const chartInputs = useMemo(() => buildChartInputs(trials), [trials]);

  return (
    <Card sx={{ mb: 2, py: 2 }}>
      {loading ? (
        <Box
          component="div"
          sx={{
            display: "flex",
            justifyContent: "center",
            alignItems: "center",
            width: "100%",
            height: "400px",
          }}
        >
          <CircularProgress color="primary" size={50} />
        </Box>
      ) : (
        <ParallelCoordinates
          data={chartInputs.data}
          dimensions={chartInputs.dimensions}
          dimensionsTypes={chartInputs.dimensionsTypes}
        />
      )}
    </Card>
  );
};

// Default export of this module: the TrailChart card component defined above.
export default TrailChart;
