# Standard library utilities
import logging
from tqdm import tqdm

# Data handling and manipulation
import numpy as np
import pandas as pd
pd.options.mode.chained_assignment = None

# External F1 data library
import fastf1
logging.getLogger("fastf1").setLevel(logging.ERROR)

# Static race metadata
from fixedData.raceDictionary import getRaces

__all__ = ["loadTelemetry"]

def convertToAbsoluteTime(df, sessionStartTime, relativeTimeColumns):
    """
    Anchors timedelta columns to a session start so they become timestamps

    The DataFrame is modified in place and also returned for chaining. Columns
    that are missing, or whose dtype is not timedelta, are left untouched.

    Parameters:
        df (pd.DataFrame): Frame holding the relative (timedelta) columns
        sessionStartTime (str or pd.Timestamp): Moment the relative values count from
        relativeTimeColumns (list of str): Names of the columns to anchor

    Returns:
        pd.DataFrame: The same frame, with the matching columns now absolute
    """

    # Normalise whatever was passed (string, datetime, Timestamp) to a Timestamp
    anchor = pd.Timestamp(sessionStartTime)

    for columnName in relativeTimeColumns:
        if columnName not in df.columns:
            continue
        # Checked per iteration: a column already anchored is no longer timedelta
        if not pd.api.types.is_timedelta64_dtype(df[columnName]):
            continue
        df[columnName] = anchor + df[columnName]

    return df

def convertToSeconds(series):
    """
    Turns a timedelta Series into float seconds with millisecond precision

    Parameters:
        series (pd.Series): Series that may hold timedelta values

    Returns:
        pd.Series: Seconds rounded to three decimals when the input is timedelta;
        otherwise the very same Series object, unchanged
    """

    if not pd.api.types.is_timedelta64_dtype(series):
        return series

    seconds = series.dt.total_seconds()
    return seconds.round(3)

def processLapTelemetry(lap, year, roundNumber, durationTimeColumns, telemetryDf, liveData):
    """
    Extracts and formats telemetry for a specific lap with race metadata

    Parameters:
        lap: In pre-saved mode, an object exposing get_car_data(), .Driver and
            .LapNumber (a FastF1 lap); in live mode, a row/mapping exposing
            lap["Driver"] and lap["LapNumber"]
        year (int): Race year
        roundNumber (int): Race round number
        durationTimeColumns (list): Timedelta columns to convert to float seconds
        telemetryDf (pd.DataFrame): Raw telemetry samples; only read in live mode
        liveData (bool): Whether using real-time or pre-saved FastF1 session

    Returns:
        pd.DataFrame or None: Processed telemetry with Driver, LapNumber,
        RoundNumber and Year columns added, or None if the telemetry is empty
        or could not be read
    """

    try:
        if liveData:
            # Copy so adding metadata columns never modifies the caller's frame
            telemetry = telemetryDf.copy()
            driver = lap["Driver"]
            lapNumber = lap["LapNumber"]
        else:
            # Get telemetry data for the lap
            telemetry = lap.get_car_data()
            driver = lap.Driver
            lapNumber = lap.LapNumber

        # Nothing to process in either mode
        if telemetry.empty:
            return None

        # Convert timedelta columns (e.g. SessionTime) to float seconds
        for col in durationTimeColumns:
            if col in telemetry.columns and pd.api.types.is_timedelta64_dtype(telemetry[col]):
                telemetry[col] = convertToSeconds(telemetry[col])

        # Add metadata to the telemetry DataFrame
        telemetry["Driver"] = driver
        telemetry["LapNumber"] = lapNumber
        telemetry["RoundNumber"] = roundNumber
        telemetry["Year"] = year

        return telemetry

    except Exception as e:
        # Best effort: an unreadable lap is skipped rather than aborting the race,
        # but the reason is kept at debug level so it can be diagnosed
        logging.getLogger(__name__).debug(
            "Skipping lap telemetry (year=%s, round=%s): %s", year, roundNumber, e
        )
        return None

def getTelemetryLabel(df, lappedMargin=3):
    """
    Flags the final recorded lap of drivers who stopped well short of the race distance

    A row gets LastLap = 1 when its LapNumber is the driver's highest LapNumber in
    that race (Year, RoundNumber) AND that highest lap is more than `lappedMargin`
    laps below the highest LapNumber recorded by anyone in the race. Every other
    row gets 0. Drivers who finish, including those lapped by up to
    `lappedMargin` laps, therefore get 0 on every lap.

    Parameters:
        df (pd.DataFrame): Telemetry DataFrame with Year, RoundNumber, Driver, LapNumber
        lappedMargin (int): How many laps behind the race's highest lap a driver's
            last lap may be without being flagged. Defaults to 3

    Returns:
        pd.DataFrame: Copy of df with an integer LastLap column (0 or 1); the
        input frame is not modified
    """

    raceKeys = ["Year", "RoundNumber"]
    driverKeys = raceKeys + ["Driver"]

    # Work on a copy so the caller's frame does not gain a LastLap column
    labelled = df.copy()

    # Broadcast each driver's and each race's highest lap back onto every row
    driverMaxLap = labelled.groupby(driverKeys)["LapNumber"].transform("max")
    raceMaxLap = labelled.groupby(raceKeys)["LapNumber"].transform("max")

    isDriversLastLap = labelled["LapNumber"] == driverMaxLap
    # The margin keeps lapped-but-classified finishers from being flagged
    stoppedEarly = driverMaxLap < raceMaxLap - lappedMargin

    labelled["LastLap"] = (isDriversLastLap & stoppedEarly).astype(int)

    return labelled

def groupTelemetryData(df):
    """
    Collapses sample-level telemetry into one row per (Driver, LapNumber, RoundNumber, Year)

    Each channel among RPM, Speed, nGear, Throttle, Brake and DRS that is present
    becomes the lap's mean cast with int() (i.e. truncated toward zero). LastLap
    becomes the lap's maximum flag.

    Parameters:
        df (pd.DataFrame): Sample-level telemetry including a LastLap column

    Returns:
        pd.DataFrame: One row per lap, sorted by Year, RoundNumber, Driver,
        LapNumber, with a fresh 0..n-1 index
    """

    def truncatedMean(values):
        # int() drops the fractional part of the per-lap mean
        return int(np.mean(values))

    candidateColumns = ("RPM", "Speed", "nGear", "Throttle", "Brake", "DRS")
    aggregation = {name: truncatedMean for name in candidateColumns if name in df.columns}
    aggregation["LastLap"] = "max"

    lapKeys = ["Driver", "LapNumber", "RoundNumber", "Year"]
    perLap = df.groupby(lapKeys, as_index=False).agg(aggregation)

    # Deterministic output ordering, like an SQL ORDER BY
    ordered = perLap.sort_values(by=["Year", "RoundNumber", "Driver", "LapNumber"])
    return ordered.reset_index(drop=True)

def loadTelemetryData(year, roundNumber, table, engine, liveData, telemetry):
    """
    Loads and processes telemetry data for a race and stores it in the database

    In pre-saved mode the race session is loaded through FastF1 and car data is
    fetched lap by lap. In live mode the supplied telemetry frame (which must carry
    Driver and LapNumber columns) is split into one chunk per (Driver, LapNumber),
    and each chunk is processed exactly once.

    Parameters:
        year (int): Year of the race
        roundNumber (int): Round number of the race
        table (str): Name of database table to store results
        engine (sqlalchemy.Engine): SQLAlchemy engine for PostgreSQL
        liveData (bool): Whether using external live telemetry
        telemetry (pd.DataFrame or None): Input telemetry DataFrame if liveData is True

    Returns:
        None
    """

    # Per-lap telemetry frames, concatenated once all laps are processed
    allTelemetryData = []

    # Timedelta columns converted to float seconds
    durationTimeColumns = ["SessionTime"]

    try:
        if not liveData:
            # Load the race
            race = fastf1.get_session(year, roundNumber, "R")
            race.load()
            # iterlaps yields (index, lap) pairs; each lap fetches its own car data
            lapInputs = ((lap, None) for _, lap in race.laps.iterlaps([]))
        else:
            # Hand each lap only its own samples. Passing the whole frame once per
            # sample row (as before) produced N^2 rows and stamped every copy with a
            # single row's Driver/LapNumber, mixing laps together.
            lapInputs = (
                (lapSamples.iloc[0], lapSamples)
                for _, lapSamples in telemetry.groupby(["Driver", "LapNumber"])
            )

        for lap, lapSamples in lapInputs:
            lapTelemetry = processLapTelemetry(
                lap, year, roundNumber, durationTimeColumns, lapSamples, liveData
            )
            if lapTelemetry is not None:
                allTelemetryData.append(lapTelemetry)
    except Exception as e:
        print(f"Error loading the race: {e}")

    if not allTelemetryData:
        print(f"No telemetry data found for race {roundNumber} in {year}.")

        # Special fallback only when liveData is used
        if liveData and telemetry is not None and not telemetry.empty:
            driver = telemetry.iloc[0].get("Driver")
            insertFallbackTelemetryRow(table, engine, year, roundNumber, driver)
        return

    # Combine all telemetry data into a single DataFrame
    df = pd.concat(allTelemetryData, ignore_index=True)

    # Label each lap, then collapse samples into one row per lap
    df = getTelemetryLabel(df)
    df = groupTelemetryData(df)

    # Insert data into PostgreSQL database
    df.to_sql(table, con=engine, if_exists="append", index=False)

def insertFallbackTelemetryRow(table, engine, year, roundNumber, driver):
    """
    Appends one synthetic telemetry row built from a driver's stored laps

    Reads the whole table and keeps the rows for this driver, year and round.
    If any exist, it appends a row for the next LapNumber whose RPM, Speed,
    Throttle and LastLap are the rounded means of those rows. Intended for
    liveData mode when no telemetry could be produced. Any error is printed,
    not raised.

    Parameters:
        table (str): Name of database table
        engine (sqlalchemy.Engine): SQLAlchemy database engine
        year (int): Year of the race
        roundNumber (int): Round number of the race
        driver (str): Driver code
    """

    try:
        storedRows = pd.read_sql_table(table, con=engine)

        matchesRace = (
            (storedRows["Driver"] == driver)
            & (storedRows["Year"] == year)
            & (storedRows["RoundNumber"] == roundNumber)
        )
        raceRows = storedRows[matchesRace]

        if raceRows.empty:
            print(f"No existing telemetry found for driver {driver}, year {year}, round {roundNumber}.")
            return

        # The synthetic row continues right after the last stored lap
        nextLapNumber = raceRows["LapNumber"].max() + 1

        averagedChannels = raceRows[["RPM", "Speed", "Throttle", "LastLap"]].mean(numeric_only=True)
        averagedChannels = averagedChannels.round().astype(int).to_dict()

        fallbackRow = {"Driver": driver, "LapNumber": nextLapNumber, "RoundNumber": roundNumber, "Year": year}
        fallbackRow.update(averagedChannels)

        pd.DataFrame([fallbackRow]).to_sql(table, con=engine, if_exists="append", index=False)
        print(f"Inserted fallback row for driver {driver}, lap {nextLapNumber}.")

    except Exception as e:
        print(f"Error inserting fallback telemetry row: {e}")

def loadTelemetry(engine, table, year, roundNumber, liveData, telemetry):
    """
    Loads telemetry data either for a specific race or all races

    Parameters:
        engine: SQLAlchemy engine or DB connection
        table (str): Destination table for telemetry data
        year (int or None): Specific race year, or None to process all
        roundNumber (int or None): Specific round number, or None to process all
        liveData (bool): If True, loads only the provided telemetry for this race
        telemetry (pd.DataFrame or None): Telemetry DataFrame for live processing

    Returns:
        None

    Notes:
        Outside live mode, a single race is loaded only when BOTH year and
        roundNumber are given; if either is None every race from getRaces()
        is loaded.
    """

    # Live mode: only the supplied telemetry is processed
    if liveData:
        loadTelemetryData(year, roundNumber, table, engine, liveData, telemetry)
        return

    # A specific race was requested: load just that one
    if year is not None and roundNumber is not None:
        loadTelemetryData(year, roundNumber, table, engine, False, None)
        return

    # Otherwise backfill every race from the static race dictionary
    for race in tqdm(getRaces(), desc="Loading all telemetry"):
        # NOTE(review): relies on each race dict's value order being
        # (year, round, <third field>) — confirm against getRaces
        year, roundNumber, _ = race.values()
        loadTelemetryData(year, roundNumber, table, engine, False, None)