# TensorFlow for deep learning and model preparation
import tensorflow as tf

# Matplotlib for plotting
import matplotlib.pyplot as plt

# Numerical operations
import numpy as np

# Python utilities for counting elements in collections
from collections import Counter

# TensorFlow Keras utilities
from tensorflow.keras.layers import Normalization  # type: ignore

# Names exported by `from <this module> import *`
__all__ = [
    "testingDataset",
    "prepareInput",
    "prepareDataset",
    "getWeights",
]

def parseTfExample(example):
    """
    Decodes one serialized TFRecord Example into a feature dictionary

    Parameters:
        example (tf.Tensor): Scalar string tensor holding a single serialized Example

    Returns:
        dict: Feature name -> dense Tensor for the scalar (non-lap-dependent)
        fields, or SparseTensor for the variable-length (lap-dependent) fields
    """

    # Non-lap-dependent scalar fields
    intScalarKeys = ["Driver", "DriverNumber", "RoundNumber", "Year", "Team", "City"]
    floatScalarKeys = ["Q1", "Q2", "Q3", "GridPosition"]

    # Lap-dependent variable-length fields
    floatLapKeys = [
        "RPM", "Speed", "Throttle", "Sector1Time", "Sector2Time", "Sector3Time",
        "AirTemp", "Humidity", "Pressure", "Rainfall", "TrackTemp", "TrackStatus",
        "WindDirection", "WindSpeed", "PitTime", "TyreLife", "FreshTyre",
        "Stint", "GapToNext", "GapToLeader", "Position", "LapTime",
    ]
    intLapKeys = ["LapNumber", "Compound", "LapsUntilNextPit"]

    spec = {}
    for key in intScalarKeys:
        spec[key] = tf.io.FixedLenFeature([], tf.int64)
    for key in floatScalarKeys:
        spec[key] = tf.io.FixedLenFeature([], tf.float32)
    for key in floatLapKeys:
        spec[key] = tf.io.VarLenFeature(tf.float32)
    for key in intLapKeys:
        spec[key] = tf.io.VarLenFeature(tf.int64)

    return tf.io.parse_single_example(example, spec)

def extractFeatures(ex, lapDependentFeatures, nonLapDependentFeatures, targetValue):
    """
    Turns a parsed example into lap-wise, static and target tensors

    Parameters:
        ex (dict): Output of parseTfExample
        lapDependentFeatures (list): Names of the per-lap features; they become
            the columns of the lap tensor in the given order. All of them must
            have the same number of laps, otherwise tf.stack fails
        nonLapDependentFeatures (list): Names of the static features, stacked in the given order
        targetValue (str): Name of the per-lap target feature

    Returns:
        tuple: (lapDependent [num_laps, D_lap], nonLapDependent [D_nonLapDependent],
        target [num_laps] int32). The first two are cast to float32 where the
        source column is int64 (lap columns) or always (static columns)
    """

    def _denseLapColumn(name):
        # Densify a sparse per-lap column; int64 columns are converted to float32
        column = tf.sparse.to_dense(ex[name], default_value=0)
        return tf.cast(column, tf.float32) if column.dtype == tf.int64 else column

    lapDependent = tf.stack([_denseLapColumn(name) for name in lapDependentFeatures], axis=-1)

    nonLapDependent = tf.stack(
        [tf.cast(ex[name], tf.float32) for name in nonLapDependentFeatures],
        axis=-1,
    )

    target = tf.cast(tf.sparse.to_dense(ex[targetValue], default_value=0), tf.int32)

    return lapDependent, nonLapDependent, target

def filterByRace(ex, year, roundNumber, driver):
    """
    Predicate that keeps only the example of one driver in one race

    Parameters:
        ex (dict): Parsed example
        year (int): Year to keep
        roundNumber (int): Round number to keep
        driver (int): Driver number to keep (compared against "DriverNumber")

    Returns:
        tf.Tensor: Scalar boolean tensor, True when Year, RoundNumber and
        DriverNumber all match
    """

    def _fieldEquals(key, wanted):
        # Compare both sides as int64 so Python ints and integer tensors line up
        return tf.equal(tf.cast(ex[key], tf.int64), tf.cast(wanted, tf.int64))

    sameRace = tf.logical_and(_fieldEquals("Year", year), _fieldEquals("RoundNumber", roundNumber))
    return tf.logical_and(sameRace, _fieldEquals("DriverNumber", driver))

def makePrefixes(lapDependent, nonLapDependent, lapDependentFeatures, nonLapDependentFeatures, targetData, maxLaps, numClasses):
    """
    Creates training prefixes for each lap to model sequences step by step

    For every t in [1, ..., min(num_laps, maxLaps)], the first t laps are kept,
    zero-padded to maxLaps rows and paired with the target of lap t. Prefixes
    longer than maxLaps cannot be represented in a maxLaps-long input (the
    padding amount would be negative and tf.pad would abort the whole
    pipeline), so they are skipped.

    Parameters:
        lapDependent (tf.Tensor): Tensor of shape [num_laps, D_lap]
        nonLapDependent (tf.Tensor): Tensor of shape [D_nonLapDependent]
        lapDependentFeatures (list): Lap-wise feature names (only their count is used)
        nonLapDependentFeatures (list): Static feature names (only their count is used)
        targetData (tf.Tensor): Target sequence of shape [num_laps]
        maxLaps (int): Maximum number of laps used for padding
        numClasses (int): Total number of classification classes

    Returns:
        tf.data.Dataset: A dataset of (sequence [maxLaps, D_lap], mask [maxLaps],
        nonLapDependent [D_nonLapDependent], target []) tuples
    """

    numLaps = tf.shape(lapDependent)[0]

    # Only prefixes that fit into the padded window can be built
    numPrefixes = tf.minimum(numLaps, maxLaps)

    def _one(t):
        # Slice the first t laps
        lapDependentToT = lapDependent[:t]

        # Pad the sequence to maxLaps rows (t <= maxLaps, so padding is non-negative)
        paddedLapDependent = tf.pad(lapDependentToT, [[0, maxLaps - t], [0, 0]])

        # True for the t observed laps, False for the padded rows
        mask = tf.concat([
            tf.ones([t], tf.bool),
            tf.zeros([maxLaps - t], tf.bool)
        ], axis=0)

        # Target values above the last class are folded into class numClasses - 1
        target = tf.minimum(targetData[t - 1], numClasses - 1)
        return paddedLapDependent, mask, nonLapDependent, target

    # Apply the function to each t in [1, ..., numPrefixes]
    sequence, masks, nonLapDependents, targets = tf.map_fn(
        _one,
        tf.range(1, numPrefixes + 1, dtype=tf.int32),
        fn_output_signature=(
            tf.TensorSpec([maxLaps, len(lapDependentFeatures)], tf.float32),
            tf.TensorSpec([maxLaps], tf.bool),
            tf.TensorSpec([len(nonLapDependentFeatures)], tf.float32),
            tf.TensorSpec([], tf.int32)
        )
    )

    return tf.data.Dataset.from_tensor_slices((sequence, masks, nonLapDependents, targets))

def toKeras(lapDependent, mask, nonLapDependent, targetValue):
    """
    Packs one example into the (inputs, label) pair consumed by model.fit

    Parameters:
        lapDependent (tf.Tensor): Lap-dependent tensor of shape [maxLaps, D_lap]
        mask (tf.Tensor): Boolean mask of shape [maxLaps]
        nonLapDependent (tf.Tensor): Static feature tensor of shape [D_nonLapDependent]
        targetValue (tf.Tensor): Target label

    Returns:
        tuple: ({"lapDependent", "lapMask", "nonLapDependent"} input dict, targetValue)
    """

    modelInputs = dict(
        lapDependent=lapDependent,
        lapMask=mask,
        nonLapDependent=nonLapDependent,
    )
    return modelInputs, targetValue

def prepareDataset(tfRecordFile, lapDependentFeatures, nonLapDependentFeatures, targetValue, batchSize, maxLaps, classes, seed=None):
    """
    Prepares a normalized and batched dataset from TFRecord files for training and evaluation

    Each TFRecord example is expanded into one training example per lap prefix
    (see makePrefixes). All examples are cached, shuffled once and split
    70% / 15% / 15% into train / validation / test.

    Parameters:
        tfRecordFile (str): Path to the TFRecord file
        lapDependentFeatures (list): Keys for lap-wise input features
        nonLapDependentFeatures (list): Keys for static input features
        targetValue (str): Key for the classification target
        batchSize (int): Size of each training batch
        maxLaps (int): Maximum number of laps per input
        classes (int): Total number of output classes
        seed (int, optional): Seed of the one-off shuffle that defines the split.
            The split is always stable across epochs within one call
            (reshuffle_each_iteration=False); pass a seed to also make it
            identical across program runs. Defaults to None (unseeded)

    Returns:
        tuple: (
            lapDependentNormalizer,     # adapted on the observed (unpadded) training laps
            nonLapDependentNormalizer,  # adapted on the training static features
            trainDataset,               # batched, repeats forever
            validationDataset,          # batched, repeats forever
            testDataset,                # batched, single pass
            validationSteps             # ceil(validationSize / batchSize)
        )

    Raises:
        ValueError: If no examples can be built from tfRecordFile
    """

    # Full preprocessing pipeline: parse -> tensors -> one example per lap prefix -> Keras format
    dataset = (
        tf.data.TFRecordDataset(tfRecordFile)
        .map(parseTfExample, num_parallel_calls=tf.data.AUTOTUNE)
        .map(lambda ex: extractFeatures(ex, lapDependentFeatures, nonLapDependentFeatures, targetValue), num_parallel_calls=tf.data.AUTOTUNE)
        .flat_map(lambda lapDependent, nonLapDependent, target: makePrefixes(lapDependent, nonLapDependent, lapDependentFeatures, nonLapDependentFeatures, target, maxLaps, classes))
        .map(toKeras)
        .cache()
    )

    # Count examples; this full pass also populates the in-memory cache
    totalSize = sum(1 for _ in dataset)
    if totalSize == 0:
        # shuffle() requires a positive buffer size; fail early with a clear message
        raise ValueError(f"No examples could be built from {tfRecordFile!r}")

    trainSize = int(0.7 * totalSize)
    validationSize = int(0.15 * totalSize)
    testSize = totalSize - trainSize - validationSize
    validationSteps = int(np.ceil(validationSize / batchSize))

    # Shuffle once; reshuffle_each_iteration=False keeps the order fixed so the
    # take/skip splits below stay disjoint across epochs
    shuffled = dataset.shuffle(totalSize, seed=seed, reshuffle_each_iteration=False)
    trainSplit = shuffled.take(trainSize)

    # Prepare training dataset
    trainDataset = (
        trainSplit
        .repeat()
        .batch(batchSize)
        .prefetch(tf.data.AUTOTUNE)
    )

    # Prepare validation dataset
    validationDataset = (
        shuffled
        .skip(trainSize)
        .take(validationSize)
        .repeat()
        .batch(batchSize)
        .prefetch(tf.data.AUTOTUNE)
    )

    # Prepare test dataset
    testDataset = (
        shuffled
        .skip(trainSize + validationSize)
        .take(testSize)
        .batch(batchSize)
        .prefetch(tf.data.AUTOTUNE)
    )

    # Create normalizers for lap-dependent and static features
    lapDependentNormalizer = Normalization(axis=-1, name="lapDependentNormalizer")
    nonLapDependentNormalizer = Normalization(axis=-1, name="nonLapDependentNormalizer")

    # Keep only the model inputs of the training split
    trainInputs = trainSplit.map(lambda inputs, label: inputs)

    # Adapt lap statistics on observed laps only: the zero rows that makePrefixes
    # pads each prefix with would otherwise pull the mean toward 0 and inflate
    # the variance. Each element becomes [t, D_lap] instead of [maxLaps, D_lap].
    # NOTE(review): early laps still appear in many prefixes and are therefore weighted more
    lapDependentNormalizer.adapt(
        trainInputs.map(lambda inp: tf.boolean_mask(inp["lapDependent"], inp["lapMask"]))
    )
    nonLapDependentNormalizer.adapt(trainInputs.map(lambda inp: inp["nonLapDependent"]))

    return lapDependentNormalizer, nonLapDependentNormalizer, trainDataset, validationDataset, testDataset, validationSteps

def testingDataset(tfRecordFile, lapDependentFeatures, nonLapDependentFeatures, year, roundNumber, driverNumber, targetValue):
    """
    Builds the structured tensors of one driver in one race from a TFRecord file

    Parameters:
        tfRecordFile (str): Path to the TFRecord file
        lapDependentFeatures (list): Lap-wise feature keys
        nonLapDependentFeatures (list): Static feature keys
        year (int): Year to keep
        roundNumber (int): Round to keep
        driverNumber (int): Driver number to keep
        targetValue (str): Target feature key

    Returns:
        tf.Dataset: (lapDependent, nonLapDependent, target) elements for every
        record whose Year, RoundNumber and DriverNumber match. One record is
        expected, but that is not enforced here
    """

    def _isRequested(ex):
        return filterByRace(ex, year, roundNumber, driverNumber)

    def _toTensors(ex):
        return extractFeatures(ex, lapDependentFeatures, nonLapDependentFeatures, targetValue)

    records = tf.data.TFRecordDataset(tfRecordFile)
    return records.map(parseTfExample).filter(_isRequested).map(_toTensors)

def getWeights(tfRecordFile, target, numClasses):
    """
    Computes inverse-frequency class weights and plots class distribution

    Labels are mapped the way the training pipeline maps them before counting:
    truncated to int (as tf.cast does in extractFeatures) and clamped to
    numClasses - 1 (as makePrefixes does). Negative and non-finite labels are
    ignored. Each weight is total / (number_of_present_classes * class_count);
    classes that never occur get weight 1.0.

    Parameters:
        tfRecordFile (str): Path to the TFRecord file
        target (str): Key name of the target label
        numClasses (int): Number of target classes

    Returns:
        np.ndarray: float32 array of length numClasses with one weight per class
    """

    labels = []

    # Read and parse each example from the TFRecord file
    for example in tf.data.TFRecordDataset(tfRecordFile):
        parsed = parseTfExample(example)
        dense = tf.sparse.to_dense(parsed[target], default_value=0)
        # A single parsed VarLenFeature densifies to 1-D, so tolist() yields scalars
        labels.extend(dense.numpy().tolist())

    # Plot the raw distribution of the target values
    rawCounts = Counter(labels)
    plt.bar(list(rawCounts.keys()), list(rawCounts.values()))
    plt.xlabel(target, fontsize=12)
    plt.ylabel("Frequency", fontsize=12)
    plt.title(f"Distribution of {target} in the dataset", fontweight="bold", fontsize=16, loc="center")
    plt.show()

    # Map raw labels onto class indices exactly like the training targets:
    # values >= numClasses - 1 all land in the last class instead of being dropped
    classCounts = Counter(
        min(int(label), numClasses - 1)
        for label in labels
        if np.isfinite(label) and label >= 0
    )

    total = sum(classCounts.values())

    # Compute inverse frequency weights
    weightsDict = {i: total / (len(classCounts) * count) for i, count in classCounts.items()}
    weights = np.array([weightsDict.get(i, 1.0) for i in range(numClasses)], dtype=np.float32)

    # Plot the computed weights
    plt.figure(figsize=(10, 4))
    plt.bar(range(len(weights)), weights)
    plt.xlabel("Class")
    plt.ylabel("Assigned weight")
    plt.title("Class weights for weighted loss")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    return weights

def prepareInput(lapDependentData, nonLapDependent, prefixLength, paddingLength):
    """
    Builds a batch-of-one model input from the first laps of a race

    Parameters:
        lapDependentData (tf.Tensor): Lap-wise data of shape [num_laps, D_lap]
        nonLapDependent (tf.Tensor): Static features of shape [D_static]
        prefixLength (int): Number of observed laps to keep
        paddingLength (int): Number of zero rows appended after the kept laps
            (normally chosen so that prefixLength + paddingLength == maxLaps)

    Returns:
        dict: "lapDependent" [1, L, D_lap], "lapMask" [1, prefixLength + paddingLength]
        (True for kept laps), "nonLapDependent" [1, D_static]. L equals
        prefixLength + paddingLength when lapDependentData has at least
        prefixLength laps
    """

    # Kept laps followed by paddingLength rows of zeros
    padded = tf.pad(lapDependentData[:prefixLength], [[0, paddingLength], [0, 0]])

    # True for the kept laps, False for the padding
    validity = tf.concat(
        [tf.ones([prefixLength], tf.bool), tf.zeros([paddingLength], tf.bool)],
        axis=0,
    )

    # Add a leading batch dimension of 1; keys match those produced by toKeras
    return {
        "lapDependent": tf.expand_dims(padded, 0),
        "lapMask": tf.expand_dims(validity, 0),
        "nonLapDependent": tf.expand_dims(nonLapDependent, 0),
    }