import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_squared_error, mean_absolute_error
import warnings
# Install an "ignore everything" warnings filter as soon as this module is
# imported -- presumably to keep model-fitting output quiet.
# NOTE(review): the filter is process-wide, so it also hides warnings raised
# by unrelated code in any program that imports this module.
warnings.simplefilter("ignore")

def forecast_plot(df: pd.DataFrame, show_or_save: int,
                  where_to_save: str = 'images/sarima_early.png',
                  horizon: int = 40) -> None:
    """Fit SARIMA(1,1,0)(2,1,1)[7] on smoothed case counts and plot a hold-out forecast.

    ``total_cases`` is re-indexed to daily frequency (missing days filled
    with 0), then smoothed with a 7-day moving average of the series lagged
    by 7 days. The last ``horizon`` days are held out as the test set, the
    model is fitted on the rest, and RMSE / MAE on the test window are
    printed. The plot shows the last ``horizon`` training days, the test
    data, the forecast, and its 95% confidence interval.

    Args:
        df: Frame with a ``date`` column (anything ``pd.to_datetime``
            accepts) and a numeric ``total_cases`` column. It is NOT modified.
        show_or_save: 0 to display the plot with ``plt.show()``, 1 to save it
            to ``where_to_save``. Any other value prints a usage message and
            returns without fitting anything.
        where_to_save: Output path used when ``show_or_save == 1``.
        horizon: Number of trailing days held out and forecast (default 40).

    Raises:
        ValueError: If ``horizon`` is not positive, or if the series is too
            short to leave a NaN-free test window plus at least one training
            value after the lag / moving-average warm-up.
    """
    # Validate arguments before the comparatively slow SARIMAX fit.
    if show_or_save not in (0, 1):
        print("sarima plot for early data requires args: DataFrame, show_or_save (0-show/1-save), path_to_save (optional)")
        return
    if horizon < 1:
        raise ValueError(f"horizon must be a positive number of days, got {horizon}")

    # Prepare data on a new frame so the caller's DataFrame is left untouched.
    series_df = df.assign(date=pd.to_datetime(df['date'])).set_index('date')
    y = series_df['total_cases'].asfreq('D').fillna(0)

    # 7-day moving average of the series lagged by 7 days; the shift plus the
    # rolling window leave the first 13 values as NaN.
    y_rolling = y.shift(7).rolling(window=7).mean()

    # Train/test split: the last `horizon` days are held out.
    train = y_rolling.iloc[:-horizon]
    test = y_rolling.iloc[-horizon:]
    if test.isna().any() or train.count() == 0:
        raise ValueError(
            f"series of {len(y)} days is too short for a {horizon}-day hold-out "
            "after the 7-day lag and 7-day moving-average warm-up"
        )

    # Build and fit the SARIMA(1,1,0)(2,1,1)[7] model.
    model = SARIMAX(train,
                    order=(1, 1, 0),
                    seasonal_order=(2, 1, 1, 7),
                    enforce_stationarity=False,
                    enforce_invertibility=False,
                    trend='c')  # 'c' = intercept/const
    results = model.fit(disp=False)

    # Forecast exactly the held-out window.
    forecast = results.get_forecast(steps=horizon)
    predicted_mean = forecast.predicted_mean
    conf_int = forecast.conf_int()  # default alpha=0.05 -> 95% interval

    # Evaluation
    rmse = np.sqrt(mean_squared_error(test, predicted_mean))
    mae = mean_absolute_error(test, predicted_mean)
    test_sd = np.var(test) ** 0.5
    # A flat test window has zero spread; avoid round(inf) -> OverflowError.
    percent_sd = round(rmse * 100 / test_sd) if test_sd > 0 else float('nan')
    print(f"RMSE: {rmse:.2f} ({rmse / 1e3:.0f}k) --> {percent_sd}% of test standard deviation")
    print(f"MAE: {mae:.2f} ({mae / 1e3:.0f}k)")

    # Plot
    train_last = train.iloc[-horizon:]

    # NOTE: this updates matplotlib's global rcParams, so later figures in the
    # same process also use font size 16. It must run before the axes are
    # created so their text picks up the new size.
    plt.rcParams.update({'font.size': 16})
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(train_last.index, train_last, label='Train - real data')
    ax.plot(test.index, test, label='Test - real data')
    ax.plot(test.index, predicted_mean,
            label=f'Forecast (predicted), RMSE={rmse / 1e3:.0f}k ≈ {percent_sd}% of total sdev of test data',
            color='green')
    ax.fill_between(test.index,
                    conf_int.iloc[:, 0],
                    conf_int.iloc[:, 1],
                    color='green',
                    alpha=0.3,
                    label='95% Confidence Interval')
    ax.set_title(f'SARIMA(1,1,0)(2,1,1)[7] – {horizon}-day Forecast of Total COVID-19 Cases')
    ax.set_xlabel("Date")
    ax.set_ylabel("Number of Cases")
    ax.legend()
    ax.tick_params(axis='x', labelrotation=45)

    # Format Y-axis to show millions
    from matplotlib.ticker import FuncFormatter

    def millions(x, pos):
        return f'{x * 1e-6:.1f}M'

    ax.yaxis.set_major_formatter(FuncFormatter(millions))

    fig.tight_layout()
    if show_or_save == 0:
        plt.show()
        return

    try:
        fig.savefig(where_to_save)
        print(f"Plot successfully saved to: {where_to_save}")
    except FileNotFoundError:
        print(f"Error: The directory for saving the plot was not found: {where_to_save}")
    except PermissionError:
        print(f"Error: Permission denied when trying to save the plot to: {where_to_save}")
    except Exception as e:
        # Best-effort save: report and carry on rather than crash the caller.
        print(f"Unexpected error while saving the plot: {e}")
    finally:
        # pyplot keeps every figure alive until closed; release this one so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
