from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Static per-country attributes looked up when building the summary table.
country_data = {
    name: {"median_age": age}
    for name, age in (
        ("Rwanda", 20.8),
        ("Poland", 42.5),
        ("South Africa", 28.7),
    )
}

# Countries to compare; this list also fixes the row order of the final table.
countries = ["Rwanda", "South Africa", "Poland"]

# Denominator for the vaccination share: the mean of four population figures
# per country (presumably consecutive yearly estimates — confirm the source).
population_by_country = {
    name: np.mean(figures)
    for name, figures in (
        ("Poland", [37899070.0, 36981559.0, 36821749.0, 36687353.0]),
        ("Rwanda", [13065837.0, 13355260.0, 13651030.0, 13954471.0]),
        ("South Africa", [60562381.0, 61502603.0, 62378410.0, 63212384.0]),
    )
}

# Load the COVID time series and keep only rows for the compared countries.
df = pd.read_csv("../data/covid_rwanda_poland_south_africa.csv").loc[
    lambda frame: frame["country"].isin(countries)
]

def smart_mean(series):
    """Return the mean of *series*, skipping zeros as well as NaN entries.

    Zeros are turned into NaN before averaging, so they do not pull the mean
    down (presumably because a 0 marks an unreported reading — confirm).
    Yields NaN when no non-zero, non-NaN value remains.
    """
    without_zeros = series.mask(series == 0)
    return without_zeros.mean(skipna=True)

# One summary row per country: the total_* and people_fully_vaccinated columns
# are reduced to their maximum, the stringency index via smart_mean (zeros
# excluded). Output columns are named directly through named aggregation.
grouped = df.groupby("country").agg(
    **{
        "Total Deaths per Million": ("total_deaths_per_million", "max"),
        "Total Cases per Million": ("total_cases_per_million", "max"),
        "Avg. Stringency Index": ("stringency_index", smart_mean),
        "People Fully Vaccinated": ("people_fully_vaccinated", "max"),
    }
)

# Percentage of the population fully vaccinated, using the averaged population
# from population_by_country as the denominator.
grouped["% Fully Vaccinated"] = (
    grouped["People Fully Vaccinated"]
    / grouped.index.map(population_by_country).to_numpy()
    * 100
)

grouped["Median Age"] = [country_data[name]["median_age"] for name in grouped.index]

# Put rows in the presentation order given by `countries`, keep only the
# displayed columns (in display order) and round to two decimals.
grouped = grouped.loc[
    countries,
    [
        "Total Deaths per Million",
        "Total Cases per Million",
        "% Fully Vaccinated",
        "Avg. Stringency Index",
        "Median Age",
    ],
].round(2)

# Table input for matplotlib: the header labels plus one row of cell values
# per country, each row starting with the country name.
header = ["Country", *grouped.columns]
data_rows = [
    [name, *values]
    for name, values in zip(grouped.index, grouped.to_numpy().tolist())
]

# Render the summary as a table-only figure and save it to disk.
fig, ax = plt.subplots(figsize=(12, 3))
ax.axis('off')  # hide axes so only the table is drawn
tbl = ax.table(
    cellText=data_rows,
    colLabels=header,
    loc='center',
    cellLoc='center',
    colColours=["#4F4F4F"] * len(header),  # dark grey background for every header cell
)

# General styling: disable automatic font sizing and use a fixed 6 pt font.
tbl.auto_set_font_size(False)
tbl.set_fontsize(6)

# Header row (row index 0, since colLabels is set) gets white bold text and a
# heavier border; body cells get a thinner border.
for (row, col), cell in tbl.get_celld().items():
    if row == 0:
        cell.set_text_props(color='white', weight='bold')
        cell.set_linewidth(1.2)
    else:
        cell.set_linewidth(0.8)

output_path = Path("../Plots/Comparisons/table_comparison.png")
# savefig does not create missing directories; ensure the target folder exists
# so a fresh checkout does not fail with FileNotFoundError.
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, bbox_inches='tight', dpi=300)
plt.show()
