# System utilities for path management and environment configuration
import os
import sys

# Logging configuration
import logging

# Suppress TensorFlow C++ logging. This has to happen *before* TensorFlow is
# imported: the native runtime picks up TF_CPP_MIN_LOG_LEVEL when it is
# loaded, so setting it after `import tensorflow` does not silence the
# messages emitted during import.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# TensorFlow core library
import tensorflow as tf

# Also silence Python-side TensorFlow warnings
logging.getLogger("tensorflow").setLevel(logging.ERROR)

# Matplotlib for plotting and visualization
import matplotlib.pyplot as plt

# Make the parent directory importable (for the `models` and `processData`
# packages). NOTE(review): os.path.abspath("..") resolves against the current
# working directory, not this file's location, so this presumably assumes the
# script is launched from its own directory — confirm.
sys.path.append(os.path.abspath(".."))

# Keras model loading utility
from keras.models import load_model  # type: ignore

# Custom model and loss function. They are not referenced directly in the
# functions below. NOTE(review): they are presumably imported so that their
# Keras serialization registrations run before load_model() — verify before
# removing.
from models.transformer import Transformer
from models.loss import WeightedSequenceLoss

# Data preprocessing and preparation utilities
from processData.formatData import testingDataset, prepareInput

__all__ = ["showAttentionMaps"]

def plotAttentionMaps(attention, numHeads=None):
    """
    Plots attention maps for each head in every Transformer layer.

    The maps are laid out as a grid with one row per layer and one column per
    head. The top row is titled with head numbers and the first column is
    labelled with layer numbers.

    Parameters:
        attention (List[tf.Tensor]): List of attention weight tensors,
            one per layer, each with shape [1, numHeads, seqLen, seqLen].
            Only batch element 0 is plotted. Any array-like that supports
            ``attn[0, head]`` indexing (e.g. NumPy arrays) is accepted too.
        numHeads (int, optional): Number of attention heads to plot per layer.
            If omitted, it is read from axis 1 of the first layer's tensor.

    Returns:
        None: Displays attention maps using matplotlib.

    Raises:
        ValueError: If ``attention`` is empty or ``numHeads`` is not positive.
    """

    numLayers = len(attention)
    if numLayers == 0:
        raise ValueError("attention must contain at least one layer of weights")

    if numHeads is None:
        numHeads = int(attention[0].shape[1])
    if numHeads < 1:
        raise ValueError(f"numHeads must be a positive integer, got {numHeads}")

    # squeeze=False guarantees a 2-D array of Axes, so axes[i, head] is valid
    # even for a single layer and/or a single head (the default squeezing
    # returns a 1-D array or a bare Axes object in those cases).
    fig, axes = plt.subplots(
        numLayers,
        numHeads,
        figsize=(numHeads * 2, numLayers * 2),
        squeeze=False,
    )

    for i, attn in enumerate(attention):
        for head in range(numHeads):
            attentionMatrix = attn[0, head]
            # TensorFlow eager tensors expose .numpy(); plain arrays are used as-is.
            if hasattr(attentionMatrix, "numpy"):
                attentionMatrix = attentionMatrix.numpy()
            ax = axes[i, head]
            ax.imshow(attentionMatrix, cmap="viridis")
            ax.set_xticks([])
            ax.set_yticks([])
            if i == 0:
                ax.set_title(f"Head {head+1}", fontsize=12)
            if head == 0:
                ax.set_ylabel(f"Layer {i+1}", fontsize=12)

    # tight_layout() below normally recomputes this spacing; these values are
    # the baseline it starts from.
    fig.subplots_adjust(wspace=0.1, hspace=0.3)
    fig.suptitle("Attention heads across layers", fontsize=16, fontweight="bold", x=0.5)
    # Reserve the top 3% of the figure for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.97])
    plt.show()
    
def _attentionWeightsFor(modelPath, tfRecordFile, lapDependentFeatures,
                         nonLapDependentFeatures, year, roundNumber,
                         driverNumber, label, prefixLength, paddingLength):
    """
    Runs one saved model on a single driver's race prefix and returns its
    attention weights.

    Parameters:
        modelPath (str): Path of the saved ``.keras`` model to load.
        tfRecordFile (str): Path to the TFRecord file with the test data.
        lapDependentFeatures (List[str]): Per-lap feature names passed to ``testingDataset``.
        nonLapDependentFeatures (List[str]): Non-lap-dependent feature names passed to ``testingDataset``.
        year (int): Year of the race event.
        roundNumber (int): Round number in the racing season.
        driverNumber (int): Driver identifier passed to ``testingDataset``.
        label (str): Label name passed to ``testingDataset``.
        prefixLength (int): Prefix length passed to ``prepareInput``.
        paddingLength (int): Padding length passed to ``prepareInput``.

    Returns:
        The ``"attentionWeights"`` entry of the model's intermediate outputs.

    Raises:
        ValueError: If ``testingDataset`` yields no element for the request.
    """
    model = load_model(modelPath)

    try:
        lapDependentData, nonLapDependentData, _ = next(iter(testingDataset(
            tfRecordFile, lapDependentFeatures, nonLapDependentFeatures,
            year, roundNumber, driverNumber, label,
        )))
    except StopIteration:
        # A bare StopIteration escaping a regular function is easily mistaken
        # for normal iterator exhaustion by callers; report it explicitly.
        raise ValueError(
            f"No test record in {tfRecordFile!r} for year={year}, "
            f"roundNumber={roundNumber}, driverNumber={driverNumber}"
        ) from None

    inputs = prepareInput(lapDependentData, nonLapDependentData, prefixLength, paddingLength)
    _, intermediate = model(inputs, training=False, returnIntermediate=True)
    return intermediate["attentionWeights"]


def showAttentionMaps(tfRecordFile, lapNumber, year, roundNumber, driverNumber):
    """
    Visualizes attention maps from Transformer models for pit stop and tyre compound predictions.

    Both models are evaluated first (pit stop, then compound). Their attention
    maps are then plotted in that order, one figure per model.

    Parameters:
        tfRecordFile (str): Path to the TFRecord file containing race telemetry and metadata.
        lapNumber (int): Current lap number. The first ``lapNumber - 1`` laps
            form the input prefix, padded up to ``maxLaps``.
        year (int): Year of the race event.
        roundNumber (int): Round number in the racing season.
        driverNumber (int): Unique identifier for the driver in the dataset.

    Returns:
        None: Displays attention map plots for each Transformer layer and head.

    Raises:
        ValueError: If ``lapNumber`` is outside ``1 .. maxLaps + 1``, which would
            give a negative prefix or padding length, or if the dataset has no
            record for the requested race and driver.
    """

    # Maximum number of laps (based on the race with most laps, Monaco)
    maxLaps = 78

    # Head count used only if it cannot be read from the attention tensors
    defaultNumHeads = 16

    # General lap-dependent features used for prediction
    sharedLapDependentFeatures = [
        "RPM", "Speed", "Throttle", "Sector1Time", "Sector2Time", "Sector3Time",
        "AirTemp", "Humidity", "Pressure", "Rainfall", "TrackTemp", "TrackStatus",
        "WindDirection", "WindSpeed", "TyreLife", "FreshTyre", "Stint",
        "GapToNext", "GapToLeader", "Position", "LapNumber", "LapTime"
    ]

    sharedNonLapDependentFeatures = [
        "Driver", "RoundNumber", "Year", "Q1",
        "Q2", "Q3", "Team", "GridPosition", "City"
    ]

    # Determine sequence length and padding; validate before any expensive
    # model loading so bad input fails fast.
    prefixLength = lapNumber - 1
    paddingLength = maxLaps - prefixLength
    if prefixLength < 0 or paddingLength < 0:
        raise ValueError(
            f"lapNumber must be between 1 and {maxLaps + 1}, got {lapNumber}"
        )

    # (saved model, label) per prediction task, pit stop first as before.
    # NOTE(review): model paths are relative to the current working directory.
    tasks = [
        ("../models/pitstop.keras", "LapsUntilNextPit"),
        ("../models/compound.keras", "Compound"),
    ]

    # Run inference for every task before showing any plot. Each task gets its
    # own copy of the feature lists, as in the original code.
    attentionPerTask = [
        _attentionWeightsFor(
            modelPath, tfRecordFile,
            list(sharedLapDependentFeatures), list(sharedNonLapDependentFeatures),
            year, roundNumber, driverNumber, label,
            prefixLength, paddingLength,
        )
        for modelPath, label in tasks
    ]

    # Plot attention maps. The head count is read from axis 1 of the tensors
    # ([1, numHeads, seqLen, seqLen]) so it cannot drift from the trained model.
    for attentionWeights in attentionPerTask:
        numHeads = (
            int(attentionWeights[0].shape[1])
            if len(attentionWeights) > 0
            else defaultNumHeads
        )
        plotAttentionMaps(attentionWeights, numHeads)