import pandas as pd
import plotly.graph_objects as go

# Load the raw COVID-19 series and the yearly population table for Rwanda.
rwa = pd.read_csv("../data/covid_rwanda.csv")
pop = pd.read_csv("../data/population_rwanda.csv")

rwa['date'] = pd.to_datetime(rwa['date'])
# The backfill below and the line traces both follow row order, so put the rows
# in chronological order explicitly instead of trusting the CSV's order. The
# index is reset so label-based slicing (``.loc[label:]``) stays positional.
rwa = rwa.sort_values('date', kind='stable').reset_index(drop=True)

# Attach each row's population via its calendar year; years missing from the
# population table map to NaN.
population_map = dict(zip(pop['Year'], pop['Population']))
rwa['year'] = rwa['date'].dt.year
rwa['population'] = rwa['year'].map(population_map)

# Backfill gaps in the vaccination columns, but only from the first strictly
# positive observation onward; rows before that are left untouched.
vaccination_cols = ['people_vaccinated', 'people_fully_vaccinated', 'total_boosters']
for col in vaccination_cols:
    # Label of the first row with a positive count. If no row is positive,
    # idxmax() falls back to the first label and the guard below skips the column.
    first_valid_index = rwa[col].fillna(0).gt(0).idxmax()
    if not rwa.loc[first_valid_index, col] > 0:
        continue
    rwa.loc[first_valid_index:, col] = rwa.loc[first_valid_index:, col].bfill()

# Restrict the analysis to the vaccination window (both ends inclusive).
start_date = pd.to_datetime('2021-03-04')
end_date = pd.to_datetime('2022-12-31')
in_window = (rwa['date'] >= start_date) & (rwa['date'] <= end_date)
# Take an explicit copy: new columns are assigned to this frame afterwards, and
# assigning into a boolean-mask slice triggers SettingWithCopyWarning.
rwa = rwa.loc[in_window].copy()

# Share of the population covered by each series, in percent (NaN wherever the
# count or the population is missing).
rwa['perc_first'] = rwa['people_vaccinated'] / rwa['population'] * 100
rwa['perc_full'] = rwa['people_fully_vaccinated'] / rwa['population'] * 100
rwa['perc_boost'] = rwa['total_boosters'] / rwa['population'] * 100

# Fixed axis ranges for the layout. The x range runs two days past end_date as
# a small right-hand margin.
x_range = [start_date, pd.to_datetime('2023-01-02')]
# DataFrame.max() skips NaN per column, and the second .max() skips NaN across
# columns. The builtin max() over per-column maxima returns NaN (depending on
# argument order) when a whole column is missing, which would break y_range.
y_max = rwa[['people_vaccinated', 'people_fully_vaccinated', 'total_boosters']].max().max()
y_range = [0, y_max * 1.05]

# Animation keyframes: every 7th distinct date in the window. The final date is
# appended when the 7-day stride skips it, so the last frame (and slider step)
# shows the complete series instead of stopping up to six days short.
all_dates = rwa['date'].dropna().sort_values().unique()
positions = list(range(0, len(all_dates), 7))
if positions and positions[-1] != len(all_dates) - 1:
    positions.append(len(all_dates) - 1)
weekly_dates = all_dates[positions]

# Hover layout shared by every trace: date, absolute count, share of population.
HOVER_TEMPLATE = (
    "<b>Date:</b> %{x|%Y-%m-%d}<br>"
    "<b>People:</b> %{y:,}<br>"
    "<b>% of Population:</b> %{customdata:.2f}%<extra></extra>"
)

# (value column, percentage column, legend name, line colour) for each trace.
SERIES_SPECS = [
    ('people_vaccinated', 'perc_first', 'People with First Dose', 'blue'),
    ('people_fully_vaccinated', 'perc_full', 'People Fully Vaccinated', 'green'),
    ('total_boosters', 'perc_boost', 'Booster Doses', 'orange'),
]


def _vaccination_traces(df):
    """Build the three vaccination line traces from the rows of *df*.

    Used for both the animation frames and the initial figure, so the two
    always share trace names, colours and hover formatting.
    """
    return [
        go.Scatter(
            x=df['date'],
            y=df[value_col],
            mode='lines+markers',
            name=label,
            line=dict(color=color, width=2.5),
            marker=dict(size=3),
            hovertemplate=HOVER_TEMPLATE,
            customdata=df[perc_col],
        )
        for value_col, perc_col, label, color in SERIES_SPECS
    ]


# Animation frames: the frame for a date shows every row up to and including it.
# Frame names are str(date) and must match the slider step args in the layout.
frames = [
    go.Frame(name=str(date), data=_vaccination_traces(rwa[rwa['date'] <= date]))
    for date in weekly_dates
]

# The initial figure shows the complete series; frames replay it progressively.
fig = go.Figure(data=_vaccination_traces(rwa), frames=frames)

# Styling shared by both axes: non-zoomable range, light grid, black axis line.
_axis_style = dict(
    tickfont=dict(size=14),
    fixedrange=True,
    showgrid=True,
    gridcolor="lightgrey",
    showline=True,
    linecolor='black',
)

# Play / Pause controls for the frame animation.
_play_button = dict(label="Play", method="animate", args=[None, {
    "frame": {"duration": 100, "redraw": False},
    "fromcurrent": True,
    "transition": {"duration": 0},
}])
_pause_button = dict(label="Pause", method="animate", args=[[None], {
    "frame": {"duration": 0, "redraw": False},
    "mode": "immediate",
    "transition": {"duration": 0},
}])

# One slider step per animation frame; args[0] must match the frame names (str(date)).
_slider_steps = []
for _step_date in weekly_dates:
    _slider_steps.append(dict(
        method="animate",
        args=[[str(_step_date)], {
            "frame": {"duration": 0, "redraw": True},
            "mode": "immediate",
        }],
        label=pd.to_datetime(_step_date).strftime('%b-%Y'),
    ))

fig.update_layout(
    title=dict(
        text="<b>COVID-19 Vaccination Progress in Rwanda</b>",
        x=0.5,
        font=dict(size=28),
    ),
    xaxis=dict(
        title=dict(text="<b>Date</b>", font=dict(size=20)),
        range=x_range,
        tickangle=-30,
        tickformat="%b %Y",
        dtick="M3",
        **_axis_style,
    ),
    yaxis=dict(
        title=dict(text="<b>Number of People</b>", font=dict(size=20)),
        range=y_range,
        tickformat=",",
        **_axis_style,
    ),
    hoverlabel=dict(font=dict(size=19)),
    legend=dict(font=dict(size=16)),
    plot_bgcolor="white",
    margin=dict(l=80, r=40, t=80, b=160),
    updatemenus=[dict(
        type="buttons",
        showactive=False,
        buttons=[_play_button, _pause_button],
        x=0, y=-0.32, xanchor="left", yanchor="bottom",
        # Small label font to shrink the buttons. NOTE(review): size 3 is tiny;
        # confirm the labels are still legible.
        font=dict(size=3),
    )],
    sliders=[dict(
        active=0,
        steps=_slider_steps,
        x=0, y=-0.38, xanchor="left", yanchor="bottom",
        currentvalue=dict(prefix="Date: ", font=dict(size=14)),
        transition=dict(duration=0),
        font=dict(size=12),
    )],
)

if __name__ == "__main__":
    import os

    output_dir = "../Plots/Rwanda"
    # The writers below do not create missing parent directories; ensure the
    # output folder exists so a fresh checkout does not fail with FileNotFoundError.
    os.makedirs(output_dir, exist_ok=True)

    fig.show()
    fig.write_html(f"{output_dir}/vaccination_rwanda.html", include_plotlyjs="cdn")
    fig.write_image(f"{output_dir}/vaccination_rwanda.png", scale=2, width=1200, height=800)
