import {
	TCalendarAdjustment,
	TMachine,
	TPeriod,
	TPeriodString,
	TPhasedItem,
	TPhaseSchedule,
} from '@repo/types'
import {
	addDays,
	addMinutes,
	areIntervalsOverlapping,
	differenceInMinutes,
	isAfter,
	isBefore,
	max,
	min,
	startOfWeek,
} from 'date-fns'

import { adjustOpenPeriods } from '../planning/adjust-open-periods'
import { getCalendarAdjustmentsForMachine } from '../planning/get-calendar-adjustments-for-machine'
import { getOpenPeriods } from '../planning/get-open-periods'

/******************************************************
 * 1) Types
 ******************************************************/

/**
 * A machine booking as consumed by `calculateStaffDemand`.
 * Within this module only `machineId`, `phases` and `staffGroups` are read.
 */
type Booking = {
	/** Id of the booked machine; matched against `TMachine.id`. */
	machineId: string
	startDate: string
	endDate: string
	/** Before / during / after windows; each boundary is parsed with `new Date(...)`. */
	phases: TPhaseSchedule<TPeriodString>
	/** Required staff groups, each weighted by `factor` and flagged per phase. */
	staffGroups: TPhasedItem[]
}

/** One step of a staff group's demand curve: at `time` the demand shifts by `delta`. */
type DemandChange = {
	staffGroup: string
	time: Date
	/** `+factor` where a demand segment starts, `-factor` where it ends. */
	delta: number
}

/** A span of time during which every staff group's demand stays constant. */
type DemandBlock = {
	period: TPeriod
	/** Demand per staff group name; a group missing from the map has zero demand. */
	demands: Record<string, number>
}

/** Demand statistics for one interval, per staff group and summed across groups. */
type AggregatedDemand = {
	/** Lowest demand seen in the interval, keyed by staff group name. */
	min: Record<string, number>
	/** Time-weighted average demand over the interval, keyed by staff group name. */
	avg: Record<string, number>
	/** Highest demand seen in the interval, keyed by staff group name. */
	max: Record<string, number>
	/** Sums of the per-group `min`, `avg` and `max` values. */
	total: { min: number; avg: number; max: number }
}

/** One aggregation bucket of a staff-demand result. */
export type AggregatedInterval = {
	/** Bucket start, ISO 8601 UTC timestamp (e.g. "2024-01-01T12:00:00.000Z"). */
	startDate: string
	/** Bucket end, ISO 8601 UTC timestamp (e.g. "2024-01-01T12:00:00.000Z"). */
	endDate: string
	aggregatedDemand: AggregatedDemand
}

/** Return value of `calculateStaffDemand`. */
export type StaffDemandResult = {
	/** Every staff group name found in any booking, de-duplicated in first-seen order. */
	staffGroups: string[]
	/** Aggregated buckets in ascending time order. */
	intervals: AggregatedInterval[]
}

/******************************************************
 * 2) Utility Functions
 ******************************************************/

/**
 * Splits the half-open range [period.startDate, period.endDate) into consecutive
 * intervals of `intervalMinutes` (e.g. 15 min, 1 h); the last interval is cut off
 * at `period.endDate`.
 *
 * An `intervalMinutes` of exactly one week (10080) is special-cased: the result is
 * the Monday-based calendar weeks (local time) that the period touches, each ending
 * on Sunday 23:59:59.999. The first and last week may therefore extend beyond the
 * period.
 *
 * @throws RangeError if `intervalMinutes` is not a positive finite number, or is too
 *   small for `addMinutes` to advance time (which would otherwise loop forever).
 */
function splitIntoIntervals(args: {
	period: TPeriod
	intervalMinutes: number
}): TPeriod[] {
	const { period, intervalMinutes } = args
	const minutesPerWeek = 7 * 24 * 60

	// Zero / negative / NaN step sizes would never reach period.endDate
	if (!Number.isFinite(intervalMinutes) || intervalMinutes <= 0) {
		throw new RangeError(
			`intervalMinutes must be a positive number, got ${intervalMinutes}`,
		)
	}

	const result: TPeriod[] = []

	// Week intervals align with calendar weeks (Monday start)
	if (intervalMinutes === minutesPerWeek) {
		// The end is exclusive: use the week of the last instant inside the period,
		// otherwise a period ending exactly on Monday 00:00 gets an extra, empty week.
		const lastInstant = isAfter(period.endDate, period.startDate)
			? new Date(period.endDate.getTime() - 1)
			: period.endDate
		const lastWeekStart = startOfWeek(lastInstant, { weekStartsOn: 1 })

		let current = startOfWeek(period.startDate, { weekStartsOn: 1 })
		while (!isAfter(current, lastWeekStart)) {
			const weekEnd = addDays(current, 7)
			result.push({
				startDate: current,
				// Ensure the week ends at Sunday 23:59:59.999
				endDate: new Date(weekEnd.getTime() - 1),
			})
			current = weekEnd
		}
		return result
	}

	// Fixed-length intervals starting at period.startDate
	let current = period.startDate
	while (isBefore(current, period.endDate)) {
		const next = addMinutes(current, intervalMinutes)
		// Some date-fns versions truncate fractional minute amounts; a step that does
		// not advance time would never terminate.
		if (!isAfter(next, current)) {
			throw new RangeError(
				`intervalMinutes ${intervalMinutes} does not advance time`,
			)
		}
		result.push({
			startDate: current,
			endDate: isAfter(next, period.endDate) ? period.endDate : next,
		})
		current = next
	}
	return result
}

/** Returns overlap in minutes between [a.startDate, a.endDate) and [b.startDate, b.endDate). */
/**
 * Length of the intersection of two half-open periods, in whole minutes
 * (fractional minutes are truncated). Returns 0 when the periods do not intersect.
 */
function getOverlapInMinutes(args: { a: TPeriod; b: TPeriod }): number {
	const { a: first, b: second } = args
	const millisPerMinute = 60 * 1000

	const intersectionStart = Math.max(
		first.startDate.getTime(),
		second.startDate.getTime(),
	)
	const intersectionEnd = Math.min(
		first.endDate.getTime(),
		second.endDate.getTime(),
	)
	const wholeMinutes = Math.trunc(
		(intersectionEnd - intersectionStart) / millisPerMinute,
	)

	return Math.max(0, wholeMinutes)
}

/** Rounds a number to two decimal places. */
/** Rounds `value` to two decimal places via `Math.round` on the scaled value. */
function roundTwoDecimals(value: number): number {
	const scale = 100
	const scaled = Math.round(value * scale)
	return scaled / scale
}

/******************************************************
 * 3) Main Calculation - Phase-Aware Line-Sweep
 ******************************************************/
/**
 * Calculates staff demand over `period`, aggregated into buckets of `intervalMinutes`.
 *
 * Each booking adds `staffGroup.factor` to a staff group's demand during every phase
 * (before / during / after) the group is flagged for. The contribution only applies
 * while the booked machine is open, as computed from `machine.availability` and
 * the machine's calendar adjustments.
 *
 * The resulting step function is summarised per bucket as min, max and time-weighted
 * average demand per staff group, plus totals across groups. All values are rounded
 * to two decimals.
 *
 * Bookings whose machine is missing from `machines` are skipped with a console
 * warning. Their staff groups are still listed in the result.
 *
 * @param args.bookings - Bookings to evaluate.
 * @param args.period - Reporting period; demand outside it is ignored.
 * @param args.intervalMinutes - Bucket length in minutes; 10080 yields Monday-based calendar weeks.
 * @param args.machines - Machines referenced by `bookings[].machineId`.
 * @param args.calendarAdjustments - Calendar adjustments; filtered per machine via `getCalendarAdjustmentsForMachine`.
 * @returns Buckets with ISO 8601 start/end and aggregated demand, plus all staff group names.
 */
export function calculateStaffDemand(args: {
	bookings: Booking[]
	period: TPeriod
	intervalMinutes: number
	machines: TMachine[]
	calendarAdjustments: TCalendarAdjustment[]
}): StaffDemandResult {
	const { bookings, period, intervalMinutes, machines, calendarAdjustments } =
		args

	// A) Change points, restricted to machine open periods and the reporting period
	const changePoints = collectChangePoints({
		bookings,
		period,
		machines,
		calendarAdjustments,
	})

	// B) Piecewise-constant demand blocks covering the reporting period
	const demandBlocks = buildDemandBlocks({ changePoints, period })

	// C) Min / max / time-weighted average per interval
	const staffGroups = collectStaffGroupNames(bookings)
	const intervals = aggregateIntervals({
		chunks: splitIntoIntervals({ period, intervalMinutes }),
		demandBlocks,
		staffGroups,
	})

	return { intervals, staffGroups }
}

/**
 * Collects demand change points. For every booking, every staff group and every phase
 * the group is flagged for, emits a `+factor` / `-factor` pair at the start and end of
 * each part of that phase lying inside both `period` and one of the machine's open
 * periods.
 */
function collectChangePoints(args: {
	bookings: Booking[]
	period: TPeriod
	machines: TMachine[]
	calendarAdjustments: TCalendarAdjustment[]
}): DemandChange[] {
	const { bookings, period, machines, calendarAdjustments } = args
	const changePoints: DemandChange[] = []

	for (const booking of bookings) {
		const machine = machines.find(m => m.id === booking.machineId)
		if (!machine) {
			console.warn(`Machine ${booking.machineId} not found for booking`)
			continue
		}

		// Open periods of this machine over the booking's full span (before → after)
		const openPeriods = adjustOpenPeriods({
			startDate: new Date(booking.phases.before.startDate),
			endDate: new Date(booking.phases.after.endDate),
			calendarAdjustments: getCalendarAdjustmentsForMachine({
				calendarAdjustments,
				machineId: machine.id,
			}),
			openPeriods: getOpenPeriods({
				startDate: new Date(booking.phases.before.startDate),
				endDate: new Date(booking.phases.after.endDate),
				availability: machine.availability,
			}),
		})

		/** Parts of [start, end) that lie inside both `period` and a machine open period. */
		const openSegmentsOf = (start: Date, end: Date) => {
			const segments: { start: Date; end: Date }[] = []

			// Clip to the reporting period. A phase entirely outside it (or an empty one)
			// contributes nothing. Without this check the clipped interval could be
			// inverted, which areIntervalsOverlapping rejects.
			const clippedStart = max([start, period.startDate])
			const clippedEnd = min([end, period.endDate])
			if (!isBefore(clippedStart, clippedEnd)) return segments

			for (const openPeriod of openPeriods) {
				// Strict overlap: periods that merely touch would yield zero-length segments
				if (
					areIntervalsOverlapping(
						{ start: clippedStart, end: clippedEnd },
						{ start: openPeriod.startDate, end: openPeriod.endDate },
					)
				) {
					segments.push({
						start: max([clippedStart, openPeriod.startDate]),
						end: min([clippedEnd, openPeriod.endDate]),
					})
				}
			}
			return segments
		}

		// Segments depend only on the booking, so compute them once for all staff groups
		const beforeSegments = openSegmentsOf(
			new Date(booking.phases.before.startDate),
			new Date(booking.phases.before.endDate),
		)
		const duringSegments = openSegmentsOf(
			new Date(booking.phases.during.startDate),
			new Date(booking.phases.during.endDate),
		)
		const afterSegments = openSegmentsOf(
			new Date(booking.phases.after.startDate),
			new Date(booking.phases.after.endDate),
		)

		for (const staffGroup of booking.staffGroups) {
			const phases = [
				{ isNeeded: staffGroup.phases.before, segments: beforeSegments },
				{ isNeeded: staffGroup.phases.during, segments: duringSegments },
				{ isNeeded: staffGroup.phases.after, segments: afterSegments },
			]
			for (const { isNeeded, segments } of phases) {
				if (!isNeeded) continue
				for (const segment of segments) {
					changePoints.push({
						time: segment.start,
						staffGroup: staffGroup.name,
						delta: staffGroup.factor,
					})
					changePoints.push({
						time: segment.end,
						staffGroup: staffGroup.name,
						delta: -staffGroup.factor,
					})
				}
			}
		}
	}

	return changePoints
}

/**
 * Line sweep: turns change points into contiguous demand blocks from
 * `period.startDate` to `period.endDate`. Each block carries a snapshot of
 * every staff group's running demand total.
 */
function buildDemandBlocks(args: {
	changePoints: DemandChange[]
	period: TPeriod
}): DemandBlock[] {
	const { changePoints, period } = args
	// Copy before sorting so the caller's array is not mutated
	const sorted = [...changePoints].sort(
		(a, b) => a.time.getTime() - b.time.getTime(),
	)

	const demandBlocks: DemandBlock[] = []
	const demandMap: Record<string, number> = {}
	let previousTime = period.startDate

	for (const { time, staffGroup, delta } of sorted) {
		// Changes before the period take effect at its start
		const currentTime = isBefore(time, period.startDate)
			? period.startDate
			: time

		// Sorted ascending, so nothing later can affect the period
		if (isAfter(currentTime, period.endDate)) break

		// Time advanced: close the block covering [previousTime, currentTime)
		if (isAfter(currentTime, previousTime)) {
			demandBlocks.push({
				period: { startDate: previousTime, endDate: currentTime },
				demands: { ...demandMap }, // snapshot
			})
			previousTime = currentTime
		}

		demandMap[staffGroup] = (demandMap[staffGroup] ?? 0) + delta
	}

	// Final block up to the end of the period, if not already reached
	if (isBefore(previousTime, period.endDate)) {
		demandBlocks.push({
			period: { startDate: previousTime, endDate: period.endDate },
			demands: { ...demandMap },
		})
	}

	return demandBlocks
}

/** All staff group names across `bookings`, de-duplicated in first-seen order. */
function collectStaffGroupNames(bookings: Booking[]): string[] {
	const names = new Set<string>()
	for (const booking of bookings) {
		for (const staffGroup of booking.staffGroups) {
			names.add(staffGroup.name)
		}
	}
	return [...names]
}

/**
 * Summarises demand per chunk. Min and max are taken over the blocks that overlap the
 * chunk. The average is weighted by overlap minutes over the chunk's full length.
 * Both `chunks` and `demandBlocks` must be ascending and non-overlapping.
 *
 * NOTE(review): week chunks can extend beyond the reporting period, where no blocks
 * exist. That uncovered time counts as zero demand in `avg` but is ignored by min/max.
 */
function aggregateIntervals(args: {
	chunks: TPeriod[]
	demandBlocks: DemandBlock[]
	staffGroups: string[]
}): AggregatedInterval[] {
	const { chunks, demandBlocks, staffGroups } = args
	const intervals: AggregatedInterval[] = []
	// Index of the first block that may still overlap the current or a later chunk
	let firstBlock = 0

	for (const chunk of chunks) {
		// Blocks ending at or before this chunk's start cannot overlap it or any later chunk
		while (firstBlock < demandBlocks.length) {
			const block = demandBlocks[firstBlock]
			if (!block || isAfter(block.period.endDate, chunk.startDate)) break
			firstBlock++
		}

		const minMap: Record<string, number> = {}
		const maxMap: Record<string, number> = {}
		const sumMap: Record<string, number> = {}
		for (const staffGroup of staffGroups) {
			minMap[staffGroup] = Number.POSITIVE_INFINITY
			maxMap[staffGroup] = Number.NEGATIVE_INFINITY
			sumMap[staffGroup] = 0
		}

		for (let i = firstBlock; i < demandBlocks.length; i++) {
			const block = demandBlocks[i]
			// Once a block starts at or after the chunk end, no later block overlaps it
			if (!block || !isBefore(block.period.startDate, chunk.endDate)) break

			const overlap = getOverlapInMinutes({ a: block.period, b: chunk })
			if (overlap > 0) {
				for (const staffGroup of staffGroups) {
					const demandValue = block.demands[staffGroup] ?? 0
					if (demandValue < minMap[staffGroup]) {
						minMap[staffGroup] = demandValue
					}
					if (demandValue > maxMap[staffGroup]) {
						maxMap[staffGroup] = demandValue
					}
					sumMap[staffGroup] += demandValue * overlap
				}
			}
		}

		const chunkMinutes = differenceInMinutes(chunk.endDate, chunk.startDate)
		const aggregatedDemand: AggregatedDemand = {
			min: {},
			avg: {},
			max: {},
			total: { min: 0, avg: 0, max: 0 },
		}

		for (const staffGroup of staffGroups) {
			// Groups without any overlapping block in this chunk report 0
			const minVal =
				minMap[staffGroup] === Number.POSITIVE_INFINITY ? 0 : minMap[staffGroup]
			const maxVal =
				maxMap[staffGroup] === Number.NEGATIVE_INFINITY ? 0 : maxMap[staffGroup]
			// Time-weighted average over the whole chunk
			const avgVal = chunkMinutes === 0 ? 0 : sumMap[staffGroup] / chunkMinutes

			aggregatedDemand.min[staffGroup] = roundTwoDecimals(minVal)
			aggregatedDemand.max[staffGroup] = roundTwoDecimals(maxVal)
			aggregatedDemand.avg[staffGroup] = roundTwoDecimals(avgVal)

			aggregatedDemand.total.min += aggregatedDemand.min[staffGroup]
			aggregatedDemand.total.max += aggregatedDemand.max[staffGroup]
			aggregatedDemand.total.avg += aggregatedDemand.avg[staffGroup]
		}

		// Re-round totals to remove floating-point noise from the summation
		aggregatedDemand.total.min = roundTwoDecimals(aggregatedDemand.total.min)
		aggregatedDemand.total.max = roundTwoDecimals(aggregatedDemand.total.max)
		aggregatedDemand.total.avg = roundTwoDecimals(aggregatedDemand.total.avg)

		intervals.push({
			startDate: chunk.startDate.toISOString(),
			endDate: chunk.endDate.toISOString(),
			aggregatedDemand,
		})
	}

	return intervals
}
