import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import os

# -- 1. DATA LOADING AND PREPARATION --
# NOTE: relative path, resolved against the current working directory (not this file's folder).
file_path = "../data/rwanda_Economy.csv"
try:
    df = pd.read_csv(file_path)
except FileNotFoundError:
    print(f"ERROR: File not found at '{file_path}'. Please check the name and path.")
    # Use SystemExit(1) rather than the site-provided exit(): that helper is absent under
    # `python -S` and, called without arguments, exits with status 0, hiding the failure.
    raise SystemExit(1)

# Parse the date column so it can drive monthly resampling and the date x-axes.
df['ds'] = pd.to_datetime(df['ds'])
print("Activity dataset loaded successfully.")

# Sectors to plot, in panel order (panels are filled row by row on the 2x2 grid).
selected_sectors = [
    'Restaurants',
    'Travel',
    'Local Events',
    'Public Good',
]

# Line colour for each sector, paired positionally with selected_sectors.
color_map = dict(zip(selected_sectors, ['#1f77b4', '#ff7f0e', '#2ca02c', '#9467bd']))

# Resample data monthly to smooth the lines: mean activity per business vertical per month.
# NOTE(review): 'M' (month end) is deprecated in pandas >= 2.2 in favour of 'ME'; kept here
# for compatibility with older pandas — TODO switch once the minimum pandas version allows.
per_vertical = df.groupby('business_vertical')
monthly_means = per_vertical.resample('M', on='ds')['activity_percentage'].mean()
monthly_sector_activity = monthly_means.reset_index()

# Keep only the rows belonging to the sectors we plot.
in_selection = monthly_sector_activity['business_vertical'].isin(selected_sectors)
focused_data = monthly_sector_activity.loc[in_selection]
print(f"\nAggregating data monthly for sectors: {selected_sectors}")

# --- 2. CREATE THE PLOT WITH SUBPLOTS ---
# One panel per sector on a 2x2 grid; subplot titles follow the order of selected_sectors.
fig = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=selected_sectors,
    vertical_spacing=0.25,  # increased (presumably to leave room for the rotated date ticks)
    horizontal_spacing=0.08,
)

for position, sector in enumerate(selected_sectors):
    # Fill the grid row by row: 0 -> (1, 1), 1 -> (1, 2), 2 -> (2, 1), 3 -> (2, 2).
    grid_row, grid_col = divmod(position, 2)

    sector_rows = focused_data.loc[focused_data['business_vertical'] == sector]

    # Bold sector name, month/year date and one-decimal activity; <extra></extra> hides
    # the secondary trace-name box in the hover label.
    hover = f"<b>{sector}</b><br>Date: %{{x|%b %Y}}<br>Activity: %{{y:.1f}}%<extra></extra>"

    trace = go.Scatter(
        x=sector_rows['ds'],
        y=sector_rows['activity_percentage'],
        mode='lines',
        name=sector,
        line=dict(color=color_map[sector], width=2.5),
        showlegend=False,  # each panel is already identified by its subplot title
        hovertemplate=hover,
    )
    fig.add_trace(trace, row=grid_row + 1, col=grid_col + 1)

# --- 3. CONFIGURE LAYOUT AND ADD ANNOTATIONS ---
# Lockdown periods shaded on every panel: (legend label, start date, end date, colour).
# A single table keeps each shaded band and its legend swatch in sync.
LOCKDOWN_PERIODS = [
    ('1st National Lockdown', '2020-03-21', '2020-05-04', 'red'),
    ('Lockdown in Kigali', '2021-01-18', '2021-02-07', 'purple'),
    ('Delta Wave Lockdown', '2021-07-17', '2021-07-31', 'sienna'),
]

# One panel per selected sector, filled row by row on a 2-column grid.
for panel_index in range(len(selected_sectors)):
    row = panel_index // 2 + 1
    col = panel_index % 2 + 1

    # Dashed reference line at 100 (presumably the baseline activity level — confirm).
    fig.add_hline(y=100, line_width=1, line_dash="dash", line_color="black", row=row, col=col)

    for _label, start, end, color in LOCKDOWN_PERIODS:
        fig.add_vrect(x0=start, x1=end, fillcolor=color, opacity=0.3, layer="below",
                      line_width=0, row=row, col=col)

# Shapes are not listed in the legend by default, so add one invisible, data-less marker
# trace per lockdown purely to produce a legend entry.
for label, _start, _end, color in LOCKDOWN_PERIODS:
    fig.add_trace(go.Scatter(x=[None], y=[None], mode='markers', name=label,
                             marker=dict(color=color, size=10, symbol='square', opacity=0.5)))

# Layout: centred bold title, white plot background, horizontal legend above the panels.
title_style = dict(
    text="<b>Economic Activity per Sector</b>",
    x=0.5,  # centre the title horizontally
    font=dict(size=28, family="Arial", color="black"),
)

legend_style = dict(
    orientation="h",
    # Anchor the legend's bottom edge slightly above the plotting area, centred.
    yanchor="bottom",
    y=1.05,
    xanchor="center",
    x=0.5,
    font=dict(family="Arial", size=12, color="black"),
    bgcolor="rgba(255,255,255,0.8)",
    bordercolor="Black",
    borderwidth=1,
)

fig.update_layout(
    title=title_style,
    plot_bgcolor="white",
    margin=dict(l=80, r=50, t=160, b=120),  # extra space at the bottom
    legend=legend_style,
)

# Y-axis range calculation: one shared range for every panel so the sectors are directly
# comparable, padded by 10% of the data span on each side.
y_min = focused_data['activity_percentage'].min()
y_max = focused_data['activity_percentage'].max()
y_span = y_max - y_min
# A perfectly flat series (span 0) would yield a zero-height range; pad by 1 unit instead.
y_padding = y_span * 0.1 if y_span > 0 else 1.0

# min()/max() return NaN when none of the selected sectors has data; leave the range
# unset in that case so Plotly autoranges instead of receiving [nan, nan].
if pd.isna(y_min) or pd.isna(y_max):
    y_range = None
else:
    y_range = [y_min - y_padding, y_max + y_padding]

# Axes style
fig.update_xaxes(
    title_text=None,
    tickfont=dict(size=10, family="Arial", color="black"),  # reduced tick size
    showgrid=True,
    gridcolor="lightgrey",
    tickformat="%b %Y",  # show month and year
    tickangle=-45,
    showline=True,
    linecolor='black',
    linewidth=1,
)

y_axis_style = dict(
    title_text="Activity (%)",
    title_font=dict(size=16, family="Arial", color="black"),
    tickfont=dict(size=12, family="Arial", color="black"),
    showgrid=True,
    gridcolor="lightgrey",
    showline=True,
    linecolor='black',
    linewidth=1,
)
if y_range is not None:
    y_axis_style['range'] = y_range
fig.update_yaxes(**y_axis_style)

# Subplot title fonts: restyle every layout annotation (this includes the subplot titles
# that make_subplots creates).
fig.update_annotations(font=dict(size=16, family="Arial", color="black"))

# --- 4. SHOW AND SAVE ---
if __name__ == "__main__":
    # NOTE: relative to the current working directory, not to this script's location.
    output_dir = "../Plots/Economy"
    os.makedirs(output_dir, exist_ok=True)

    output_path = os.path.join(output_dir, "economic_activity_per_sector.html")
    png_path = os.path.join(output_dir, "economic_activity_per_sector.png")

    # Write the files before displaying, so a failure in fig.show() (e.g. no usable
    # renderer/browser) cannot prevent the outputs from being saved.
    # include_plotlyjs="cdn" keeps the HTML small by loading plotly.js from a CDN.
    fig.write_html(output_path, include_plotlyjs="cdn")
    print(f"\nFinal interactive graph saved successfully to: {output_path}")

    # Static export needs a Plotly image engine (kaleido in recent versions) installed.
    fig.write_image(png_path, width=1200, height=800)
    print(f"Static image saved to: {png_path}")

    fig.show()