# Environment variable loading
from dotenv import load_dotenv

# Database connection tools
import os
from sqlalchemy import create_engine
import psycopg2

# TensorFlow for machine learning processing
import tensorflow as tf

# Standard datetime utilities
import datetime

# Data retrieval utility
from processData.postProcess import getData

# Public API: a wildcard import (`from <module> import *`) exposes only loadData
__all__ = ["loadData"]

def createConnection():
    """
    Creates connections to a PostgreSQL database using environment variables

    Credentials are read from HOST, NAME, USER, PASSWORD and PORT after loading
    any values defined in a .env file. Both returned objects point at the same
    host, port and database.

    Returns:
        tuple: (SQLAlchemy engine, psycopg2 connection)
    """
    from urllib.parse import quote_plus

    # Load environment variables from .env file
    # NOTE(review): load_dotenv() does not override variables already present in the
    # environment by default, so a shell-level USER (commonly set on Unix) may win
    # over the .env value — confirm this is intended.
    load_dotenv()

    # Get database credentials from the environment
    host = os.getenv("HOST")
    dbName = os.getenv("NAME")
    user = os.getenv("USER")
    password = os.getenv("PASSWORD")
    port = os.getenv("PORT")

    # Escape credentials for the URL so characters such as '@', ':' or '/' in a
    # password cannot break the URL structure.
    urlUser = quote_plus(user) if user is not None else ""
    urlPassword = quote_plus(password) if password is not None else ""
    connectionString = f"postgresql://{urlUser}:{urlPassword}@{host}:{port}/{dbName}"

    # Pass the port as well, so the raw connection reaches the same server as the engine
    connection = psycopg2.connect(
        host=host,
        port=port,
        database=dbName,
        user=user,
        password=password
    )

    # Create and return SQLAlchemy engine alongside the raw connection
    return create_engine(connectionString), connection

def getRows(connection, dataTable):
    """
    Retrieves all rows from a table, ordered by Year, RoundNumber, Driver, and LapNumber

    Parameters:
        connection: psycopg2 database connection (anything exposing cursor())
        dataTable (str): Name of the table to query

    Returns:
        list: List of tuples, one per row, as returned by cursor.fetchall()
    """

    # Quote the table name as a PostgreSQL identifier: wrap it in double quotes and
    # double any embedded double quote, so the name cannot end the identifier early.
    quotedTable = '"' + str(dataTable).replace('"', '""') + '"'
    query = f'SELECT * FROM {quotedTable} ORDER BY "Year", "RoundNumber", "Driver", "LapNumber";'

    cursor = connection.cursor()
    try:
        cursor.execute(query)
        return cursor.fetchall()
    finally:
        # Always release the cursor, even if the query or fetch fails
        cursor.close()

def safeFloat(v):
    """
    Converts a value to float without raising

    Parameters:
        v: Any value (number, numeric string, None, ...)

    Returns:
        float: float(v), or 0.0 when float() raises TypeError or ValueError
    """

    try:
        result = float(v)
    except (TypeError, ValueError):
        result = 0.0
    return result
    
def floatFeature(value):
    """
    Wraps a scalar or a list of values in a TensorFlow float_list Feature

    Each element goes through safeFloat, so values that cannot be converted
    to float are stored as 0.0 instead of raising.

    Parameters:
        value: A single value, or a list of values (only `list` is treated as a sequence)

    Returns:
        tf.train.Feature: Feature whose float_list holds the converted value(s)
    """

    items = value if isinstance(value, list) else [value]
    floats = [safeFloat(item) for item in items]
    return tf.train.Feature(float_list=tf.train.FloatList(value=floats))

def safeInt(v):
    """
    Converts a value to int, mapping None (e.g. a NULL database value) to 0

    Parameters:
        v: Input value (None, int, bool, float, integer string, numpy integer, ...)

    Returns:
        int: 0 when v is None, otherwise int(v)

    Raises:
        TypeError, ValueError: if v is not None and int() cannot convert it
    """

    return 0 if v is None else int(v)

def intFeature(value):
    """
    Converts a value or list into a TensorFlow int64_list Feature

    None is stored as 0 for both a single value and list elements, so NULL
    per-lap values do not crash serialization.

    Parameters:
        value: Single int or list of ints (only `list` is treated as a sequence)

    Returns:
        tf.train.Feature: TensorFlow int list feature
    """

    values = value if isinstance(value, list) else [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[safeInt(v) for v in values]))

def bytesFeature(value):
    """
    Wraps a scalar or a list of values in a TensorFlow bytes_list Feature

    Every element is converted with str() and UTF-8 encoded, so non-string
    values are stored as their string representation.

    Parameters:
        value: A single value, or a list of values (only `list` is treated as a sequence)

    Returns:
        tf.train.Feature: Feature whose bytes_list holds the encoded value(s)
    """

    items = value if isinstance(value, list) else [value]
    encoded = [str(item).encode("utf-8") for item in items]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=encoded))

def processData(rows, existingColumns):
    """
    Organizes raw database rows into per-driver, per-race groups and builds the lookup layers

    Rows are grouped by (Year, RoundNumber, Driver). For each group:
        - "staticData" holds City, DriverNumber, Team, Q1, Q2, Q3 and GridPosition,
          taken from the first row seen for that group
        - "lapData" maps every other column name to a list with one value per row

    Parameters:
        rows (list): Rows returned by the database, each a sequence of column values
        existingColumns (list): Column names, in the same order as the values in each row

    Returns:
        tuple: groupedData, driverLookup, teamLookup, compoundLookup, cityLookup
    """

    # Columns stored once per group in "staticData" rather than once per lap
    staticColumns = ("City", "DriverNumber", "Team", "Q1", "Q2", "Q3", "GridPosition")

    # Every remaining column is accumulated per lap; this depends only on the
    # column list, so compute it once instead of per group.
    lapColumns = [col for col in existingColumns if col not in staticColumns]

    # Create lookup vocabularies
    allDrivers = sorted(map(str, [
        "SAI", "BEA", "LEC", "HAD", "TSU", "HUL", "ZHO", "HAM", "RUS", "VET", "RIC", "STR", "DOO", "MAG", "SAR", "DEV",
        "BOR", "BOT", "PER", "ALB", "MSC", "LAT", "NOR", "OCO", "VER", "LAW", "ALO", "GAS", "PIA", "ANT", "COL"
    ]))
    allTeams = sorted(map(str, [
        "Aston Martin", "Williams", "McLaren", "Alpine", "Red Bull Racing",
        "Haas F1 Team", "Kick Sauber", "Mercedes", "Ferrari", "Racing Bulls"
    ]))
    allCompounds = sorted(map(str, ["SOFT", "MEDIUM", "HARD", "INTERMEDIATE", "WET"]))
    allCities = sorted(map(str, [
        "Mexico City", "Las Vegas", "Abu Dhabi", "Budapest", "Barcelona", "Baku", "Spielberg",
        "Shanghai", "Monza", "Montreal", "Sakhir", "Lusail", "Monaco", "Suzuka", "Zandvoort",
        "Austin", "Jeddah", "Silverstone", "Sao Paulo", "Imola", "Singapore", "Le Castellet",
        "Spa-Francorchamps", "Miami", "Melbourne"
    ]))

    # Initialize StringLookup layers with proper string input
    # NOTE(review): with num_oov_indices=0 there is no out-of-vocabulary bucket, so a
    # name missing from these lists (e.g. an older team name) presumably makes the
    # layer raise when called — confirm the source data only uses these names.
    driverLookup = tf.keras.layers.StringLookup(vocabulary=allDrivers, output_mode="int", mask_token=None, num_oov_indices=0)
    teamLookup = tf.keras.layers.StringLookup(vocabulary=allTeams, output_mode="int", mask_token=None, num_oov_indices=0)
    compoundLookup = tf.keras.layers.StringLookup(vocabulary=allCompounds, output_mode="int", mask_token=None, num_oov_indices=0)
    cityLookup = tf.keras.layers.StringLookup(vocabulary=allCities, output_mode="int", mask_token=None, num_oov_indices=0)

    groupedData = {}
    for row in rows:
        rowDict = dict(zip(existingColumns, row))
        key = (rowDict.get("Year"), rowDict.get("RoundNumber"), rowDict.get("Driver"))

        group = groupedData.get(key)
        if group is None:
            group = {
                "staticData": {col: rowDict.get(col) for col in staticColumns},
                "lapData": {col: [] for col in lapColumns},
            }
            groupedData[key] = group

        # Append this row's value for every per-lap column
        for col, values in group["lapData"].items():
            values.append(rowDict.get(col))

    return groupedData, driverLookup, teamLookup, compoundLookup, cityLookup

def convertTimeStringToSeconds(timeVal):
    """
    Converts a time value to total seconds

    Accepted inputs:
        - None -> 0.0
        - datetime.datetime / datetime.time -> seconds since midnight (any date part is ignored)
        - datetime.timedelta -> timedelta.total_seconds()
        - "M:SS.fff" or "H:MM:SS.fff" strings -> seconds
        - anything else -> float(timeVal)

    Parameters:
        timeVal: Time input as None, str, number, timedelta, datetime or time

    Returns:
        float: Total seconds, or 0.0 if the value cannot be interpreted
    """

    if timeVal is None:
        return 0.0
    # Only the clock reading is used for datetimes and times
    if isinstance(timeVal, (datetime.datetime, datetime.time)):
        return timeVal.hour * 3600 + timeVal.minute * 60 + timeVal.second + timeVal.microsecond / 1e6
    if isinstance(timeVal, datetime.timedelta):
        return timeVal.total_seconds()
    try:
        if isinstance(timeVal, str) and ":" in timeVal:
            parts = timeVal.split(":")
            # Only "M:S" and "H:M:S" are meaningful; longer forms are rejected
            if len(parts) > 3:
                return 0.0
            # Fold from the most significant unit: each step shifts by 60
            seconds = 0.0
            for part in parts:
                seconds = seconds * 60 + float(part)
            return seconds
        return float(timeVal)
    except (TypeError, ValueError, OverflowError):
        return 0.0

def serializeStaticFeatures(year, roundNumber, driver, staticData, driverLookup, teamLookup, cityLookup):
    """
    Serializes non-lap-dependent (static) data into a TensorFlow TFRecord feature dictionary

    Parameters:
        year (int): The year of the race
        roundNumber (int): The round number of the race in the season
        driver (str): The driver's identifier
        staticData (dict): Static race data with keys:
            - "City": Name of the host city
            - "DriverNumber": The driver's number
            - "Team": The team name
            - "Q1", "Q2", "Q3": Qualifying times (numbers, time strings or timedeltas)
            - "GridPosition": The driver's starting grid position
        driverLookup (tf.keras.layers.StringLookup): Maps driver identifiers to integer IDs
        teamLookup (tf.keras.layers.StringLookup): Maps team names to integer IDs
        cityLookup (tf.keras.layers.StringLookup): Maps city names to integer IDs

    Returns:
        dict: Feature dictionary for a TFRecord Example, with keys:
            - Year, RoundNumber, DriverNumber (int64)
            - Driver, City, Team (int64 lookup IDs)
            - Q1, Q2, Q3 (float seconds)
            - GridPosition (float)
    """

    return {
        # intFeature itself converts to int and maps None to 0, so no extra int() here
        "Year": intFeature(year),
        "RoundNumber": intFeature(roundNumber),
        "Driver": intFeature(int(driverLookup([str(driver)]).numpy()[0])),
        "City": intFeature(int(cityLookup([str(staticData["City"])]).numpy()[0])),
        "DriverNumber": intFeature(staticData["DriverNumber"]),
        "Team": intFeature(int(teamLookup([str(staticData["Team"])]).numpy()[0])),
        # Qualifying times may arrive as time strings or timedeltas; numeric values are unchanged
        "Q1": floatFeature(convertTimeStringToSeconds(staticData["Q1"])),
        "Q2": floatFeature(convertTimeStringToSeconds(staticData["Q2"])),
        "Q3": floatFeature(convertTimeStringToSeconds(staticData["Q3"])),
        "GridPosition": floatFeature(staticData["GridPosition"])
    }

def serializeCategorialFeatures(lapData, compoundLookup):
    """
    Serializes categorical lap-dependent features into TensorFlow format

    Parameters:
        lapData (dict): Per-lap data; must contain a "Compound" list with one entry per lap
        compoundLookup (tf.keras.layers.StringLookup): Lookup layer that maps compound strings to integer IDs

    Returns:
        dict: {"Compound": int64_list Feature with one compound ID per lap}

    Raises:
        KeyError: if lapData has no "Compound" entry
    """

    # Look up every lap in one batched call instead of one eager TensorFlow op per lap
    compoundNames = tf.convert_to_tensor([str(c) for c in lapData["Compound"]], dtype=tf.string)
    compoundIds = compoundLookup(compoundNames).numpy().tolist()

    return {"Compound": intFeature(compoundIds)}

def serializeTimeFeatures(lapData):
    """
    Serializes time-related lap data fields into float seconds

    Every field goes through convertTimeStringToSeconds, so numeric values,
    "M:SS.fff"/"H:MM:SS.fff" strings and timedeltas are all handled the same way.
    Fields missing from lapData are skipped.

    Parameters:
        lapData (dict): Per-lap data; the fields considered are:
            - "Time"
            - "Sector1SessionTime"
            - "Sector2SessionTime"
            - "Sector3SessionTime"
            - "LapTime"

    Returns:
        dict: TensorFlow Feature dictionary with serialized float values in seconds
    """

    timeFields = (
        "Time",
        "Sector1SessionTime",
        "Sector2SessionTime",
        "Sector3SessionTime",
        "LapTime"
    )

    return {
        field: floatFeature([convertTimeStringToSeconds(t) for t in lapData[field]])
        for field in timeFields
        if field in lapData
    }

def serializeLapDependentFeatures(lapData):
    """
    Serializes numeric per-lap features, skipping static, time and categorical fields

    Parameters:
        lapData (dict): Maps feature names to lists of per-lap values

    Returns:
        dict: TensorFlow Feature dictionary, in lapData's key order

    Notes:
        - "LapNumber", "IsPersonalBest" and "LapsUntilNextPit" are stored as int64
        - Every other remaining feature is stored as float
        - Static fields (Q1, Q2, Q3, Driver, City, Team, GridPosition, DriverNumber,
          Year, RoundNumber) are skipped
        - Compound and time fields are left to the other serialization functions
    """

    staticNames = {"Q1", "Q2", "Q3", "Driver", "City", "Team", "GridPosition", "DriverNumber", "Year", "RoundNumber"}
    handledElsewhere = {"Compound", "Time", "LapTime", "Sector1SessionTime", "Sector2SessionTime", "Sector3SessionTime"}
    integerNames = {"LapNumber", "IsPersonalBest", "LapsUntilNextPit"}

    serialized = {}
    for name, perLapValues in lapData.items():
        if name in staticNames or name in handledElsewhere:
            continue
        encoder = intFeature if name in integerNames else floatFeature
        serialized[name] = encoder(perLapValues)

    return serialized

def serializeSequence(year, roundNumber, driver, staticData, lapData, driverLookup, teamLookup, compoundLookup, cityLookup):
    """
    Builds and serializes one TensorFlow Example holding a driver's full race sequence

    Static features are written first, then compound IDs, then time features,
    then the remaining numeric per-lap features. On a key collision, the later
    group wins.

    Parameters:
        year (int): Year of the race
        roundNumber (int): Round number in the season
        driver (str): Driver identifier
        staticData (dict): Static data ("City", "DriverNumber", "Team", "Q1", "Q2", "Q3", "GridPosition")
        lapData (dict): Lap-dependent data, one list entry per lap
        driverLookup (tf.keras.layers.StringLookup): Lookup layer for driver names
        teamLookup (tf.keras.layers.StringLookup): Lookup layer for team names
        compoundLookup (tf.keras.layers.StringLookup): Lookup layer for tyre compounds
        cityLookup (tf.keras.layers.StringLookup): Lookup layer for city names

    Returns:
        bytes: Serialized tf.train.Example
    """

    staticPart = serializeStaticFeatures(year, roundNumber, driver, staticData, driverLookup, teamLookup, cityLookup)
    categoricalPart = serializeCategorialFeatures(lapData, compoundLookup)
    timePart = serializeTimeFeatures(lapData)
    numericPart = serializeLapDependentFeatures(lapData)

    allFeatures = {**staticPart, **categoricalPart, **timePart, **numericPart}

    example = tf.train.Example(features=tf.train.Features(feature=allFeatures))
    return example.SerializeToString()

def loadData(outputFileName="allDataByLaps", dataTable="allDataByLaps", lapsTable="laps", telemetryTable="telemetry", raceYear=None, roundNumberYear=None, cityYear=None, liveData=False, raceStartTime=None, race=None, qualifying=None, telemetry=None):
    """
    Loads and processes race data from a database and writes it to a TFRecord file

    The output is written to "<outputFileName>.tfrecord". The database connection
    is closed and the engine disposed even if a step fails.

    Parameters:
        outputFileName (str): Base name for the TFRecord output file (".tfrecord" is appended)
        dataTable (str): Main table with combined lap data
        lapsTable (str): Table with raw lap information
        telemetryTable (str): Table with telemetry data
        raceYear (int): Optional filter for race year
        roundNumberYear (int): Optional filter for round number
        cityYear (str): Optional filter for city
        liveData (bool): Whether to use preloaded session objects; also suppresses the success message
        raceStartTime: Optional race start time
        race, qualifying, telemetry: Optional preloaded data

    Returns:
        None
    """

    # Create the SQLAlchemy engine and the raw psycopg2 connection
    engine, connection = createConnection()

    try:
        # Populate/refresh the data table and get its column names
        # NOTE(review): race is passed before raceStartTime here, the reverse of this
        # function's parameter order — confirm against getData's signature.
        existingColumns = getData(engine, dataTable, lapsTable, telemetryTable, raceYear, roundNumberYear, cityYear, liveData, race, raceStartTime, qualifying, telemetry)

        rows = getRows(connection, dataTable)

        groupedData, driverLookup, teamLookup, compoundLookup, cityLookup = processData(rows, existingColumns)

        with tf.io.TFRecordWriter(f"{outputFileName}.tfrecord") as writer:
            for (year, roundNumber, driver), data in groupedData.items():
                example = serializeSequence(year, roundNumber, driver, data["staticData"], data["lapData"], driverLookup, teamLookup, compoundLookup, cityLookup)
                writer.write(example)

        # Only report success for non-live runs
        if not liveData:
            print(f"TFRecord file {outputFileName} created successfully.")
    finally:
        # Release database resources even when loading or serialization fails
        connection.close()
        engine.dispose()