import argparse
from datetime import datetime
from pathlib import Path

import pandas as pd
from bokeh.models import ColumnDataSource, HoverTool, NumeralTickFormatter, DatetimeTickFormatter, FixedTicker, Grid
from bokeh.models.tickers import MonthsTicker
from bokeh.plotting import figure, show, output_file, save
from matplotlib import rcParams

# Command-line interface: one positional integer selects the output mode
# (0 opens the plot, 1 writes it to disk).
parser = argparse.ArgumentParser(
    description="Choose whether to display the plot (0) or save it to a file (1).",
)
parser.add_argument(
    'action',
    help="0 - display the plot, 1 - save to file",
    choices=(0, 1),
    type=int,
)
args = parser.parse_args()


# NOTE: Bokeh does not read matplotlib's rcParams, so a global font cannot be
# configured here; fonts have to be set on the individual Bokeh models instead.

# The CSV exports live in ../csv relative to this script. Resolve that against
# the script's own location (not the current working directory) so the script
# can be launched from any folder.
_CSV_DIR = Path(__file__).resolve().parent.parent / 'csv'


def _load_deaths(csv_name):
    """Read one country's daily report from ``_CSV_DIR``.

    The 'Date_reported' column is converted to datetime and the rows are
    sorted by it (with a fresh 0..n-1 index). This ensures the line glyph is
    drawn chronologically even if the export is not ordered by date.
    """
    frame = pd.read_csv(_CSV_DIR / csv_name)
    frame['Date_reported'] = pd.to_datetime(frame['Date_reported'])
    return frame.sort_values('Date_reported', ignore_index=True)


death_sw = _load_deaths('SWEDEN-daily-cases-deaths.csv')
death_hun = _load_deaths('HUNGARY-daily-cases-deaths.csv')

# Bokeh data sources exposing the two plotted columns under short names
# ('date', 'deaths') that the glyphs and hover tooltips refer to.
source_sw = ColumnDataSource(data={
    'date': death_sw['Date_reported'],
    'deaths': death_sw['Cumulative_deaths']
})
source_hun = ColumnDataSource(data={
    'date': death_hun['Date_reported'],
    'deaths': death_hun['Cumulative_deaths']
})

# Figure with a datetime x-axis; only pan/zoom/reset tools are enabled.
p = figure(
    width=900,
    height=400,
    x_axis_type='datetime',
    x_axis_label="Date",
    y_axis_label="Cumulative Deaths",
    title="COVID-19: Cumulative Deaths – Sweden vs Hungary",
    tools="pan,wheel_zoom,box_zoom,reset",
)

# Title: centered, 20pt, serif face.
title = p.title
title.text_font = 'Liberation Serif'
title.text_font_size = "20pt"
title.align = "center"

# Fill colour of the plotting frame (the area bounded by the axes).
p.background_fill_color = "#E6F2FF"

# Initial visible date window: 4 Jan 2020 up to 1 Jan 2022.
p.x_range.start, p.x_range.end = datetime(2020, 1, 4), datetime(2022, 1, 1)

# One line per country. Each renderer gets a ``name`` so the hover tooltip can
# identify which series it is reporting on.
line_sw = p.line('date', 'deaths', source=source_sw, name="Sweden",
                 line_color="black", line_dash="dashed", line_width=2, legend_label="Sweden")
line_hun = p.line('date', 'deaths', source=source_hun, name="Hungary",
                  line_color="red", line_width=2, legend_label="Hungary")

# Hover tooltip. In 'vline' mode both lines can be hit at the same x position,
# so a "Country" row (the renderer's name via the $name special field) is
# needed to tell the tooltips apart. Dates are shown as YYYY-MM-DD and deaths
# with thousands separators.
hover = HoverTool(
    tooltips=[
        ("Country", "$name"),
        ("Date", "@date{%F}"),
        ("Deaths", "@deaths{0,0}"),
    ],
    formatters={'@date': 'datetime'},
    mode='vline',
    renderers=[line_sw, line_hun],
)
p.add_tools(hover)

# Abbreviate large y-axis tick values (e.g., 30000 -> 30K).
p.yaxis.formatter = NumeralTickFormatter(format="0a")

# Use the title's serif face for every other text element as well. Bokeh does
# not pick up matplotlib's rcParams, so without these settings the axis labels,
# tick labels and legend fall back to Bokeh's default font.
p.axis.axis_label_text_font = 'Liberation Serif'
p.axis.major_label_text_font = 'Liberation Serif'
p.legend.label_text_font = 'Liberation Serif'

# Larger axis titles.
p.xaxis.axis_label_text_font_size = "14pt"
p.yaxis.axis_label_text_font_size = "14pt"

# Tick labels; x labels are rotated by 0.52 rad (roughly 30 degrees).
p.xaxis.major_label_orientation = 0.52
p.xaxis.major_label_text_color = "#2C3E50"
p.xaxis.major_label_text_font_size = "12pt"
p.yaxis.major_label_text_color = "#2C3E50"
p.yaxis.major_label_text_font_size = "12pt"

# Legend in the top-left corner; clicking an entry hides its line.
p.legend.location = "top_left"
p.legend.click_policy = "hide"
p.legend.label_text_font_size = "13pt"

# Major x ticks every six months (January and July), labelled dd.mm.yyyy.
major_ticker = MonthsTicker(months=[1, 7])
p.xaxis.ticker = major_ticker
p.xaxis.formatter = DatetimeTickFormatter(months="%d.%m.%Y", years="%d.%m.%Y")

# Major x-grid lines follow the January/July major ticks.
p.xgrid[0].ticker = major_ticker
p.xgrid[0].grid_line_color = "#FFFFFF"
p.xgrid[0].grid_line_width = 1.5

# Thinner minor x-grid lines for every month.
minor_ticker_x = MonthsTicker(months=list(range(1, 13)))
minor_grid_x = Grid(dimension=0, ticker=minor_ticker_x,
                    grid_line_color="#FFFFFF", grid_line_width=1)
p.add_layout(minor_grid_x)

# Major y-grid lines (default y-axis ticks).
p.ygrid[0].grid_line_color = "#FFFFFF"
p.ygrid[0].grid_line_width = 1.5

# Minor y-grid lines every 5,000 deaths.
# The tick list runs from 0 up to the first multiple of the step that is at or
# above the larger series peak. It never stops below 50,000, so the grid keeps
# covering the data even when a series grows past that value.
minor_y_step = 5000
peak_deaths = pd.concat(
    [death_sw['Cumulative_deaths'], death_hun['Cumulative_deaths']]
).max()
# An all-NaN / empty column yields NaN; fall back to the 50,000 baseline.
minor_y_top = 50000 if pd.isna(peak_deaths) else max(50000, int(peak_deaths))
minor_y_ticker = FixedTicker(
    ticks=list(range(0, minor_y_top + minor_y_step, minor_y_step))
)
minor_grid_y = Grid(dimension=1, ticker=minor_y_ticker,
                    grid_line_color="#FFFFFF", grid_line_width=1)
p.add_layout(minor_grid_y)

# Standalone HTML output in ../images relative to this script. The path is
# anchored to the script's location rather than the current working directory.
# The folder is created if missing, because once output_file() is active both
# show() and save() write this file.
output_path = str(
    Path(__file__).resolve().parent.parent
    / 'images'
    / 'covid_cumulative_deaths_comparison.html'
)
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
# Reuse the plot title as the HTML page title.
output_file(output_path, title=p.title.text)

# Show the plot or save it to file, depending on the CLI argument.
if args.action == 0:
    show(p)
else:
    save(p, filename=output_path)
    print(f"Plot saved to: {output_path}")
