import pandas as pd
import plotly.graph_objects as go

# Load the data and parse the date column
df = pd.read_csv("../data/covid_rwanda_poland_south_africa.csv")
df['date'] = pd.to_datetime(df['date'])

# Countries to compare and the analysis window (both ends inclusive)
countries = ['Rwanda', 'Poland', 'South Africa']
start_date = pd.to_datetime('2020-03-01')
end_date = pd.to_datetime('2023-06-18')

# Keep only rows for the selected countries that fall inside the window;
# rows with a missing (NaT) date fail the range test and are dropped.
in_scope = df['country'].isin(countries) & df['date'].between(start_date, end_date)
df = df[in_scope]

# Animation sample points: every 7th unique date (roughly weekly, assuming one
# row per day — TODO confirm against the dataset). The stride alone can skip
# the last available date, which would leave the animation ending short of
# the data, so the final position is appended whenever it was missed.
all_dates = df['date'].dropna().sort_values().unique()
sample_positions = list(range(0, len(all_dates), 7))
if sample_positions and sample_positions[-1] != len(all_dates) - 1:
    sample_positions.append(len(all_dates) - 1)
weekly_dates = all_dates[sample_positions]

# Line colour for each country
color_map = {
    'Rwanda': 'blue',
    'Poland': 'red',
    'South Africa': 'green'
}

def _country_trace(country, country_data):
    """Return the styled line+marker trace for one country.

    ``country_data`` holds rows of ``df`` for ``country``; its ``date`` and
    ``stringency_index`` columns become the x and y values. Shared by the
    animation frames and the initial traces so both always look the same.
    """
    return go.Scatter(
        x=country_data['date'],
        y=country_data['stringency_index'],
        mode='lines+markers',
        name=country,
        line=dict(color=color_map[country], width=2.5),
        marker=dict(size=3),
        # One [country] entry per point so the hover template can print it
        customdata=[[country] for _ in range(len(country_data))],
        hovertemplate=(
            "<b>%{customdata[0]}</b><br>" +
            "%{x|%d %b %Y}<br>" +
            "Stringency Index: %{y:.1f}" +
            "<extra></extra>"
        ),
        hoverlabel=dict(font=dict(size=14))
    )


# Split the data by country once instead of re-filtering inside every frame
data_by_country = {country: df[df['country'] == country] for country in countries}

# Animation frames: each frame shows every country's series up to that date
frames = [
    go.Frame(
        name=str(date),
        data=[
            _country_trace(
                country,
                data_by_country[country][data_by_country[country]['date'] <= date],
            )
            for country in countries
        ],
    )
    for date in weekly_dates
]

# Initial traces: the full series for each country
initial_data = [_country_trace(country, data_by_country[country]) for country in countries]

# Layout: assemble the figure from the initial traces and the animation frames
fig = go.Figure(
    data=initial_data,
    frames=frames
)

# y-axis top: 5% headroom above the largest value in the filtered data. When
# the column has no values at all, max() is NaN and the range would be
# invalid, so fall back to 100 (assumed top of the index scale — TODO confirm).
y_max = df['stringency_index'].max()
y_top = (y_max if pd.notna(y_max) else 100) * 1.05

fig.update_layout(
    title=dict(
        text="<b>COVID-19 Stringency Index Over Time</b>",
        x=0.5,
        font=dict(size=26)
    ),
    xaxis=dict(
        title=dict(text="<b>Date</b>", font=dict(size=20)),
        tickfont=dict(size=14),
        # Fixed range so the axis does not rescale as frames grow
        range=[start_date, end_date],
        tickformat="%b %Y",
        dtick="M2",
        showgrid=True,
        gridcolor="lightgrey"
    ),
    yaxis=dict(
        title=dict(text="<b>Stringency Index</b>", font=dict(size=20)),
        tickfont=dict(size=14),
        range=[0, y_top],
        showgrid=True,
        gridcolor="lightgrey"
    ),
    plot_bgcolor="white",
    margin=dict(l=80, r=40, t=80, b=120),
    updatemenus=[dict(
        type="buttons",
        showactive=False,
        buttons=[
            # Play: animate through the frames, 100 ms per frame
            dict(label="▶", method="animate", args=[None, {
                "frame": {"duration": 100, "redraw": False},
                "fromcurrent": True,
                "transition": {"duration": 0}
            }]),
            # Pause: plotly.js convention of animating to [None] immediately
            dict(label="❚❚", method="animate", args=[[None], {
                "frame": {"duration": 0, "redraw": False},
                "mode": "immediate",
                "transition": {"duration": 0}
            }])
        ],
        x=0, y=-0.35, xanchor="left", yanchor="bottom",
        direction="left",
        pad=dict(r=10, t=10),
        font=dict(size=14)
    )],
    sliders=[dict(
        # The initial traces hold the full series, so select the last step to
        # keep the slider's "Date:" readout consistent with what is drawn.
        active=max(len(weekly_dates) - 1, 0),
        steps=[
            # Each step jumps to the frame with the matching name
            dict(method="animate", args=[[str(date)], {
                "frame": {"duration": 0, "redraw": True},
                "mode": "immediate"
            }], label=pd.to_datetime(date).strftime('%b-%Y')) for date in weekly_dates
        ],
        x=0, y=-0.42, xanchor="left", yanchor="bottom",
        currentvalue=dict(prefix="Date: ", font=dict(size=14)),
        transition=dict(duration=0),
        font=dict(size=12),
        len=1.0
    )]
)

if __name__ == "__main__":
    import os

    # Display the figure, then save interactive HTML and a static PNG.
    # write_html/write_image do not create missing directories, so make sure
    # the output folder exists first.
    os.makedirs("../Plots/Comparisons", exist_ok=True)
    fig.show()
    fig.write_html("../Plots/Comparisons/covid_stringency_index_comparison.html", include_plotlyjs="cdn")
    fig.write_image("../Plots/Comparisons/covid_stringency_index_comparison.png", width=1200, height=800, scale=2)
