# Standard library utilities
import logging

# Progress bar for loops
from tqdm import tqdm

# Data handling
import pandas as pd
pd.options.mode.chained_assignment = None

# F1 timing data API
import fastf1
logging.getLogger("fastf1").setLevel(logging.ERROR)

# Fixed metadata for race events
from fixedData.raceDictionary import getRaces

__all__ = ["loadLaps"]

def convertToAbsoluteTime(df, sessionStartTime, relativeTimeColumns):
    """
    Turns session-relative timedelta columns into absolute timestamps

    Each listed column that exists in the DataFrame and has a timedelta dtype
    is replaced by the session start time plus the original offset. Columns
    that are missing or hold another dtype are left untouched. The DataFrame
    is modified in place.

    Parameters:
        df (pd.DataFrame): Input DataFrame containing relative time columns
        sessionStartTime (str or pd.Timestamp): Start time of the session
        relativeTimeColumns (list of str): Candidate columns to convert

    Returns:
        pd.DataFrame: The same DataFrame with the converted columns
    """

    # Normalise whatever was passed in to a pandas Timestamp
    origin = pd.Timestamp(sessionStartTime)

    for name in relativeTimeColumns:
        # Unknown columns are silently skipped
        if name not in df.columns:
            continue

        offsets = df[name]

        # Only genuine timedeltas can be shifted onto the wall clock
        if not pd.api.types.is_timedelta64_dtype(offsets):
            continue

        df[name] = origin + offsets

    return df

def convertToSeconds(series):
    """
    Expresses a timedelta Series as float seconds rounded to milliseconds

    Parameters:
        series (pd.Series): Series that may hold timedelta values

    Returns:
        pd.Series: Seconds rounded to three decimals, or the Series itself
        (unchanged) when its dtype is not timedelta
    """

    # Anything that is not a timedelta is passed straight through
    if not pd.api.types.is_timedelta64_dtype(series):
        return series

    seconds = series.dt.total_seconds()
    return seconds.round(3)

def loadQualifyingResults(year, roundNumber, qualifyingResults, liveData):
    """
    Builds the per-driver qualifying table used to enrich race laps

    When liveData is False the qualifying session is fetched through FastF1;
    otherwise the results passed in by the caller are used. The "Position"
    column is exposed as "GridPosition" and Q1/Q2/Q3 are converted to seconds.

    Parameters:
        year (int): Season year
        roundNumber (int): Round number of the race
        qualifyingResults (DataFrame or None): Pre-loaded results or None
        liveData (bool): When True, use the pre-loaded results instead of FastF1

    Returns:
        pd.DataFrame or None: Columns DriverNumber, Q1, Q2, Q3, GridPosition;
        None when "Position" is missing or any error occurs
    """

    try:
        if not liveData:
            # Fetch the qualifying session and work on a copy of its results
            session = fastf1.get_session(year, roundNumber, "Q")
            session.load()
            qualifyingResults = session.results.copy()

        results = qualifyingResults.reset_index(drop=True)

        # A pre-existing "GridPosition" would clash with the renamed "Position"
        results = results.drop(columns=["GridPosition"], errors="ignore")

        # Without a finishing position there is nothing to derive the grid from
        if "Position" not in results.columns:
            return None

        results = results.rename(columns={"Position": "GridPosition"})

        # Keep only what the lap table needs
        results = results[["DriverNumber", "Q1", "Q2", "Q3", "GridPosition"]]

        # Express the three qualifying segments in seconds
        for segment in ("Q1", "Q2", "Q3"):
            results[segment] = convertToSeconds(results[segment])

        return results

    except Exception as e:
        print(f"Error loading qualifying results: {e}")
        return None

def countLapsUntilNextPit(df):
    """
    Computes the number of laps until the next pit stop for each lap

    A lap is a pit lap when "PitInTime" is set or, if that column is absent,
    when "PitTime" is greater than zero. Pit laps get 0, laps before a stop get
    the lap distance to that stop, and laps after a driver's final stop get the
    distance to that driver's highest lap number. Rows are matched to the
    "next" stop by comparing index labels, so each driver's rows must appear in
    lap order with increasing index labels.

    Parameters:
        df (pd.DataFrame): Lap data with "Driver", "LapNumber" and either
            "PitInTime" or "PitTime"; modified in place

    Returns:
        pd.DataFrame: The same DataFrame with an integer LapsUntilNextPit column

    Raises:
        KeyError: If neither "PitInTime" nor "PitTime" is present
    """

    # Work out which rows are pit laps
    if "PitInTime" in df.columns:
        # Parsing to datetime makes missing entries detectable as NaT
        df["PitInTime"] = pd.to_datetime(df["PitInTime"])
        pitCondition = df["PitInTime"].notna()
    elif "PitTime" in df.columns:
        pitCondition = df["PitTime"].fillna(0) > 0
    else:
        raise KeyError("countLapsUntilNextPit needs a 'PitInTime' or 'PitTime' column")

    # Initialize the new column
    df["LapsUntilNextPit"] = 0

    # Process each driver separately
    for driver in df["Driver"].unique():
        driverMask = df["Driver"] == driver
        driverLaps = df.loc[driverMask, "LapNumber"]

        # Row labels of this driver's pit laps, in row order
        pitStopIndices = driverLaps.index[pitCondition.loc[driverMask].to_numpy()]
        pitStopLookup = set(pitStopIndices)

        # Laps after the final stop count down to the driver's last lap
        lastLap = driverLaps.max()

        remainingLaps = []
        pitStopIterator = iter(pitStopIndices)
        nextPitStopIdx = next(pitStopIterator, None)

        for idx, lap in driverLaps.items():
            if idx in pitStopLookup:
                remainingLaps.append(0)
                continue

            # Advance to the first pit stop located after the current row
            while nextPitStopIdx is not None and nextPitStopIdx <= idx:
                nextPitStopIdx = next(pitStopIterator, None)

            targetLap = lastLap if nextPitStopIdx is None else driverLaps.loc[nextPitStopIdx]
            remainingLaps.append(int(targetLap - lap))

        df.loc[driverMask, "LapsUntilNextPit"] = remainingLaps

    # Return only after every driver has been processed
    return df
    
def processLaps(session, sessionStartTime, qualifyingResults, relativeTimeColumns, durationTimeColumns, laps, liveData):
    """
    Processes raw lap data with weather, time, and qualifying results

    When liveData is False the laps and weather are read from the session;
    otherwise the laps passed in are used and the session is not touched.
    Relative times become absolute timestamps, durations become seconds, and
    qualifying results (when provided) are left-merged on DriverNumber. Laps of
    drivers without a GridPosition are dropped; when no GridPosition column
    exists (no qualifying data) all laps are kept.

    Parameters:
        session (Session): FastF1 session object (only used when liveData is False)
        sessionStartTime (datetime or None): Start time of the session; when
            None the relative time columns are left as they are
        qualifyingResults (pd.DataFrame or None): Qualifying results to merge
        relativeTimeColumns (list): Columns with relative times
        durationTimeColumns (list): Columns with durations in time format
        laps (pd.DataFrame or None): Preloaded lap data (only used when liveData is True)
        liveData (bool): Whether to use the preloaded laps instead of the session

    Returns:
        pd.DataFrame or None: Fully processed lap data or None if there are no laps
    """

    if not liveData:
        laps = session.laps
        if laps.empty:
            return None

        # Reset index so the weather rows line up positionally below
        laps = laps.reset_index(drop=True)
        laps["sessionName"] = session.name

        # Add the weather, skipping its own "Time" column to avoid a duplicate
        weatherData = session.laps.get_weather_data().reset_index(drop=True)
        laps = pd.concat([laps, weatherData.loc[:, ~(weatherData.columns == "Time")]], axis=1)
    elif laps is None or laps.empty:
        # Nothing was preloaded, so report "no data" like the session branch does
        return None

    # Convert relative times to absolute timestamps
    if sessionStartTime is not None:
        laps = convertToAbsoluteTime(laps, sessionStartTime, relativeTimeColumns)

    # Convert duration times to seconds.milliseconds
    for col in durationTimeColumns:
        if col in laps.columns and pd.api.types.is_timedelta64_dtype(laps[col]):
            laps[col] = convertToSeconds(laps[col])

    # Merge qualifying results if available
    if qualifyingResults is not None:
        laps = laps.merge(qualifyingResults, on="DriverNumber", how="left")

    # Keep only laps whose driver has a grid position; without qualifying data
    # the column does not exist and there is nothing to filter on
    if "GridPosition" in laps.columns:
        laps = laps.dropna(subset=["GridPosition"])

    # Drop unnecessary columns (returns a new frame, so the conversions below
    # never write into the caller's DataFrame)
    laps = laps.drop(columns=["FastF1Generated", "IsAccurate"], errors="ignore")

    # Convert object columns that should be numeric
    numericColumns = ["AirTemp", "Humidity", "Pressure", "Rainfall", "TrackTemp", "WindDirection", "WindSpeed", "Position", "GridPosition"]
    for col in numericColumns:
        if col in laps.columns:
            laps[col] = pd.to_numeric(laps[col], errors="coerce")

    # Normalise 64-bit numeric columns to the platform's default int/float dtype
    for col in laps.select_dtypes(include=["int64"]).columns:
        laps[col] = laps[col].astype(int)

    for col in laps.select_dtypes(include=["float64"]).columns:
        laps[col] = laps[col].astype(float)

    # Convert datetime columns to string
    dateTimeColumns = laps.select_dtypes(include=["datetime64"]).columns
    for col in dateTimeColumns:
        laps[col] = laps[col].astype(str)

    return laps

def processPitStopTimes(df):
    """
    Computes pit stop durations and adds a PitTime column

    When no "PitTime" column exists, the duration of a stop is stored on the
    in-lap: it is the "PitOutTime" of the following row of the same
    (Driver, Year, RoundNumber) group minus the "PitInTime" of the current row.
    Rows are paired positionally, so each group must already be in lap order.
    An existing "PitTime" column is kept and only has its gaps filled.

    Parameters:
        df (pd.DataFrame): Lap data with "PitInTime"/"PitOutTime" plus the
            grouping columns, or with a ready-made "PitTime"; modified in place

    Returns:
        pd.DataFrame: The same DataFrame with PitTime in seconds (0 where there
        was no stop) and without the PitInTime/PitOutTime columns
    """

    if "PitTime" not in df.columns:
        # Convert PitInTime and PitOutTime to datetime
        df["PitInTime"] = pd.to_datetime(df["PitInTime"])
        df["PitOutTime"] = pd.to_datetime(df["PitOutTime"])

        # A grouped shift keeps the original row index, so the result stays
        # aligned with df even when there is only a single group
        nextPitOut = df.groupby(["Driver", "Year", "RoundNumber"])["PitOutTime"].shift(-1)
        pitTime = nextPitOut - df["PitInTime"]
    else:
        pitTime = df["PitTime"]

    if pd.api.types.is_timedelta64_dtype(pitTime):
        # Same conversion as convertToSeconds: float seconds with millisecond precision
        df["PitTime"] = pitTime.fillna(pd.Timedelta(seconds=0)).dt.total_seconds().round(3)
    else:
        # Already numeric: fill gaps with a plain zero rather than a Timedelta,
        # which would turn the column into a mix of floats and Timedelta objects
        df["PitTime"] = pitTime.fillna(0)

    # Remove the original columns
    df.drop(columns=["PitInTime", "PitOutTime"], inplace=True, errors="ignore")
    return df

def processBooleanColumns(df):
    """
    Normalises the flag columns IsPersonalBest, FreshTyre and Rainfall to integers

    Values are coerced to numbers; anything that cannot be parsed, and any
    missing value, becomes 0. Columns that are not present are skipped. The
    DataFrame is modified in place.

    Parameters:
        df (pd.DataFrame): Input DataFrame

    Returns:
        pd.DataFrame: The same DataFrame with integer flag columns
    """

    for flag in ("IsPersonalBest", "FreshTyre", "Rainfall"):
        if flag not in df.columns:
            continue

        # Unparseable entries turn into NaN here and into 0 on the next line
        asNumbers = pd.to_numeric(df[flag], errors="coerce")
        df[flag] = asNumbers.fillna(0).astype(int)

    return df

def fillSectorTimes(row):
    """
    Reconstructs a single missing sector time from the lap time

    Only acts when LapTime is known and exactly one of the three sector times
    is missing; the gap is then LapTime minus the two known sectors. In every
    other case the row is returned as it came in. The row is modified in place.

    Parameters:
        row (pd.Series): A row from the lap data

    Returns:
        pd.Series: Row with the missing sector time filled if possible
    """

    sectorNames = ["Sector1Time", "Sector2Time", "Sector3Time"]
    gaps = row[sectorNames].isnull()

    # Nothing can be derived without a lap time or with zero / several gaps
    if pd.isnull(row["LapTime"]) or gaps.sum() != 1:
        return row

    # idxmax on the boolean mask yields the label of the only missing sector
    target = gaps.idxmax()

    # Total of the two sectors that are present
    knownTotal = row[sectorNames].sum(skipna=True)

    row[target] = row["LapTime"] - knownTotal
    return row

def processSectorImputation(df):
    """
    Fills gaps in the sector time columns in two passes

    First every row is passed through fillSectorTimes, which rebuilds a single
    missing sector from the lap time. Whatever is still missing afterwards is
    replaced by the mean of the same column within the row's
    (Driver, Year, RoundNumber) group, and all sector values are rounded to
    three decimals.

    Parameters:
        df (pd.DataFrame): Lap data with missing sector times

    Returns:
        pd.DataFrame: New DataFrame with filled sector times
    """

    groupKeys = ["Driver", "Year", "RoundNumber"]

    def fillWithGroupMean(values):
        # Mean of the group's known values, then millisecond rounding
        return values.fillna(values.mean()).round(3)

    # Pass 1: row-wise reconstruction when exactly one sector is missing
    df = df.apply(fillSectorTimes, axis=1)

    # Pass 2: group mean for everything that is still empty
    for sector in ("Sector1Time", "Sector2Time", "Sector3Time"):
        if sector not in df.columns:
            continue
        df[sector] = df.groupby(groupKeys)[sector].transform(fillWithGroupMean)

    return df

def processSpeedImputation(df):
    """
    Replaces missing speed readings with the rounded group mean

    For each of SpeedI1, SpeedI2, SpeedFL and SpeedST that is present, missing
    values are filled with the mean of that column within the row's
    (Driver, Year, RoundNumber) group, rounded to a whole number. The DataFrame
    is modified in place.

    Parameters:
        df (pd.DataFrame): Lap data with potential speed columns

    Returns:
        pd.DataFrame: The same DataFrame with filled speed values
    """

    groupKeys = ["Driver", "Year", "RoundNumber"]

    def fillWithRoundedMean(values):
        # The mean is rounded before filling; existing readings are untouched
        return values.fillna(round(values.mean(), 0))

    presentSpeeds = [name for name in ("SpeedI1", "SpeedI2", "SpeedFL", "SpeedST") if name in df.columns]
    for name in presentSpeeds:
        df[name] = df.groupby(groupKeys)[name].transform(fillWithRoundedMean)

    return df

def processQualifying(df):
    """
    Replaces missing Q1/Q2/Q3 times with the sentinel value 99999.999

    The fill is applied per (Driver, Year, RoundNumber) group to each of the
    three columns that is present. The DataFrame is modified in place.

    Parameters:
        df (pd.DataFrame): Lap data with qualifying columns

    Returns:
        pd.DataFrame: The same DataFrame with filled qualifying times
    """

    groupKeys = ["Driver", "Year", "RoundNumber"]

    def fillWithSentinel(values):
        # A very large time marks "no time set in this segment"
        return values.fillna(99999.999)

    for segment in ("Q1", "Q2", "Q3"):
        if segment not in df.columns:
            continue
        df[segment] = df.groupby(groupKeys)[segment].transform(fillWithSentinel)

    return df

def formatTimeValues(df):
    """
    Converts datetime columns to string format with millisecond precision

    Handles Time, Sector1SessionTime, Sector2SessionTime, Sector3SessionTime
    and LapStartTime when present. A missing Sector1SessionTime is rebuilt as
    Sector2SessionTime minus Sector2Time: per the FastF1 lap schema the
    "SectorNSessionTime" columns mark the moment sector N was completed, so the
    end of sector 1 is the end of sector 2 minus the duration of sector 2.
    Values that cannot be rebuilt stay missing. The DataFrame is modified in place.

    Parameters:
        df (pd.DataFrame): DataFrame with time-related columns; Sector2Time is
            read as a number of seconds

    Returns:
        pd.DataFrame: The same DataFrame with "YYYY-MM-DD HH:MM:SS.mmm" strings
    """

    timeColumns = ["Time", "Sector1SessionTime", "Sector2SessionTime", "Sector3SessionTime", "LapStartTime"]
    for col in timeColumns:
        if col not in df.columns:
            continue

        # Parse first so that textual placeholders such as "NaT" count as missing
        stamps = pd.to_datetime(df[col])

        # Rebuild missing sector-1 timestamps when the needed columns exist
        if col == "Sector1SessionTime" and "Sector2SessionTime" in df.columns and "Sector2Time" in df.columns:
            sector2End = pd.to_datetime(df["Sector2SessionTime"])
            sector2Duration = pd.to_timedelta(df["Sector2Time"], unit="s")
            fillable = stamps.isna() & sector2End.notna() & sector2Duration.notna()
            stamps = stamps.where(~fillable, sector2End - sector2Duration)

        # %f yields microseconds; dropping the last three digits leaves milliseconds
        df[col] = stamps.dt.strftime("%Y-%m-%d %H:%M:%S.%f").str[:-3]
    return df

def formatLapTime(df):
    """
    Renders LapTime values as text with exactly three decimals

    Missing values are left as they are. Nothing happens when the DataFrame
    has no LapTime column. The DataFrame is modified in place.

    Parameters:
        df (pd.DataFrame): Lap data

    Returns:
        pd.DataFrame: The same DataFrame with formatted LapTime
    """

    if "LapTime" not in df.columns:
        return df

    def asText(value):
        # Nulls pass through untouched so they stay recognisable as missing
        if pd.isnull(value):
            return value
        return f"{value:.3f}"

    df["LapTime"] = df["LapTime"].apply(asText)
    return df

def processTrackStatus(df):
    """
    Reduces raw TrackStatus codes to three categories and labels them

    Each raw value is read as text, its highest digit is taken (0 when it has
    no digit) and that digit is collapsed: 4 and 8 become 2, 6 and 7 become 1,
    1/2/3/5 become 0; any other digit is kept as it is. The categories 0, 1 and
    2 are labelled "Clear", "Virtual Safety Car" and "Safety Car" in a new
    TrackStatusDescription column; other values get no label. Nothing happens
    when TrackStatus is absent. The DataFrame is modified in place.

    Parameters:
        df (pd.DataFrame): DataFrame with TrackStatus codes

    Returns:
        pd.DataFrame: The same DataFrame with numeric TrackStatus and the new
        TrackStatusDescription column
    """

    if "TrackStatus" not in df.columns:
        return df

    # Highest digit -> category; digits not listed here are left unchanged
    categoryByDigit = {1: 0, 2: 0, 3: 0, 4: 2, 5: 0, 6: 1, 7: 1, 8: 2}

    # Category -> human-readable label
    labelByCategory = {
        0: "Clear",
        1: "Virtual Safety Car",
        2: "Safety Car",
    }

    def toCategory(rawStatus):
        highest = max((int(ch) for ch in rawStatus if ch.isdigit()), default=0)
        return categoryByDigit.get(highest, highest)

    df["TrackStatus"] = df["TrackStatus"].astype(str).apply(toCategory)
    df["TrackStatusDescription"] = df["TrackStatus"].map(labelByCategory)

    return df

def postProcess(dfInput, year, roundNumber, city):
    """
    Runs the final cleaning and enrichment pipeline on one race's laps

    Works on a re-indexed copy of the input. In order: laps-until-pit counter,
    event identifiers, team renames, track status categories, pit durations,
    removal of helper columns, then flag/sector/speed/qualifying imputation
    and text formatting of timestamps and lap times.

    Parameters:
        dfInput (pd.DataFrame): Raw lap data (left unmodified)
        year (int): Year of the event
        roundNumber (int): Round number
        city (str): Race city

    Returns:
        pd.DataFrame: Fully processed lap data
    """

    # Never touch the caller's frame; a clean 0..n-1 index is needed downstream
    df = dfInput.copy().reset_index(drop=True)

    # Must run before the pit timestamp columns are removed further down
    df = countLapsUntilNextPit(df)

    # Tag every lap with its event
    for column, value in (("Year", year), ("RoundNumber", roundNumber), ("City", city)):
        df[column] = value

    # Map older team names onto the names used in the rest of the data
    isAlfaRomeo = df["Team"] == "Alfa Romeo"
    df.loc[isAlfaRomeo, "Team"] = "Kick Sauber"
    isRacingBulls = df["Team"].isin(["AlphaTauri", "RB"])
    df.loc[isRacingBulls, "Team"] = "Racing Bulls"

    # Track status categories, then pit stop durations
    df = processTrackStatus(df)
    df = processPitStopTimes(df)

    # Helper columns that are not stored
    discarded = ["sessionName", "Deleted", "DeletedReason", "LapStartDate", "LapStartTime"]
    df = df.drop(columns=discarded, errors="ignore")

    # Imputation first, text formatting last
    remainingSteps = (
        processBooleanColumns,
        processSectorImputation,
        processSpeedImputation,
        processQualifying,
        formatTimeValues,
        formatLapTime,
    )
    for step in remainingSteps:
        df = step(df)

    return df

def loadRaceLaps(year, roundNumber, city, table, engine, liveData, race, raceStartTime, qualiResults, laps):
    """
    Loads one race's laps, processes them and appends them to a database table

    When liveData is False the race session is fetched through FastF1 and its
    start time is read from the event's "Session5Date"; otherwise the race,
    start time, qualifying results and laps passed in are used. Errors while
    loading or processing are printed and treated as "no data".

    Parameters:
        year (int): Year of the race
        roundNumber (int): Round number
        city (str): Race city
        table (str): Target database table name
        engine: Database connection passed to DataFrame.to_sql
        liveData (bool): Whether to use the preloaded objects instead of FastF1
        race: Existing session or None
        raceStartTime: Optional start time
        qualiResults: Optional preloaded qualifying results
        laps: Optional preloaded lap data

    Returns:
        None
    """

    # Columns holding offsets from the session start
    relativeTimeColumns = [
        "Time", "LapStartTime", "LapStartDate", "Sector1SessionTime",
        "Sector2SessionTime", "Sector3SessionTime", "PitInTime", "PitOutTime"
    ]

    # Columns holding durations
    durationTimeColumns = ["LapTime", "Sector1Time", "Sector2Time", "Sector3Time"]

    # Qualifying data is loaded outside the try block; it handles its own errors
    qualifyingResults = loadQualifyingResults(year, roundNumber, qualiResults, liveData)

    processedLaps = None
    try:
        if not liveData:
            race = fastf1.get_session(year, roundNumber, "R")
            race.load()

            # NOTE(review): this uses the event's "Session5Date" as time zero for
            # the relative columns - confirm it matches the session's timing origin
            raceStartTime = race.event.get("Session5Date")

        processedLaps = processLaps(race, raceStartTime, qualifyingResults, relativeTimeColumns, durationTimeColumns, laps, liveData)

    except Exception as e:
        print(f"Error loading the race: {e}")

    # Nothing usable was produced (empty session or an error above)
    if processedLaps is None:
        print(f"No lap data found for race {roundNumber} in {year}.")
        return

    df = pd.concat([processedLaps], ignore_index=True)

    # Add all derived columns and clean up
    df = postProcess(df, year, roundNumber, city)

    # Append the rows to the target table
    df.to_sql(table, con=engine, if_exists="append", index=False)

def loadLaps(engine, table, year, roundNumber, city, liveData, race, raceStartTime, qualifyingResults):
    """
    Loads lap data for one or all races and stores it into the database

    Three modes, checked in this order:
      1. liveData is True: the objects passed in are processed for the given event.
      2. year, roundNumber and city are all given: that single race is fetched.
      3. otherwise: every race returned by getRaces() is fetched.

    Parameters:
        engine: Database engine
        table (str): Target database table
        year (int or None): Year of the race to load
        roundNumber (int or None): Round of the race to load
        city (str or None): City of the race to load
        liveData (bool): Whether to use live data
        race: Live data object; it is forwarded to loadRaceLaps both as the
            session and as the preloaded laps
        raceStartTime: Optional session start time
        qualifyingResults: Optional qualifying results

    Returns:
        None
    """

    # Live mode: process what the caller already has
    if liveData:
        # NOTE(review): `race` fills both the `race` and the `laps` parameter of
        # loadRaceLaps, so in live mode it presumably has to be the laps
        # DataFrame - confirm against the caller
        loadRaceLaps(year, roundNumber, city, table, engine, liveData, race, raceStartTime, qualifyingResults, race)
        return

    # A fully specified event: load just that race
    if year is not None and roundNumber is not None and city is not None:
        loadRaceLaps(year, roundNumber, city, table, engine, False, None, None, None, None)
        return

    # No complete selection: load every known race. Each entry is unpacked
    # positionally, so it must yield exactly (year, round, city) in that order.
    for raceInfo in tqdm(getRaces(), desc="Loading all laps"):
        raceYear, raceRound, raceCity = raceInfo.values()
        loadRaceLaps(raceYear, raceRound, raceCity, table, engine, False, None, None, None, None)