import matplotlib.pyplot as plt
import matplotlib.animation as animation
import pandas as pd
import numpy as np
import matplotlib.ticker as ticker

def _abbreviate_count(value, small_format):
    """Format *value* with an ``M`` (>= 1e6) or ``K`` (>= 1e3) suffix.

    Values below 1000 are delegated to the *small_format* callable so the
    y-axis ticks (integers) and the moving counter (one decimal) can render
    small numbers differently.
    """
    if value >= 1e6:
        return f'{value*1e-6:.1f}M'
    if value >= 1e3:
        return f'{value*1e-3:.1f}K'
    return small_format(value)


def _prepare_covid_macro_data(cases, inflation, unemployment, end_date):
    """Return case totals joined with inflation and unemployment, up to *end_date*.

    The input frames are copied, never modified. ``date`` columns are parsed
    with ``pd.to_datetime`` *before* the cut-off is applied, so the comparison
    is between timestamps rather than strings. The result is sorted by date,
    has a fresh 0..n-1 index, holds the columns ``date``, ``total_cases``,
    ``inflation`` and ``unemployment``, and has the two macro columns
    forward- then back-filled.
    """
    cutoff = pd.Timestamp(end_date)

    def _parse_and_clip(frame, labels=None):
        frame = frame.copy()
        if labels is not None:
            # Positional relabelling: the macro frames are expected to be
            # (date, time period, value) in that order.
            frame.columns = labels
        frame['date'] = pd.to_datetime(frame['date'])
        return frame[frame['date'] <= cutoff]

    cases = _parse_and_clip(cases)
    inflation = _parse_and_clip(inflation, ['date', 'time_period', 'inflation'])
    unemployment = _parse_and_clip(unemployment, ['date', 'time_period', 'unemployment'])

    macro = inflation.merge(unemployment, how='inner', on='date')
    data = (
        cases.merge(macro, how='left', on='date')
        [['date', 'total_cases', 'inflation', 'unemployment']]
        .sort_values('date', kind='stable')
        .reset_index(drop=True)
    )
    for column in ('inflation', 'unemployment'):
        # Case dates without a matching macro row are NaN after the left
        # merge; carry the last known value forward and back-fill the start.
        data[column] = data[column].ffill().bfill()
    return data


def _build_frame_schedule(n_rows, pause_rows, pause_len, step):
    """Return the data-row index to draw for each animation frame.

    Rows are thinned to every *step*-th row between pauses, and every row in
    *pause_rows* is repeated *pause_len* times so the animation holds on it.
    The thinned ranges exclude the pause rows themselves. The last row is
    always appended so the line reaches the right edge of the x-axis.
    """
    schedule = []
    start = 0
    for pause_row in sorted(pause_rows):
        schedule.extend(range(start, pause_row, step))
        schedule.extend([pause_row] * pause_len)
        start = pause_row + 1
    schedule.extend(range(start, n_rows, step))
    if n_rows and (not schedule or schedule[-1] != n_rows - 1):
        schedule.append(n_rows - 1)
    return schedule


def create_covid_animation_with_macro(cases, inflation, unemployment, output_path, *,
                                      end_date='2024-04-30', pause_frames=80,
                                      step=5, interval=50):
    """Save an animation of cumulative COVID-19 cases with macro context.

    The ``total_cases`` line is drawn progressively while a text box shows the
    current date with inflation and unemployment in Germany. The animation
    holds for *pause_frames* frames on the first data row on or after
    2022-01-01 and on or after 2023-01-01, showing a caption and a dashed
    marker line; milestones that fall after the last data row are skipped.
    Fixed lockdown dates are marked with grey dashed lines.

    Args:
        cases: DataFrame with ``date`` and ``total_cases`` columns.
        inflation: DataFrame whose three columns are, in order, a date, a
            time-period label and the inflation rate. Not modified.
        unemployment: Same layout as *inflation*, holding the unemployment
            rate. Not modified.
        output_path: Destination file, written with the ``pillow`` writer.
        end_date: Rows dated after this are dropped from all three inputs.
        pause_frames: Number of animation frames to hold on each milestone.
        step: Only every *step*-th data row is rendered between milestones.
        interval: Delay between frames in milliseconds.

    Raises:
        ValueError: If no case rows remain on or before *end_date*.
    """
    data = _prepare_covid_macro_data(cases, inflation, unemployment, end_date)
    n_rows = len(data)
    if n_rows == 0:
        raise ValueError(f'no COVID case data on or before {end_date}')

    dates = data['date']
    case_values = data['total_cases'].to_numpy(dtype=float, na_value=np.nan)
    inflation_values = data['inflation'].to_numpy()
    unemployment_values = data['unemployment'].to_numpy()

    # Lockdown dates (example dates in Germany), drawn as static markers.
    lockdowns = [
        ('2020-03-22', 'Lockdown 1'),
        ('2020-11-02', 'Lockdown 2'),
        ('2020-12-16', 'Lockdown 3'),
        ('2021-04-24', 'Lockdown 4'),
    ]
    # (date, caption, colour) of the moments where the animation pauses.
    milestone_specs = [
        ('2022-01-01', 'Begin of COVID explosion, time : 2022-01-01', 'red'),
        ('2023-01-01', 'Covid situation stabilizes, time : 2023-01-01', 'green'),
    ]

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        line, = ax.plot([], [], color='blue')

        ax.yaxis.set_major_formatter(ticker.FuncFormatter(
            lambda x, pos: _abbreviate_count(x, lambda v: f'{int(v)}')))
        ax.set_xlim(dates.min(), dates.max())
        finite_cases = case_values[np.isfinite(case_values)]
        peak = finite_cases.max() if finite_cases.size else 0.0
        # Ignore NaN totals (a NaN limit raises); fall back to 1 when there is
        # no positive value to scale to.
        ax.set_ylim(0, peak * 1.1 if peak > 0 else 1.0)
        y_top = ax.get_ylim()[1]
        ax.set_title('Total COVID-19 Cases Over Time')
        ax.set_xlabel('Date')
        ax.set_ylabel('Total Cases')

        # Lockdown lines and labels are always visible (not animated).
        for when, label in lockdowns:
            stamp = pd.Timestamp(when)
            ax.axvline(x=stamp, color='gray', linestyle='--', linewidth=1, alpha=0.4)
            ax.text(stamp, y_top * 0.8, label, rotation=90,
                    verticalalignment='bottom', horizontalalignment='right',
                    fontsize=9, color='gray', alpha=0.7)

        # Eight evenly spaced date ticks, deduplicated for very short series.
        tick_rows = np.unique(np.linspace(0, n_rows - 1, 8, dtype=int))
        tick_dates = dates.iloc[tick_rows]
        ax.set_xticks(tick_dates)
        ax.set_xticklabels([d.strftime('%Y-%m') for d in tick_dates], rotation=45)

        text_counter = ax.text(dates.iloc[0], case_values[0], '', color='blue', fontsize=10,
                               ha='left', va='bottom', backgroundcolor='white')
        info_text = ax.text(0.02, 0.70, '', transform=ax.transAxes, ha='left', va='center',
                            fontsize=10, color='black', backgroundcolor='None')

        # Milestone artists are created once, hidden, and only toggled per frame.
        milestones = []
        for when, caption, color in milestone_specs:
            hits = (dates >= pd.Timestamp(when)).to_numpy().nonzero()[0]
            if hits.size == 0:
                continue  # data ends before this milestone: no pause, no marker
            row = int(hits[0])
            caption_artist = ax.text(dates.iloc[row], y_top * 0.9, caption, color=color,
                                     fontsize=12, ha='center', va='bottom', visible=False)
            marker = ax.axvline(x=dates.iloc[row], color=color, linestyle='--', visible=False)
            milestones.append((row, caption_artist, marker))
        milestone_artists = [artist for _, caption_artist, marker in milestones
                             for artist in (caption_artist, marker)]

        schedule = _build_frame_schedule(
            n_rows, [row for row, _, _ in milestones], pause_frames, step)

        def update(frame_i):
            row = schedule[frame_i]
            line.set_data(dates.iloc[:row + 1], case_values[:row + 1])

            # Trailing counter just below the tip of the line.
            current_y = case_values[row]
            text_counter.set_visible(bool(np.isfinite(current_y)))
            text_counter.set_position((dates.iloc[row], current_y - y_top * 0.05))
            text_counter.set_text(_abbreviate_count(current_y, lambda v: f'{v:.1f}'))

            current_date = dates.iloc[row].strftime('%Y-%m-%d')
            info_text.set_text(
                f'{current_date}\n'
                f'Inflation in Germany: {inflation_values[row]:.1f}%\n'
                f'Unemployment in Germany: {unemployment_values[row]:.1f}%'
            )

            for milestone_row, caption_artist, marker in milestones:
                # Thinned rows never equal the milestone row, so the caption
                # shows while the animation holds on it; the marker stays on.
                caption_artist.set_visible(row == milestone_row)
                marker.set_visible(row >= milestone_row)

            return (line, text_counter, info_text, *milestone_artists)

        ani = animation.FuncAnimation(fig, update, frames=len(schedule),
                                      interval=interval, blit=True)
        ani.save(output_path, writer='pillow')
    finally:
        plt.close(fig)
