Import libraries
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import pandas as pd
from pandas.plotting import register_matplotlib_converters
register_matplotlib_converters()  # datetime converter for a matplotlib
import seaborn as sns
sns.set(style="ticks", font_scale=1.5)
from statsmodels.tsa.seasonal import seasonal_decompose
import matplotlib.dates as mdates
from matplotlib.dates import DateFormatter
from datetime import datetime as dt
import time
from statsmodels.tsa.stattools import adfuller
import matplotlib.colors as mcolors
import urllib.request
import json
import plotly.io as pio
pio.renderers.default = "plotly_mimetype+notebook_connected"
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
pio.templates.default = "presentation"

# %matplotlib widget
Download the latest data and save it as CSV
# helper function to check if a year is a leap year
def is_leap_year(year: int) -> bool:
    """Return True if *year* is a leap year in the Gregorian calendar.

    A year is a leap year when it is divisible by 4, except century years,
    which must also be divisible by 400 (2000 is a leap year, 1900 is not).
    Delegates to the standard library instead of re-implementing the rule.
    """
    import calendar

    return calendar.isleap(year)
    
# Daily SST JSON feeds from https://climatereanalyzer.org/clim/sst_daily/,
# keyed by the region name accepted by load_and_process_data().
_SST_URLS = {
    # world average SST
    "world": "https://climatereanalyzer.org/clim/sst_daily/json/oisst2.1_world2_sst_day.json",
    # North Atlantic SST
    "north": "https://climatereanalyzer.org/clim/sst_daily/json/oisst2.1_natlan1_sst_day.json",
}


def _records_to_daily_sst(records):
    """Flatten the decoded feed into one date-indexed frame with an 'sst' column.

    *records* is the decoded JSON. Each record has a 'name' (a year such as
    "1982", or a non-year label) and a 'data' list of daily values. Non-year
    records are skipped, and missing values are dropped.
    """
    table = pd.DataFrame(records)
    # Transpose so each record becomes a column (labelled by its name) and
    # each row is a day position within that record.
    by_name = pd.DataFrame(table['data'].tolist()).T
    by_name.columns = table['name'].astype(str).tolist()

    yearly = []
    for column in by_name.columns:
        try:
            year = int(column)
        except ValueError:
            # Not a year, e.g. "1982-2011 mean", "plus 2σ", "minus 2σ".
            continue
        days_in_year = 366 if is_leap_year(year) else 365
        dates = pd.date_range(start=f'{year}-01-01', periods=days_in_year, freq='D')
        # Force exactly one value per calendar day. Extra trailing slots are cut
        # off, and a too-short record is padded with NaN (removed by dropna
        # below) instead of raising a length mismatch.
        values = by_name[column].reindex(range(days_in_year)).to_numpy()
        yearly.append(pd.DataFrame({'sst': values}, index=dates))

    if not yearly:
        raise ValueError("SST feed contained no per-year series")

    sst = pd.concat(yearly)
    sst.index.name = 'date'
    return sst.dropna()


def load_and_process_data(type, filename):
    """Download the daily SST feed for a region, flatten it, and save it as CSV.

    Parameters
    ----------
    type : str
        Region key: "world" (world average) or "north" (North Atlantic).
        The name shadows the builtin but is kept for backward compatibility.
    filename : str or path-like
        Destination CSV, written with a 'date' index column and an 'sst' column.

    Returns
    -------
    pandas.DataFrame
        The frame that was written to *filename*.

    Raises
    ------
    ValueError
        If *type* is not a known region, or the feed holds no per-year series.
    """
    try:
        url = _SST_URLS[type]
    except KeyError:
        raise ValueError(
            f"unknown region {type!r}; expected one of {sorted(_SST_URLS)}"
        ) from None

    with urllib.request.urlopen(url) as response:
        records = json.load(response)

    df_sst = _records_to_daily_sst(records)
    df_sst.to_csv(filename, index=True)
    return df_sst

# Fetch both regional feeds and cache each one as a local CSV.
for region, csv_path in (("world", "sst_world.csv"), ("north", "sst_north.csv")):
    load_and_process_data(region, csv_path)
load and plot data
# Read back the CSVs written above; the 'date' column becomes a parsed DatetimeIndex.
_csv_read_opts = {"index_col": "date", "parse_dates": True}
df_north, df_world = (
    pd.read_csv(path, **_csv_read_opts) for path in ("sst_north.csv", "sst_world.csv")
)
add columns: day of year and year
# Add day-of-year ('doy') and calendar-year ('year') columns to both regions.
for df_region in (df_world, df_north):
    df_region['doy'] = df_region.index.day_of_year
    df_region['year'] = df_region.index.year

# Per-year groups of the North Atlantic series (df_grouped refers to north only).
df_grouped = df_north.groupby('year')

34.1 range slider

widget 1
# First two colors of the D3 qualitative palette: one per region and per y axis.
blue = px.colors.qualitative.D3[0]
orange = px.colors.qualitative.D3[1]

# One shared time axis, with two y axes so each region keeps its own scale.
fig = make_subplots(specs=[[{"secondary_y": True}]])

# World mean on the primary (left) y axis.
fig.add_trace(
    go.Scatter(
        x=df_world.index,
        y=df_world['sst'],
        name='world mean',
        line=dict(color=blue),
    ),
    secondary_y=False,
)

# North Atlantic on the secondary (right) y axis.
fig.add_trace(
    go.Scatter(
        x=df_north.index,
        y=df_north['sst'],
        name='north atlantic',
        line=dict(color=orange),
    ),
    secondary_y=True,
)

# Title, a date x axis with a range slider, and a horizontal legend centered above the plot.
fig.update_layout(
    title='sea surface temperature',
    xaxis=dict(
        rangeslider={"visible": True},
        type="date",
    ),
    legend={
        "orientation": "h",    # horizontal legend
        "yanchor": "top",      # y refers to the legend's top edge
        "y": 1.1,              # place it slightly above the plotting area
        "xanchor": "center",   # x refers to the legend's horizontal center
        "x": 0.5,              # center it horizontally
    },
)

# Color each y-axis title like its trace so the two axes are distinguishable.
# `title_font` replaces the deprecated `titlefont`, which newer Plotly releases reject.
fig.update_yaxes(
    title_text="temperature (°C)",
    title_font=dict(color=blue),
    secondary_y=False,
)
fig.update_yaxes(
    title_text="temperature (°C)",
    title_font=dict(color=orange),
    secondary_y=True,
)

fig.show()

34.2 temperature vs DOY, with colorbar

choose colors for next widget
# Colormap for the per-year lines: a slice of `turbo` (earliest year at `start`,
# latest year at `end`). Other slices tried earlier: plt.cm.hot_r with
# (0.3, 0.8) and plt.cm.coolwarm with (0.3, 0.7).
base_cmap = plt.cm.turbo
start, end = 0.1, 0.75

# Keep only the [start, end] fraction of the base colormap, resampled to 256 colors.
new_colors = base_cmap(np.linspace(start, end, 256))
new_cmap = mcolors.LinearSegmentedColormap.from_list(
    "trunc({n},{a:.2f},{b:.2f})".format(n=base_cmap.name, a=start, b=end),
    new_colors,
)

# Derive the year span from the loaded data (instead of hard-coding it), so
# newly downloaded years always get a color and never cause a KeyError later.
first_year = int(min(df_world.index.year.min(), df_north.index.year.min()))
last_year = int(max(df_world.index.year.max(), df_north.index.year.max()))
years = range(first_year, last_year + 1)

# Map each year to a hex color, spread evenly along new_cmap. A single-year
# span maps to the start of the colormap instead of dividing by zero.
span = max(last_year - first_year, 1)
year_colors = {
    year: mcolors.rgb2hex(new_cmap((year - first_year) / span)[:3])  # drop alpha
    for year in years
}

# Overwrite the two most recent years with hotter colors:
# the latest year is deeppink and the one before it is hotpink.
for year, color_name in zip(reversed(years), ('deeppink', 'hotpink')):
    year_colors[year] = mcolors.to_hex(color_name)
widget 2
# Years in the color map (in year_colors insertion order) and their colors.
years = list(year_colors.keys())
colors = [year_colors[year] for year in years]

# Colorscale for the colorbar: one evenly spaced stop per year.
colorscale = [[i / (len(years) - 1), color] for i, color in enumerate(colors)]

# The colorbar range and ticks follow the color map instead of hard-coded years,
# so the bar stays in sync when new years are downloaded. Ticks fall every 10
# years from the first year, plus the last year. Decade ticks within 5 years of
# the last year are dropped so labels don't collide (e.g. 1981-2024 gives
# 1981, 1991, 2001, 2011, 2024).
first_year, last_year = min(years), max(years)
tickvals = [y for y in range(first_year, last_year, 10) if last_year - y >= 5] + [last_year]

fig = go.Figure()

# One line per calendar year: SST against day of year, colored by year.
# groupby visits the years in ascending order, in a single pass over the data.
for year, data_year in df_world.groupby(df_world.index.year):
    fig.add_trace(go.Scatter(
        x=data_year.index.dayofyear,
        y=data_year['sst'],
        mode='lines',
        line=dict(color=year_colors[year]),
        # Show the full date and the SST value on hover.
        hovertemplate='<b>Date</b>: %{text}<br><b>SST (°C)</b>: %{y}',
        text=data_year.index.strftime('%Y-%m-%d'),
        hoverlabel=dict(namelength=0),  # hide the trace-name tag in the hover box
    ))

fig.update_layout(
    title='Sea Surface Temperature, World Average',
    xaxis_title='day of year',
    yaxis_title='temperature (°C)',
    showlegend=False,
)

# Invisible dummy trace whose only purpose is to draw the year colorbar.
colorbar_trace = go.Scatter(
    x=[None],
    y=[None],
    mode='markers',
    marker=dict(
        colorscale=colorscale,
        showscale=True,
        cmin=first_year,
        cmax=last_year,
        colorbar=dict(thickness=15, tickvals=tickvals),
    ),
    hoverinfo='none',
)
fig.add_trace(colorbar_trace)
fig.show()