import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.impute import SimpleImputer
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from diffprivlib.mechanisms import GaussianAnalytic

import tensorflow as tf

import time 

# --- Parameters ---
# Input table and the column the classifier predicts.
CSV_FILE_NAME = "../data/variables_final.csv"
TARGET_COLUMN_NAME = "los_class"
# Columns removed before building the feature matrix: the target itself,
# identifier columns, the raw 'los' column (presumably the value los_class is
# derived from — dropping it avoids target leakage; confirm), and columns that
# are presumably categorical.
COLUMNS_TO_DROP_FOR_X_INITIAL = [TARGET_COLUMN_NAME] + [
    "hadm_id",
    "subject_id",
    "los",
    "gender",
    "first_careunit",
    "admission_type",
    "admission_location",
]
# Numeric columns whose fraction of missing values is >= this value are dropped.
NULL_THRESHOLD = 0.9

# --- Keras Model Hyperparameters ---
KERAS_EPOCHS = 50
KERAS_BATCH_SIZE = 256

# Function to define your Keras model
def create_keras_model(n_features, n_classes, hidden_units=(128, 64),
                       dropout_rate=0.3, learning_rate=0.001):
    """Build and compile a small fully-connected softmax classifier.

    The architecture is: input -> [Dense(units, relu) -> Dropout] for each
    entry of ``hidden_units`` -> Dense(n_classes, softmax).

    Args:
        n_features: Number of input features per sample.
        n_classes: Number of target classes; sets the width of the softmax
            output layer. Because the loss is sparse categorical
            cross-entropy, labels passed to ``fit`` must be integer codes in
            ``[0, n_classes)``.
        hidden_units: Widths of the ReLU hidden layers, in order. Each hidden
            layer is followed by a Dropout layer.
        dropout_rate: Dropout rate applied after every hidden layer.
        learning_rate: Adam learning rate (0.001 is Adam's default).

    Returns:
        A compiled ``tf.keras.Sequential`` model that tracks accuracy.
    """
    # Declare the input with an explicit Input layer rather than passing
    # ``input_shape=`` to the first Dense layer (deprecated in Keras 3).
    layers = [tf.keras.Input(shape=(n_features,))]
    for units in hidden_units:
        layers.append(tf.keras.layers.Dense(units, activation='relu'))
        layers.append(tf.keras.layers.Dropout(dropout_rate))
    layers.append(tf.keras.layers.Dense(n_classes, activation='softmax'))
    model = tf.keras.Sequential(layers)

    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
    model.compile(optimizer=optimizer, loss=loss_function, metrics=['accuracy'])
    return model

# 1. Load the raw table and split off the target column
print("Loading data...")
data_df = pd.read_csv(CSV_FILE_NAME)

print("Separating target variable and initial features...")
# Raw target labels (no encoding applied yet).
y_labels_original = data_df.loc[:, TARGET_COLUMN_NAME].values
# errors='ignore' lets the list name columns that may be absent from the CSV.
features_df = data_df.drop(COLUMNS_TO_DROP_FOR_X_INITIAL, axis=1, errors='ignore')
print(f"Initial features in features_df: {len(features_df.columns)}")

# Keep only numeric columns; non-numeric ones are discarded here.
numeric_features_df = features_df.select_dtypes(include=np.number)
print(f"Numeric features after type selection: {numeric_features_df.shape[1]}")

print("Replacing infinities with NaN...")
# Reassign rather than mutate with inplace=True: select_dtypes() returns a
# derived frame, and in-place mutation of it can raise SettingWithCopyWarning.
numeric_features_df = numeric_features_df.replace([np.inf, -np.inf], np.nan)

# The cut-off is inclusive: a column is dropped when its null fraction is
# >= NULL_THRESHOLD (the log message says "at least" to match).
print(f"Dropping columns with at least {NULL_THRESHOLD*100}% null values...")
percent_null = numeric_features_df.isnull().mean()
cols_to_drop_by_nulls = percent_null[percent_null >= NULL_THRESHOLD].index
numeric_features_df = numeric_features_df.drop(columns=cols_to_drop_by_nulls)
print(f"Columns dropped due to excessive nulls: {list(cols_to_drop_by_nulls)}")
print(f"Remaining features after null-based dropping: {numeric_features_df.shape[1]}")
if numeric_features_df.shape[1] == 0:
    raise ValueError(
        "No numeric features remaining after null-based dropping. "
        "Please check the input data and NULL_THRESHOLD."
    )

print("Imputing remaining NaNs with the median...")
# NOTE(review): the imputer is fit on ALL rows, before the train/test split
# further down, so test-set values influence the medians (mild leakage).
feature_names_after_null_processing = numeric_features_df.columns.tolist()
imputer = SimpleImputer(strategy='median')
imputed_features_array = imputer.fit_transform(numeric_features_df)
# SimpleImputer discards columns whose fitted statistic is NaN (columns that
# are entirely missing, possible only when NULL_THRESHOLD > 1). Keep the
# column names aligned with the columns it actually returned.
feature_names_after_null_processing = [
    name
    for name, statistic in zip(feature_names_after_null_processing, imputer.statistics_)
    if not np.isnan(statistic)
]
processed_features_df = pd.DataFrame(imputed_features_array, columns=feature_names_after_null_processing)
print(f"Remaining features after imputation: {processed_features_df.shape[1]}")

print("Dropping constant columns (zero standard deviation)...")
# Detect constant columns by counting distinct values instead of testing
# std() == 0 exactly. For a constant float column the computed mean can be
# off by one rounding step, so the std can come out as a tiny non-zero number
# (~1e-17) and slip through an exact-equality check. After median imputation
# no NaNs remain, so nunique() <= 1 means the column holds a single value.
distinct_value_counts = processed_features_df.nunique()
cols_to_drop_by_std = distinct_value_counts[distinct_value_counts <= 1].index
processed_features_df = processed_features_df.drop(columns=cols_to_drop_by_std)
print(f"Columns dropped due to zero standard deviation: {list(cols_to_drop_by_std)}")
print(f"Final features for scaling: {processed_features_df.shape[1]}")

if processed_features_df.empty:
    raise ValueError("No features remaining after preprocessing. Please check the steps.")
final_features_array = processed_features_df.values

# Missing targets cannot be mapped to a class code; fail early and clearly.
if pd.isnull(y_labels_original).any():
    raise ValueError(f"Target variable '{TARGET_COLUMN_NAME}' contains missing values.")

# Map raw labels to contiguous integer codes 0..K-1 (in sorted label order).
# SparseCategoricalCrossentropy and the K-unit softmax output require exactly
# this encoding; labels such as {1, 2, 3} or strings would otherwise fail or
# silently mis-train. When the labels are already 0..K-1 the mapping is the
# identity, so results are unchanged in that case.
class_labels_sorted, y_labels_for_keras = np.unique(y_labels_original, return_inverse=True)
n_classes_target = len(class_labels_sorted)
if n_classes_target <= 1:
    raise ValueError(f"Target variable '{TARGET_COLUMN_NAME}' has {n_classes_target} unique classes.")
print(f"Number of unique classes (n_classes_target): {n_classes_target}")
print(f"Class code -> original label: {dict(enumerate(class_labels_sorted.tolist()))}")

print("Splitting data into training and test sets...")
# Stratified 80/20 split with a fixed seed so every run sees the same partition.
(
    X_train_original_full,
    X_test_original,
    y_train_original_full,
    y_test_for_keras,
) = train_test_split(
    final_features_array,
    y_labels_for_keras,
    test_size=0.2,
    random_state=42,
    stratify=y_labels_for_keras,
)

# Keep only as many (already shuffled) training rows as fill whole batches.
# NOTE(review): fit() later holds out validation_split=0.1 of these rows, so
# the rows actually trained on are generally no longer a batch multiple.
num_train_samples_full = len(X_train_original_full)
num_train_samples_to_use = num_train_samples_full - num_train_samples_full % KERAS_BATCH_SIZE
X_train_truncated = X_train_original_full[:num_train_samples_to_use]
y_train_truncated = y_train_original_full[:num_train_samples_to_use]
print(f"Using {num_train_samples_to_use} training samples (multiple of batch_size {KERAS_BATCH_SIZE})")

print("Applying StandardScaler...")
# Standardization statistics come from the (truncated) training rows only;
# the test rows are transformed with those same statistics.
scaler_standard = StandardScaler()
scaler_standard.fit(X_train_truncated)
X_train_standardized = scaler_standard.transform(X_train_truncated)
X_test_standardized = scaler_standard.transform(X_test_original)

print("Applying MinMaxScaler to X_train_standardized for noise addition stage (only for noise calibration)...")
# Map each standardized training column onto [0, 1] (MinMaxScaler's default
# range) so a single value can change by at most 1 — the sensitivity used
# when noise is added. The scaler is kept to invert the mapping afterwards.
scaler_min_max = MinMaxScaler().fit(X_train_standardized)
X_train_normalized_for_noise_calibration = scaler_min_max.transform(X_train_standardized)

# --- Prepare datasets for training ---
# Each entry describes one training set: a display name, the epsilon of the
# input-noise mechanism (np.inf marks "no noise added by this script"), the
# feature matrix, and the labels. Entry 0 is the noise-free baseline built
# from the standardized training features.
datasets_to_train = [
    dict(
        name="Original_Standardized_Baseline",
        epsilon=np.inf,
        X_train=X_train_standardized,
        y_train=y_train_truncated,
    ),
]
print(f"Prepared dataset: Original_Standardized_Baseline (shape {X_train_standardized.shape})")

# Privacy budgets for the noisy datasets: one dataset per epsilon, all
# sharing the same delta.
epsilon_values_to_test = [
    0.1, 1.0, 2.5, 5.0, 7.27,
    13.56, 20.0, 39.28, 100.0, 250.48,
]
delta_value = 1e-5

# --- Start Timer ---
# The timer is stopped only after the training/evaluation loop, so the
# reported duration covers noisy-dataset generation AND model training.
script_start_time = time.time()

print("\nGenerating noisy datasets...")
# For each epsilon, perturb every value of the [0, 1]-scaled training matrix
# independently with the Gaussian mechanism, then map the result back to the
# standardized feature space.
for current_epsilon in epsilon_values_to_test:
    print(f"Generating noisy dataset for epsilon = {current_epsilon:.2f}")
    # Using GaussianAnalytic as it handles epsilon > 1 well and was used in last successful run
    # (the classic Gaussian mechanism's calibration is only valid for epsilon < 1; the analytic
    # variant has no such restriction). sensitivity=1.0 matches the [0, 1] range produced by the
    # MinMaxScaler fitted on the training matrix.
    # NOTE(review): epsilon here applies to each released scalar. Under basic composition a record
    # with d features consumes roughly d times this budget; the MinMaxScaler min/max are also
    # computed from the private data without noise. Confirm this is the intended privacy accounting.
    gaussian_mechanism = GaussianAnalytic(epsilon=current_epsilon, delta=delta_value, sensitivity=1.0)
    
    X_noisy_normalized_current = np.zeros_like(X_train_normalized_for_noise_calibration)

    # randomise() takes one scalar at a time, so this is rows * cols Python-level calls per
    # epsilon (slow on large matrices). Values are visited column by column; changing the
    # traversal order changes which random draw lands on which cell.
    for col_idx in range(X_train_normalized_for_noise_calibration.shape[1]):
        for row_idx in range(X_train_normalized_for_noise_calibration.shape[0]):
            scalar_value = X_train_normalized_for_noise_calibration[row_idx, col_idx]
            X_noisy_normalized_current[row_idx, col_idx] = gaussian_mechanism.randomise(scalar_value)

    # Noisy values are not clipped to [0, 1], so the inverse mapping can produce values
    # outside each column's training min/max.
    X_noisy_std = scaler_min_max.inverse_transform(X_noisy_normalized_current)
    datasets_to_train.append({
        "name": f"Noisy_Std_Eps_{current_epsilon:.2f}", # Adjusted to ensure two decimal places for consistency
        "epsilon": current_epsilon,
        "X_train": X_noisy_std,
        "y_train": y_train_truncated
    })
    print(f"Prepared dataset: Noisy_Std_Eps_{current_epsilon:.2f} (shape {X_noisy_std.shape})")

# --- Train Keras models on the prepared datasets and evaluate ---
# One fresh model per prepared dataset; every model is scored on the same
# noise-free standardized test split, and one result row is recorded per
# dataset (with NaN metrics if training/evaluation fails).
print("\nTraining and evaluating Keras models...")
results_list = []
n_features_model = X_train_standardized.shape[1]

for dataset_info in datasets_to_train:
    current_dataset_name = dataset_info["name"]
    current_epsilon_val = dataset_info["epsilon"]
    current_X_train = dataset_info["X_train"]
    current_y_train_int_labels = dataset_info["y_train"]

    print(f"\nProcessing dataset: {current_dataset_name} (Input Data Noise Epsilon: {current_epsilon_val})")

    # Release the global Keras state left by the previous iteration's model
    # before building a new one, so memory does not grow across the loop.
    tf.keras.backend.clear_session()
    keras_model = create_keras_model(n_features=n_features_model, n_classes=n_classes_target)

    print(f"Training Keras model for {current_dataset_name}...")
    try:
        # validation_split holds out the LAST 10% of the rows (Keras does not
        # shuffle before splitting); the rows were shuffled by train_test_split.
        history = keras_model.fit(
            current_X_train, current_y_train_int_labels,
            epochs=KERAS_EPOCHS,
            batch_size=KERAS_BATCH_SIZE,
            validation_split=0.1,
            verbose=0,
        )
        print("Model training completed.")

        print("Evaluating model on clean, standardized test set...")
        loss, acc = keras_model.evaluate(X_test_standardized, y_test_for_keras, verbose=0)

        # verbose=0 keeps predict() as quiet as fit()/evaluate().
        y_pred_proba = keras_model.predict(X_test_standardized, verbose=0)
        y_pred_classes = np.argmax(y_pred_proba, axis=1)

        # NOTE(review): precision/recall are support-weighted but F1 is
        # macro-averaged, so test_f1_score is not the harmonic mean of the
        # reported precision/recall. Left unchanged so the CSV stays comparable
        # with earlier runs — confirm which averaging is intended.
        prec = precision_score(y_test_for_keras, y_pred_classes, average='weighted', zero_division=0)
        rec = recall_score(y_test_for_keras, y_pred_classes, average='weighted', zero_division=0)
        f1_val = f1_score(y_test_for_keras, y_pred_classes, average='macro', zero_division=0)

        # For a binary target, roc_auc_score expects a 1-D score for the
        # positive (greater) class. Passing the full (n_samples, 2) probability
        # matrix raises ValueError, which would force AUC to NaN on every
        # binary run. Multi-class targets keep the full probability matrix.
        auc_scores = y_pred_proba[:, 1] if y_pred_proba.shape[1] == 2 else y_pred_proba
        try:
            auc = roc_auc_score(y_test_for_keras, auc_scores, multi_class='ovo', average='macro')
        except ValueError as e_auc:
            print(f"Could not calculate AUC: {e_auc}. Setting AUC to NaN.")
            auc = np.nan

        # Keras history values may be NumPy float32; convert them to plain
        # Python floats so they serialize cleanly in the results table.
        serializable_history = {k: [float(val) for val in v] for k, v in history.history.items()}

        results_list.append({
            "dataset_name": current_dataset_name,
            "input_data_noise_epsilon": current_epsilon_val,
            "test_loss": float(loss),
            "test_accuracy": float(acc),
            "test_precision": float(prec),
            "test_recall": float(rec),
            "test_f1_score": float(f1_val),
            "test_auc": float(auc),  # float(np.nan) is still NaN
            "keras_training_history": serializable_history,
            "training_epochs_completed": len(history.history['loss']),
        })
        print(f"Results for {current_dataset_name}: Test Acc={acc:.4f}, Test F1={f1_val:.4f}, Test AUC={auc:.4f}")

    except Exception as e:
        # Top-level boundary for one experiment: log the traceback, record a
        # NaN row, and continue with the remaining datasets.
        print(f"ERROR training or evaluating Keras model for {current_dataset_name}: {e}")
        import traceback
        traceback.print_exc()
        results_list.append({
            "dataset_name": current_dataset_name,
            "input_data_noise_epsilon": current_epsilon_val,
            "test_loss": np.nan,
            "test_accuracy": np.nan,
            "test_precision": np.nan,
            "test_recall": np.nan,
            "test_f1_score": np.nan,
            "test_auc": np.nan,
            "keras_training_history": {"error": str(e)},
            "training_epochs_completed": 0,
            "error_message": str(e),
        })

# --- End Timer and Calculate Duration ---
# Elapsed wall-clock time since script_start_time, which was taken just
# before noisy-dataset generation, so it includes model training too.
script_end_time = time.time()
total_execution_time = script_end_time - script_start_time
print(
    "\nDataset perturbation execution time: "
    f"{total_execution_time:.2f} seconds ({total_execution_time / 60:.2f} minutes)."
)

# 4. Save results
import os

print("\nSaving all results...")
results_dataframe = pd.DataFrame(results_list)
output_csv_filename = "../results/dataset_perturbation_results.csv"
try:
    # DataFrame.to_csv does not create missing parent directories. Create the
    # target folder first so a long experiment's results are not lost at the
    # final step.
    os.makedirs(os.path.dirname(output_csv_filename) or ".", exist_ok=True)
    results_dataframe.to_csv(output_csv_filename, index=False)
    print(f"Results saved to {output_csv_filename}")
except Exception as e:
    # Best-effort save: report the failure instead of crashing at the end.
    print(f"Error saving results to CSV: {e}")

print("\nProcess completed.")