# TensorFlow for deep learning and tensor operations
import tensorflow as tf

# Numerical and tabular data handling
import numpy as np
import pandas as pd

# Matplotlib for static plotting and visual customization
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib.colors import to_rgb
from matplotlib.lines import Line2D

# Python standard libraries for timing and datetime manipulation
import time
from datetime import datetime, timezone, timedelta

# System utilities for path management
import sys
import os
sys.path.append(os.path.abspath(".."))

# Data extraction and preparation utilities
from processData.modelData import getData
from processData.formatData import testingDataset, prepareInput

# Live data retrieval module
from testing.getLiveData import getLiveData

__all__ = ["runSimulator"]

def plotPredictions(predictionHistory, lapNumber, lapsUntilNextPit, compound, raceLaps, city):
    """
    Record a new prediction and redraw the pit stop prediction chart.

    The new prediction is appended to the history, every recorded point is
    drawn in the color of its compound, and consecutive points are joined
    by a line whose color fades from one compound color to the next.
    The chart is saved as ``../testing/{city}/lap{lapNumber}.pdf`` (relative
    to the current working directory) and then shown.

    Parameters:
        predictionHistory (pd.DataFrame): Predictions made so far, with the
            columns "LapNumber", "LapsUntilNextPit" and "Compound". May be empty.
        lapNumber (int): Current lap number.
        lapsUntilNextPit (int): Predicted number of laps until next pit stop.
        compound (int): Index into the alphabetically sorted compound names.
        raceLaps (int): Total number of laps in the race (sets the x-axis range).
        city (str): City name used for naming the output directory.

    Returns:
        pd.DataFrame: Updated prediction history, sorted by lap number,
        including the new prediction.
    """

    # Tyre types and associated colors
    tyres = ["SOFT", "MEDIUM", "HARD", "INTERMEDIATE", "WET"]
    tyreColors = {
        "SOFT": "#FF192E",
        "MEDIUM": "#F8C200",
        "HARD": "#272727",
        "INTERMEDIATE": "#009F2B",
        "WET": "#3765A8"
    }

    # The compound index refers to the alphabetically sorted names.
    # NOTE(review): presumably mirrors the label encoding used when the
    # compound model was trained - confirm against the training code.
    compoundName = sorted(tyres)[compound]

    # Append the current prediction
    newPrediction = pd.DataFrame([{
        "LapNumber": int(lapNumber),
        "LapsUntilNextPit": lapsUntilNextPit,
        "Compound": compoundName
    }])
    if predictionHistory.empty:
        # Concatenating with an empty frame is deprecated in recent pandas
        # versions, so the first prediction simply becomes the history.
        predictionHistory = newPrediction
    else:
        predictionHistory = pd.concat([predictionHistory, newPrediction], ignore_index=True)

    # Sort and enforce integer lap numbers
    predictionHistory = predictionHistory.sort_values("LapNumber").reset_index(drop=True)
    predictionHistory["LapNumber"] = predictionHistory["LapNumber"].astype(int)

    # Plain arrays for plotting
    laps = predictionHistory["LapNumber"].to_numpy(dtype=float)
    lapsToPit = predictionHistory["LapsUntilNextPit"].to_numpy(dtype=float)
    pointColors = [tyreColors[name] for name in predictionHistory["Compound"]]
    pointRgb = np.array([to_rgb(color) for color in pointColors])

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        # Scatter all points with their corresponding color
        ax.scatter(laps, lapsToPit, color=pointColors)

        # Draw gradient lines between consecutive points: each connection is
        # split into short segments, each colored by the blend at its start.
        amountSegments = 50
        fractions = np.arange(amountSegments + 1) / amountSegments
        for i in range(len(laps) - 1):
            xs = laps[i] * (1 - fractions) + laps[i + 1] * fractions
            ys = lapsToPit[i] * (1 - fractions) + lapsToPit[i + 1] * fractions

            for j in range(amountSegments):
                color = (1 - fractions[j]) * pointRgb[i] + fractions[j] * pointRgb[i + 1]
                ax.plot(xs[j:j + 2], ys[j:j + 2], color=color)

        ax.set_xlabel("Lap number", fontsize=12)
        ax.set_ylabel("Laps to pit", fontsize=12)
        ax.set_title("Pit stop prediction over race laps", fontweight="bold", fontsize=16, loc="center")
        ax.grid(True)

        # Axis settings
        ax.set_xlim(0, raceLaps + 3)
        ax.set_ylim(0, lapsToPit.max() + 3)
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))

        # Add the legend (one round marker per compound)
        legend = [Line2D([0], [0], marker="o", color="w", label=name.capitalize(),
                         markerfacecolor=tyreColors[name], markersize=10)
                  for name in tyres]
        ax.legend(handles=legend, title="Compound", loc="upper right")

        fig.tight_layout()

        # Save the image
        savePath = f"../testing/{city}"
        os.makedirs(savePath, exist_ok=True)
        fig.savefig(f"{savePath}/lap{lapNumber}.pdf", dpi=300)

        plt.show()
    finally:
        # One figure is created per lap; release it so figures do not pile
        # up on backends where show() leaves them open.
        plt.close(fig)

    return predictionHistory

def runPrediction(model, tfRecordFile, year, roundNumber, driverNumber, lapDependentFeatures, nonLapDependentFeatures, target, maxLaps, lapNumber):
    """
    Produce one model prediction for the given lap of a driver's race.

    Parameters:
        model (tf.keras.Model): The model to use for prediction.
        tfRecordFile (str): Path to the TFRecord file containing race data.
        year (int): Year of the race.
        roundNumber (int): Round number of the race.
        driverNumber (str): Driver identifier.
        lapDependentFeatures (list): List of lap-based feature names.
        nonLapDependentFeatures (list): List of static feature names.
        target (str): Target label to predict.
        maxLaps (int): Combined length of the completed laps and the padding
            handed to prepareInput.
        lapNumber (int): Current lap number.

    Returns:
        Index of the highest-scoring class in the first row of the model output.
    """

    # Build the testing dataset for this driver and race, and take its first element
    dataset = testingDataset(
        tfRecordFile,
        lapDependentFeatures,
        nonLapDependentFeatures,
        year,
        roundNumber,
        driverNumber,
        target,
    )
    lapFeatures, staticFeatures, _ = next(iter(dataset))

    # Only the laps before the current one are complete; the rest is padding
    completedLaps = lapNumber - 1
    remainingPadding = maxLaps - completedLaps

    # Format the model inputs and query the model
    modelInputs = prepareInput(lapFeatures, staticFeatures, completedLaps, remainingPadding)
    classScores = model.predict(modelInputs, verbose=0)

    # Pick the class with the highest score
    return tf.argmax(classScores[0]).numpy()

def makePredictions(pitstopModel, compoundModel, tfRecordFile, year, roundNumber, driverNumber,
                    lapDependentFeaturesPitstop, nonLapDependentFeaturesPitstop,
                    lapDependentFeaturesCompound, nonLapDependentFeaturesCompound,
                    maxLaps, raceLaps, city, lapNumber, predictionHistory):
    """
    Predict pit stop timing and tyre compound for a lap, then refresh the plot.

    Parameters:
        pitstopModel (tf.keras.Model): Model for pit stop timing.
        compoundModel (tf.keras.Model): Model for tyre compound prediction.
        tfRecordFile (str): TFRecord file with race data.
        year (int): Race year.
        roundNumber (int): Race round number.
        driverNumber (str): Driver number to predict for.
        lapDependentFeaturesPitstop (list): Features used for pitstop prediction.
        nonLapDependentFeaturesPitstop (list): Static features for pitstop prediction.
        lapDependentFeaturesCompound (list): Features used for compound prediction.
        nonLapDependentFeaturesCompound (list): Static features for compound prediction.
        maxLaps (int): Forwarded to runPrediction for both models.
        raceLaps (int): Number of race laps, forwarded to plotPredictions.
        city (str): City where the race is held.
        lapNumber (int): Current lap number.
        predictionHistory (pd.DataFrame): Existing prediction history.

    Returns:
        pd.DataFrame: Updated prediction history as returned by plotPredictions.
    """

    # Arguments identifying the race and driver, shared by both predictions
    raceSelection = (tfRecordFile, year, roundNumber, driverNumber)

    # Pit stop timing prediction
    lapsToPit = runPrediction(
        pitstopModel,
        *raceSelection,
        lapDependentFeaturesPitstop,
        nonLapDependentFeaturesPitstop,
        "LapsUntilNextPit",
        maxLaps,
        lapNumber,
    )

    # Tyre compound prediction
    compound = runPrediction(
        compoundModel,
        *raceSelection,
        lapDependentFeaturesCompound,
        nonLapDependentFeaturesCompound,
        "Compound",
        maxLaps,
        lapNumber,
    )

    # Record and plot both predictions, handing back the updated history
    return plotPredictions(predictionHistory, lapNumber, lapsToPit, compound, raceLaps, city)

def runSimulator(city, driverNumber):
    """
    Main entry point to simulate real-time predictions for a Formula 1 race.

    This function:
        - Waits until the race start time returned by getData.
        - Then polls getLiveData once per fixed interval, each time for the
          next time window.
        - If new data is reported, makes predictions for pit stop timing and
          tyre compound and saves/updates the plots.
        - Reports any error raised during an iteration on stderr and keeps
          polling.

    Parameters:
        city (str): The host city of the race.
        driverNumber (str): The driver number to track.

    Returns:
        None. The polling loop has no exit condition, so in practice the
        function runs until it is interrupted (e.g. KeyboardInterrupt).
    """

    # Extract the necessary data to build the models
    data, predictionData, raceStartTime = getData(driverNumber, city)

    # Seconds between two consecutive data retrievals
    timeInterval = 10

    # Dataframe to store the seen laps
    lapsData = pd.DataFrame()

    # Dataframe to store all predictions
    predictionHistory = pd.DataFrame(columns=["LapNumber", "LapsUntilNextPit", "Compound"])

    # Convert raceStartTime string to datetime object
    startTime = datetime.fromisoformat(raceStartTime)
    if startTime.tzinfo is None:
        # A naive datetime cannot be compared with the timezone-aware "now"
        # used below (TypeError).
        # NOTE(review): assumes a start time without offset is UTC - confirm.
        startTime = startTime.replace(tzinfo=timezone.utc)

    # The first retrieval covers the window ending at the start time
    previousTime = startTime - timedelta(seconds=timeInterval)
    currentTime = startTime

    # Wait until start time
    while datetime.now(timezone.utc) < startTime:
        print(f"Waiting for race start time... ({datetime.now(timezone.utc)} < {startTime})")
        time.sleep(1)

    print(f"Starting simulated live updates from {startTime.strftime('%d/%m/%Y at %H:%M:%S')}")

    while True:
        # Taken outside the try block so the pacing below never depends on it
        iterationStart = time.time()

        try:
            print(f"Retrieving data at {currentTime.strftime('%H:%M:%S')}")

            # Retrieve the data
            newData, lapsData, lapNumber = getLiveData(
                *data,
                driverNumber,
                currentTime.isoformat(timespec="microseconds"),
                previousTime.isoformat(timespec="microseconds"),
                lapsData
            )

            # Update time markers (skipped if the retrieval failed, so the
            # same window is requested again on the next iteration)
            previousTime = currentTime
            currentTime += timedelta(seconds=timeInterval)

            # If there's new data we make predictions using the simulator
            if newData:
                predictionHistory = makePredictions(*predictionData, lapNumber, predictionHistory)

        except Exception as e:
            # Keep the simulator alive, but never hide the failure
            print(f"Live update failed: {type(e).__name__}: {e}", file=sys.stderr)

        # Calculate elapsed time and adjust sleep to maintain fixed interval
        elapsedTime = time.time() - iterationStart
        sleepDuration = max(0, timeInterval - elapsedTime)
        time.sleep(sleepDuration)