import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight

# Input CSV path (relative, so it is resolved against the current working directory).
csv_file_name = "../data/variables_final.csv"

# Column holding the classification target.
target_column_name = "los_class"

# A feature column is discarded when its fraction of NaN values exceeds this.
nan_threshold = 0.90

# Columns that must never be used as features, grouped by the reason they are excluded.
_identifier_columns = ['hadm_id', 'subject_id']
_categorical_text_columns = [
    'gender',
    'first_careunit',
    'admission_type',
    'admission_location',
]
columns_to_drop_for_X_initial = [
    target_column_name,
    *_identifier_columns,
    'los',  # Original Length of Stay, likely correlated with los_class
    *_categorical_text_columns,
]

# Load dataset
df = pd.read_csv(csv_file_name)

# Separate target variable y (as a NumPy array, row-aligned with df)
y = df[target_column_name].values

# Create DataFrame for features X by dropping the target, the identifier columns
# and the text columns listed in columns_to_drop_for_X_initial.
X_df = df.drop(columns=columns_to_drop_for_X_initial)

# NOTE: constant (zero-variance) columns are intentionally NOT handled here.
# The standard deviation is only meaningful once NaNs/Infs have been dealt with,
# so that filter is applied in Step 3 below.



# --- Step 1: Drop feature columns whose share of NaN values exceeds nan_threshold ---
print("\nStep 1: Checking for columns that are >90% NaN...")

# Per-column fraction of missing entries: missing count divided by the row count.
missing_fraction = X_df.isnull().sum().div(len(X_df))

# Keep the original column order; a NaN fraction never compares greater than the threshold.
sparse_columns = [
    column
    for column, fraction in missing_fraction.items()
    if fraction > nan_threshold
]

if not sparse_columns:
    print(f"No columns found with more than {nan_threshold:.0%} NaN values.")
else:
    print(f"Found {len(sparse_columns)} columns with more than {nan_threshold:.0%} NaN values. These will be removed:")
    for column in sparse_columns:
        # Show the exact percentage so the removal can be audited from the log.
        print(f" - {column} (has {missing_fraction[column]:.2%} NaN values)")

    # drop() returns a new frame, so X_df is rebound rather than mutated in place.
    X_df = X_df.drop(columns=sparse_columns)
    print(f"Number of columns in X_df after removal: {X_df.shape[1]}")

# --- Step 2: Report and impute the NaNs / Infs that survived Step 1 ---
def _count_infs(column):
    """Return the number of +/-inf entries in a numeric column (0 for non-numeric columns)."""
    if not pd.api.types.is_numeric_dtype(column):
        return 0
    return np.isinf(column).sum()


print("\nStep 2: Checking for remaining NaNs and Infs in X_df...")
if not X_df.empty:
    missing_per_column = X_df.isnull().sum()
    infs_per_column = X_df.apply(_count_infs)
    nan_report = missing_per_column[missing_per_column > 0]
    inf_report = infs_per_column[infs_per_column > 0]
    has_nans = not nan_report.empty
    has_infs = not inf_report.empty

    if has_nans:
        print("Columns with remaining NaNs and their counts:")
        print(nan_report)
    else:
        print("No remaining NaNs found in X_df.")

    if has_infs:
        print("Columns with Infs and their counts:")
        print(inf_report)
    else:
        print("No Infs found in X_df.")

    if not (has_nans or has_infs):
        print("No further NaNs or Infs to impute.")
    else:
        print("Applying temporary imputation for remaining NaNs/Infs: NaNs with mean, Infs with 0.")
        # NOTE(review): the column mean is computed over every row, i.e. before the
        # train/validation/test split performed later in this script.
        for feature in X_df.columns:
            # Only numeric columns can be imputed this way; others are left untouched.
            if not pd.api.types.is_numeric_dtype(X_df[feature]):
                continue

            # Infs are replaced first, so the zeros they become take part in the mean below.
            if np.isinf(X_df[feature]).any():
                X_df[feature] = X_df[feature].replace([np.inf, -np.inf], 0)

            if not X_df[feature].isnull().any():
                continue

            fill_value = X_df[feature].mean()
            if np.isnan(fill_value):
                # The mean is NaN when the column holds no finite value at all.
                print(f"Warning: Mean for column {feature} is NaN during imputation. Filling with 0 instead.")
                fill_value = 0
            X_df[feature] = X_df[feature].fillna(fill_value)

        print("Temporary imputation for remaining NaNs/Infs completed.")
        nan_counts_after = X_df.isnull().sum().sum()
        inf_counts_after = X_df.apply(_count_infs).sum()
        print(f"NaNs remaining after final imputation: {nan_counts_after}")
        print(f"Infs remaining after final imputation: {inf_counts_after}")
else:
    # Nothing to inspect: the frame has no rows or no columns left.
    print("X_df is empty after removing all-NaN columns. No further NaN/Inf processing needed.")


# --- Step 3: Drop numeric columns whose standard deviation is (almost) zero ---
print("\nStep 3: Checking for columns with problematic (near-zero) standard deviation...")
if not X_df.empty and X_df.shape[1] != 0:
    numeric_cols_for_std_check = X_df.select_dtypes(include=np.number).columns
    if numeric_cols_for_std_check.empty:
        print("No numeric columns available to check for standard deviation.")
    else:
        # Anything below this is treated as constant.
        std_threshold = 1e-9
        column_stds = X_df[numeric_cols_for_std_check].std()
        is_near_constant = column_stds < std_threshold
        near_constant_columns = column_stds.index[is_near_constant].tolist()

        if not near_constant_columns:
            print("No columns with problematic near-zero standard deviation found.")
        else:
            print(f"Found {len(near_constant_columns)} columns with standard deviation < {std_threshold}. These will be removed:")
            for column in near_constant_columns:
                print(f" - {column} (std: {column_stds[column]})")
            X_df = X_df.drop(columns=near_constant_columns)
else:
    print("X_df is empty or has no columns. Skipping std deviation check.")

# Names of the surviving feature columns, in the same order as the columns of X.
# An X_df without columns simply yields an empty list.
final_feature_names = X_df.columns.tolist()

print(f"Final number of features: {len(final_feature_names)}")
if final_feature_names:
    # Long lists are abbreviated to the first and last five names.
    print(f"Final feature names: {final_feature_names if len(final_feature_names) < 10 else final_feature_names[:5] + ['...'] + final_feature_names[-5:]}")

# Always convert from the DataFrame so X is two-dimensional, (n_rows, n_features),
# even when no feature column survived. The previous 1-D `np.array([])` fallback
# made `X.shape[1]` raise IndexError further down and lost the row count that has
# to match y.
X = X_df.values
print(f"Feature DataFrame X_df converted to NumPy array X. Shape of X: {X.shape}")



# Calculate class_weight ('balanced' weighting over the labels present in y)
unique_classes = np.unique(y)
class_weights_computed = compute_class_weight(class_weight='balanced', classes=unique_classes, y=y)

# compute_class_weight returns one weight per entry of `classes`, in that order,
# so pair every weight with its actual class label. Keying by position
# (enumerate) is only correct when the labels happen to be exactly 0..n-1; for
# such labels this mapping is identical to the positional one.
# .tolist() turns the labels into plain Python scalars for the dict keys.
class_weights = dict(zip(unique_classes.tolist(), class_weights_computed))

# Split data (70/15/15) and normalize numerical features.
# The split happens BEFORE scaling so that the scaler statistics come from the
# training rows only; fitting on the full matrix would leak validation/test
# information into the training features.

# Fixed seed so repeated runs produce the same train/validation/test partition.
split_random_state = 42

# Split data: 70% for training, 30% for temporary
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=split_random_state
)

# Split temporary data: 50% of temp for validation, 50% for test
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=split_random_state
)

scaler = StandardScaler()
if X.ndim == 2 and X.shape[1] > 0:  # Only scale if there are features left
    # Fit on the training split only, then apply the same transform everywhere.
    X_train = scaler.fit_transform(X_train)
    X_val = scaler.transform(X_val)
    X_test = scaler.transform(X_test)
else:
    print("Warning: No features left after removing zero-variance columns. Check your data.")

# Summarize every split: array shapes first, then the per-class label counts.
_split_summary = (
    ("train", X_train, y_train),
    ("val", X_val, y_val),
    ("test", X_test, y_test),
)

for _split, _features, _labels in _split_summary:
    print(f"Shape of X_{_split}: {_features.shape}")
    print(f"Shape of y_{_split}: {_labels.shape}")

# A blank line separates the shape report from the class distributions.
print()
for _split, _features, _labels in _split_summary:
    print(f"Distribution of classes in y_{_split}: {np.unique(_labels, return_counts=True)}")

# Target-level information, again preceded by a blank line.
print()
print(f"Unique classes in target variable 'y': {unique_classes}")
print(f"Number of classes (for n_classes in train_dp.py): {len(unique_classes)}")
print(f"Calculated class weights: {class_weights}")

# Persist all splits plus metadata into a single .npz archive.
_processed_arrays = {
    "X_train": X_train,
    "y_train": y_train,
    "X_val": X_val,
    "y_val": y_val,
    "X_test": X_test,
    "y_test": y_test,
    # (class, weight) pairs stored as an object array; reading an object array
    # back with np.load requires allow_pickle=True.
    "class_weights": np.array(list(class_weights.items()), dtype=object),
    "columns_X": final_feature_names,
}
np.savez('../data/processed_data.npz', **_processed_arrays)

print("\nPreprocessed data saved to processed_data.npz")