import os
import pickle
import re
import time

import numpy as np
import tensorflow as tf
import tensorflow_privacy as tfp
from sklearn.metrics import roc_auc_score, f1_score
from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy

# --- Load Preprocessed Data ---
# Pull the six pre-split arrays out of the .npz archive while it is still open.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code; only load archives from a trusted source, and drop the
# flag if every stored array is numeric.
print("Loading preprocessed data from processed_data.npz...")
with np.load('../data/processed_data.npz', allow_pickle=True) as data:
    (X_train_full, y_train_full,
     X_val, y_val,
     X_test, y_test) = (data[name] for name in ('X_train', 'y_train', 'X_val', 'y_val', 'X_test', 'y_test'))
print("Data loaded successfully.")

# --- Hyperparameters ---
epochs = 50       # training epochs for each (noise_multiplier, l2_norm_clip) configuration
batch_size = 256  # examples per batch; also used as num_microbatches for the DP optimizer
delta = 1e-5      # target delta of the (epsilon, delta)-DP guarantee

# --- Truncate training data to be a multiple of batch_size ---
# The DP optimizer splits every batch into num_microbatches (== batch_size)
# microbatches, so each batch size must be divisible by it. A final partial
# batch would fail; dropping the remainder (< batch_size samples) avoids that.
# NOTE(review): the remainder is taken from the end of the array, so this
# assumes the preprocessed training data is not ordered by label.
num_train_samples_full = X_train_full.shape[0]
samples_to_keep = (num_train_samples_full // batch_size) * batch_size

if samples_to_keep == 0:
    raise ValueError(
        f"Training set has only {num_train_samples_full} samples, fewer than "
        f"batch_size={batch_size}; truncation would leave no training data."
    )

if num_train_samples_full > samples_to_keep:
    print(f"Truncating training data from {num_train_samples_full} to {samples_to_keep} samples to ensure full batches.")
    X_train = X_train_full[:samples_to_keep]
    y_train = y_train_full[:samples_to_keep]
else:
    X_train = X_train_full
    y_train = y_train_full
    print(f"Training data ({num_train_samples_full} samples) is already a multiple of batch_size ({batch_size}).")


print(f"X_train shape after truncation: {X_train.shape}, y_train shape after truncation: {y_train.shape}")

# Derive the label set from every split, not just the truncated training set:
# truncation can drop the last examples of a rare class, which would make the
# softmax layer narrower than the label range found in validation/test data.
unique_y_classes = np.unique(np.concatenate([np.ravel(y_train_full), np.ravel(y_val), np.ravel(y_test)]))
n_classes = len(unique_y_classes)
# SparseCategoricalCrossentropy and argmax-based predictions both require the
# labels to be exactly the integers 0..n_classes-1; fail early and clearly.
if not np.array_equal(unique_y_classes, np.arange(n_classes)):
    raise ValueError(
        f"Labels must be the contiguous integers 0..{n_classes - 1}; found {unique_y_classes}."
    )
print(f"Number of classes (n_classes): {n_classes}")

n_features = X_train.shape[1]
print(f"Number of features (n_features): {n_features}")

# delta is conventionally chosen well below 1/n; flag configurations where it is not.
if delta >= 1.0 / X_train.shape[0]:
    print(f"Warning: delta={delta} is not smaller than 1/n = {1.0 / X_train.shape[0]:.2e}; "
          "the resulting (epsilon, delta) guarantee is weak.")

# DP parameters to iterate over. Smaller noise multipliers add less noise and
# therefore yield larger epsilon values.
noise_multipliers = [0.1, 0.3, 0.5, 0.7, 0.9]
l2_norm_clips = [0.5, 1.0]

# --- Helpers for a single DP training run ---
def _build_dp_model(num_features, num_classes, l2_norm_clip, noise_multiplier,
                    num_microbatches, learning_rate=0.001):
    """Build and compile a small MLP classifier trained with DP-Adam.

    Args:
        num_features: Number of input columns per example.
        num_classes: Number of softmax outputs; labels must lie in 0..num_classes-1.
        l2_norm_clip: Clipping norm applied to each microbatch gradient.
        noise_multiplier: Gaussian noise stddev as a multiple of l2_norm_clip.
        num_microbatches: Number of microbatches each batch is split into; every
            batch size must be divisible by it.
        learning_rate: Adam learning rate.

    Returns:
        A compiled ``tf.keras.Sequential`` model.
    """
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(128, activation='relu', input_shape=(num_features,)),
        tf.keras.layers.Dropout(0.3),  # regularization
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ])
    optimizer = tfp.DPKerasAdamOptimizer(
        l2_norm_clip=l2_norm_clip,
        noise_multiplier=noise_multiplier,
        num_microbatches=num_microbatches,
        learning_rate=learning_rate,
    )
    model.compile(
        optimizer=optimizer,
        # The DP optimizer needs the per-example loss vector (not its mean) so it
        # can clip each microbatch's gradient separately; hence reduction NONE.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(
            reduction=tf.keras.losses.Reduction.NONE, from_logits=False),
        metrics=['accuracy'],
        run_eagerly=True,  # kept from the original configuration
    )
    return model


def _compute_epsilon(num_examples, batch_size, noise_multiplier, num_epochs, delta):
    """Return the DP-SGD epsilon for the given training configuration.

    ``compute_dp_sgd_privacy`` takes (n, batch_size, noise_multiplier, epochs,
    delta); the keyword names ``number_of_examples``/``num_epochs`` belong to
    ``compute_dp_sgd_privacy_statement`` and raise TypeError here, so the
    arguments are passed positionally.

    NOTE(review): this accountant assumes Poisson subsampling, whereas Keras
    ``fit`` iterates over shuffled fixed-size batches; treat the value as an
    approximation.
    """
    result = compute_dp_sgd_privacy(num_examples, batch_size, noise_multiplier, num_epochs, delta)
    # Returns (epsilon, optimal RDP order); tolerate a bare float as well.
    return result[0] if isinstance(result, tuple) else result


def _compute_auc(y_true, y_proba, num_classes):
    """Return test ROC AUC: macro one-vs-one for >2 classes, positive-class AUC for 2.

    ``y_proba`` is the softmax output, one column per class. Raises ValueError
    when AUC is undefined (fewer than two classes, or sklearn rejects the input,
    e.g. a class missing from ``y_true``).
    """
    if num_classes > 2:
        return roc_auc_score(y_true, y_proba, multi_class='ovo', average='macro')
    if num_classes == 2:
        # Binary roc_auc_score expects the score of the positive class (label 1).
        return roc_auc_score(y_true, y_proba[:, 1])
    raise ValueError(f"AUC is undefined for {num_classes} class(es)")


# --- Start Timer ---
script_start_time = time.perf_counter()

results = []  # one dict per (noise_multiplier, l2_norm_clip) configuration

# --- Training Loop ---
for noise_multiplier in noise_multipliers:
    for l2_norm_clip in l2_norm_clips:
        print(f"\nTraining with noise_multiplier={noise_multiplier}, l2_norm_clip={l2_norm_clip}")

        # Release Keras global state from the previous configuration so it does
        # not accumulate over the sweep.
        tf.keras.backend.clear_session()

        # num_microbatches == batch_size means one example per microbatch
        # (per-example clipping); every batch must therefore be full, which the
        # training-set truncation guarantees.
        model = _build_dp_model(n_features, n_classes, l2_norm_clip, noise_multiplier,
                                num_microbatches=batch_size)

        print(f"Starting training for {epochs} epochs...")
        history = model.fit(
            X_train, y_train,
            batch_size=batch_size,
            epochs=epochs,
            validation_data=(X_val, y_val),
            verbose=2,
        )

        # Privacy budget spent by this run (n = number of training examples).
        epsilon_value = _compute_epsilon(X_train.shape[0], batch_size, noise_multiplier, epochs, delta)
        print(f"Calculated Epsilon (ε): {epsilon_value:.4f}")

        test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

        # Class probabilities (for AUC) and hard predictions (for F1).
        y_pred_proba = model.predict(X_test, verbose=0)
        y_pred_classes = np.argmax(y_pred_proba, axis=1)

        auc_label = "Macro OVO" if n_classes > 2 else "binary"
        try:
            auc_macro = _compute_auc(y_test, y_pred_proba, n_classes)
            print(f"Test AUC ({auc_label}): {auc_macro:.4f}")
        except ValueError as e:
            # NaN, not 0.0: 0.0 is a real (worst-case) AUC and would silently
            # skew any later aggregation of the results.
            print(f"Could not calculate AUC: {e}. Setting AUC to NaN.")
            auc_macro = float('nan')

        f1_macro = f1_score(y_test, y_pred_classes, average='macro', zero_division=0)
        print(f"Test F1-score (Macro): {f1_macro:.4f}")

        results.append({
            'noise_multiplier': noise_multiplier,
            'l2_norm_clip': l2_norm_clip,
            'epsilon': epsilon_value,
            'history': history.history,
            'test_loss': test_loss,
            'test_accuracy': test_accuracy,
            'test_auc_macro': auc_macro,
            'test_f1_macro': f1_macro,
        })

# --- End Timer and Calculate Duration ---
script_end_time = time.perf_counter()
total_execution_time = script_end_time - script_start_time
print(f"\nGradient perturbation execution time: {total_execution_time:.2f} seconds ({total_execution_time/60:.2f} minutes).")

# --- Save Results ---
# Persist every run's metrics and training history. The sweep above is long, so
# a failed write must not silently discard the results: create the output
# directory if it is missing and, if that location still cannot be written,
# fall back to the current working directory.
results_file = '../results/gradient_perturbation_results.pkl'
fallback_file = os.path.basename(results_file)
candidate_files = [results_file] if fallback_file == results_file else [results_file, fallback_file]

saved_to = None
for candidate in candidate_files:
    print(f"\nSaving all results to {candidate}...")
    try:
        output_dir = os.path.dirname(candidate)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(candidate, 'wb') as f:
            pickle.dump(results, f)
    except Exception as e:  # top-level boundary: report, then try the next location
        print(f"Error saving results: {e}")
    else:
        saved_to = candidate
        break

if saved_to is None:
    print("Process completed, but results could NOT be saved.")
else:
    print("Process completed.")