# Database connection
import psycopg2

# Data handling and manipulation
import numpy as np
import pandas as pd
# Silence SettingWithCopyWarning: several functions below assign columns on frames
# produced by boolean filtering (e.g. the filtered result of mergeTables)
pd.options.mode.chained_assignment = None

# Data loading utilities
from processData.laps import loadLaps
from processData.telemetry import loadTelemetry

# Public API: only the end-to-end pipeline entry point is exported
__all__ = ["getData"]

def mergeTables(engine, lapsTable, telemetryTable):
    """
    Reads the laps and telemetry tables and inner-joins them into a single DataFrame

    Rows are matched on Driver, LapNumber, RoundNumber and Year; telemetry rows
    without a matching lap (and vice versa) are dropped by the inner join.

    Parameters:
        engine (sqlalchemy.engine.Engine): Connection used by pandas.read_sql_query
        lapsTable (str): Name of the table holding lap data (interpolated as a quoted identifier)
        telemetryTable (str): Name of the table holding telemetry data (interpolated as a quoted identifier)

    Returns:
        pandas.DataFrame: Joined rows, telemetry columns first, minus rows whose Compound is "UNKNOWN"
    """

    joinKeys = ["Driver", "LapNumber", "RoundNumber", "Year"]

    laps = pd.read_sql_query(f"""SELECT * FROM "{lapsTable}";""", engine)
    telemetry = pd.read_sql_query(f"""SELECT * FROM "{telemetryTable}";""", engine)

    # Telemetry is the left frame so its columns come first in the result
    merged = telemetry.merge(laps, on=joinKeys, how="inner")

    # Laps with an unidentified tyre compound are discarded
    knownCompound = merged["Compound"] != "UNKNOWN"
    return merged[knownCompound]

def handleData(df, liveData):
    """
    Computes time-based gap features (GapToNext, GapToLeader) for each row

    Rows are ordered by Time within each (Year, RoundNumber, LapNumber) group; the row
    just before another in that order is treated as the driver ahead, and the earliest
    Time in the group as the leader.

    Parameters:
        df (pandas.DataFrame): Input DataFrame containing lap and telemetry data; its
            Time and LapTime columns are converted in place before sorting
        liveData (bool): When False, GapToNext and GapToLeader are always computed from
            Time. When True, existing GapToNext/GapToLeader columns (presumably supplied
            by the live source — verify against loadLaps) are kept as-is, and only a
            missing one is computed from Time instead of raising KeyError

    Returns:
        pandas.DataFrame: New DataFrame sorted by Year, RoundNumber, LapNumber and Time with
        a fresh RangeIndex; GapToNext/GapToLeader are numeric seconds with NaN filled by 0
        and rounded to 3 decimals
    """

    # NOTE(review): if Time is stored as numeric seconds, pd.to_datetime reads it as
    # nanoseconds, which would shrink every gap by 1e9 — confirm the column's source type
    df["Time"] = pd.to_datetime(df["Time"])

    # Ensure LapTime is numeric; unparsable values become NaN
    df["LapTime"] = pd.to_numeric(df["LapTime"], errors="coerce")

    lapKeys = ["Year", "RoundNumber", "LapNumber"]
    df = df.sort_values(by=lapKeys + ["Time"]).reset_index(drop=True)

    # Time of the previous row in the same lap (driver ahead) and earliest Time in the lap (leader);
    # kept as local Series so no temporary columns leak into the frame
    lapTimes = df.groupby(lapKeys)["Time"]
    prevTimeStart = lapTimes.shift(1)
    leaderTimeStart = lapTimes.transform("min")

    # Live sources may already provide the gaps; compute any that are absent
    if not liveData or "GapToNext" not in df.columns:
        df["GapToNext"] = (df["Time"] - prevTimeStart).dt.total_seconds()
    if not liveData or "GapToLeader" not in df.columns:
        df["GapToLeader"] = (df["Time"] - leaderTimeStart).dt.total_seconds()

    # The first row of each lap has no driver ahead (NaN), which becomes 0
    df["GapToNext"] = pd.to_numeric(df["GapToNext"], errors="coerce").fillna(0).round(3)
    df["GapToLeader"] = pd.to_numeric(df["GapToLeader"], errors="coerce").fillna(0).round(3)

    return df

def countLapsUntilNextPit(df):
    """
    Computes how many laps remain until the next pit stop for each row

    A lap counts as a pit lap when any of its rows has PitTime > 0. For every row of a
    (Year, RoundNumber, Driver) group the value is:
        - 0 on a pit lap,
        - (next pit lap - current lap) when a later pit lap exists,
        - (last lap of the group - current lap) when no pit lap follows.
    Rows need not be sorted, and a lap may span several rows (e.g. multiple telemetry
    samples after the merge) without disturbing the result.

    Parameters:
        df (pandas.DataFrame): DataFrame with PitTime, LapNumber, Year, RoundNumber and Driver columns

    Returns:
        pandas.DataFrame: The same DataFrame (modified in place) with a new column LapsUntilNextPit
    """

    # Ensure "PitTime" is numeric (in case of unexpected data types)
    df["PitTime"] = pd.to_numeric(df["PitTime"], errors="coerce")

    # Default value; rows whose group key is missing are skipped by groupby and keep it
    df["LapsUntilNextPit"] = 0

    for _, driverData in df.groupby(["Year", "RoundNumber", "Driver"]):
        laps = driverData["LapNumber"].to_numpy()
        lastLap = driverData["LapNumber"].max()

        # Sorted, de-duplicated pit laps: each stop counts once however many rows it spans
        pitLaps = np.unique(driverData.loc[driverData["PitTime"] > 0, "LapNumber"].to_numpy())

        # No pit stops: count down to the driver's last recorded lap
        if len(pitLaps) == 0:
            df.loc[driverData.index, "LapsUntilNextPit"] = lastLap - driverData["LapNumber"]
            continue

        # Index of the first pit lap >= each lap (len(pitLaps) when none follows);
        # on a pit lap this points at the lap itself, giving 0
        nextPit = np.searchsorted(pitLaps, laps, side="left")
        hasNextPit = nextPit < len(pitLaps)
        targetLap = np.where(hasNextPit, pitLaps[np.minimum(nextPit, len(pitLaps) - 1)], lastLap)

        df.loc[driverData.index, "LapsUntilNextPit"] = targetLap - laps

    return df

def setLapTime(df):
    """
    Back-fills missing LapTime values with the sum of the three sector times

    All four timing columns are coerced to numbers first (unparsable values become NaN).
    A LapTime that is already present is left untouched.

    Parameters:
        df (pandas.DataFrame): DataFrame with Sector1Time, Sector2Time, Sector3Time and LapTime columns

    Returns:
        pandas.DataFrame: The same DataFrame (modified in place) with LapTime filled where possible
    """

    for name in ("Sector1Time", "Sector2Time", "Sector3Time", "LapTime"):
        df[name] = pd.to_numeric(df[name], errors="coerce")

    # Plain '+' rather than DataFrame.sum, so a missing sector yields NaN instead of a partial lap
    sectorTotal = (df["Sector1Time"] + df["Sector2Time"] + df["Sector3Time"]).round(3)
    needsFill = df["LapTime"].isna()
    df.loc[needsFill, "LapTime"] = sectorTotal

    return df

def setOrder(df):
    """
    Projects the DataFrame onto the canonical column layout and casts integer-like columns

    Columns not in the canonical layout are dropped; layout columns that are absent are
    simply skipped. Integer-like columns are parsed with errors coerced to <NA> and stored
    as pandas' nullable Int64.

    Parameters:
        df (pandas.DataFrame): Input DataFrame with telemetry and lap features

    Returns:
        tuple:
            - pandas.DataFrame: Reordered and type-adjusted DataFrame
            - list of str: The layout columns that were present, in output order
    """

    # Canonical order of the stored dataset
    layout = [
        "RPM", "Speed", "nGear", "Throttle", "Brake", "DRS", "Driver", "LapNumber", "RoundNumber", "Year", "LastLap",
        "City", "Time", "DriverNumber", "LapTime", "Stint", "Sector1Time", "Sector2Time", "Sector3Time",
        "Sector1SessionTime", "Sector2SessionTime", "Sector3SessionTime", "SpeedI1", "SpeedI2", "SpeedFL", "SpeedST",
        "IsPersonalBest", "Compound", "TyreLife", "FreshTyre", "Team", "Position", "AirTemp", "Humidity", "Pressure",
        "Rainfall", "TrackTemp", "WindDirection", "WindSpeed", "Q1", "Q2", "Q3", "GridPosition", "LapsUntilNextPit",
        "PitTime", "GapToNext", "GapToLeader", "TrackStatus", "TrackStatusDescription"
    ]

    integerColumns = ["LapNumber", "DriverNumber", "Stint", "TyreLife", "Humidity", "GridPosition", "LapsUntilNextPit"]
    for name in (c for c in integerColumns if c in df.columns):
        # Nullable Int64 keeps unparsable entries as <NA> instead of failing on NaN
        df[name] = pd.to_numeric(df[name], errors="coerce").astype("Int64")

    present = [name for name in layout if name in df.columns]
    return df[present], present

def existingTable(engine, dataTable):
    """
    Checks whether a table with the given name is listed in information_schema.tables

    The lookup matches table_name exactly (case-sensitive) and does not filter by schema.

    Parameters:
        engine (sqlalchemy.engine.Engine): SQLAlchemy engine connected to the database
        dataTable (str): Name of the table to verify

    Returns:
        bool: True if table exists, False otherwise
    """

    # Double any single quote so the name is always parsed as one SQL string literal
    escapedName = str(dataTable).replace("'", "''")

    # NOTE(review): no table_schema filter — a same-named table in another schema also
    # counts as existing; confirm that is acceptable for this database
    query = f"""
    SELECT EXISTS (
        SELECT FROM information_schema.tables
        WHERE table_name = '{escapedName}'
    );
    """
    result = pd.read_sql_query(query, engine)

    # Convert numpy.bool_ to a plain Python bool, matching the documented return type
    return bool(result.iloc[0, 0])

def removeDuplicates(engine, dataTable, df):
    """
    Drops rows of df whose (Driver, LapNumber, RoundNumber, Year) key is already stored in dataTable

    If the table does not exist yet, df is returned unchanged.

    Parameters:
        engine (sqlalchemy.engine.Engine): SQLAlchemy engine
        dataTable (str): Target table for comparison
        df (pandas.DataFrame): DataFrame to filter for new rows only

    Returns:
        pandas.DataFrame: Rows of df with a key not yet present in the table, in original order

    Raises:
        Any database error other than an undefined-table error is re-raised.
    """

    keys = ["Driver", "LapNumber", "RoundNumber", "Year"]

    try:
        if not existingTable(engine, dataTable):
            return df

        # DISTINCT: only the set of keys matters, so fetch each one once
        existingData = pd.read_sql_query(
            f"""SELECT DISTINCT "Driver", "LapNumber", "RoundNumber", "Year" FROM "{dataTable}" """,
            engine,
        )
    except Exception as err:
        # SQLAlchemy wraps driver errors (original exception in .orig) and pandas' fallback
        # path chains them (__cause__), so a bare psycopg2 except clause would never match
        causes = (err, getattr(err, "orig", None), err.__cause__)
        if any(isinstance(cause, psycopg2.errors.UndefinedTable) for cause in causes):
            return df
        raise

    # Left anti-join: keep only rows that found no match in the stored keys
    newData = df.merge(existingData, on=keys, how="left", indicator=True)
    return newData[newData["_merge"] == "left_only"].drop(columns=["_merge"])

def removeNone(df):
    """
    Patches missing (NaN/None/NA) values in the final row of the DataFrame

    With two or more rows, each missing cell in the last row takes the value of the
    same column in the row directly above, even if that value is itself missing. A
    single-row DataFrame has its missing cells replaced with 0. Empty input is returned as-is.

    Parameters:
        df (pandas.DataFrame): Input DataFrame to process (modified in place)

    Returns:
        pandas.DataFrame: The same DataFrame with its last row patched
    """

    if df.empty:
        return df

    if len(df) == 1:
        # Nothing above to borrow from, so fall back to zero
        df.iloc[-1] = df.iloc[-1].fillna(0)
        return df

    lastRow = df.iloc[-1]
    for position, name in enumerate(df.columns):
        if not pd.isna(lastRow[name]):
            continue
        df.iat[-1, position] = df.iat[-2, position]

    return df

def getData(engine, dataTable, lapsTable, telemetryTable, raceYear, roundNumberYear, cityYear, liveData, race, raceStartTime, qualifying, telemetry):
    """
    Runs the full pipeline: load raw laps/telemetry, merge, derive features, and append new rows

    Steps, in order: loadLaps, loadTelemetry, mergeTables, handleData,
    countLapsUntilNextPit, setLapTime, setOrder, removeNone, removeDuplicates, then
    DataFrame.to_sql(..., if_exists="append") into dataTable.

    Parameters:
        engine (sqlalchemy.engine.Engine): SQLAlchemy engine for database access
        dataTable (str): Destination table for processed data
        lapsTable (str): Table containing raw lap data
        telemetryTable (str): Table containing raw telemetry data
        raceYear (int): Target year of the race
        roundNumberYear (int): Round number in the season
        cityYear (str): City where the race occurred
        liveData (bool): Forwarded to the loaders and to handleData
        race (optional): Race object forwarded to loadLaps
        raceStartTime (optional): Race start time forwarded to loadLaps
        qualifying (optional): Qualifying result object forwarded to loadLaps
        telemetry (optional): Telemetry data object forwarded to loadTelemetry

    Returns:
        list of str: Column names present in the final stored dataset
    """

    # The loaders presumably populate lapsTable/telemetryTable, which are read back below
    loadLaps(engine, lapsTable, raceYear, roundNumberYear, cityYear, liveData, race, raceStartTime, qualifying)
    loadTelemetry(engine, telemetryTable, raceYear, roundNumberYear, liveData, telemetry)

    merged = mergeTables(engine, lapsTable, telemetryTable)
    enriched = countLapsUntilNextPit(handleData(merged, liveData))
    enriched = setLapTime(enriched)
    ordered, columnNames = setOrder(enriched)
    ordered = removeNone(ordered)

    # Only rows whose (Driver, LapNumber, RoundNumber, Year) key is not stored yet are appended
    fresh = removeDuplicates(engine, dataTable, ordered)
    fresh.to_sql(dataTable, con=engine, if_exists="append", index=False)

    return columnNames