import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def _load_country_frame(csv_path, population, start_date, end_date):
    """Read one country's CSV, derive per-million vaccinations, clip to a date window.

    Args:
        csv_path: Path (or anything ``pd.read_csv`` accepts) of a CSV that has
            the columns ``date``, ``total_vaccinations`` and
            ``excess_mortality_cumulative_per_million``.
        population: Head count used to scale ``total_vaccinations`` to a
            per-million figure.
        start_date: Inclusive lower bound of the date window.
        end_date: Inclusive upper bound of the date window.

    Returns:
        The rows whose ``date`` lies inside ``[start_date, end_date]``, with
        ``date`` parsed to datetimes and a new
        ``total_vaccinations_per_million`` column.

    Raises:
        KeyError: If any of the required columns is absent from the CSV.
    """
    frame = pd.read_csv(csv_path)

    # Fail early with one message naming every absent column instead of
    # surfacing them one at a time further down.
    required = ('date', 'total_vaccinations', 'excess_mortality_cumulative_per_million')
    missing = [column for column in required if column not in frame.columns]
    if missing:
        raise KeyError(f"{csv_path}: missing required column(s): {', '.join(missing)}")

    frame['date'] = pd.to_datetime(frame['date'])
    # The derived column is added before slicing, so the assignment targets
    # the full frame rather than a filtered view of it.
    frame['total_vaccinations_per_million'] = (frame['total_vaccinations'] / population) * 1e6
    # excess_mortality_cumulative_per_million is used as provided; the data is
    # expected to be normalized per million already.

    in_window = frame['date'].between(pd.to_datetime(start_date), pd.to_datetime(end_date))
    return frame[in_window]


def _line_trace(frame, column, name, color, width, axis, dash=None):
    """Build a line-mode ``go.Scatter`` of ``frame[column]`` over ``frame['date']``.

    ``axis`` is the plotly y-axis reference (``'y'`` for the primary axis,
    ``'y2'`` for the secondary one); ``dash`` is only set when given so that
    solid lines keep plotly's default.
    """
    line_style = dict(color=color, width=width)
    if dash is not None:
        line_style['dash'] = dash
    return go.Scatter(
        x=frame['date'],
        y=frame[column],
        mode='lines',
        name=name,
        line=line_style,
        yaxis=axis,
    )


def plot_total_vaccinations_and_excess_mortality_plotly(
        poland_csv_path, germany_csv_path, save_path_html, *,
        population_poland=36.69e6, population_germany=83.28e6,
        start_date='2020-11-01', end_date='2022-09-01', show=True):
    """Plot vaccinations and cumulative excess mortality for Poland and Germany.

    Total vaccinations per million are drawn as solid lines on the primary
    (left) y-axis and cumulative excess mortality per million as dotted lines
    on the secondary (right) y-axis, one colour per country. The figure is
    written to ``save_path_html`` as HTML and, by default, displayed.

    Args:
        poland_csv_path: CSV for Poland with the columns ``date``,
            ``total_vaccinations`` and ``excess_mortality_cumulative_per_million``.
        germany_csv_path: CSV for Germany with the same columns.
        save_path_html: Destination of the HTML export. When it is a path,
            its parent directory is created if it does not exist yet.
        population_poland: Population used to normalise Poland's vaccinations
            (default 36.69 million).
        population_germany: Population used to normalise Germany's
            vaccinations (default 83.28 million).
        start_date: Inclusive start of the plotted window (default 2020-11-01).
        end_date: Inclusive end of the plotted window (default 2022-09-01).
        show: Call ``fig.show()`` after saving. Pass ``False`` for batch or
            headless runs.

    Raises:
        KeyError: If either CSV lacks one of the required columns.
    """
    import os

    # Plot window: November 2020 through 1 September 2022 by default.
    window_start = pd.to_datetime(start_date)
    window_end = pd.to_datetime(end_date)

    poland = _load_country_frame(poland_csv_path, population_poland, window_start, window_end)
    germany = _load_country_frame(germany_csv_path, population_germany, window_start, window_end)
    countries = (
        ('Poland', poland, 'blue'),
        ('Germany', germany, 'orange'),
    )

    # Single subplot with a secondary y-axis.
    fig = make_subplots(specs=[[{"secondary_y": True}]])

    # Primary y-axis: total vaccinations per million (solid lines).
    for label, frame, color in countries:
        fig.add_trace(_line_trace(
            frame, 'total_vaccinations_per_million',
            f'{label} (Total Vaccinations per Million)',
            color, 6, 'y',
        ))

    # Secondary y-axis: cumulative excess mortality per million (wider,
    # dotted lines so the two metrics can be told apart per country).
    for label, frame, color in countries:
        fig.add_trace(_line_trace(
            frame, 'excess_mortality_cumulative_per_million',
            f'{label} (Excess Mortality Cumulative per Million)',
            color, 8, 'y2', dash='dot',
        ))

    # All text sizes are a base size enlarged by one common factor.
    font_scale = 1.8
    title_size = font_scale * 18
    axis_title_size = font_scale * 14
    tick_size = font_scale * 12

    fig.update_layout(
        title=dict(
            text='COVID-19 Vaccinations and Excess Mortality: Poland vs Germany',
            font=dict(size=title_size),
        ),
        xaxis=dict(
            title=dict(text=''),  # no x-axis label
            tickformat='%Y-%m',
            range=[window_start, window_end],
            tickfont=dict(size=tick_size),
        ),
        yaxis=dict(
            title=dict(
                text='Total Vaccinations per Million',
                font=dict(color="blue", size=axis_title_size),
            ),
            tickformat=".2s",  # SI-prefix format
            tickfont=dict(color="blue", size=tick_size),
        ),
        yaxis2=dict(
            title=dict(
                text='Excess Mortality Cumulative per Million',
                font=dict(color="purple", size=axis_title_size),
            ),
            overlaying='y',  # share the plot area with the primary axis
            side='right',
            # No tickformat here: plotly picks the tick labels itself.
            tickfont=dict(color="purple", size=tick_size),
        ),
        # Horizontal legend centred underneath the plot area.
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.3,
            xanchor="center",
            x=0.5,
            font=dict(size=22),
        ),
    )

    # Create the output directory up front so a fresh checkout without e.g.
    # an images/ folder does not fail at the very last step.
    if isinstance(save_path_html, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(save_path_html))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    fig.write_html(save_path_html)
    if show:
        fig.show()


# plot_total_vaccinations_and_excess_mortality_plotly(
#                           'data/poland_whole.csv',
#                           'data/germany_whole.csv',
#                           'images/plotly_vaccinations_excess_mortality_comparison.html'
# )
