import pandas as pd
import pmdarima as pm
import matplotlib.pyplot as plt
from statsmodels.tsa.statespace.sarimax import SARIMAX
import sys
from wiki_table_info_extraction import data_extraction
import matplotlib.colors as mcolors

# Relative path of the CSV that main() loads with pandas.
data_file = "../data/2020-2022/weekly-confirmed-covid-19-cases_2020_2022_20250516.csv"
# Plot title string; not referenced by the code visible in this file.
title = "COVID-19 Cases in Canada (2020–2022) forecast"
# Name of the CSV column that main() models and forecasts.
y_column_name = 'Weekly cases'

def _save_requested(argv):
    """Decide from the command line whether the figure is saved or shown.

    Args:
        argv: The process argument list (normally ``sys.argv``).

    Returns:
        True when a first argument is present and is not ``'0'``; False when
        it is ``'0'`` or missing. When it is missing, a usage hint is printed.
    """
    if len(argv) < 2:
        print("Add argument '0' if you want to show interactively, '1' to save to file. Default is 0.")
        return False
    return argv[1] != '0'


def main():
    """Fit a seasonal ARIMA model to the Canadian case series and plot it.

    Steps:
      1. Load ``data_file``, keep the rows whose ``Entity`` is "Canada" and
         index them by ``Day``.
      2. Hold out the last 15% of the series as the test set.
      3. Let ``pmdarima.auto_arima`` pick the (seasonal) orders on the
         training part, then refit them with statsmodels' ``SARIMAX``.
      4. Forecast over the test horizon and plot training data, actual test
         data, the forecast and its confidence band.
      5. Mark each variant's earliest-sample and VOC-designation dates.
      6. Save the figure or show it, depending on ``sys.argv[1]``.

    Raises:
        ValueError: If the filtered series is too short to hold out at least
            one observation for testing.
    """
    # Imported locally so the file's top-level import block stays untouched.
    from matplotlib.ticker import FuncFormatter

    df = pd.read_csv(data_file, parse_dates=['Day'])
    df = df[df['Entity'] == "Canada"].set_index('Day')
    y = df[y_column_name]

    test_size = int(len(y) * 0.15)
    if test_size < 1:
        # With test_size == 0, ``y.iloc[:-0]`` would be an *empty* training
        # set, so fail early with a clear message instead.
        raise ValueError(
            f"Need at least 7 observations to hold out a 15% test set, got {len(y)}."
        )
    y_train = y.iloc[:-test_size]
    y_test = y.iloc[-test_size:]

    # Auto-tune SARIMA with seasonality, parameter tuning
    model = pm.arima.auto_arima(
        y_train,
        seasonal=True,
        m=7,
        start_p=0, max_p=3,
        start_q=0, max_q=3,
        start_P=0, max_P=2,
        start_Q=0, max_Q=2,
        d=None, D=None,
        trace=True,
        error_action='ignore',
        suppress_warnings=True,
        stepwise=True
    )

    print(model.summary())

    # Refit the orders chosen by auto_arima with statsmodels' SARIMAX.
    sarimax_model = SARIMAX(
        y_train,
        order=model.order,                    # e.g., (1, 1, 1)
        seasonal_order=model.seasonal_order,  # e.g., (1, 1, 0, 7)
        enforce_stationarity=False,
        enforce_invertibility=False
    )

    results = sarimax_model.fit()
    print(results.summary())

    # Forecast over the held-out horizon.
    forecast = results.get_forecast(steps=test_size)
    forecast_mean = forecast.predicted_mean
    conf_int = forecast.conf_int()

    # Plot
    plt.figure(figsize=(14, 6))
    plt.plot(y_train.index, y_train, label='Training Data', color='magenta')
    plt.plot(y_test.index, y_test, label='Actual Test Data', color='blue')
    plt.plot(y_test.index, forecast_mean, label='Forecast', color='orange')
    # Confidence band: first column is the lower bound, second the upper.
    plt.fill_between(y_test.index, conf_int.iloc[:, 0], conf_int.iloc[:, 1], color='orange', alpha=0.3)

    plt.title(f"Forecast vs Actual for Last {test_size} Days")
    plt.xlabel("Date")
    # NOTE(review): the plotted column is 'Weekly cases' while this label says
    # "Daily" -- confirm which wording matches the data.
    plt.ylabel("Daily COVID-19 Cases")
    plt.legend()

    variants_df = data_extraction("../data/original_data/covid-variants-wikipedia-table.csv").sort_values(by='Earliest sample').reset_index(drop=True)
    variants_df = variants_df[variants_df['WHO label'] != "Alpha"]
    # Generate unique, consistent colors for each variant.
    # ``plt.get_cmap`` replaces ``plt.cm.get_cmap``, which newer matplotlib
    # releases no longer provide.
    cmap = plt.get_cmap('tab10')
    color_map = {variant: cmap(i % 10) for i, variant in enumerate(variants_df['WHO label'].unique())}
    for i, row in variants_df.iterrows():
        variant = row['WHO label']
        early_date = row['Earliest sample']
        voc_date = row['Designated VOC']
        place = row['First outbreak']
        color = color_map[variant]
        light_color = mcolors.to_rgba(color, alpha=0.3)  # same color, lighter

        # Vertical line for earliest sample (skipped when the date is missing).
        if pd.notna(early_date):
            plt.axvline(early_date, color=color, linestyle='--', linewidth=2)
            plt.text(
                early_date,
                # Alternate the label height by row index to reduce overlap.
                plt.ylim()[1] * (0.99 - 0.2 * (i % 2)),
                f"{variant}\n{place}\n{early_date.date()}",
                rotation=90, verticalalignment='top',
                color=color, fontsize=9
            )

        # Vertical line for VOC designation (skipped when the date is missing).
        if pd.notna(voc_date):
            plt.axvline(voc_date, color=light_color, linestyle='--', linewidth=2)
            plt.text(
                voc_date,
                # Opposite height to the earliest-sample label of the same row.
                plt.ylim()[1] * (0.99 + 0.2 * (i % 2 - 1)),
                f"{variant} VOC\n{voc_date.date()}",
                rotation=90, verticalalignment='top',
                color=color, fontsize=9
            )
    plt.xlim(y.index.min(), y.index.max())
    plt.ylim(0)

    # Show y ticks in thousands. A formatter (rather than fixed tick labels)
    # stays correct if the ticks change, e.g. when zooming interactively.
    plt.gca().yaxis.set_major_formatter(
        FuncFormatter(lambda value, _pos: f'{value * 1.0 / 1_000:.0f}k')
    )

    plt.tight_layout()
    plt.grid()

    # Only argument parsing is special-cased; errors raised by savefig/show
    # now propagate instead of being swallowed by a bare ``except``.
    if _save_requested(sys.argv):
        output_path = "../plots/cases_forecast_2020-22.png"
        plt.savefig(output_path, dpi=300)
        print(output_path)
    else:
        plt.show()

# Run the forecast/plot pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
