import pandas as pd
import plotly.graph_objects as go

# --- Load and prepare the data ---
# Read the raw Rwanda records (path is relative to the current working
# directory) and turn the "date" column into datetimes so it can serve as a
# DatetimeIndex for resampling.
rwa = pd.read_csv("../data/covid_rwanda.csv")
rwa['date'] = pd.to_datetime(rwa['date'])

# Bucket the rows into weekly bins ('W' = weeks ending on Sunday), keep the
# largest value observed in each week for the two case columns, then discard
# any week where either column is missing.
_case_columns = ['total_cases', 'total_cases_per_million']
_by_week = rwa.set_index('date').resample('W')
weekly_rwa = _by_week[_case_columns].max().dropna()

# --- Variant periods ---
# Each variant wave is described by a (start, end) date window; these windows
# drive both the shaded bands and the annotation positions further down.
def _period(start, end):
    """Return the *start* and *end* date strings converted to Timestamps."""
    return pd.to_datetime(start), pd.to_datetime(end)


def _peak_total_cases(start, end):
    """Return the largest weekly ``total_cases`` value between *start* and *end*.

    Label slicing on the DatetimeIndex is inclusive of both endpoints; an
    empty window yields NaN.
    """
    return weekly_rwa['total_cases'].loc[start:end].max()


ancestral_start, ancestral_end = _period('2020-08-02', '2020-09-27')
alpha_beta_start, alpha_beta_end = _period('2020-12-06', '2021-03-28')
delta_start, delta_end = _period('2021-05-30', '2021-10-24')
omicron_start, omicron_end = _period('2021-12-01', '2022-02-28')

peak_ancestral = _peak_total_cases(ancestral_start, ancestral_end)
peak_alpha_beta = _peak_total_cases(alpha_beta_start, alpha_beta_end)
peak_delta = _peak_total_cases(delta_start, delta_end)
peak_omicron = _peak_total_cases(omicron_start, omicron_end)

# --- Build the figure ---
fig = go.Figure()

# Blue line (with small markers) of the weekly total_cases values. The hover
# box also shows total_cases_per_million, passed in as a single-column
# customdata array and referenced as customdata[0] in the template.
_hover_lines = [
    "<b>Date:</b> %{x|%Y-%m-%d}<br>",
    "<b>Total Cases:</b> %{y:,.0f}<br>",
    "<b>Total Cases per million:</b> %{customdata[0]:,.1f}<extra></extra>",
]
_cases_trace = go.Scatter(
    x=weekly_rwa.index,
    y=weekly_rwa['total_cases'],
    customdata=weekly_rwa[['total_cases_per_million']].to_numpy(),
    mode='lines+markers',
    name='Total Cases',
    line=dict(color='blue', width=2.5),
    marker=dict(size=3, color='blue'),
    hovertemplate="".join(_hover_lines),
    hoverlabel=dict(bgcolor='white'),
)
fig.add_trace(_cases_trace)

# Variant bands: one translucent rectangle per variant window, spanning the
# full plot height (yref="paper", 0 to 1) between the window's start and end.
for _band_x0, _band_x1, _band_fill, _band_opacity in (
    (ancestral_start, ancestral_end, "#b3cde3", 0.15),
    (alpha_beta_start, alpha_beta_end, "#00a5e1", 0.10),
    (delta_start, delta_end, "orange", 0.15),
    (omicron_start, omicron_end, "yellow", 0.15),
):
    fig.add_shape(
        type="rect",
        xref="x",
        yref="paper",
        x0=_band_x0,
        x1=_band_x1,
        y0=0,
        y1=1,
        fillcolor=_band_fill,
        opacity=_band_opacity,
        line_width=0,
    )

# Annotations: one boxed callout per variant. The arrow tip is placed above
# the window's peak weekly total_cases (5% padding for the first two windows,
# 2% for the last two) and the label box sits 40 px above the tip.
for _note_x, _note_peak, _note_pad, _note_text in (
    ('2020-08-30', peak_ancestral, 0.05,
     "Ancestral Lineages<br>Aug–Sep 2020"),
    (alpha_beta_start + (alpha_beta_end - alpha_beta_start) / 2, peak_alpha_beta, 0.05,
     "Alpha / Beta variants<br>Dec 2020 – Mar 2021"),
    ('2021-07-25', peak_delta, 0.02,
     "Delta variant<br>May – Oct 2021"),
    ('2022-01-02', peak_omicron, 0.02,
     "Omicron variant<br>Dec 2021 – Feb 2022"),
):
    fig.add_annotation(
        x=_note_x,
        y=_note_peak + _note_peak * _note_pad,
        text=_note_text,
        showarrow=True,
        arrowhead=2,
        ax=0,
        ay=-40,
        font=dict(size=18, family="Arial", color="black"),
        bgcolor="white",
        bordercolor="black",
        borderwidth=1,
    )

# Layout: centred bold title, date x-axis with quarterly ticks, fixed y range,
# light grey grid on a white background, and the legend in the top-left.
def _arial(size):
    """Return a black Arial font spec of the given size."""
    return dict(size=size, family="Arial", color="black")


# Styling shared by both axes: grey grid plus a thin black axis line on one side.
_axis_common = dict(
    tickfont=_arial(16),
    showgrid=True,
    gridcolor="lightgrey",
    showline=True,
    linecolor='black',
    linewidth=1,
    mirror=False,
)

fig.update_layout(
    title=dict(
        text="<b> Total COVID-19 Cases in Rwanda</b>",
        x=0.5,
        font=_arial(28),
    ),
    xaxis=dict(
        title=dict(text="<b>Date</b>", font=_arial(20)),
        range=[pd.to_datetime('2020-03-15'), pd.to_datetime('2023-03-12')],
        tickangle=-30,
        tickformat="%b %Y",
        dtick="M3",
        **_axis_common,
    ),
    yaxis=dict(
        title=dict(text="<b>Number of total Cases </b>", font=_arial(20)),
        range=[0, 160000],
        tickformat=",",
        **_axis_common,
    ),
    legend=dict(font=dict(size=18), yanchor="top", y=0.95, xanchor="left", x=0.02),
    plot_bgcolor="white",
    margin=dict(l=120, r=50, t=90, b=70),
)

# Display and export. The figure above is built at import time; the export
# and display below only run when the file is executed as a script. Output
# paths are relative to the current working directory.
if __name__ == "__main__":
    from pathlib import Path

    # Plotly's file writers do not create missing parent folders, so ensure
    # the output directory exists (otherwise the export raises
    # FileNotFoundError on a fresh checkout).
    Path("../Plots/Rwanda").mkdir(parents=True, exist_ok=True)

    # Export before showing: with plotly's default browser renderer,
    # fig.show() may block until the browser fetches the page, or fail when
    # no browser is available, which would otherwise prevent the files from
    # being written. NOTE: write_image needs a static-image engine (kaleido)
    # installed.
    fig.write_image("../Plots/Rwanda/total_cases_rwanda.png", width=1200, height=800)
    fig.write_html("../Plots/Rwanda/total_cases_rwanda.html", include_plotlyjs="cdn")
    fig.show()

