import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ast
import os

# --- Configuration ---
# Path of the results CSV generated by the training script. It is passed
# unchanged to pd.read_csv below, so a relative path is resolved against the
# current working directory. The script reads the 'input_data_noise_epsilon'
# and (optionally) 'keras_training_history' columns from it.
CSV_FILE_PATH = "../results/dataset_perturbation_results.csv"

# --- Plot Styling ---
NON_DP_LABEL = "Non-DP (Baseline)" # Legend/epsilon label for the model without DP noise (rows whose epsilon is infinite)
TITLE_FONTSIZE = 21   # Axes title
LABEL_FONTSIZE = 18   # X/Y axis labels
LEGEND_FONTSIZE = 16  # Legend entries
TICK_FONTSIZE = 14    # Major tick labels on both axes

# Line thickness: the baseline curve is drawn thicker than the DP curves.
NON_DP_LINEWIDTH = 4.0
DP_LINEWIDTH = 3.4

# Colors
NON_DP_COLOR = 'black' # Consistent color for the Non-DP model
# Warm sequential colormap 'YlOrRd' (Yellow-Orange-Red). Only the colormap
# object is fetched here; the plotting section samples one color per DP model
# from it.
WARM_COLORS_CMAP = plt.get_cmap('YlOrRd')

# --- Ensure the output directory exists ---
def ensure_dir(directory):
    """Create ``directory`` (including missing parents) if it does not exist.

    Calling this on an already existing directory is a no-op.

    Args:
        directory: Path of the directory to create.

    Raises:
        FileExistsError: If ``directory`` exists but is not a directory.
        OSError: If the directory cannot be created (e.g. permissions).
    """
    # exist_ok=True avoids the check-then-create race of
    # `if not os.path.exists(...): os.makedirs(...)`, where another process
    # could create the directory between the two calls.
    os.makedirs(directory, exist_ok=True)

# --- Load and Process Data ---
# Read the raw results table; abort with a non-zero exit status if the CSV is
# missing so that shells and pipelines can detect the failure.
print(f"Loading data from {CSV_FILE_PATH}...")
try:
    results_df = pd.read_csv(CSV_FILE_PATH)
except FileNotFoundError:
    print(f"ERROR: File not found at '{CSV_FILE_PATH}'.")
    print("Please update the CSV_FILE_PATH variable in the script to the correct filename.")
    # `exit()` is a helper injected by the `site` module for interactive use
    # and would report success (status 0). Raise SystemExit directly instead.
    raise SystemExit(1)

def _parse_history(history_data, epsilon):
    """Return a training history as a dict, or None if missing or unusable.

    Args:
        history_data: Raw ``keras_training_history`` cell: a dict, the string
            repr of a dict, or anything else (e.g. NaN/None for an empty cell).
        epsilon: Epsilon of the owning row; only used in warning messages.

    Returns:
        The history dict, or None when the cell is empty, cannot be parsed, or
        does not parse to a dict. A warning is printed in the last two cases.
    """
    if isinstance(history_data, dict):
        return history_data
    if not isinstance(history_data, str):
        return None
    try:
        parsed = ast.literal_eval(history_data)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
        # literal_eval is documented to raise any of these on malformed input;
        # one bad cell should cost one curve, not the whole run.
        print(f"Warning: Could not parse history for row with epsilon {epsilon}: {e}.")
        return None
    if not isinstance(parsed, dict):
        # A valid literal that is not a dict (list, number, ...) cannot be
        # indexed by metric name later on, so treat it as missing.
        print(f"Warning: Could not parse history for row with epsilon {epsilon}: "
              f"expected a dict, got {type(parsed).__name__}.")
        return None
    return parsed


def process_dataframe(df):
    """Parse training histories and add plot-ready epsilon columns.

    If no row has an infinite ``input_data_noise_epsilon`` (the marker used
    here for the Non-DP baseline), a hard-coded placeholder baseline row
    without training history is prepended and a warning is printed.

    Args:
        df: Results table. Must contain an ``input_data_noise_epsilon``
            column; ``keras_training_history`` is optional and may hold a dict
            or the string repr of one.

    Returns:
        A new DataFrame holding every column of the (possibly extended) input
        plus ``keras_training_history_dict`` (dict or None),
        ``epsilon_numeric`` (float) and ``epsilon_label`` (legend text).

    Raises:
        KeyError: If ``input_data_noise_epsilon`` is missing from ``df``.
    """
    # Check if a baseline model row exists
    if not np.isinf(df['input_data_noise_epsilon']).any():
        # This is a placeholder. REPLACE with your actual baseline data if it's not in the CSV.
        baseline_placeholder = {'dataset_name': 'Original_Standardized_Baseline', 'input_data_noise_epsilon': np.inf,
                                'test_accuracy': 0.6942, 'test_f1_score': 0.6243, 'test_auc': 0.7211,
                                'keras_training_history': None} # Add a placeholder history if you have one
        df = pd.concat([pd.DataFrame([baseline_placeholder]), df], ignore_index=True)
        print("Warning: Manually added a placeholder for the baseline model. Please ensure this data is correct.")

    processed_rows = []
    for _, row in df.iterrows():
        epsilon_val = row['input_data_noise_epsilon']

        new_row = row.to_dict()
        new_row['keras_training_history_dict'] = _parse_history(
            row.get('keras_training_history', None), epsilon_val
        )

        if np.isinf(epsilon_val):
            new_row['epsilon_numeric'] = np.inf
            new_row['epsilon_label'] = NON_DP_LABEL
        else:
            new_row['epsilon_numeric'] = float(epsilon_val)
            new_row['epsilon_label'] = f"ε = {epsilon_val:g}" # Use 'g' for general format
        processed_rows.append(new_row)

    return pd.DataFrame(processed_rows)

results_df = process_dataframe(results_df)

# --- Select Data for Plotting ---
epsilons_to_plot = [1.0, 7.27, 39.28, 250.48]

# Get the baseline model results
baseline_model_results = results_df[results_df['epsilon_numeric'] == np.inf]

# Get the specific DP models requested, sorted by epsilon.
# Epsilons are floats that went through a CSV round trip, so compare with a
# tolerance instead of exact equality. `epsilon_matches[r, t]` is True when
# row r matches requested epsilon t; infinite (baseline) rows never match a
# finite target.
epsilon_values = results_df['epsilon_numeric'].to_numpy(dtype=float)
requested_epsilons = np.asarray(epsilons_to_plot, dtype=float)
epsilon_matches = np.isclose(epsilon_values[:, None], requested_epsilons[None, :])
dp_models_to_plot = results_df[epsilon_matches.any(axis=1)].sort_values(by='epsilon_numeric')

# Requested epsilons with no matching row at all. Checked per target rather
# than by row count, so duplicate rows for one epsilon cannot hide or fake a gap.
missing_epsilons = [
    eps for eps, found in zip(epsilons_to_plot, epsilon_matches.any(axis=0)) if not found
]

print("Data loaded and plotting models for specified epsilons...")
if baseline_model_results.empty:
    print("Warning: Baseline (Non-DP) model not found. The plot will only show DP models.")
if missing_epsilons:
    print(f"Warning: Not all specified epsilon values were found in the results file. Missing: {missing_epsilons}")


# --- Plotting Section ---
output_dir = "../results/"
ensure_dir(output_dir)
print(f"Plots will be saved to the '{output_dir}' directory.")


def _plot_val_accuracy(axes, history, label, **line_kwargs):
    """Draw one validation-accuracy curve against 1-based epoch numbers.

    Args:
        axes: Matplotlib axes to draw on.
        history: Training-history dict expected to contain 'val_accuracy';
            anything else (None, NaN, empty dict) is treated as missing.
        label: Legend label for the curve; also used in the skip warning.
        **line_kwargs: Extra keyword arguments forwarded to ``axes.plot``.

    Returns:
        True if a line was drawn, False if the model was skipped.
    """
    if not isinstance(history, dict) or 'val_accuracy' not in history:
        # Make the omission visible instead of silently dropping the curve.
        print(f"Warning: No 'val_accuracy' history available for '{label}'; skipping its curve.")
        return False
    val_accuracy = history['val_accuracy']
    epochs = range(1, len(val_accuracy) + 1)
    axes.plot(epochs, val_accuracy, label=label, **line_kwargs)
    return True


plt.figure(figsize=(14, 8))
ax = plt.gca() # Get current axes

curves_drawn = 0

# Plot Non-DP baseline (zorder=10 keeps it on top of the DP curves)
if not baseline_model_results.empty:
    baseline_history = baseline_model_results.iloc[0]['keras_training_history_dict']
    curves_drawn += _plot_val_accuracy(
        ax, baseline_history, NON_DP_LABEL,
        color=NON_DP_COLOR, linewidth=NON_DP_LINEWIDTH, zorder=10,
    )

# Plot the selected DP models with warm colors.
# One color per selected model, sampled from the 0.3-0.9 range of the colormap.
warm_colors = WARM_COLORS_CMAP(np.linspace(0.3, 0.9, len(dp_models_to_plot)))

for curve_color, (_, row) in zip(warm_colors, dp_models_to_plot.iterrows()):
    curves_drawn += _plot_val_accuracy(
        ax, row['keras_training_history_dict'], row['epsilon_label'],
        color=curve_color, linewidth=DP_LINEWIDTH,
    )

# --- Apply Styling and Axis Limits ---
ax.set_title('Validation Accuracy During Training', fontsize=TITLE_FONTSIZE)
ax.set_xlabel('Epoch', fontsize=LABEL_FONTSIZE)
ax.set_ylabel('Validation Accuracy', fontsize=LABEL_FONTSIZE)
ax.tick_params(axis='both', which='major', labelsize=TICK_FONTSIZE)
if curves_drawn:
    # A legend without any labelled artist only triggers a matplotlib warning.
    ax.legend(loc='best', fontsize=LEGEND_FONTSIZE)
else:
    print("Warning: No validation-accuracy curves were drawn; the saved plot will be empty.")
ax.grid(True, linestyle='--')

# Set axis limits as requested: fixed accuracy window padded by 5% on each
# side, and a fixed epoch range.
y_base_min, y_base_max = 0.56, 0.70
margin = (y_base_max - y_base_min) * 0.05
ax.set_xlim(0, 20)
ax.set_ylim(y_base_min - margin, y_base_max + margin)

# --- Save the Plot ---
filename = os.path.join(output_dir, 'dataset_perturbation_accuracy_comparison.png')
plt.savefig(filename)
plt.close()
print(f"Saved {filename}")

print("\nPlot generation complete.")