import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

def load_and_fix_dates(path, start_date_str='2020-01-01'):
    """Read a CSV and rebuild calendar dates from its 'Year' column.

    The values found in the 'Year' column are used as day offsets counted
    from ``start_date_str``. The returned frame keeps every original
    column and additionally carries:

    * 'Days_since_start' - a copy of the raw 'Year' values,
    * 'Date' - ``start_date_str`` shifted by that many days,
    * 'Year' - overwritten with the calendar year of 'Date'.

    Args:
        path: CSV location, passed straight to ``pd.read_csv``.
        start_date_str: Date string parsed with ``pd.to_datetime`` and
            used as day zero.

    Returns:
        The loaded DataFrame with the columns described above.
    """
    frame = pd.read_csv(path)
    origin = pd.to_datetime(start_date_str)

    # Grab the raw offsets before 'Year' is replaced by real calendar years.
    day_offsets = frame['Year']
    dates = origin + pd.to_timedelta(day_offsets, unit='D')

    return frame.assign(
        Days_since_start=day_offsets,
        Date=dates,
        Year=dates.dt.year,
    )

def prepare_yearly_data(df, country):
    """Sum the 'Weekly cases' column per year for a single country.

    Only rows whose 'Country' equals ``country`` and whose 'Year' lies in
    the inclusive range 2020-2023 contribute to the totals.

    Args:
        df: DataFrame with 'Country', 'Year' and 'Weekly cases' columns.
        country: Value to match against the 'Country' column.

    Returns:
        A Series indexed by 'Year' holding the summed 'Weekly cases'.
        Years without any matching rows are absent from the index.
    """
    matches_country = df['Country'] == country
    within_period = df['Year'].between(2020, 2023)
    selected_rows = df[matches_country & within_period]
    return selected_rows.groupby('Year')['Weekly cases'].sum()

def plot_comparison(df1, df2, country1, country2, save_path='plots/country_comparison.png'):
    """Save a grouped bar chart comparing yearly case totals of two countries.

    For each of the years 2020-2023 the chart shows two side-by-side bars
    (one per country) with the yearly total printed above each bar. Years
    with no data for a country are plotted as 0.

    Args:
        df1: DataFrame holding the rows for ``country1``; forwarded to
            ``prepare_yearly_data``.
        df2: DataFrame holding the rows for ``country2``; forwarded to
            ``prepare_yearly_data``.
        country1: Name of the first country (filter value and legend label).
        country2: Name of the second country (filter value and legend label).
        save_path: Output image path. Its parent directory is created when
            missing; a bare file name (no directory part) is written to the
            current working directory.

    Returns:
        None. The figure is written to ``save_path`` and then closed.
    """
    years = [2020, 2021, 2022, 2023]
    # reindex guarantees one value per plotted year, even if a year is missing.
    data1 = prepare_yearly_data(df1, country1).reindex(years, fill_value=0)
    data2 = prepare_yearly_data(df2, country2).reindex(years, fill_value=0)

    x = np.arange(len(years))
    width = 0.35

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        # Shift each series half a bar width so the pairs sit side by side.
        bars1 = ax.bar(x - width / 2, data1, width, label=country1)
        bars2 = ax.bar(x + width / 2, data2, width, label=country2)

        ax.set_xlabel('Year')
        ax.set_ylabel('Total Weekly Cases')
        ax.set_title(f'Comparison of Weekly COVID-19 Cases: {country1} vs {country2}')
        ax.set_xticks(x)
        ax.set_xticklabels(years)
        ax.legend()

        # Print the total above every bar, with thousands separators.
        for bars in (bars1, bars2):
            for bar in bars:
                height = bar.get_height()
                ax.annotate(f'{int(height):,}',
                            xy=(bar.get_x() + bar.get_width() / 2, height),
                            xytext=(0, 3), textcoords='offset points',
                            ha='center', va='bottom')

        fig.tight_layout()

        # os.path.dirname returns '' for a bare file name, and
        # os.makedirs('') raises, so only create a directory when one is named.
        target_dir = os.path.dirname(save_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        fig.savefig(save_path, dpi=300)
        print(f'Wykres zapisany do: "{save_path}"')
    finally:
        # Close this specific figure even if plotting or saving failed, so
        # it does not linger in pyplot's global figure registry.
        plt.close(fig)

if __name__ == '__main__':
    # Script entry point: load both country datasets using the same
    # reference date for their day offsets, then render the comparison chart.
    reference_date = '2020-01-01'

    df_uk = load_and_fix_dates(
        'data/united_kingdom_covid.csv',
        start_date_str=reference_date,
    )
    df_pl = load_and_fix_dates(
        'data/poland_covid.csv',
        start_date_str=reference_date,
    )

    plot_comparison(df_uk, df_pl, 'United Kingdom', 'Poland')
