# TensorFlow for deep learning and prediction
import tensorflow as tf

# Plotting (matplotlib) and numerical array handling (NumPy)
import matplotlib.pyplot as plt
import numpy as np

# Enhanced visualization library built on top of matplotlib
import seaborn as sns

# Evaluation metrics and tools from scikit-learn
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    mean_absolute_error,
    mean_squared_error,
    top_k_accuracy_score
)

# For model confidence calibration 
from sklearn.calibration import calibration_curve

# Names exported by a star-import of this module: only the top-level
# evaluation entry point is listed; the other functions defined in this file
# are not exported that way.
__all__ = ["getTheoreticalResults"]

def getPredictions(model, dataset):
    """
    Runs model inference on a dataset to collect ground truth labels, predicted classes, and per-class model outputs.

    Every batch is evaluated with ``model.predict_on_batch``. Outputs with more
    than two dimensions (e.g. sequence models producing (batch, steps, classes))
    are flattened to (samples, classes) so that the three returned arrays stay
    aligned element by element.

    Parameters:
        model (tf.keras.Model): The trained model used for inference. It must provide ``predict_on_batch``.
        dataset (tf.data.Dataset): An iterable yielding (input, target) batches. Targets may be
            tensors exposing ``.numpy()`` or anything ``np.asarray`` accepts.

    Returns:
        tuple: A tuple containing:
            - yTrue (np.ndarray): Flattened ground truth labels.
            - yPred (np.ndarray): Flattened predicted class indices (argmax over the last axis).
            - yProbs (np.ndarray): Model outputs for each class, one row per flattened sample.
        If the dataset yields no batches, all three arrays are empty 1-D arrays.
    """

    trueBatches = []
    predBatches = []
    probBatches = []

    for batchInputs, batchTargets in dataset:
        # The dataset is already batched, so run exactly one forward pass per
        # batch; Keras documents predict() as not intended for use in loops.
        preds = np.asarray(model.predict_on_batch(batchInputs))

        # Collapse every leading dimension into the sample axis so that rows of
        # the score matrix line up with the flattened labels and predictions.
        if preds.ndim > 2:
            preds = preds.reshape(-1, preds.shape[-1])

        # Keep the tensor path of the original code, but also accept plain arrays.
        if hasattr(batchTargets, "numpy"):
            targets = batchTargets.numpy()
        else:
            targets = np.asarray(batchTargets)

        trueBatches.append(np.ravel(targets))
        predBatches.append(np.ravel(np.argmax(preds, axis=-1)))
        probBatches.append(preds)

    # np.concatenate rejects an empty list, so handle an empty dataset explicitly.
    if not probBatches:
        return np.array([]), np.array([]), np.array([])

    return np.concatenate(trueBatches), np.concatenate(predBatches), np.concatenate(probBatches)

def plotConfusionMatrix(yTrue, yPred, numClasses, normalize=False, figsize=(20, 12), filename=""):
    """
    Plots a confusion matrix heatmap, saves it as a PDF file, and shows it.

    The matrix is always built over the label set 0..numClasses-1, so it has
    shape (numClasses, numClasses) even when some class never occurs in
    ``yTrue`` or ``yPred``; this keeps the tick labels aligned with the cells.

    Parameters:
        yTrue (np.ndarray): Ground truth labels.
        yPred (np.ndarray): Predicted labels.
        numClasses (int): Number of target classes.
        normalize (bool): If True, normalize the matrix by row (per-class).
        figsize (tuple): Size of the figure (default: (20, 12)).
        filename (str): File name prefix for the saved PDF. The full name
            ``f"{filename}.pdf"`` is lower-cased before saving.
    """

    # Fix the label set explicitly; without it scikit-learn only uses the
    # labels that appear in the data and the matrix can be smaller than the
    # list of tick labels.
    classIds = np.arange(numClasses)
    tickLabels = [f"Class {i}" for i in classIds]

    # Compute the confusion matrix
    cm = confusion_matrix(
        yTrue,
        yPred,
        labels=classIds,
        normalize="true" if normalize else None
    )

    # Plot the matrix
    plt.figure(figsize=figsize)
    sns.heatmap(
        cm,
        annot=True,
        # Normalized cells are fractions, raw cells are integer counts
        fmt=".2f" if normalize else "d",
        cmap="Blues",
        xticklabels=tickLabels,
        yticklabels=tickLabels
    )
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)
    plt.xlabel("Predicted", fontsize=12)
    plt.ylabel("True", fontsize=12)
    title = "Normalized Confusion Matrix" if normalize else "Confusion Matrix"
    plt.title(title, fontweight="bold", fontsize=16, loc="center")
    plt.tight_layout()
    plt.savefig(f"{filename}.pdf".lower())
    plt.show()

def getClassificationMetrics(yTrue, yPred, probs, numClasses):
    """
    Prints detailed classification metrics and top-k accuracies.

    Printed, in order: the scikit-learn classification report, the mean
    absolute error and root mean squared error between the label indices,
    and the top-2 and top-3 accuracies. Nothing is returned.

    All metrics are computed over the label set 0..numClasses-1, so the report
    still works when some class is missing from ``yTrue`` and ``yPred``.

    Parameters:
        yTrue (np.ndarray): True class labels.
        yPred (np.ndarray): Predicted class labels.
        probs (np.ndarray): Model-predicted class probabilities.
        numClasses (int): Total number of classes.
    """

    classIds = np.arange(numClasses)
    targetNames = [f"Class {i}" for i in classIds]

    print("Classification Report")
    # Passing labels explicitly keeps target_names the same length as the
    # label set; otherwise scikit-learn raises when a class is absent.
    print(
        classification_report(
            yTrue,
            yPred,
            labels=classIds,
            target_names=targetNames,
            zero_division=0
        )
    )

    mae = mean_absolute_error(yTrue, yPred)
    rmse = np.sqrt(mean_squared_error(yTrue, yPred))
    print(f"\nMean Absolute Error (MAE):  {mae:.4f}")
    print(f"Root Mean Squared Error (RMSE): {rmse:.4f}")

    topK = {}
    for k in (2, 3):
        if k < numClasses:
            topK[k] = top_k_accuracy_score(yTrue, probs, k=k, labels=classIds)
        else:
            # With k >= number of classes every label is within the top k, so
            # the score is 1.0 by definition. Skipping the library call avoids
            # its error for two-class targets with 2-D scores and its
            # "meaningless score" warning.
            topK[k] = 1.0

    print(f"\nTop-2 Accuracy: {topK[2]:.4f}")
    print(f"Top-3 Accuracy: {topK[3]:.4f}")

def plotCalibrationCurves(yTrue, probs, selectedClasses, numBins=10, filename=""):
    """
    Draws one-vs-rest calibration curves for the requested classes, saves the figure as a PDF, and shows it.

    For each selected class the labels are binarized (this class vs. the rest)
    and compared against that class's column of ``probs`` using uniformly
    spaced bins. A dashed diagonal marks perfect calibration.

    Parameters:
        yTrue (np.ndarray): Ground truth class labels.
        probs (np.ndarray): Predicted class probabilities (shape: N x C).
        selectedClasses (list of int): Indices of classes to include in the plot.
        numBins (int): Number of bins for calibration computation.
        filename (str): File name prefix for the output PDF. The full name
            ``f"{filename}.pdf"`` is lower-cased before saving.
    """

    plt.figure(figsize=(8, 6))

    for classIdx in selectedClasses:
        # 1 where the sample belongs to this class, 0 otherwise
        isThisClass = (yTrue == classIdx).astype(int)
        observedFraction, meanPredicted = calibration_curve(
            isThisClass,
            probs[:, classIdx],
            n_bins=numBins,
            strategy="uniform"
        )
        plt.plot(meanPredicted, observedFraction, marker="o", label=f"Class {classIdx}")

    # Reference line: predicted probability equals observed frequency
    plt.plot([0, 1], [0, 1], "k--", label="Perfect calibration")

    plt.xlabel("Mean predicted probability", fontsize=12)
    plt.ylabel("Fraction of positives", fontsize=12)
    plt.title("Calibration curves", fontweight="bold", fontsize=16, loc="center")
    plt.legend()
    plt.tight_layout()

    outputPath = f"{filename}.pdf".lower()
    plt.savefig(outputPath)
    plt.show()

def computeEce(yTrue, probs, nBins=10):
    """
    Computes, prints, and returns the Expected Calibration Error (ECE) of a model.

    ECE measures the difference between predicted probability confidence and actual accuracy.
    Samples are grouped by their top-class confidence into ``nBins`` equal-width
    bins over (0, 1]; each bin contributes |accuracy - mean confidence| weighted
    by the fraction of samples it holds.

    Parameters:
        yTrue (np.ndarray): Ground truth class labels. Any shape; flattened before use.
        probs (np.ndarray): Predicted class probabilities (one row per sample).
        nBins (int): Number of bins used to calculate calibration error.

    Returns:
        float: The expected calibration error.
    """

    # Flatten the labels: a column vector of shape (N, 1) would otherwise
    # broadcast against the (N,) predictions into an N x N comparison.
    labels = np.ravel(yTrue)
    probs = np.asarray(probs)

    confidences = np.max(probs, axis=1)
    predictions = np.argmax(probs, axis=1)
    isCorrect = (predictions == labels)

    binEdges = np.linspace(0, 1, nBins + 1)
    numSamples = len(labels)
    ece = 0.0

    # Bins are left-open, right-closed: (lower, upper]
    for lower, upper in zip(binEdges[:-1], binEdges[1:]):
        inBin = (confidences > lower) & (confidences <= upper)
        binSize = np.sum(inBin)
        if binSize > 0:
            binAccuracy = np.mean(isCorrect[inBin])
            binConfidence = np.mean(confidences[inBin])
            ece += (binSize / numSamples) * np.abs(binAccuracy - binConfidence)

    print(f"Expected Calibration Error (ECE): {ece:.4f}")
    return float(ece)

def getTheoreticalResults(model, dataset, classes, name):
    """
    Runs the full evaluation pipeline for a trained classifier.

    Steps, in order:
        1. Collect labels, predicted classes, and per-class outputs via getPredictions.
        2. Print the classification report, MAE, RMSE, and top-k accuracies.
        3. Plot the raw and the row-normalized confusion matrix (each saved as PDF).
        4. Plot calibration curves (saved as PDF) for every class when there are
           fewer than 10 classes, otherwise for every fifth class starting at 0.
        5. Print the expected calibration error (ECE).

    Nothing is returned.

    Parameters:
        model (tf.keras.Model): The trained model to evaluate.
        dataset (tf.data.Dataset): The input dataset for testing.
        classes (int): Number of output classes.
        name (str): Identifier prefix for saved plot files.
    """

    labels, predicted, scores = getPredictions(model, dataset)

    getClassificationMetrics(labels, predicted, scores, classes)

    # Raw counts first, then the row-normalized variant
    for isNormalized, suffix in ((False, "ConfusionMatrix"), (True, "NormalizedConfusionMatrix")):
        plotConfusionMatrix(
            labels,
            predicted,
            classes,
            normalize=isNormalized,
            filename=f"{name}{suffix}"
        )

    # With many classes, plot only every fifth one to keep the figure readable
    stride = 5 if classes >= 10 else 1
    classesToPlot = list(range(0, classes, stride))
    plotCalibrationCurves(
        labels,
        scores,
        selectedClasses=classesToPlot,
        filename=f"{name}CalibrationCurves"
    )

    # Compute the Expected Calibration Error
    computeEce(labels, scores)