import pandas as pd
import plotly.graph_objects as go

# --- Load and prepare the data ---
rwa = pd.read_csv("../data/covid_rwanda.csv")
rwa['date'] = pd.to_datetime(rwa['date'])

# Aggregate the daily rows into weekly totals ('W' = weeks ending on Sunday).
# min_count=1 makes a week with no reported values come out as NaN instead
# of 0. Without it, sum() never yields NaN, dropna() removes nothing, and
# weeks with no test data are plotted as a spurious drop to zero.
weekly_rwa = (
    rwa.set_index('date')
       .resample('W')[['new_tests', 'new_tests_per_thousand']]
       .sum(min_count=1)
       .dropna()
)

# --- Build the chart ---
fig = go.Figure()

# Blue weekly-tests line with small markers. The hover box also reports the
# weekly tests-per-thousand value, passed in as a one-column customdata array.
fig.add_scatter(
    name='Weekly Tests',
    x=weekly_rwa.index,
    y=weekly_rwa['new_tests'],
    customdata=weekly_rwa[['new_tests_per_thousand']].to_numpy(),
    mode='lines+markers',
    line={'color': 'blue', 'width': 2.5},
    marker={'color': 'blue', 'size': 3},
    hoverlabel={'bgcolor': 'white'},
    hovertemplate=(
        "<b>Date:</b> %{x|%Y-%m-%d}<br>"
        "<b>Weekly Tests:</b> %{y:,.0f}<br>"
        "<b>Tests per thousand:</b> %{customdata[0]:,.1f}<extra></extra>"
    ),
)

# Layout: centred bold title, Arial black text, light-grey gridlines on a
# white plot background, and fixed ranges on both axes. The y range is pinned
# to 0-150,000, so any weekly total above that would be cut off.
fig.update_layout(
    title_text="<b>Weekly COVID-19 Tests in Rwanda</b>",
    title_x=0.5,
    title_font={'family': "Arial", 'size': 28, 'color': "black"},
    xaxis={
        'title': {
            'text': "<b>Date</b>",
            'font': {'family': "Arial", 'size': 20, 'color': "black"},
        },
        'tickfont': {'family': "Arial", 'size': 16, 'color': "black"},
        'range': [pd.to_datetime('2020-03-15'), pd.to_datetime('2023-03-12')],
        'tickformat': "%b %Y",  # month abbreviation + year, e.g. "Mar 2020"
        'dtick': "M3",          # one tick every 3 months
        'tickangle': -30,
        'showgrid': True,
        'gridcolor': "lightgrey",
        'showline': True,
        'linecolor': 'black',
        'linewidth': 1,
        'mirror': False,
    },
    yaxis={
        'title': {
            'text': "<b>Number of Tests (weekly)</b>",
            'font': {'family': "Arial", 'size': 20, 'color': "black"},
        },
        'tickfont': {'family': "Arial", 'size': 16, 'color': "black"},
        'range': [0, 150000],
        'tickformat': ",",      # thousands separators
        'showgrid': True,
        'gridcolor': "lightgrey",
        'showline': True,
        'linecolor': 'black',
        'linewidth': 1,
        'mirror': False,
    },
    # Legend anchored by its top-left corner near the top-left of the plot.
    legend={
        'font': {'size': 18},
        'xanchor': "left",
        'x': 0.02,
        'yanchor': "top",
        'y': 0.95,
    },
    plot_bgcolor="white",
    margin={'l': 120, 'r': 50, 't': 90, 'b': 70},
)

# Display the figure and export it as PNG and HTML when run as a script.
if __name__ == "__main__":
    from pathlib import Path

    out_dir = Path("../Plots/Rwanda")
    # Exporting fails if the destination folder is missing, so create it
    # (and any missing parents) first; exist_ok keeps reruns harmless.
    out_dir.mkdir(parents=True, exist_ok=True)

    fig.show()
    fig.write_image(out_dir / "weekly_tests_rwanda.png", width=1200, height=800)
    # include_plotlyjs="cdn" loads plotly.js from a CDN instead of embedding it.
    fig.write_html(out_dir / "weekly_tests_rwanda.html", include_plotlyjs="cdn")
