import { Props } from './WaterfallChart.types';
import React, { useRef } from 'react';
import { useEffect } from 'react';
import * as d3 from 'd3';
import { useState } from 'react';
import './WaterfallChart.css';

/**
 * Horizontal waterfall chart rendered with d3 into a container `<div>`.
 *
 * Every datum in `chartDetails.data` becomes one bar spanning from the running
 * total before it to the running total after it; a synthetic `TOTAL` row
 * spanning from 0 to the final sum is appended at the end. The SVG skeleton
 * (svg > g.chart > g.x-axis / g.y-axis) is created once by the first effect;
 * bars, labels and connectors are (re)bound by the second effect whenever the
 * init flag, data or measured width change.
 *
 * @param chartDetails - chart configuration; only `data` is read here. Each
 *   datum is read for `key`, `value` and `label`.
 * @returns the container element the chart is drawn into.
 */
export default function WaterfallChart({ chartDetails }: Props): JSX.Element {
    const data = chartDetails.data;
    const parentRef = useRef<HTMLDivElement>(null);
    const [chartInit, setChartInit] = useState(false);
    const [width, setWidth] = useState(0);
    const height = 700;
    // Created once per component instance. They are effect dependencies below;
    // fresh objects on every render would force a redraw on every render.
    const [xScale] = useState(() => d3.scaleLinear());
    const [yScale] = useState(() => d3.scaleBand());
    const marginLeft = 70;
    const margin = 10;

    // Measure the drawable width and build the SVG skeleton exactly once.
    useEffect(() => {
        const container = parentRef.current;
        if (!container) return;

        setWidth(container.offsetWidth - marginLeft);

        if (chartInit) return;

        // Guard against a second <svg> when the effect re-runs before
        // `chartInit` is committed (e.g. React StrictMode double invocation).
        if (d3.select(container).select('svg').empty()) {
            const svg = d3
                .select(container)
                .append('svg')
                .attr('width', container.offsetWidth)
                .attr('height', height + margin);

            const chart = svg
                .append('g')
                .attr('class', 'chart')
                .attr('width', container.offsetWidth - marginLeft)
                .attr('height', height)
                .attr('transform', `translate(${marginLeft}, ${margin})`);

            chart
                .append('g')
                .attr('class', 'x-axis')
                .attr('transform', `translate(0, ${height - margin * 2})`);

            chart.append('g').attr('class', 'y-axis').attr('transform', `translate(0, ${-20})`);
        }

        setChartInit(true);
    }, [chartInit, width, height]);

    // Keep the chart in sync with the window size and (re)draw the bars.
    useEffect(() => {
        const resize = () => {
            const host = parentRef.current;
            if (!host) return;
            const fullWidth = host.offsetWidth;
            // The margin must be subtracted from the measured width itself;
            // `offsetWidth || 0 - marginLeft` would bind as `offsetWidth || -marginLeft`.
            setWidth(fullWidth - marginLeft);

            const svg = d3
                .select(host)
                .select('svg')
                .attr('width', fullWidth)
                .attr('height', height + margin);

            svg.select('.chart')
                .attr('width', fullWidth - marginLeft)
                .attr('height', height)
                .attr('transform', `translate(${marginLeft}, ${margin})`);
        };
        window.addEventListener('resize', resize);
        // Returned on every path so no listener outlives its effect run.
        const cleanup = () => window.removeEventListener('resize', resize);

        const container = parentRef.current;
        if (!chartInit || !container) return cleanup;

        // Each bar spans [running total before the datum, running total after it].
        let runningTotal = 0;
        const adjustedData: any[] = data.map((element) => {
            const start = runningTotal;
            runningTotal += element.value;
            return { ...element, start, end: runningTotal, label: element.label };
        });

        adjustedData.push({
            key: 'TOTAL',
            start: 0,
            end: runningTotal,
            label: runningTotal.toFixed(2),
            value: runningTotal,
        });

        // All lookups are scoped to this component's container so several
        // charts on one page do not draw into each other.
        const chart = d3.select(container).select('.chart');

        yScale.range([0, height]).domain(adjustedData.map((d: any) => d.key));

        // x-domain: min/max of the running sum over the real rows (TOTAL is
        // excluded), seeded with 0 because the y-axis reference line is drawn at x = 0.
        let domainMin = 0;
        let domainMax = 0;
        let cumulative = 0;
        for (const row of adjustedData) {
            if (row.key === 'TOTAL') continue;
            cumulative += row.end - row.start;
            domainMin = Math.min(domainMin, cumulative);
            domainMax = Math.max(domainMax, cumulative);
        }

        xScale.range([0, width - 100]).domain([domainMin, domainMax]);
        const xAxis: any = d3.axisBottom(xScale);
        const yAxis: any = d3
            .axisLeft(yScale)
            .tickFormat(('' as unknown) as (domainValue: string, index: number) => string);

        const xGroup = chart.select('.x-axis');
        const yGroup = chart.select('.y-axis');

        xGroup.call(xAxis);
        yGroup.call(yAxis);

        // Only the y-axis domain line is kept, moved to x = 0.
        yGroup.selectAll('.tick').remove();
        yGroup.select('path').attr('transform', `translate(${xScale(0)}, 0)`);

        const cell: any = chart.selectAll('.cell').data(adjustedData);

        cell.exit().remove();

        const cellEnter = cell
            .enter()
            .append('g')
            .attr('class', (d: any) => `cell ${d.key}`);

        cellEnter
            .append('foreignObject')
            .attr('class', 'bar')
            .attr('width', 0)
            .attr('height', 20)
            .attr('y', 0)
            .attr('x', 0);

        cellEnter
            .append('line')
            .attr('class', 'connector')
            .style('stroke', '#32a852')
            .style('stroke-dasharray', '5,3')
            .attr('y1', 30)
            .attr('y2', 0);

        cellEnter
            .append('foreignObject')
            .attr('class', 'label')
            .attr('width', 0)
            .attr('height', 0)
            .attr('y', 0)
            .attr('x', 0);

        cellEnter
            .append('foreignObject')
            .attr('class', 'yLabel')
            .attr('width', 0)
            .attr('height', yScale.bandwidth() - 10)
            .attr('y', 0)
            .attr('x', 0);

        const cellUpdate = cellEnter.merge(cell);

        cellUpdate
            .select('.bar')
            .attr('transform', (d: any) => `translate(0, ${yScale(d.key)})`)
            .transition()
            .duration(500)
            .attr('width', (d: any) => Math.abs(xScale(d.start) - xScale(d.end)))
            .attr('x', (d: any) => xScale(Math.min(d.start, d.end)));

        // NOTE(review): labels and keys are interpolated into HTML unescaped;
        // confirm they never come from untrusted input.
        cellUpdate
            .select('.bar')
            .html((d: any) => `<div class="rect ${d.value >= 0 ? 'positive' : 'negative'}"></div>`);

        cellUpdate
            .select('.label')
            .attr('transform', (d: any) => `translate(0, ${yScale(d.key)})`)
            .attr('height', 20)
            .attr('width', (d: any) => Math.abs(xScale(d.start) - xScale(d.end)))
            .attr('y', 0)
            .transition()
            .delay(300)
            .duration(500)
            .attr('x', (d: any) => (d.value >= 0 ? xScale(d.end) + 5 : xScale(d.start) + 5));

        // Denominator for the per-row percentage: sum of |value| over real rows.
        // Computed once instead of once per row.
        const absoluteSum = adjustedData.reduce(
            (acc: number, row: any) => acc + (row.key === 'TOTAL' ? 0 : Math.abs(row.value)),
            0,
        );

        cellUpdate.select('.label').html((d: any) => {
            if (!d.value) return '';
            if (d.key === 'TOTAL') return `<div class="label">${d.label}</div>`;
            const percent = ((Math.abs(d.value) / absoluteSum) * 100).toFixed(3);
            return `<div class="label">${d.label} <span class="sublabel">(${percent || 0}%)</span></div>`;
        });

        cellUpdate
            .select('.yLabel')
            .attr('transform', (d: any) => `translate(0, ${yScale(d.key)})`)
            .attr('height', 20)
            .attr('x', -105)
            .attr('y', 0);

        // `nodes` is the group of yLabel elements; with a single node (presumably
        // only the TOTAL row, i.e. empty data) no y-label is rendered.
        cellUpdate
            .select('.yLabel')
            .html((d: any, _i: any, nodes: any) => (nodes.length > 1 ? `<div class="yLabel">${d.key}</div>` : ''));

        // Vertical dashed connector at each bar's end, reaching down one band;
        // hidden (transparent) for the TOTAL row.
        cellUpdate
            .select('.connector')
            .attr('transform', (d: any) => `translate(0, ${(yScale(d.key) || 0) + 10})`)
            .attr('x1', (d: any) => xScale(d.end))
            .attr('x2', (d: any) => xScale(d.end))
            .style('stroke', (d: any) => (d.key === 'TOTAL' ? 'transparent' : '#BDBDBD'))
            .attr('y1', 0)
            .attr('y2', yScale.bandwidth());

        return cleanup;
    }, [chartInit, yScale, xScale, data, width, height]);
    return <div ref={parentRef} className="waterfall"></div>;
}
