import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Multiplier applied to inflation rates so that they are visible on the
# primary (cases-per-million) axis. Hover text still reports the unscaled rate.
_INFLATION_PLOT_SCALE = 600


def _load_cases_per_million(file_path, population, cutoff_date):
    """Load a COVID-19 case CSV and derive new cases per million inhabitants.

    The CSV must provide a ``date`` column and a ``new_cases`` column.

    Args:
        file_path: Path of the CSV file to read.
        population: Population used to normalise ``new_cases``.
        cutoff_date (pd.Timestamp): Rows dated after this are dropped.

    Returns:
        pd.DataFrame: The input rows up to ``cutoff_date`` whose
        ``new_cases_per_million`` is strictly positive, with ``date`` parsed
        to datetimes and the extra ``new_cases_per_million`` column.
    """
    cases = pd.read_csv(file_path)
    cases['date'] = pd.to_datetime(cases['date'])
    cases['new_cases_per_million'] = (cases['new_cases'] / population) * 1e6
    # Non-positive (and missing) values are dropped so that the line does not
    # keep falling back to zero and look like a histogram.
    keep = (cases['date'] <= cutoff_date) & (cases['new_cases_per_million'] > 0)
    return cases[keep]


def _load_economic_data(file_path, start_date, end_date):
    """Load an economic indicator CSV and restrict it to a date range.

    Only the first and third columns of the file are read; they are renamed
    to ``date`` and ``rate`` respectively.

    Args:
        file_path: Path of the CSV file to read.
        start_date (pd.Timestamp): First date to keep (inclusive).
        end_date (pd.Timestamp): Last date to keep (inclusive).

    Returns:
        pd.DataFrame: A frame with ``date`` and ``rate`` columns.
    """
    indicator = pd.read_csv(file_path, usecols=[0, 2])
    indicator.columns = ['date', 'rate']
    indicator['date'] = pd.to_datetime(indicator['date'])
    in_range = (indicator['date'] >= start_date) & (indicator['date'] <= end_date)
    return indicator[in_range]


def make_plot5(data_dir='data',
               output_path='images/plotly_poland_vs_germany_cases_and_economic.html',
               show=True):
    """Plot COVID-19 cases against unemployment and inflation for Poland and Germany.

    New cases per million (solid lines) and inflation (markers, multiplied by
    ``_INFLATION_PLOT_SCALE``) are drawn on the primary y-axis; unemployment
    rates (dashed lines) are drawn on the secondary y-axis. The figure is
    written as HTML and optionally displayed.

    Args:
        data_dir: Directory containing ``<country>_whole.csv``,
            ``<country>_unemployment.csv`` and ``<country>_inflation.csv``
            for ``poland`` and ``germany``.
        output_path: Destination of the HTML file. Missing parent directories
            are created.
        show (bool): Whether to call ``fig.show()`` after writing the file.
    """
    import math
    from pathlib import Path

    data_dir = Path(data_dir)

    # A single cutoff is shared by the case data and the economic indicators.
    cutoff_date = pd.to_datetime('2024-04-30')
    economic_start_date = pd.to_datetime('2020-01-01')

    # (display name, trace colour, population)
    countries = [
        ('Poland', 'blue', 36.69e6),
        ('Germany', 'orange', 83.28e6),
    ]

    cases = {}
    unemployment = {}
    inflation = {}
    for name, _, population in countries:
        stem = name.lower()
        cases[name] = _load_cases_per_million(
            data_dir / (stem + '_whole.csv'), population, cutoff_date)
        unemployment[name] = _load_economic_data(
            data_dir / (stem + '_unemployment.csv'), economic_start_date, cutoff_date)
        inflation[name] = _load_economic_data(
            data_dir / (stem + '_inflation.csv'), economic_start_date, cutoff_date)

    fig = make_subplots(specs=[[{"secondary_y": True}]])

    # Traces are added indicator by indicator (not country by country) so the
    # legend keeps the order: cases, unemployment, inflation.
    for name, color, _ in countries:
        frame = cases[name]
        fig.add_trace(
            go.Scatter(x=frame['date'], y=frame['new_cases_per_million'],
                       mode='lines', name=name + ' COVID-19 Cases',
                       line=dict(color=color, width=4, dash='solid')),
            secondary_y=False)

    for name, color, _ in countries:
        frame = unemployment[name]
        fig.add_trace(
            go.Scatter(x=frame['date'], y=frame['rate'],
                       mode='lines', name=name + ' Unemployment Rate',
                       line=dict(color=color, dash='dash', width=2)),
            secondary_y=True)

    # Inflation shares the primary axis with the case counts: passing
    # secondary_y=False makes add_trace bind the trace to that axis, so the
    # values are scaled for visibility and the original rate is carried in
    # customdata for the hover text.
    for name, color, _ in countries:
        frame = inflation[name]
        fig.add_trace(
            go.Scatter(x=frame['date'], y=frame['rate'] * _INFLATION_PLOT_SCALE,
                       mode='markers', name=name + ' Inflation Rate',
                       marker=dict(color=color, size=8),
                       customdata=frame['rate'],
                       hovertemplate='Date: %{x}<br>Inflation Rate: %{customdata:.2f}%<extra></extra>'),
            secondary_y=False)

    # Secondary axis: start at the floor of the lowest unemployment rate and
    # choose the top so that the highest rate sits at 60% of the axis height.
    yaxis2_layout = dict(
        side='right',
        overlaying='y',
        showgrid=False,
        tickfont_size=20,
    )
    unemployment_rates = pd.concat(
        [unemployment[name]['rate'] for name, _, _ in countries])
    min_unemployment_rate = unemployment_rates.min()
    max_unemployment_rate = unemployment_rates.max()
    # With no unemployment data in range the bounds are NaN; leave the axis
    # on autorange instead of failing on the conversion to int.
    if pd.notna(min_unemployment_rate):
        yaxis2_lower_bound = math.floor(min_unemployment_rate)
        yaxis2_upper_bound = max_unemployment_rate / 0.6 if max_unemployment_rate > 0 else 10
        yaxis2_layout['range'] = [yaxis2_lower_bound, yaxis2_upper_bound]

    # The x-axis starts at the earliest date of any plotted series; NaT from
    # an empty case frame is ignored.
    first_dates = [cases[name]['date'].min() for name, _, _ in countries]
    first_dates.append(economic_start_date)
    x_start = min(d for d in first_dates if pd.notna(d))

    fig.update_layout(
        title='COVID-19 Cases and Economic Indicators: Poland vs Germany',
        title_font_size=30,
        xaxis_title='',
        yaxis_title='New Cases per Million / Inflation Rate (%)',
        yaxis_title_font_size=26,
        yaxis2_title='Unemployment Rate (%)',
        yaxis2_title_font_size=26,
        yaxis2=yaxis2_layout,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.2,  # Below the plot area
            xanchor="center",
            x=0.5,
            font_size=22,
        ),
        xaxis=dict(
            rangeselector=dict(
                buttons=[
                    dict(count=6, label="6m", step="month", stepmode="backward"),
                    dict(count=1, label="1y", step="year", stepmode="backward"),
                    dict(step="all"),
                ]
            ),
            range=[x_start, cutoff_date],
            tickfont_size=20,
        ),
        margin=dict(l=80, r=80, b=120, t=100),
        hovermode='x unified',
        font_size=20,
    )

    # Create the output directory first so a missing folder does not discard
    # the finished figure.
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig.write_html(str(output_path))
    if show:
        fig.show()
    
