#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Dec 21 15:39:33 2022

@author: badler
This script creates quicklooks from the WINDoe output

"""
import os
import sys
import glob
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import matplotlib
import string
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib import gridspec
from cycler import cycler
from scipy.interpolate import interp1d
import datetime
import numpy as np
import numpy.matlib
import pandas as pd
from pandas.tseries.frequencies import to_offset
import xarray as xr
from netCDF4 import Dataset
import scipy.io as sio
import scipy
import shutil
import timeit
import tol_colors as tc
import matplotlib
#matplotlib.use('TkAgg') # Or 'TkAgg', 'GTK3Agg', etc.
matplotlib.use('Agg') # Or 'TkAgg', 'GTK3Agg', etc.
import matplotlib.pyplot as plt

import warnings
# Silence every UserWarning so the batch log of this script stays readable.
warnings.simplefilter("ignore", UserWarning)

# Global matplotlib appearance: larger fonts, wide default figure and tick
# marks on the top and right edges of every axes as well.
params = {
    'legend.fontsize': 'xx-large',
    'figure.figsize': (15, 5),
    'axes.labelsize': 'x-large',
    'axes.titlesize': 'xx-large',
    'xtick.labelsize': 'x-large',
    'ytick.labelsize': 'x-large',
    'xtick.minor.size': 3.5,
    'xtick.top': True,
    'ytick.right': True,
}
plt.rcParams.update(params)

# Default line colours: 15 colours sampled evenly along nipy_spectral.
default_cycler = cycler(color=plt.cm.nipy_spectral(np.linspace(0, 1, 15)))
plt.rc('axes', prop_cycle=default_cycler)

# Colour sequence applied explicitly to the time series panels.
tsdefcolor = [
    'black', 'indianred', 'forestgreen', 'royalblue', 'mediumpurple',
    'magenta', 'cyan', 'lime', 'orange', 'yellow', 'red',
]
tscycler = cycler(color=tsdefcolor)

# Resolution (dots per inch) of the saved PNG quicklooks.
dpi = 50

#some functions
def datenum_to_datetime(datenum):
    """Convert a Matlab datenum into a naive Python ``datetime``.

    Adapted from https://gist.github.com/victorkristof/b9d794fe1ed12e708b9d

    :param datenum: Matlab serial date number (days, fractional part = time of day)
    :return:        ``datetime.datetime`` for the same instant
    """
    # Matlab day 1 is 0000-01-01, Python ordinal 1 is 0001-01-01: the two
    # counts differ by the 366 days of (leap) year 0.
    matlab_year_zero = datetime.timedelta(days=366)
    time_of_day = datetime.timedelta(days=datenum % 1)
    midnight = datetime.datetime.fromordinal(int(datenum))
    return midnight + time_of_day - matlab_year_zero
           
def datestr2num(*args):
    """Parse a date string into a matplotlib date number.

    :param args[0]: date string
    :param args[1]: optional ``strptime`` format, default ``'%Y%m%d'``
    :return:        float days since the matplotlib epoch
    """
    date_string = args[0]
    fmt = '%Y%m%d' if len(args) == 1 else args[1]
    parsed = datetime.datetime.strptime(date_string, fmt)
    return mdates.date2num(parsed)

def num2datestr(*args):
    """Format a matplotlib date number as a string.

    :param args[0]: matplotlib date number
    :param args[1]: optional ``strftime`` format, default ``'%Y%m%d%H%M%S'``
    :return:        formatted date string
    """
    number = args[0]
    fmt = args[1] if len(args) != 1 else '%Y%m%d%H%M%S'
    return mdates.num2date(number).strftime(fmt)
      
def get_ticks(start, end):
    """Pick a major tick locator and date formatter for a time axis.

    Tick spacing grows with ``end - start``: every 6 h up to 1.5 days,
    every 12 h up to 3 days, daily up to 15 days, then every 2, 5 and 10
    days up to 30, 50 and 100 days, and every 20 days beyond that.

    :param start: left edge of the axis (datetime)
    :param end:   right edge of the axis (datetime)
    :return:      ``(locator, formatter)`` for ``ax.xaxis``
    """
    from datetime import timedelta as td
    span = end - start

    if span <= td(days=1.5):
        return (mdates.HourLocator(byhour=range(0, 24, 6)),
                mdates.DateFormatter('%H:%M\n%m/%d'))
    if span <= td(days=3):
        return (mdates.HourLocator(byhour=range(0, 24, 12)),
                mdates.DateFormatter('%m/%d %H'))

    # From here on ticks sit on whole days and are labelled month/day.
    if span <= td(days=15):
        day_locator = mdates.DayLocator()
    else:
        step = next((every for limit, every in ((30, 2), (50, 5), (100, 10))
                     if span <= td(days=limit)), 20)
        day_locator = mdates.DayLocator(interval=step)
    return day_locator, mdates.DateFormatter('%m/%d')

def get_ticks_minor(start, end):
    """Pick a minor tick locator for a time axis spanning ``end - start``.

    Minor ticks are placed every minute (up to 10 min), every 5 min (up to
    1 h), hourly (up to 1.5 days), every 6 h (up to 6 days), daily (up to
    15 weeks), monthly (up to 104 weeks), every 3 months (up to 208 weeks)
    and every 6 months beyond that.

    :param start: left edge of the axis (datetime)
    :param end:   right edge of the axis (datetime)
    :return:      matplotlib date locator
    """
    from datetime import timedelta as td
    span = end - start

    # (inclusive upper bound of the span, factory for the locator), in order.
    rules = (
        (td(minutes=10), lambda: mdates.MinuteLocator()),
        (td(hours=1), lambda: mdates.MinuteLocator(byminute=range(0, 60, 5))),
        (td(hours=6), lambda: mdates.HourLocator()),
        (td(days=1.5), lambda: mdates.HourLocator(byhour=range(0, 24, 1))),
        (td(days=6), lambda: mdates.HourLocator(byhour=range(0, 24, 6))),
        (td(weeks=15), lambda: mdates.DayLocator()),
        (td(weeks=104), lambda: mdates.MonthLocator()),
        (td(weeks=208), lambda: mdates.MonthLocator(interval=3)),
    )
    for limit, make_locator in rules:
        if span <= limit:
            return make_locator()
    return mdates.MonthLocator(interval=6)

def uv2ddff(u,v):
    """Convert horizontal wind components into wind speed and direction.

    Originally Matlab code by Andreas Wieser (29-SEP-2004), modified by
    Bianca Adler (5.12.2018), translated to Python starting 13.04.2020.

    :param u: eastward wind component (scalar, array or DataArray)
    :param v: northward wind component, same shape as ``u``
    :return:  ``(ff, dd)`` -- wind speed and wind direction in degrees
    """
    speed = np.sqrt(u**2 + v**2)
    # arctan2(u, v) is the bearing (clockwise from north) the wind vector
    # points to; adding 180 deg gives the direction the wind comes from,
    # and the modulo wraps the result into [0, 360).
    bearing_towards = np.rad2deg(np.arctan2(u, v))
    direction = np.mod(180 + bearing_towards, 360)
    return speed, direction

def plot_time_height(data,pspecs,drange):
    """Draw time-height cross sections of WINDoe output and save them as PNG.

    One panel is drawn per entry of ``pspecs['varsall']``; every entry is a
    list of variable names from ``data['windoe']`` plotted with pcolormesh.
    For variables whose name contains 'wspd', unit-length wind vectors built
    from 'u_wind'/'v_wind' are overlaid.

    :param data:   dict with key 'windoe' -> xarray Dataset with 'time' and
                   'height' coordinates
    :param pspecs: dict with 'varsall', 'clabelstr', 'clim' (one [min, max]
                   per panel), 'ylim', 'xlim' (matplotlib date numbers) and
                   'fout' (output path without suffix)
    :param drange: number of days shown; appended to the file name
    Uses the module-level globals ``datestring``, ``site``, ``combination``,
    ``campaigndir`` and ``dpi``. The figure is closed after saving.
    """
    varsall=pspecs['varsall']
    clabelstr=pspecs['clabelstr']
    clim=pspecs['clim']
    ylim=pspecs['ylim']
    xlim=pspecs['xlim']
    fout=pspecs['fout']+'_'+str(drange)+'.png'
    fig,axs = plt.subplots(nrows=len(varsall),ncols=1,figsize=(13.5,4.5*len(varsall)),\
        facecolor='w',edgecolor='w',gridspec_kw={'hspace': 0.07,'wspace': 0.00},\
        tight_layout={'pad':0},sharex=True)
    # plt.subplots returns a bare Axes for a single panel; make it indexable.
    axs=np.atleast_1d(axs)

    # Coordinates are the same for every panel, so convert them once.
    dplot=data['windoe']
    plottime=mdates.date2num(dplot.time.values)
    height=dplot['height'].values

    for i_varsall,vars_ in enumerate(varsall):
        ax=axs[i_varsall]
        cmin=clim[i_varsall][0]
        cmax=clim[i_varsall][1]
        if 'velocity (' in clabelstr[i_varsall]:
            # velocities: 'sunset' colormap with colour limits symmetric about 0
            new_cmap=plt.get_cmap(tc.tol_cmap('sunset'))
            cm=np.max([np.abs(cmin),cmax])
            cmin=-1*cm
            cmax=cm
        else:
            new_cmap=plt.get_cmap(tc.tol_cmap('rainbow_PuRd'))

        for var in vars_:
            p=ax.pcolormesh(plottime,height,dplot[var].values.T,vmin=cmin,vmax=cmax,\
                cmap=new_cmap,rasterized=True,shading='auto')
            if 'wspd' in var:
                # overlay wind vectors normalised by wind speed (direction only)
                u=dplot['u_wind'].values/dplot['wspd'].values
                v=dplot['v_wind'].values/dplot['wspd'].values
                # thin the arrows: every 2nd time step and every 10th height
                # level for deep plots (ylim top > 2), every 5th otherwise
                hstep=10 if ylim[1] > 2 else 5
                tstep=2
                ax.quiver(plottime[::tstep],height[::hstep],u[::tstep,::hstep].T,\
                    v[::tstep,::hstep].T,width=0.002)
            ax.grid(True)
        ax.set_ylim(ylim)
        ax.set_xlim(xlim)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cb=fig.colorbar(p,cax=cax,extend='both')
        cb.set_label(label=clabelstr[i_varsall])
        ax.set_ylabel('Height (km AGL)')
        loc,fmt=get_ticks(mdates.num2date(xlim[0]),mdates.num2date(xlim[1]))
        ax.xaxis.set_major_locator(loc)
        ax.xaxis.set_major_formatter(fmt)
        locminor=get_ticks_minor(mdates.num2date(xlim[0]),mdates.num2date(xlim[1]))
        ax.xaxis.set_minor_locator(locminor)
    axs[-1].set_xlabel('Time (UTC)')
    axs[0].set_title(datestring+', '+site+' '+combination)

    # Logo and caption positions (figure fraction) depend on the number of
    # panels, i.e. on the figure height: (logo bottom, caption y).
    logo_y,text_y={3:(0.015,0.03),5:(0.025,0.05)}.get(len(varsall),(0.002,0.01))
    logo=fig.add_axes([0.1,logo_y,0.05,0.05],anchor='NW',zorder=1)
    im=plt.imread(os.path.join(campaigndir,'functions','noaa_web.png'))
    logo.imshow(im)
    logo.axis('off')
    fig.text(0.9, text_y, 'WINDoe (Gebauer and Bell, 2024) '+combination+\
                 '\nContact: Bianca.Adler@noaa.gov, Creation date: '+datetime.datetime.now(datetime.UTC).strftime('%d %b %Y %H:%M UTC'), \
             fontsize=14, transform=fig.transFigure,\
             ha='right')

    print('save as '+fout)
    fig.savefig(fout,format='png',dpi=dpi,bbox_inches='tight')
    # Release the figure; otherwise every call keeps one open in pyplot.
    plt.close(fig)


def plot_timeseries_obsflag(data,pspecs,drange):
    """Plot time series of WINDoe input observation counts and save as PNG.

    :param data:   xarray Dataset with a 'time' coordinate; each label in
                   ``pspecs['varsall']`` names a variable after removing spaces
    :param pspecs: dict with 'varsall' (list of label lists, one per panel),
                   'ylabelstr', 'ylims', 'markers' (per panel, '' for lines),
                   'xlim' (matplotlib date numbers) and 'fout'
    :param drange: number of days shown; appended to the file name; for
                   drange == 1 hourly minor ticks are added
    Uses the module-level globals ``datestring``, ``site``, ``combination``,
    ``campaigndir``, ``tscycler`` and ``dpi``. The figure is closed after saving.
    """
    varsall=pspecs['varsall']
    ylabelstr=pspecs['ylabelstr']
    ylims=pspecs['ylims']
    markers=pspecs['markers']
    # Take the x range from pspecs (as plot_time_height does) instead of a
    # module-level `xlim` that merely happens to hold the same value.
    xlim=pspecs['xlim']
    fout=pspecs['fout']+'_'+str(drange)+'.png'
    fig,axs = plt.subplots(nrows=len(varsall),ncols=1,figsize=(10.5,3*len(varsall)),\
        facecolor='w',edgecolor='w',gridspec_kw={'hspace': 0.07,'wspace': 0.00},\
        tight_layout={'pad':0},sharex=True)
    # plt.subplots returns a bare Axes for a single panel; make it indexable.
    axs=np.atleast_1d(axs)

    plotdt=mdates.date2num(data['time'].values)
    for i_varsall,vars_ in enumerate(varsall):
        ax=axs[i_varsall]
        ax.set_prop_cycle(tscycler)
        for var in vars_:
            p,=ax.plot(plotdt,data[var.replace(' ','')].values,label=var)
            if len(markers[i_varsall])>0:
                p.set_marker(markers[i_varsall])
                p.set_linestyle('None')

        ax.grid(True)
        try:
            ax.set_ylim(ylims[i_varsall])
        except (IndexError,TypeError,ValueError):
            # missing/empty limits -> keep matplotlib's autoscaled range
            print('no ylim defined for subplot '+str(i_varsall))
        ax.set_ylabel(ylabelstr[i_varsall])

        # Spread long legends over several columns; ncol must stay >= 1.
        lines,_=ax.get_legend_handles_labels()
        ncol=len(lines)//3 if len(lines)>8 else len(lines)//2
        ax.legend(fontsize='medium',ncol=max(1,ncol))

    # The x axis is shared, so configuring the bottom panel is sufficient.
    ax=axs[-1]
    ax.set_xlim(xlim[0],xlim[1])
    loc,fmt=get_ticks(mdates.num2date(ax.get_xlim()[0]),mdates.num2date(ax.get_xlim()[1]))
    ax.xaxis.set_major_locator(loc)
    ax.xaxis.set_major_formatter(fmt)
    if drange==1:
        ax.xaxis.set_minor_locator(mdates.HourLocator())
    ax.set_xlabel('Time (UTC)')
    axs[0].set_title(datestring+', '+site+' '+combination)

    logo=fig.add_axes([0.1,0.04,0.05,0.05],anchor='NW',zorder=1)
    im=plt.imread(os.path.join(campaigndir,'functions','noaa_web.png'))
    logo.imshow(im)
    logo.axis('off')

    fig.text(0.9, 0.05, 'Sample size of observations used as input for WINDoe'+\
                 '\nContact: Bianca.Adler@noaa.gov, Creation date: '+datetime.datetime.now(datetime.UTC).strftime('%d %b %Y %H:%M UTC'), \
             fontsize=14, transform=fig.transFigure,\
             ha='right')
    print('save as '+fout)
    fig.savefig(fout,format='png',dpi=dpi,bbox_inches='tight')
    # Release the figure; otherwise every call keeps one open in pyplot.
    plt.close(fig)


def plot_timeseries_nativeres(data,pspecs,drange):
    """Plot native-resolution time series from ``data['lidar']`` and save as PNG.

    :param data:   dict with key 'lidar' -> xarray Dataset with a 'time'
                   coordinate and the variables named in ``pspecs['varsall']``
    :param pspecs: dict with 'varsall' (list of variable lists, one per panel),
                   'ylabelstr', 'ylims', 'markers' (per panel, '' for lines),
                   'xlim' (matplotlib date numbers), 'fout' and optionally
                   'lidar' (instrument name used in the title)
    :param drange: number of days shown; appended to the file name; for
                   drange == 1 hourly minor ticks are added
    Uses the module-level globals ``datestring``, ``site``, ``campaigndir``,
    ``tscycler`` and ``dpi``. The figure is closed after saving.
    """
    varsall=pspecs['varsall']
    ylabelstr=pspecs['ylabelstr']
    ylims=pspecs['ylims']
    markers=pspecs['markers']
    # Take the x range from pspecs instead of a module-level `xlim`.
    xlim=pspecs['xlim']
    # The title used a global `lidar` that is not defined in this script
    # (NameError); accept it via pspecs, else a module-level `lidar` if any.
    lidar_name=pspecs.get('lidar',globals().get('lidar',''))
    fout=pspecs['fout']+'_'+str(drange)+'.png'
    fig,axs = plt.subplots(nrows=len(varsall),ncols=1,figsize=(13.5,3*len(varsall)),\
        facecolor='w',edgecolor='w',gridspec_kw={'hspace': 0.07,'wspace': 0.00},\
        tight_layout={'pad':0},sharex=True)
    # plt.subplots returns a bare Axes for a single panel; make it indexable.
    axs=np.atleast_1d(axs)

    dplot=data['lidar']
    plotdt=mdates.date2num(dplot['time'].values)
    for i_varsall,vars_ in enumerate(varsall):
        ax=axs[i_varsall]
        ax.set_prop_cycle(tscycler)
        for var in vars_:
            p,=ax.plot(plotdt,dplot[var].values,label=var)
            if len(markers[i_varsall])>0:
                p.set_marker(markers[i_varsall])
                p.set_linestyle('None')

        ax.grid(True)
        try:
            ax.set_ylim(ylims[i_varsall])
        except (IndexError,TypeError,ValueError):
            # missing/empty limits -> keep matplotlib's autoscaled range
            print('no ylim defined for subplot '+str(i_varsall))
        ax.set_ylabel(ylabelstr[i_varsall])

        # Spread long legends over several columns; ncol must stay >= 1
        # (a single-line panel previously produced ncol=0 and crashed).
        lines,_=ax.get_legend_handles_labels()
        ncol=len(lines)//3 if len(lines)>8 else len(lines)//2
        ax.legend(fontsize='medium',ncol=max(1,ncol))

    # The x axis is shared, so configuring the bottom panel is sufficient.
    ax=axs[-1]
    ax.set_xlim(xlim[0],xlim[1])
    loc,fmt=get_ticks(mdates.num2date(ax.get_xlim()[0]),mdates.num2date(ax.get_xlim()[1]))
    ax.xaxis.set_major_locator(loc)
    ax.xaxis.set_major_formatter(fmt)
    if drange==1:
        ax.xaxis.set_minor_locator(mdates.HourLocator())
    ax.set_xlabel('Time (UTC)')
    axs[0].set_title(datestring+', '+site+' '+lidar_name)

    logo=fig.add_axes([0.1,0.04,0.05,0.05],anchor='NW',zorder=1)
    im=plt.imread(os.path.join(campaigndir,'functions','noaa_web.png'))
    logo.imshow(im)
    logo.axis('off')

    # NOTE(review): this caption reads 'ASSIST housekeeping' although the
    # function plots data['lidar']; looks copied from another script -- confirm.
    fig.text(0.9, 0.05, 'ASSIST housekeeping'+\
                 '\nContact: Bianca.Adler@noaa.gov, Creation date: '+datetime.datetime.now(datetime.UTC).strftime('%d %b %Y %H:%M UTC'), \
             fontsize=14, transform=fig.transFigure,\
             ha='right')
    print('save as '+fout)
    fig.savefig(fout,format='png',dpi=dpi,bbox_inches='tight')
    # Release the figure; otherwise every call keeps one open in pyplot.
    plt.close(fig)


#%%
#get some global variables
# All settings are passed in by the calling workflow as environment variables;
# every one of them is used below (paths, file patterns, plot titles).
windoedir=os.getenv('windoedir')
campaigndir=os.getenv('campaigndir')
combination=os.getenv('combination')
datestring=os.getenv('pdate')
site=os.getenv('site')
siteid=os.getenv('siteid')
oversion=os.getenv('oversion')
# Stop with a clear message instead of a TypeError from os.path.join or from
# string concatenation with None further down.
_missing_env=[name for name,value in (('windoedir',windoedir),('campaigndir',campaigndir),
                                      ('combination',combination),('pdate',datestring),
                                      ('site',site),('siteid',siteid),('oversion',oversion))
              if value is None]
if _missing_env:
    sys.exit('environment variable(s) not set: '+', '.join(_missing_env))
fpathout=os.path.join(windoedir,'quicklooks')
os.makedirs(fpathout,exist_ok=True)

#24 hr periods
sdate=pd.to_datetime(datestring)
edate=pd.to_datetime(datestring)+pd.to_timedelta('1D')
#statustime=edate
#for realtime processing
#if os.environ.get('statustime'):
#    edate=pd.Timestamp(os.environ.get('statustime'))
#    sdate=edate-pd.to_timedelta('1D')
#else:
#    sdate=pd.Timestamp.now()-pd.to_timedelta('1D')
#    edate=pd.Timestamp.now()
statustime=edate
yesterday =  (statustime-pd.to_timedelta('1D')).strftime('%Y%m%d')

# if yesterday is processed do full day
# With statustime=edate this is true whenever pdate is given as YYYYMMDD: the
# plot window then ends at midnight after pdate and `yesterday` becomes the
# day before pdate (used to read the previous day's retrieval file).
if yesterday == datestring:
    statustime=datetime.datetime.strptime(datestring+'000000','%Y%m%d%H%M%S')+datetime.timedelta(days=1)
    yesterday =  (statustime-pd.to_timedelta('2D')).strftime('%Y%m%d')
plotinfo={}

#%% WINDoe
def _read_windoe(fname):
    """Load one WINDoe retrieval file and attach a decoded 'time' coordinate.

    The variables 'Sa', 'cov' and 'dfs' are skipped on load, and whichever of
    the dimensions 'obs_dim'/'arb_dim1' exist are dropped together with every
    variable defined on them. 'time' is rebuilt from base_time + time_offset
    (POSIX seconds) as naive UTC datetimes.

    :param fname: path of the netCDF file
    :return:      xarray Dataset
    """
    print('read '+fname)
    ds=xr.load_dataset(fname,decode_times=False,drop_variables=['Sa','cov','dfs'])
    # Drop the dimensions that are present instead of failing (and keeping
    # both) when one of them is absent; applied to every file read.
    extra_dims=('obs_dim','arb_dim1')
    missing=[d for d in extra_dims if d not in ds.sizes]
    if missing:
        print(', '.join(missing)+' not found, cannot be dropped')
    ds=ds.drop_dims([d for d in extra_dims if d in ds.sizes])
    ds['time']=[datetime.datetime.fromtimestamp(t,datetime.UTC).replace(tzinfo=None)
                for t in ds['base_time'].values+ds['time_offset'].values]
    return ds

#read present day and yesterday
# Both days use the same <siteid>.<combination>.WINDoe.<oversion>.<date>*.nc
# naming (the previous-day pattern used to lack the site id prefix and so
# never matched).
_retrieval_prefix=siteid+'.'+combination+'.WINDoe.'+oversion+'.'
listdir = sorted(glob.glob(os.path.join(windoedir,'retrieval_output',_retrieval_prefix+datestring+'*.nc')))
listdiryesterday = sorted(glob.glob(os.path.join(windoedir,'retrieval_output',_retrieval_prefix+yesterday+'*.nc')))
if len(listdir)==0:
    print('no WINDoe output found for current day')
    sys.exit()

INt=_read_windoe(listdir[0])
if len(listdiryesterday) > 0:
    INy=_read_windoe(listdiryesterday[0])
    IN = xr.concat([INy,INt],dim='time')
else:
    IN = INt

IN['wspd'],IN['wdir']=uv2ddff(IN.u_wind,IN.v_wind)
# Per-level DFS: first level plus level-to-level differences of cdfs_U/V
# (presumably cumulative DFS profiles -- confirm against WINDoe docs).
IN['dfsu']=xr.concat([IN['cdfs_U'].isel(height=0),IN['cdfs_U'].diff(dim='height')],dim='height')
IN['dfsv']=xr.concat([IN['cdfs_V'].isel(height=0),IN['cdfs_V'].diff(dim='height')],dim='height')

# Quality filter: keep retrievals whose 1-sigma uncertainty is at most
# 3 m s-1 in both u and v. Other criteria (e.g. dfsu > 0 and dfsv > 0, or a
# 2.5 m s-1 threshold) could be combined here with `&`.
mask = (IN['sigma_u']<=3) & (IN['sigma_v']<=3)
for _name in ('u_wind','v_wind','wspd','wdir'):
    IN[_name] = IN[_name].where(mask)


#plot time height raw,vertical
# Panels shared by both time-height figures.
varsall = [['wspd'], ['sigma_u'], ['sigma_v'], ['dfsu'], ['dfsv']]
clabelstr = ['Wind speed (m s$^{-1}$)', '1-sigma u (m s$^{-1}$)',
             '1-sigma v (m s$^{-1}$)', 'DFS u', 'DFS v']
clim = [[0, 20], [0, 3], [0, 3], [0, 1], [0, 1]]

# fig01: heights up to 5 km.
ylim = [0, 5]
fout = os.path.join(fpathout, datestring + '_WINDoe_' + oversion + '_fig01_time_height')
set_1 = dict(varsall=varsall, clabelstr=clabelstr, ylim=ylim, clim=clim, fout=fout)

# fig03: zoom on the lowest 2 km.
ylim = [0, 2]
fout = os.path.join(fpathout, datestring + '_WINDoe_' + oversion + '_fig03_time_height')
set_2 = dict(varsall=varsall, clabelstr=clabelstr, ylim=ylim, clim=clim, fout=fout)

# Length of the plotted period in days (e.g. [90] for a long-term overview).
drange_ = [1]

for drange in drange_:
    # Window ends at statustime and reaches `drange` days back.
    xlim = [mdates.date2num(statustime - datetime.timedelta(days=drange)),
            mdates.date2num(statustime)]
    tstart = mdates.num2date(xlim[0]).replace(tzinfo=None)
    tend = mdates.num2date(xlim[1]).replace(tzinfo=None)
    for set_ in (set_1, set_2):
        set_['xlim'] = xlim
        data = {'windoe': IN.sel(time=slice(tstart, tend))}
        plot_time_height(data, set_, drange)


# plot obscount_flag
# Readable names of the observation types encoded in obsunique_flag.
obsnames = {1: 'Raw lidar',
            2: 'Proc lidar u',
            3: 'Proc lidar v',
            4: 'Proc RWP high and low u',
            5: 'Proc RWP high and low v',
            6: 'In situ u',
            7: 'In situ v'}
obsunique = np.unique(IN.obsunique_flag)
obsunique = obsunique[obsunique>0]

#loop through each observation
plotvar=xr.Dataset()
instr=[]
for flag in obsunique:
    # Unknown codes get a generic label instead of silently reusing the
    # previous iteration's name (which overwrote that series).
    instr_ = obsnames.get(int(flag),'Obs type '+str(int(flag)))
    mask = IN.obsunique_flag == flag
    print('process '+instr_)
    plotvar[instr_.replace(' ','')]=IN.obscount_flag.where(mask).sum(dim='obsflag_dim')
    instr.append([instr_])

if len(instr)==0:
    # Nothing to plot; plt.subplots(nrows=0) and .sel(time=...) would fail.
    print('no observation flags found, skip observation count plot')
else:
    varsall = instr
    ylabelstr = np.full(len(instr),'# observations')
    # One empty entry per panel: no fixed y-limits, keep autoscaling.
    ylim = [[] for _ in instr]
    markers = np.full(len(instr),'')
    fout=os.path.join(fpathout,datestring+'_WINDoe_'+oversion+'_fig02_obsflag')

    set_1={'varsall': varsall,'ylabelstr': ylabelstr, \
            'ylims': ylim,'markers':markers,'fout':fout}
    drange_=[1]

    for drange in drange_:
        xlim=[mdates.date2num(statustime-datetime.timedelta(days=drange)),\
              mdates.date2num(statustime)]
        set_1['xlim']=xlim
        data=plotvar.sel(time=slice(mdates.num2date(xlim[0]).replace(tzinfo=None),mdates.num2date(xlim[1]).replace(tzinfo=None)))
        plot_timeseries_obsflag(data,set_1,drange)
