import os
import pandas as pd
from prophet import Prophet 
import matplotlib.pyplot as plt
from matplotlib import rcParams
from matplotlib.ticker import FuncFormatter
import matplotlib.dates as mdates
import argparse

# Command-line interface: one required positional integer selects the output
# mode (0 = show the figure interactively, 1 = write it to a PNG file).
parser = argparse.ArgumentParser(
    description="Choose whether to display the plot (0) or save it to a file (1).",
)
parser.add_argument(
    'action',
    type=int,
    choices=[0, 1],
    help="0 - display the plot, 1 - save to file",
)
args = parser.parse_args()

# Global Matplotlib text defaults (font family, base size, text colour),
# applied in a single update so the settings read as one unit.
rcParams.update({
    'font.family': 'Liberation Serif',
    'font.size': 16,
    'text.color': '#2C3E50',
})

# Load data
# Resolve the CSV relative to this script's location (<project>/csv/...)
# instead of the current working directory, so the script can be launched
# from any directory. This mirrors how the save branch at the end of the
# script locates the sibling "images" directory via __file__.
_script_dir = os.path.dirname(os.path.abspath(__file__))
_csv_path = os.path.join(
    os.path.dirname(_script_dir), 'csv', 'SWEDEN-daily-cases-deaths.csv'
)
# NOTE: despite the "death_" prefix, the series modelled below is New_cases.
death_sw = pd.read_csv(_csv_path)

# Parse dates and keep only rows reported on or before 2022-01-14
death_sw['ds'] = pd.to_datetime(death_sw['Date_reported'])
death_sw = death_sw[death_sw['ds'] <= '2022-01-14']

# Prepare for Prophet: it expects exactly the columns 'ds' (date) and 'y'
# (value). rename() already returns a new DataFrame, so no explicit copy
# is needed to keep d_sw independent of death_sw.
d_sw = death_sw[['ds', 'New_cases']].rename(columns={'New_cases': 'y'})

# Model configuration: interval width of 0.95 for the forecast band, weekly
# and yearly seasonal components switched on, daily component switched off.
prophet_options = {
    'interval_width': 0.95,
    'daily_seasonality': False,
    'weekly_seasonality': True,
    'yearly_seasonality': True,
}
m = Prophet(**prophet_options)
model = m.fit(d_sw)

# Build a frame with 180 additional daily periods and predict over it.
future = m.make_future_dataframe(periods=180, freq='D')
forecast = m.predict(future)

# Plot the forecast with Prophet's built-in helper, then restyle its axes.
fig = m.plot(forecast)
ax = fig.gca()
fig.subplots_adjust(top=0.93)  # adjust the top margin for the enlarged title

# Title, labels, background
ax.set_title("COVID-19 Daily Case Forecast – Sweden", fontsize=30)
ax.set_xlabel("Date", fontsize=25)
ax.set_ylabel("New Cases", fontsize=25)
ax.set_facecolor('#EAF4FB')

# Y axis: show values in thousands, e.g. 25000 -> "25K"
ax.yaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{x/1e3:.0f}K'))

# X axis: a major tick every 6 months, labelled like "Jan 2021".
# (A YearLocator used to be installed here first, but it was immediately
# replaced by the MonthLocator below and never had any effect.)
ax.xaxis.set_major_locator(mdates.MonthLocator(interval=6))
ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))
ax.tick_params(axis='x', rotation=30)

# Grid: white lines on the tinted background, thicker for major ticks.
# NOTE: this script configures no minor locator, so the minor settings only
# become visible if minor ticks are enabled on these axes.
for axis in (ax.xaxis, ax.yaxis):
    axis.grid(True, which='minor', color='#FFFFFF', linewidth=1)
    axis.grid(True, which='major', color='#FFFFFF', linewidth=2)

# Ticks: same label size and tick colour for major and minor ticks
ax.tick_params(axis='both', which='both', labelsize=18, color="#2C3E50")

# Hide the frame around the plotting area
for spine in ax.spines.values():
    spine.set_visible(False)

# Show or save
if args.action == 0:
    plt.show()
else:
    # Normalise to an absolute path first: when __file__ is a bare relative
    # name (e.g. "script.py"), dirname(dirname(__file__)) collapses to ''
    # and the image would land in ./images instead of <project>/images.
    script_path = os.path.abspath(__file__)

    # Output file is named after this script: <script name>.png
    filename = os.path.splitext(os.path.basename(script_path))[0] + '.png'

    # Images go to the "images" directory next to the script's directory.
    images_dir = os.path.join(os.path.dirname(os.path.dirname(script_path)), 'images')
    os.makedirs(images_dir, exist_ok=True)

    save_path = os.path.join(images_dir, filename)
    fig.savefig(save_path, format='png', dpi=300, bbox_inches='tight', pad_inches=1.0)
    print(f"Plot saved to: {save_path}")
