import { Selection, EnterElement } from "d3"
import { D3VisualizationRenderer } from "../../D3VisualizationRenderer";
import { D3ScatterPlot } from "./D3ScatterPlot";
import { D3ScatterPlotConfigurationBuilder } from "./D3ScatterPlotConfigurationBuilder";
import { D3ScatterPlotXAxis } from "./D3ScatterPlotXAxis";
import { D3ScatterPlotYAxis } from "./D3ScatterPlotYAxis";
import { D3ScatterPlotPointConfig } from "../../../../Types/ScatterPlot";
import { D3Timeline } from "../../../D3/Timeline/D3Timeline";
import { D3ScatterPlotCanvas } from "./D3ScatterPlotCanvas";
import { requestAutoScale } from "../../autoScale";
import { TimeSeriesData } from "../../../../Data/TimeSeriesData";
import { DataSource } from "../../../../Types/DataSource";

/**
 * Renders a D3 scatter plot of one modality (x axis) against another (y axis) for the data source
 * mapped to `DataSource.CURRENT_PATIENT`, together with axis legends and a timeline drawn below
 * the plot's bounding box.
 */
export class D3ScatterPlotRenderer extends D3VisualizationRenderer<D3ScatterPlot, D3ScatterPlotConfigurationBuilder> {
	/** Fraction of the data span added on each side of an auto-scaled axis domain. */
	private static readonly AUTO_SCALE_PADDING_RATIO = 0.05
	/** Distance (px) the y-axis legend is placed to the left of the bounding box. */
	private static readonly Y_LEGEND_OFFSET = 40
	/** Distance (px) the x-axis legend is placed below the bounding box. */
	private static readonly X_LEGEND_OFFSET = 35

	public className: string = "d3-scatter-plot"

	public xAxis?: D3ScatterPlotXAxis
	public yAxis?: D3ScatterPlotYAxis
	private timelineSpacing = 45
	public scatterPlotCanvas?: D3ScatterPlotCanvas
	private timelineContainerClassName: string = "d3-scatter-plot-timeline"

	/**
	 * Called when the visible time window changes: refreshes the timeline, recomputes the plotted
	 * points, optionally schedules an auto-scale (skipped while the timeline is playing), and
	 * re-renders the canvas.
	 */
	public viewTimesChanged(): void {
		this.timeline?.viewTimesChanged()
		this.visualization.data = this.getConfigs()

		if (this.visualization.config.autoScale && !this.timeline?.isPlaying()) {
			requestAnimationFrame(() => {
				requestAutoScale(this.visualization.timeSeriesPageManager, this.autoScale)
			})
		}

		if (this.scatterPlotCanvas) {
			this.scatterPlotCanvas.updateData(this.visualization.data)
			this.scatterPlotCanvas.render()
		}
	}

	/**
	 * Lightweight refresh used while the timeline slider is dragged: pushes freshly computed points
	 * straight to the canvas. Unlike `viewTimesChanged`, it does not store them on
	 * `visualization.data` and does not trigger auto-scaling.
	 */
	public onTimelineSliderDrag(): void {
		this.timeline?.viewTimesChanged()
		this.scatterPlotCanvas?.updateData(this.getConfigs())
		this.scatterPlotCanvas?.render()
	}

	/**
	 * Creates the root `<svg>` with legends, the bounding-box group (axes and canvas), and the
	 * timeline container, then returns the new svg selection.
	 */
	protected enter(enterElements: Selection<EnterElement, any, any, any>): Selection<any, any, any, any> {
		const svg = enterElements.append("svg").attr("class", this.className).attr("width", "100%").attr("height", "100%")

		// Legend group; the text and position of its children are set by updateLegends below.
		const legendGroup = svg.append("g")
			.attr("class", "d3-scatter-plot-legend")
			.style('font-family', '"source-sans-pro", sans-serif')
			.style("font-weight", "bold")
			.style("font-size", "14px")

		legendGroup.append("text")
			.attr("class", "d3-scatter-plot-legend-yAxis")
			.style("text-anchor", "end")

		legendGroup.append("text")
			.attr("class", "d3-scatter-plot-legend-xAxis")
			.style("text-anchor", "end")

		this.updateLegends(svg)

		const boundingBox = svg
			.append("g")
			.attr("class", "d3-scatter-plot-bounding-box")
			.attr("transform", `translate(${this.visualization.boundingBox.x}, ${this.visualization.boundingBox.y})`)

		boundingBox.each((config, index, nodes) => this.createBoundingBoxChildren(config, index, nodes))

		const timelineContainer = svg
			.append('g')
			.attr("class", this.timelineContainerClassName)
			.attr("transform", this.timelineTransform())

		timelineContainer.each((config, index, nodes) => {
			this.timeline = new D3Timeline(nodes[index], this.configBuilder.getTimelineConfig(), this.visualization.timeSeriesPageManager, this.visualization.reactCallbacks)
		})

		return svg
	}

	/** Re-positions the legends and timeline for the current bounding box and re-renders the timeline. */
	protected update = (updatedSVG: Selection<any, any, any, any>): Selection<any, any, any, any> => {
		const svg = updatedSVG

		this.updateLegends(svg)

		svg.select("." + this.timelineContainerClassName).attr("transform", this.timelineTransform())

		this.timeline?.render()

		return svg
	}

	/** Rendering is only possible once the bounding box has a positive width and height. */
	protected canRender(): boolean {
		return this.visualization.boundingBox.height > 0 && this.visualization.boundingBox.width > 0
	}

	/**
	 * Sets the text and position of both axis legends inside `svg`.
	 * The y legend is rotated -90° so it ends `Y_LEGEND_OFFSET` px left of the bounding box, at its top;
	 * the x legend ends at the box's right edge, `X_LEGEND_OFFSET` px below its bottom.
	 */
	private updateLegends(svg: Selection<any, any, any, any>): void {
		const { x, y, width, height } = this.visualization.boundingBox
		const legendGroup = svg.select(".d3-scatter-plot-legend")

		legendGroup.select(".d3-scatter-plot-legend-yAxis")
			.text(`[${this.visualization.config.yAxisConfig.modality}]`)
			.attr("transform", `rotate(-90) translate(${-y}, ${x - D3ScatterPlotRenderer.Y_LEGEND_OFFSET})`)

		legendGroup.select(".d3-scatter-plot-legend-xAxis")
			.text(`[${this.visualization.config.xAxisConfig.modality}]`)
			.attr("x", width + x)
			.attr("y", height + y + D3ScatterPlotRenderer.X_LEGEND_OFFSET)
	}

	/** SVG transform placing the timeline `timelineSpacing` px below the bounding box. */
	private timelineTransform(): string {
		const { x, y, height } = this.visualization.boundingBox
		return `translate(${x}, ${y + height + this.timelineSpacing})`
	}

	/** Creates the axes and the canvas inside the bounding-box group, seeding `visualization.data` first. */
	private createBoundingBoxChildren = (config: D3ScatterPlot, index: number, nodes: ArrayLike<SVGGElement>) => {
		const root = nodes[index]

		this.xAxis = new D3ScatterPlotXAxis(root, this.configBuilder.getXAxisConfig(), this.visualization.reactCallbacks)
		this.yAxis = new D3ScatterPlotYAxis(root, this.configBuilder.getYAxisConfig(), this.visualization.reactCallbacks)

		this.visualization.data = this.getConfigs()

		this.scatterPlotCanvas = new D3ScatterPlotCanvas(root, this.configBuilder.getScatterPlotCanvasConfig(), this.visualization.reactCallbacks)
	}

	/**
	 * Binary search for the first index whose timestamp is >= `timestamp`
	 * (assumes `data.times` is sorted ascending and parallel to `data.data`).
	 */
	private findLowerIndex = (data: TimeSeriesData, timestamp: number) => {
		let low = 0
		let high = data.data.length - 1
		while (low <= high) {
			const mid = Math.floor((low + high) / 2)
			const midValue = data.times[mid]
			// Explicit null check: a timestamp of 0 is valid and must not be treated as missing.
			if (midValue != null && midValue < timestamp) {
				low = mid + 1
			} else {
				high = mid - 1
			}
		}
		return low
	}

	/**
	 * Binary search for the first index whose timestamp is > `timestamp`, i.e. an exclusive upper
	 * bound (assumes `data.times` is sorted ascending and parallel to `data.data`).
	 */
	private findUpperIndex = (data: TimeSeriesData, timestamp: number) => {
		let low = 0
		let high = data.data.length
		while (low < high) {
			const mid = Math.floor((low + high) / 2)
			const midValue = data.times[mid]
			// Explicit null check: a timestamp of 0 is valid and must not be treated as missing.
			if (midValue != null && timestamp < midValue) {
				high = mid
			} else {
				low = mid + 1
			}
		}
		return low
	}

	/** A sample value can be plotted when it is present and not NaN (zero is a valid reading). */
	private static isPlottable(value: number | null | undefined): value is number {
		return value != null && !Number.isNaN(value)
	}

	/**
	 * Builds one point per sample in the current view window, pairing x- and y-modality values at the
	 * same index when their timestamps match. Samples are taken from every page in view.
	 */
	public getConfigs(): D3ScatterPlotPointConfig[] {
		const config = this.visualization.config
		const pages = this.visualization.timeSeriesPageManager.getPagesInView()
		const points: D3ScatterPlotPointConfig[] = []

		const [startDate, endDate] = config.viewScale.domain()
		const startTime = startDate.getTime()
		const endTime = endDate.getTime()
		const dataObjectId = this.visualization.reactCallbacks.dataSourceMap.get(DataSource.CURRENT_PATIENT) as number

		// Loop-invariant values, read once rather than per sample.
		const xModality = config.xAxisConfig.modality
		const yModality = config.yAxisConfig.modality
		const { color, shape, size } = config
		const xScale = this.visualization.xScale
		const yScale = this.visualization.yScale

		pages.forEach(page => {
			const pageData = page?.data.get(dataObjectId)
			const xModalityData: TimeSeriesData | undefined = pageData?.get(xModality)
			const yModalityData: TimeSeriesData | undefined = pageData?.get(yModality)

			if (!xModalityData || !yModalityData) {
				return
			}

			const lowerIndex = this.findLowerIndex(xModalityData, startTime)
			const upperIndex = this.findUpperIndex(xModalityData, endTime)

			for (let i = lowerIndex; i < upperIndex; i++) {
				const timestamp = xModalityData.times[i]
				const x = xModalityData.data[i]
				const y = yModalityData.data[i]

				// Skip only missing/NaN samples; a truthiness test here would also drop every 0 reading.
				if (D3ScatterPlotRenderer.isPlottable(x) && D3ScatterPlotRenderer.isPlottable(y) && timestamp === yModalityData.times[i]) {
					points.push({ color, shape, size, x, y, timestamp, xScale, yScale })
				}
			}
		})

		return points
	}

	/** Recomputes the plotted points and pushes fresh configs to the axes, timeline and canvas. */
	public updateChildren = () => {
		this.visualization.data = this.getConfigs()
		this.xAxis?.updateConfig(this.configBuilder.getXAxisConfig())
		this.yAxis?.updateConfig(this.configBuilder.getYAxisConfig())
		this.timeline?.updateConfig(this.configBuilder.getTimelineConfig())
		this.scatterPlotCanvas?.updateConfig(this.configBuilder.getScatterPlotCanvasConfig())
	}

	/** Refreshes children; when auto-scale is on and the timeline is not playing, also runs `viewTimesChanged`. */
	public renderPage = () => {
		this.updateChildren()
		if (this.visualization.config.autoScale && !this.timeline?.isPlaying()) {
			this.viewTimesChanged()
		}
	}

	public onYAxisDrag = () => {
		this.scatterPlotCanvas?.render()
	}

	public onXAxisDrag = () => {
		this.scatterPlotCanvas?.render()
	}

	/**
	 * Returns the extent of `values` plus a padded version widened by `AUTO_SCALE_PADDING_RATIO` of the
	 * span on each side, so points do not sit exactly on the axis boundaries. When all values are
	 * equal the span is zero, so a fallback padding is used to avoid a zero-width domain.
	 * Callers must pass a non-empty array.
	 */
	private static paddedExtent(values: number[]): { min: number, max: number, paddedMin: number, paddedMax: number } {
		let min = Infinity
		let max = -Infinity
		// Iterative min/max: spreading a large array into Math.min/Math.max can exceed the argument limit.
		for (const value of values) {
			min = Math.min(min, value)
			max = Math.max(max, value)
		}

		const ratio = D3ScatterPlotRenderer.AUTO_SCALE_PADDING_RATIO
		const span = max - min
		const padding = span > 0 ? span * ratio : (Math.abs(min) * ratio || 1)

		return { min, max, paddedMin: min - padding, paddedMax: max + padding }
	}

	/**
	 * Fits the y axis to the current data: updates the root config and the y scale domain to the
	 * padded data extent, then re-renders the canvas and y axis. No-op when there is no data.
	 */
	public onYAxisAutoScale = () => {
		const data = this.visualization.data
		if (!data || data.length === 0) {
			return
		}

		const { min, paddedMin, paddedMax } = D3ScatterPlotRenderer.paddedExtent(data.map(point => point.y))

		// Clamp the lower bound at 0 only when all data is non-negative, so padding does not dip below
		// zero; for data with negative values an unconditional clamp would hide those points.
		const domainMin = min >= 0 ? Math.max(0, paddedMin) : paddedMin

		// Config and scale receive the same bounds so they do not disagree on the next config update.
		this.visualization.reactCallbacks.setRootConfig(previous => {
			const yAxisConfig = { ...previous.yAxisConfig }
			yAxisConfig.maxValue = paddedMax
			yAxisConfig.minValue = domainMin
			return { ...previous, yAxisConfig: yAxisConfig }
		})

		this.visualization.yScale.domain([domainMin, paddedMax])
		this.scatterPlotCanvas?.render()
		this.yAxis?.render()
	}

	/**
	 * Fits the x axis to the current data: updates the root config and the x scale domain to the
	 * padded data extent, then re-renders the canvas and x axis. No-op when there is no data.
	 */
	public onXAxisAutoScale = () => {
		const data = this.visualization.data
		if (!data || data.length === 0) {
			return
		}

		const { paddedMin, paddedMax } = D3ScatterPlotRenderer.paddedExtent(data.map(point => point.x))

		this.visualization.reactCallbacks.setRootConfig(previous => {
			const xAxisConfig = { ...previous.xAxisConfig }
			xAxisConfig.maxValue = paddedMax
			xAxisConfig.minValue = paddedMin
			return { ...previous, xAxisConfig: xAxisConfig }
		})

		this.visualization.xScale.domain([paddedMin, paddedMax])
		this.scatterPlotCanvas?.render()
		this.xAxis?.render()
	}

	/** Auto-scales both axes unless live mode is enabled. */
	public autoScale = () => {
		// There's no need to auto scale during live mode.
		// This guard exists because after a page loads, it tries to auto scale.
		if (this.visualization.config.liveModeEnabled) {
			return
		}

		this.onXAxisAutoScale()
		this.onYAxisAutoScale()
	}
}