import argparse
from datetime import datetime
from pathlib import Path

import pandas as pd
from bokeh.models import ColumnDataSource, HoverTool, NumeralTickFormatter, DatetimeTickFormatter, FixedTicker, Grid
from bokeh.models.tickers import MonthsTicker
from bokeh.palettes import Category10
from bokeh.plotting import figure, show, output_file, save
from matplotlib import rcParams

# Command-line interface: one positional integer selects the output mode.
#   0 -> open the plot (show), 1 -> write it to the HTML file (save).
parser = argparse.ArgumentParser(
    description="Choose whether to display the plot (0) or save it to a file (1)."
)
parser.add_argument(
    "action",
    type=int,
    choices=[0, 1],
    help="0 - display the plot, 1 - save to file",
)
args = parser.parse_args()

# NOTE: this only changes matplotlib's global defaults; Bokeh does not read
# matplotlib rcParams, so Bokeh fonts must be set on the figure itself.
rcParams["font.family"] = "Liberation Serif"

# Load the per-country vaccination series. Each CSV is read for its "Day"
# column and its "COVID-19 doses (daily, 7-day average)" column.
df_hun = pd.read_csv("../csv/HUNGARY-avg-vaccine.csv")
df_swe = pd.read_csv("../csv/SWEDEN-avg-vaccine.csv")

# Parse the "Day" column (in place) so it can drive a datetime x-axis.
for _df in (df_hun, df_swe):
    _df["Day"] = pd.to_datetime(_df["Day"])

# One ColumnDataSource per country, both exposing the same "date"/"doses"
# field names so the line glyphs and hover tooltips can refer to them uniformly.
source_hun, source_swe = (
    ColumnDataSource(data={
        "date": _df["Day"],
        "doses": _df["COVID-19 doses (daily, 7-day average)"],
    })
    for _df in (df_hun, df_swe)
)

# Base figure: 900x400 canvas, datetime x-axis, and a reduced toolbar
# (pan, wheel zoom, box zoom, reset).
p = figure(
    width=900,
    height=400,
    x_axis_type="datetime",
    title="COVID-19: Daily Vaccine Doses (7-day average)",
    x_axis_label="Date",
    y_axis_label="Doses",
    tools="pan,wheel_zoom,box_zoom,reset",
)

# Large, centered serif title.
p.title.text_font = 'Liberation Serif'
p.title.text_font_size = "20pt"
p.title.align = "center"

# Light-blue fill for the area inside the axes.
p.background_fill_color = "#E6F2FF"

# Initial x window: 1 Jan 2021 through 1 Jan 2022.
p.x_range.start, p.x_range.end = datetime(2021, 1, 1), datetime(2022, 1, 1)

# One 2px line per country, colored with the first two Category10 entries;
# legend_label registers each line in the plot legend.
palette = Category10[10]
line_hun = p.line(
    'date', 'doses',
    source=source_hun,
    line_width=2,
    color=palette[0],
    legend_label="Hungary",
)
line_swe = p.line(
    'date', 'doses',
    source=source_swe,
    line_width=2,
    color=palette[1],
    legend_label="Sweden",
)

# Abbreviated y tick labels with one decimal (numbro "0.0a" format, e.g. "1.5k").
p.yaxis.formatter = NumeralTickFormatter(format="0.0a")


def _country_hover(country, renderer):
    """Return a vline HoverTool attached only to *renderer*.

    The tooltip shows the fixed *country* label, the point's date formatted
    as ``%F`` (YYYY-MM-DD) via the datetime formatter, and the dose count
    with thousands separators.
    """
    return HoverTool(
        renderers=[renderer],
        mode="vline",
        formatters={"@date": "datetime"},
        tooltips=[
            ("Country", country),
            ("Date", "@date{%F}"),
            ("Doses", "@doses{0,0}"),
        ],
    )


# Separate tools per line so each tooltip names its own country.
hover_hun = _country_hover("Hungary", line_hun)
hover_swe = _country_hover("Sweden", line_swe)
p.add_tools(hover_hun, hover_swe)

# Shared axis styling: 14pt axis titles, 12pt dark-slate tick labels.
for _axis in (p.xaxis, p.yaxis):
    _axis.axis_label_text_font_size = "14pt"
    _axis.major_label_text_color = "#2C3E50"
    _axis.major_label_text_font_size = "12pt"

# Tilt the date labels by 0.52 rad (about 30 degrees) so they don't collide.
p.xaxis.major_label_orientation = 0.52

# Legend in the top-left corner; clicking an entry hides its line.
p.legend.location = "top_left"
p.legend.click_policy = "hide"
p.legend.label_text_font_size = "13pt"

# Major x ticks (and their labels) every 6 months: on 1 January and 1 July.
# The major grid below reuses this ticker; the minor grid adds the months
# in between.
major_ticker = MonthsTicker(months=[1, 7])
p.xaxis.ticker = major_ticker
# Full day.month.year labels at both month and year zoom scales.
p.xaxis.formatter = DatetimeTickFormatter(months="%d.%m.%Y", years="%d.%m.%Y")

# Major x-grid lines: aligned with the 6-month ticks, drawn thicker.
p.xgrid[0].ticker = major_ticker
p.xgrid[0].grid_line_color = "#FFFFFF"
p.xgrid[0].grid_line_width = 1.5

# Minor x-grid lines on the first of every month (thinner than major lines).
minor_ticker_x = MonthsTicker(months=list(range(1, 13)))
minor_grid_x = Grid(dimension=0, ticker=minor_ticker_x,
                    grid_line_color="#FFFFFF", grid_line_width=1)
p.add_layout(minor_grid_x)

# Major y grid: restyle the figure's default y grid as thicker white lines.
major_ygrid = p.ygrid[0]
major_ygrid.grid_line_color = "#FFFFFF"
major_ygrid.grid_line_width = 1.5

# Minor y grid: a fixed white line every 5,000 doses from 0 to 125,000
# (range() stops before 130,000).
minor_y_ticker = FixedTicker(ticks=list(range(0, 130_000, 5_000)))
minor_grid_y = Grid(
    dimension=1,
    ticker=minor_y_ticker,
    grid_line_color="#FFFFFF",
    grid_line_width=1,
)
p.add_layout(minor_grid_y)

# Destination of the standalone HTML document.
output_path = "../images/vaccine_doses_comparison.html"
# Create the target directory if needed: both show() and save() write this
# file and would raise FileNotFoundError if ../images does not exist.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
# Register the output file once as Bokeh's default output target.
output_file(output_path)

# Show plot or save to file based on CLI argument
if args.action == 0:
    show(p)
else:
    save(p, filename=output_path)
    print(f"Plot saved to: {output_path}")


