import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import statsmodels.api as sm
from sklearn.metrics import mean_squared_error
from statsmodels.tools.sm_exceptions import ConvergenceWarning

# Suppress only optimizer convergence warnings and library FutureWarnings for a
# cleaner output. A blanket 'ignore' would also hide diagnostics that matter for
# this analysis, such as statsmodels' model-specification or index-frequency
# warnings, so every other category is left visible.
warnings.filterwarnings('ignore', category=ConvergenceWarning)
warnings.filterwarnings('ignore', category=FutureWarning)

#
# == 1. DATA LOADING AND PREPARATION ==
#
# Path is resolved relative to the current working directory.
DATA_PATH = "../data/covid_rwanda.csv"

try:
    # Load the dataset from the specified file path.
    rwa = pd.read_csv(DATA_PATH)
except FileNotFoundError:
    # Raising SystemExit with a message prints it to stderr and exits with
    # status 1, so shells and CI detect the failure. The bare site-provided
    # exit() is intended for the interactive interpreter and exits with 0.
    raise SystemExit(
        f"ERROR: Data file not found at '{DATA_PATH}'. Please check the file path."
    ) from None

# Ensure the 'date' column is in datetime format and set it as the DataFrame index.
rwa['date'] = pd.to_datetime(rwa['date'])
rwa.set_index('date', inplace=True)

# Keep only rows strictly before 2022-03-30 (the end of the study period).
rwa = rwa[rwa.index < '2022-03-30']

# Resample to weekly bins ('W' = weeks ending on Sunday), taking each week's
# maximum. Assuming 'total_deaths' is cumulative (non-decreasing), the weekly
# maximum is the end-of-week total. Weeks with no data give NaN and are dropped.
weekly_deaths = rwa['total_deaths'].resample('W').max().dropna()
print("Weekly deaths data loaded and prepared successfully.")


#
# == 2. TRAIN-TEST SPLIT ==
#
# Define the split date to train the model on data before the Omicron wave.
# Weeks on or before the split date train the model; later weeks are held out.
split_date = '2021-10-15'
train = weekly_deaths[weekly_deaths.index <= split_date]
test = weekly_deaths[weekly_deaths.index > split_date]

# Fail fast with a clear message if the split leaves either side empty (e.g. the
# CSV covers a different period than expected). An empty training set cannot be
# fitted, and an empty test set leaves nothing to forecast or score, which would
# otherwise surface as an obscure error in the later modelling steps.
if train.empty or test.empty:
    raise SystemExit(
        f"ERROR: split at {split_date} produced {len(train)} training and "
        f"{len(test)} test weeks; both must be non-empty."
    )

print(f"\nData split at: {split_date}")
print(f"Training set size: {len(train)} weeks | Test set size: {len(test)} weeks")


#
# == 3. SARIMAX MODEL DEFINITION AND TRAINING ==
#
# Baseline specification:
#   non-seasonal (p, d, q)  = (1, 1, 1): one AR lag, first difference, one MA lag.
#   seasonal (P, D, Q, m)   = (1, 1, 0, 52): one seasonal AR lag and a seasonal
#                             difference at period 52, i.e. annual seasonality
#                             for weekly observations.
# Stationarity and invertibility constraints are disabled, so the optimizer is
# not restricted to those parameter regions.
# NOTE(review): a seasonal difference at m=52 consumes the first 52 training
# weeks; confirm the training-set size leaves enough observations for the
# seasonal AR term to be meaningfully estimated.
arima_order = (1, 1, 1)
seasonal_spec = (1, 1, 0, 52)

model = sm.tsa.SARIMAX(
    train,
    order=arima_order,
    seasonal_order=seasonal_spec,
    enforce_stationarity=False,
    enforce_invertibility=False,
)

# Estimate parameters quietly, then report the fitted-model summary.
print("\nTraining the SARIMAX model...")
results = model.fit(disp=False)
print(results.summary())


#
# == 4. FORECASTING AND EVALUATION ==
#
# Forecast exactly as many weeks ahead as the held-out test set contains, so
# predictions and actuals pair up position by position.
horizon = len(test)
forecast_object = results.get_forecast(steps=horizon)
predictions = forecast_object.predicted_mean

# Root Mean Squared Error of the point forecast against the actual test values.
mse = mean_squared_error(test, predictions)
rmse = np.sqrt(mse)

print("\n--- Model Evaluation ---")
print(f"Forecast RMSE: {rmse:.2f}")


#
# == 5. VISUALIZATION ==
#
# 95% confidence intervals for the forecast; column 0 is used as the lower band
# and column 1 as the upper band below.
conf_int = forecast_object.conf_int(alpha=0.05)

# Create the plot.
fig, ax = plt.subplots(figsize=(16, 9))

# Plot training data, test data, and the model's forecast with its interval.
ax.plot(train.index, train, label='Training Data (Actual)', color='#1f77b4')
ax.plot(test.index, test, label='Test Data (Actual)', color='#ff7f0e', marker='.')
ax.plot(predictions.index, predictions, label=f'SARIMAX Forecast (RMSE={rmse:.0f})', color='#2ca02c', linestyle='--')
ax.fill_between(conf_int.index, conf_int.iloc[:, 0], conf_int.iloc[:, 1], color='green', alpha=0.2, label='95% Confidence Interval')

# Configure plot titles, labels, and legend for clarity.
ax.set_title("SARIMAX: Forecasting Total Deaths for the Omicron Wave", fontsize=18, weight='bold')
ax.set_xlabel('Date', fontsize=14)
ax.set_ylabel('Total Deaths (Weekly)', fontsize=14)
ax.legend(loc='upper left', fontsize=12)
ax.grid(True, which='major', linestyle='--', linewidth=0.5)

# Zoom the x-axis to the forecast period; deaths cannot be negative, so clip
# the y-axis (and any negative lower confidence band) at zero.
ax.set_xlim(pd.to_datetime('2021-02-01'), pd.to_datetime('2022-03-30'))
ax.set_ylim(bottom=0)

# Save BEFORE showing. plt.show() blocks until the window is closed, and some
# backends close the figure at that point, so saving first guarantees the file
# is written. Create the output directory if missing so savefig does not raise
# FileNotFoundError.
output_path = Path("../Plots/Forecast/total_deaths_forecast_rwanda.png")
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, bbox_inches='tight', dpi=100)

# Display the final plot.
plt.show()