import pandas as pd
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

def analyze_feature_importance(early_df, inflation_df, unemployment_df, where_to_save,
                               exclude_lags=False, *, cutoff_date='2024-04-01'):
    """Fit a random forest predicting ``new_cases`` and plot its feature importances.

    ``early_df`` is left-joined with the inflation and unemployment series on
    ``date``, truncated at ``cutoff_date``, and a ``RandomForestRegressor`` is
    fitted on a random 80% training split. The model's ``feature_importances_``
    are plotted as a bar chart, saved to ``where_to_save`` and shown.

    The input DataFrames are not modified.

    Args:
        early_df: DataFrame with a ``date`` column, the target ``new_cases`` and
            every feature listed in ``cols_to_use`` below. A column named
            ``index``, if present, is dropped.
        inflation_df: DataFrame with exactly three columns, relabelled
            positionally as ``date``, ``time_period``, ``inflation``.
        unemployment_df: DataFrame with exactly three columns, relabelled
            positionally as ``date``, ``time_period``, ``unemployment``.
        where_to_save: Destination passed to ``plt.savefig``.
        exclude_lags: If True, drop every feature whose name contains ``"lag"``.
        cutoff_date: Keep only rows with ``date <= cutoff_date``. Keyword-only;
            defaults to the previously hard-coded ``'2024-04-01'``.

    Returns:
        pandas Series of feature importances indexed by feature name, sorted in
        descending order.

    Raises:
        ValueError: if ``inflation_df`` or ``unemployment_df`` does not have
            exactly three columns (raised by the column relabelling).
        KeyError: if a required column is missing after the merges.
    """
    # Clean up the data
    if 'index' in early_df.columns:
        early_df = early_df.drop(columns='index')

    # Relabel copies: assigning ``.columns`` on the arguments themselves would
    # silently rename the caller's DataFrames.
    inflation_df = inflation_df.copy()
    inflation_df.columns = ['date', 'time_period', 'inflation']
    unemployment_df = unemployment_df.copy()
    unemployment_df.columns = ['date', 'time_period', 'unemployment']

    # Inner join: keep only dates present in both macro series.
    macro = inflation_df.merge(unemployment_df, on='date', how='inner')

    # Left join keeps every early_df row; dates without macro data (and any other
    # missing values) are filled with 0.
    # NOTE(review): zero-imputing inflation/unemployment is a modelling choice — confirm.
    df = early_df.merge(macro, on='date', how='left').fillna(0)
    df = df[df['date'] <= cutoff_date]

    target = 'new_cases'
    cols_to_use = [
        'new_cases', 'weekly_hosp_admissions', 'icu_patients', 'weekly_icu_admissions',
        'stringency_index', 'reproduction_rate', 'total_tests', 'positive_rate',
        'tests_per_case', 'total_vaccinations', 'people_vaccinated',
        'people_fully_vaccinated', 'month', 'cases_lag_1',
        'cases_lag_2', 'cases_lag_3', 'cases_lag_7', 'cases_lag_14',
        'cases_lag_21', 'cases_lag_28', 'inflation', 'unemployment'
    ]
    df = df[cols_to_use]

    # Every selected column except the target, optionally without lagged-case columns.
    features = [
        col for col in cols_to_use
        if col != target and not (exclude_lags and 'lag' in col)
    ]
    X = df[features]
    y = df[target]

    # Importances are computed from the training split only; the held-out part is
    # not used here, but the split is kept so the fixed random_state selects the
    # same training rows as before.
    X_train, _, y_train, _ = train_test_split(X, y, test_size=0.2, random_state=42)

    model = RandomForestRegressor(random_state=42)
    model.fit(X_train, y_train)

    feat_importances = pd.Series(
        model.feature_importances_, index=X.columns
    ).sort_values(ascending=False)

    plt.figure(figsize=(10, 6))
    feat_importances.plot(kind="bar")
    plt.title("Feature Importances from RandomForestRegressor")
    plt.ylabel("Importance")
    plt.tight_layout()
    plt.savefig(where_to_save)
    plt.show()

    return feat_importances
