import pandas as pd
import matplotlib.pyplot as plt
import os
from matplotlib.ticker import MultipleLocator
import matplotlib.dates as mdates
from matplotlib import rcParams
import argparse

# Command-line interface: a single required positional integer that must be
# either 0 or 1 (argparse rejects anything else before the script continues).
parser = argparse.ArgumentParser(
    description="Choose whether to display the plot (0) or save it to a file (1).",
)
parser.add_argument(
    'action',
    type=int,
    choices=[0, 1],
    help="0 - display the plot, 1 - save to file",
)
args = parser.parse_args()


# Global matplotlib text defaults (font family, base size, text colour) for
# everything drawn after this point.
rcParams.update({
    'font.family': 'Liberation Serif',
    'font.size': 13,
    'text.color': '#2C3E50',
})


# Month-abbreviation + year date formatter.
# NOTE(review): nothing else in the visible script references this name.
date_format = mdates.DateFormatter('%b %Y')

# Countries whose rows are kept from the dataset.
nordic_countries = ["Sweden", "Denmark", "Finland", "Norway", "Iceland"]

# Read the CSV, turn the "Date_reported" column into datetimes, then keep only
# the rows whose "Country" value is one of the names listed above.
countries_per_100 = (
    pd.read_csv("../csv/cases_per_100.csv")
    .assign(Date_reported=lambda frame: pd.to_datetime(frame["Date_reported"]))
    .loc[lambda frame: frame["Country"].isin(nordic_countries)]
)

# 2 x 3 grid of panels; the bottom-right cell is removed from the figure,
# leaving five visible axes.
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(11.5, 6))
fig.delaxes(axs[1, 2])

# Tighten the outer margins and the spacing between panels.
fig.subplots_adjust(
    left=0.02,
    right=0.98,
    top=0.95,
    bottom=0.05,
    wspace=0.1,
    hspace=0.2,
)

# Figure title and shared y-axis caption. Both are positioned with coordinates
# outside the 0-1 range (y=1.06 and x=-0.08 respectively).
fig.text(
    0.08,
    1.06,
    'Confirmed COVID-19 deaths per 100K People in the Nordic Region',
    va='center',
    rotation='horizontal',
    fontsize=22,
)
fig.text(
    -0.08,
    0.5,
    'deaths',
    va='center',
    rotation='vertical',
    fontsize=19,
)

# Countries present in the filtered data, in order of first appearance.
countries = countries_per_100["Country"].drop_duplicates().tolist()

# Line colour keyed by country name. Looking colours up by name (instead of by
# position in a list) guarantees each country gets its intended colour no
# matter how the rows are ordered in the CSV.
country_colors = {
    "Sweden": '#C0392B',
    "Denmark": '#2980B9',
    "Finland": '#27AE60',
    "Norway": '#F4D03F',
    "Iceland": '#8E44AD',
}

# Draw one panel per country, filling the grid row by row (axs.flat iterates
# the 2-D axes array in row-major order, matching i // 3 and i % 3).
for ax, country in zip(axs.flat, countries):
    # Select this country's rows once and reuse them for both plot columns.
    country_rows = countries_per_100[countries_per_100["Country"] == country]
    ax.set_title(country, fontsize=20)
    ax.plot(
        country_rows["Date_reported"],
        country_rows["deaths_per_100"],
        label=country,  # picked up by the figure-level legend
        color=country_colors[country],
        linewidth=2,
    )

# Upper y-limit shared by every panel so the countries are directly comparable.
# It depends only on the data, so it is computed once rather than per axes.
y_up_limit = countries_per_100["deaths_per_100"].max()

# Apply identical styling to every axes in the grid.
for row, axes_row in enumerate(axs):
    for col, ax in enumerate(axes_row):
        ax.set_ylim([0, y_up_limit])
        ax.set_facecolor('#EAF4FB')

        # Date axis: labelled major tick every 10 months, minor tick every 5.
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))
        ax.xaxis.set_major_locator(mdates.MonthLocator(interval=10))
        ax.xaxis.set_minor_locator(mdates.MonthLocator(interval=5))

        # Value axis: major tick every 50 units, minor tick every 25.
        ax.yaxis.set_major_locator(MultipleLocator(50))
        ax.yaxis.set_minor_locator(MultipleLocator(25))

        # White grid on the tinted background; major lines are thicker.
        ax.xaxis.grid(True, 'minor', color='#FFFFFF', linewidth=0.9)
        ax.yaxis.grid(True, 'minor', color='#FFFFFF', linewidth=0.9)
        ax.xaxis.grid(True, 'major', color='#FFFFFF', linewidth=1.6)
        ax.yaxis.grid(True, 'major', color='#FFFFFF', linewidth=1.6)

        # Tick marks use the same colour as the axes background.
        ax.tick_params(axis='both', which='major', labelsize=12, color="#EAF4FB")
        ax.tick_params(axis='both', which='minor', labelsize=10, color="#EAF4FB")

        # Only the second row keeps its x tick labels.
        if row != 1:
            ax.tick_params(axis='x', which='both', labelbottom=False, length=0)

        # Only the first column keeps its y tick labels.
        if col != 0:
            ax.tick_params(axis='y', which='both', labelleft=False, length=0)

        # Hide the frame around each panel.
        for spine in ax.spines.values():
            spine.set_visible(False)

# Figure-level legend built from the labelled lines of all panels; its upper
# left corner is anchored at x=1, y=0.98 in figure coordinates.
fig.legend(loc='upper left', bbox_to_anchor=(1, 0.98), ncol=1, fontsize=15)

# Honour the command-line choice: 1 saves the figure, 0 displays it.
# Previously the parsed `action` was ignored and the figure was always saved.
if args.action == 1:
    # The PNG is named after this script and written to ../images/.
    # bbox_inches='tight' plus padding is used so that text placed outside the
    # 0-1 figure box (title, y caption, legend) is included in the image.
    output_name = os.path.basename(os.path.splitext(__file__)[0] + '.png')
    fig.savefig(
        '../images/' + output_name,
        format='png',
        dpi=300,
        bbox_inches='tight',
        pad_inches=1.0,
    )
else:
    # NOTE(review): the interactive window does not apply the tight bounding
    # box, so artists positioned outside the figure box may be clipped here.
    plt.show()
