from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd

# Read the UKHSA genomic export and parse the 'date' column into datetimes.
df = pd.read_csv("data/ukhsa_genomic.csv")
df["date"] = pd.to_datetime(df["date"])

# Keep only the rows reporting the weekly lineage-percentage metric.
df = df.loc[df["metric"] == "COVID-19_cases_lineagePercentByWeek"]

# Reshape to wide form: one row per date, one column per 'stratum' value
# (the lineage), holding 'metric_value'. Columns are then ordered by name so
# the legend lists lineages alphabetically.
# NOTE(review): DataFrame.pivot raises ValueError if any (date, stratum) pair
# occurs more than once; this assumes the filtered export has one row per pair.
pivot = (
    df.pivot(index="date", columns="stratum", values="metric_value")
    .sort_index(axis=1)
)

# Plot each lineage's percentage share over time as its own line, then save
# the figure to disk and display it.
fig, ax = plt.subplots(figsize=(14, 5))  # wider horizontal layout

# One line per lineage column of the wide `pivot` table built above.
# NOTE(review): the default colour cycle repeats after 10 entries, so with many
# lineages some lines will share a colour — consider a larger colormap.
for column in pivot.columns:
    ax.plot(pivot.index, pivot[column], label=column)

ax.set_title("Genomic Lineages of COVID-19 in England Over Time", fontsize=14)
ax.set_xlabel("Date")
ax.set_ylabel("Percentage of Cases (%)")
ax.set_ylim(0, 100)
ax.grid(True, linestyle='--', alpha=0.4)
# The legend is anchored just outside the right edge of the axes.
ax.legend(title="Variant", bbox_to_anchor=(1.05, 1), loc='upper left')
fig.tight_layout()

# savefig does not create missing directories; make sure 'plots/' exists so
# the script does not fail with FileNotFoundError on a fresh checkout.
output_path = Path("plots/genomic_lineages_line_plot.png")
output_path.parent.mkdir(parents=True, exist_ok=True)
# bbox_inches="tight" expands the saved area to include the outside-the-axes
# legend, so it is not clipped in the PNG.
fig.savefig(output_path, dpi=300, bbox_inches="tight")
plt.show()

