#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Dec 21 15:39:33 2022

@author: badler
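This script creates realtime lidar quicklook plots.
This includes time-height plots of velocity, intensity, backscatter and
spectral width for each scan elevation (90, 70, 10 and 60 deg), 30-min
vertical-velocity variance and skewness profiles, and time series of the
elevation, azimuth, pitch and roll angles.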

"""
import os
import sys
import glob
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import matplotlib
import string
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib import gridspec
from cycler import cycler
from scipy.interpolate import interp1d
#from metpy.units import units
#import metpy.calc as mpcalc
import datetime
import numpy as np
import pandas as pd
from pandas.tseries.frequencies import to_offset
import xarray as xr
from netCDF4 import Dataset
import scipy.io as sio
import scipy
import shutil
import timeit
import tol_colors as tc
import warnings
# silence RuntimeWarnings for the whole script
warnings.simplefilter("ignore", RuntimeWarning)

# matplotlib defaults applied to every figure
params = {
    'legend.fontsize': 'xx-large',
    'figure.figsize': (15, 5),
    'axes.labelsize': 'x-large',
    'axes.titlesize': 'xx-large',
    'xtick.labelsize': 'x-large',
    'ytick.labelsize': 'x-large',
    'xtick.minor.size': 3.5,
    'xtick.top': True,
    'ytick.right': True,
}
plt.rcParams.update(params)

# default color cycle: 15 colors sampled evenly from nipy_spectral
spectral_colors = plt.cm.nipy_spectral(np.linspace(0, 1, 15))
default_cycler = cycler(color=spectral_colors)
plt.rc('axes', prop_cycle=default_cycler)

# some general definitions: line colors for the time-series plots
tsdefcolor = ['black', 'indianred', 'forestgreen', 'royalblue', 'mediumpurple',
              'magenta', 'cyan', 'lime', 'orange', 'yellow', 'red']
tscycler = cycler(color=tsdefcolor)

# resolution (dots per inch) of the saved figures
dpi = 50

#some functions
def datenum_to_datetime(datenum):
    """
    Convert a MATLAB datenum into a Python datetime.

    Adapted from https://gist.github.com/victorkristof/b9d794fe1ed12e708b9d

    The integer part is treated as a day ordinal and the fractional part as
    the time of day. The MATLAB and Python day counts are offset by 366
    days, which is subtracted at the end.

    :param datenum: Date in MATLAB datenum format (float or int).
    :return:        Naive datetime.datetime for the same instant.
    """
    whole_days = int(datenum)
    day_fraction = datenum % 1
    matlab_offset = datetime.timedelta(days=366)
    as_ordinal = datetime.datetime.fromordinal(whole_days)
    with_time = as_ordinal + datetime.timedelta(days=day_fraction)
    return with_time - matlab_offset
           
def datestr2num(*args):
    """
    Parse a date string and return it as a matplotlib date number.

    :param args: ``(datestring,)`` or ``(datestring, format)``. The format
        defaults to ``'%Y%m%d'``; any further arguments are ignored.
    :return: matplotlib date number (float days) for the parsed date.
    """
    text = args[0]
    fmt = args[1] if len(args) > 1 else '%Y%m%d'
    parsed = datetime.datetime.strptime(text, fmt)
    return mdates.date2num(parsed)

def num2datestr(*args):
    """
    Format a matplotlib date number as a string.

    Indented with spaces (the original used tabs, unlike the rest of the
    file).

    :param args: ``(num,)`` or ``(num, format)``. The format defaults to
        ``'%Y%m%d%H%M%S'``; any further arguments are ignored.
    :return: the date formatted with ``strftime``.
    """
    x = args[0]
    if len(args) == 1:
        str_format = '%Y%m%d%H%M%S'
    else:
        str_format = args[1]
    d = mdates.num2date(x)
    return d.strftime(str_format)
      
def get_ticks(start, end):
    """
    Choose a major tick locator and date formatter for a time axis.

    :param start: start of the plotted interval (datetime).
    :param end:   end of the plotted interval (datetime).
    :return: ``(locator, formatter)`` from matplotlib.dates; tick spacing
        grows with the length of the interval.
    """
    from datetime import timedelta as td
    span = end - start
    if span <= td(days=1.5):
        return (mdates.HourLocator(byhour=range(0, 24, 6)),
                mdates.DateFormatter('%H:%M\n%m/%d'))
    if span <= td(days=3):
        return (mdates.HourLocator(byhour=range(0, 24, 12)),
                mdates.DateFormatter('%m/%d %H'))
    # longer spans: one tick every `step` days, labelled month/day;
    # (upper limit in days, step) pairs, first match wins, else 20 days
    day_steps = ((15, 1), (30, 2), (50, 5), (100, 10))
    step = next((n for max_days, n in day_steps if span <= td(days=max_days)), 20)
    return mdates.DayLocator(interval=step), mdates.DateFormatter('%m/%d')

def get_ticks_minor(start, end):
    """
    Choose a minor tick locator for a time axis.

    :param start: start of the plotted interval (datetime).
    :param end:   end of the plotted interval (datetime).
    :return: a matplotlib.dates locator whose spacing grows with the
        length of the interval.
    """
    from datetime import timedelta as td
    span = end - start
    # (upper limit of the span, locator factory); the first match wins.
    # Neighbouring limits that produced identical locators are merged.
    rules = (
        (td(minutes=10), lambda: mdates.MinuteLocator()),
        (td(hours=1), lambda: mdates.MinuteLocator(byminute=range(0, 60, 5))),
        (td(days=1.5), lambda: mdates.HourLocator()),
        (td(days=6), lambda: mdates.HourLocator(byhour=range(0, 24, 6))),
        (td(weeks=15), lambda: mdates.DayLocator()),
        (td(weeks=104), lambda: mdates.MonthLocator()),
        (td(weeks=208), lambda: mdates.MonthLocator(interval=3)),
    )
    for limit, make_locator in rules:
        if span <= limit:
            return make_locator()
    return mdates.MonthLocator(interval=6)

def resample_xarray(data, endtime=None):
    """
    Put a time series onto a regular time grid so that gaps appear as NaN.

    Duplicate timestamps are dropped (keeping the first) and the data are
    sorted in time. The grid spacing is the median sampling interval. If
    the data end before ``endtime``, a NaN-valued timestamp at ``endtime``
    is appended so the grid reaches that time.

    :param data: xarray object with a ``time`` dimension.
    :param endtime: time the grid should extend to; defaults to the
        module-level ``statustime`` (previous behaviour).
    :return: resampled object. Each grid point takes the nearest original
        sample within twice the median interval, otherwise NaN.
    """
    if endtime is None:
        endtime = statustime
    # drop duplicate timestamps; comparing datetime64 values directly avoids
    # the float rounding of date2num, and np.unique also sorts
    _, idx = np.unique(data.time.values, return_index=True)
    data = data.isel(time=idx)
    # median temporal resolution in seconds, computed after de-duplication
    # so repeated timestamps cannot drag the median to zero
    sec = np.median(np.diff(data.time.values)) / np.timedelta64(1, 's')
    # if the data end before endtime, extend to endtime (values are NaN)
    if data.time.values[-1] < np.datetime64(endtime):
        timeint = np.concatenate((data.time.values,
                                  np.atleast_1d(np.datetime64(endtime))))
        data = data.interp(time=timeint)
    # resample so that missing values have a timestamp; lowercase 's' is the
    # seconds alias accepted by old and new pandas ('S' deprecated in 2.2)
    return data.resample(time=str(sec) + 's').nearest(tolerance=str(sec * 2) + 's')

def calc_variance(da, window='30min', min_fraction=0.75):
    """
    Compute a time-height profile of variance over fixed time windows.

    :param da: 2-D DataArray with dims ``time`` and ``range`` (in this
        script the vertical velocity); NaNs mark missing samples.
    :param window: pandas offset string for the averaging window
        (default 30 minutes).
    :param min_fraction: minimum fraction of the timestamps in a window that
        must hold a valid value at a given height; windows with less data
        are set to NaN (default 0.75).
    :return: DataArray named ``varw`` with dims (time, range). The time
        coordinate is the centre of each window. Values are the pandas
        sample variance (ddof=1) of the valid samples.
    """
    # one DataFrame for all heights (index: time, columns: range gates)
    # instead of one DataFrame and one resample per height
    df = da.transpose('time', 'range').to_pandas()
    windows = df.resample(window)
    varw = windows.var()
    # fraction of timestamps in each window that hold a valid value, per height
    finratio = windows.count().div(windows.size(), axis=0)
    varw = varw.mask(finratio < min_fraction)
    return xr.DataArray(
        data=varw.values.astype(float),
        dims=('time', 'range'),
        coords=dict(
            # label each window by its centre instead of its left edge
            time=('time', varw.index + pd.to_timedelta(window) / 2),
            range=da.range.values),
        name='varw')

def calc_skewness(da, window='30min', min_fraction=0.75):
    """
    Compute a time-height profile of skewness over fixed time windows.

    :param da: 2-D DataArray with dims ``time`` and ``range`` (in this
        script the vertical velocity); NaNs mark missing samples.
    :param window: pandas offset string for the averaging window
        (default 30 minutes).
    :param min_fraction: minimum fraction of the timestamps in a window that
        must hold a valid value at a given height; windows with less data
        are set to NaN (default 0.75).
    :return: DataArray named ``skeww`` with dims (time, range). The time
        coordinate is the centre of each window. Values are
        ``scipy.stats.skew`` (default, biased estimator) of the valid
        samples.
    """
    # explicit import: `import scipy` alone does not load scipy.stats on
    # every scipy version
    from scipy import stats

    # one DataFrame for all heights (index: time, columns: range gates)
    df = da.transpose('time', 'range').to_pandas()
    windows = df.resample(window)

    def _skew(x):
        # use only the valid samples; with NaNs included, skew's default
        # nan_policy='propagate' blanks any window missing a single sample,
        # which made the availability threshold below meaningless
        return stats.skew(x.dropna())

    skeww = windows.apply(_skew)
    # same data-availability criterion as the variance profile
    finratio = windows.count().div(windows.size(), axis=0)
    skeww = skeww.mask(finratio < min_fraction)
    return xr.DataArray(
        data=skeww.values.astype(float),
        dims=('time', 'range'),
        coords=dict(
            # label each window by its centre instead of its left edge
            time=('time', skeww.index + pd.to_timedelta(window) / 2),
            range=da.range.values),
        name='skeww')



def plot_time_height(data, pspecs, drange):
    """
    Plot stacked time-height (pcolormesh) panels and save them as a PNG.

    Reads the module-level globals ``datestring``, ``site`` and ``lidar``
    (title/footer), ``campaigndir`` (logo location) and ``dpi``.

    :param data: dict whose key 'lidar' holds an xarray Dataset with
        ``time`` and ``range`` coordinates and the plotted variables.
    :param pspecs: dict with keys
        'varsall'   list of lists of variable names, one inner list per panel,
        'clabelstr' colorbar label per panel,
        'clim'      [cmin, cmax] per panel,
        'ylim'      height limits (m AGL),
        'xlim'      time limits as matplotlib date numbers,
        'fout'      output path prefix.
    :param drange: plotted time span in days; appended to the file name.
    """
    varsall = pspecs['varsall']
    clabelstr = pspecs['clabelstr']
    clim = pspecs['clim']
    ylim = pspecs['ylim']
    xlim = pspecs['xlim']
    fout = pspecs['fout'] + '_' + str(drange) + '.png'
    fig, axs = plt.subplots(nrows=len(varsall), ncols=1,
                            figsize=(13.5, 4.5 * len(varsall)),
                            facecolor='w', edgecolor='w',
                            gridspec_kw={'hspace': 0.07, 'wspace': 0.00},
                            tight_layout={'pad': 0}, sharex=True)
    axs = np.atleast_1d(axs)

    # loop-invariant inputs shared by all panels
    dplot = data['lidar']
    plot_time = mdates.date2num(dplot.time.values)
    heights = dplot['range'].values
    tstart = mdates.num2date(xlim[0])
    tend = mdates.num2date(xlim[1])

    for i, ax in enumerate(axs):
        cmin = clim[i][0]
        cmax = clim[i][1]
        if 'velocity (' in clabelstr[i]:
            # velocities: diverging colormap with limits symmetric about zero
            new_cmap = plt.get_cmap(tc.tol_cmap('sunset'))
            cm = np.max([np.abs(cmin), cmax])
            cmin = -1 * cm
            cmax = cm
        else:
            new_cmap = plt.get_cmap(tc.tol_cmap('rainbow_PuRd'))

        for var in varsall[i]:
            p = ax.pcolormesh(plot_time, heights, dplot[var].values.T,
                              vmin=cmin, vmax=cmax, cmap=new_cmap,
                              rasterized=True, shading='auto')
        ax.grid(True)
        ax.set_ylim(ylim)
        ax.set_xlim(xlim)
        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        cb = fig.colorbar(p, cax=cax, extend='both')
        cb.set_label(label=clabelstr[i])
        ax.set_ylabel('Height (m AGL)')
        # fresh locator objects per axis (locators are bound to one axis)
        loc, fmt = get_ticks(tstart, tend)
        ax.xaxis.set_major_locator(loc)
        ax.xaxis.set_major_formatter(fmt)
        ax.xaxis.set_minor_locator(get_ticks_minor(tstart, tend))

    axs[-1].set_xlabel('Time (UTC)')
    axs[0].set_title(datestring + ', ' + site + ' ' + lidar)

    # footer (logo + contact text); its vertical position depends on the
    # number of panels
    logo_y, text_y = (0.015, 0.03) if len(varsall) == 3 else (0.002, 0.01)
    logo = fig.add_axes([0.1, logo_y, 0.05, 0.05], anchor='NW', zorder=1)
    logo.imshow(plt.imread(os.path.join(campaigndir, 'functions', 'noaa_web.png')))
    logo.axis('off')
    created = datetime.datetime.now(datetime.timezone.utc).strftime('%d %b %Y %H:%M UTC')
    fig.text(0.9, text_y,
             'Lidar ' + lidar +
             '\nContact: Bianca.Adler@noaa.gov, Creation date: ' + created,
             fontsize=14, ha='right')

    print('save as ' + fout)
    fig.savefig(fout, format='png', dpi=dpi, bbox_inches='tight')
    # release the figure; the script creates several in a loop
    plt.close(fig)



def plot_timeseries_nativeres(data, pspecs, drange):
    """
    Plot stacked time-series panels at native time resolution and save a PNG.

    Reads the module-level globals ``datestring``, ``site`` and ``lidar``
    (title/footer), ``campaigndir`` (logo location) and ``dpi``.

    :param data: dict whose key 'lidar' holds an xarray Dataset with a
        ``time`` coordinate and the plotted variables.
    :param pspecs: dict with keys
        'varsall'   list of lists of variable names, one inner list per panel,
        'ylabelstr' y-axis label per panel,
        'ylims'     [ymin, ymax] per panel (an empty entry keeps autoscaling),
        'markers'   marker per panel ('' draws lines instead of markers),
        'fout'      output path prefix,
        'xlim'      optional time limits as matplotlib date numbers.
    :param drange: plotted time span in days; appended to the file name and
        used to pick minor ticks.
    """
    varsall = pspecs['varsall']
    ylabelstr = pspecs['ylabelstr']
    ylims = pspecs['ylims']
    markers = pspecs['markers']
    # previously the module-global `xlim` was read here instead of pspecs
    xlim = pspecs.get('xlim')
    fout = pspecs['fout'] + '_' + str(drange) + '.png'
    fig, axs = plt.subplots(nrows=len(varsall), ncols=1,
                            figsize=(13.5, 3 * len(varsall)),
                            facecolor='w', edgecolor='w',
                            gridspec_kw={'hspace': 0.07, 'wspace': 0.00},
                            tight_layout={'pad': 0}, sharex=True)
    axs = np.atleast_1d(axs)

    dplot = data['lidar']
    plotdt = mdates.date2num(dplot['time'].values)

    for i, (ax, vars_) in enumerate(zip(axs, varsall)):
        ax.set_prop_cycle(tscycler)
        for var in vars_:
            p, = ax.plot(plotdt, dplot[var].values, label=var)
            if len(markers[i]) > 0:
                p.set_marker(markers[i])
                p.set_linestyle('None')

        ax.grid(True)
        try:
            ax.set_ylim(ylims[i])
        except (IndexError, TypeError, ValueError):
            # missing or empty entry (e.g. []): keep autoscaled limits
            print('no ylim defined for subplot ' + str(i))
        ax.set_ylabel(ylabelstr[i])

        lines, _ = ax.get_legend_handles_labels()
        if len(lines) > 8:
            ncol = int(len(lines) / 3.)
        else:
            ncol = max([int(len(lines) / 2.), 1])
        ax.legend(fontsize='medium', ncol=ncol)

    # x axis is shared, so configure it once on the bottom panel
    ax = axs[-1]
    if xlim is not None:
        ax.set_xlim(xlim[0], xlim[1])
    xmin, xmax = ax.get_xlim()
    loc, fmt = get_ticks(mdates.num2date(xmin), mdates.num2date(xmax))
    ax.xaxis.set_major_locator(loc)
    ax.xaxis.set_major_formatter(fmt)
    if drange == 1:
        ax.xaxis.set_minor_locator(mdates.HourLocator())
    ax.set_xlabel('Time (UTC)')
    axs[0].set_title(datestring + ', ' + site + ' ' + lidar)

    logo = fig.add_axes([0.1, 0.04, 0.05, 0.05], anchor='NW', zorder=1)
    logo.imshow(plt.imread(os.path.join(campaigndir, 'functions', 'noaa_web.png')))
    logo.axis('off')

    created = datetime.datetime.now(datetime.timezone.utc).strftime('%d %b %Y %H:%M UTC')
    # footer previously read 'ASSIST housekeeping' (copied from another
    # instrument's script); this figure shows lidar data
    fig.text(0.9, 0.05,
             'Lidar ' + lidar + ' housekeeping' +
             '\nContact: Bianca.Adler@noaa.gov, Creation date: ' + created,
             fontsize=14, ha='right')
    print('save as ' + fout)
    fig.savefig(fout, format='png', dpi=dpi, bbox_inches='tight')
    # release the figure memory
    plt.close(fig)


#%%
#get some global variables from the calling environment
campaigndir=os.getenv('campaigndir')
lidar=os.getenv('lidar')
lidarraw=os.getenv('lidarraw')
datestring=os.getenv('pdate')  # processing day, YYYYMMDD
site=os.getenv('site')
siteid=os.getenv('siteid')

# fail early with a clear message instead of a TypeError in os.path.join
_missing_env = [name for name, value in (('campaigndir', campaigndir),
                                         ('lidar', lidar),
                                         ('lidarraw', lidarraw),
                                         ('pdate', datestring),
                                         ('site', site),
                                         ('siteid', siteid))
                if value is None]
if _missing_env:
    sys.exit('missing environment variable(s): ' + ', '.join(_missing_env))

fpathin = os.path.join(campaigndir,siteid,lidar)

fpathout=os.path.join(campaigndir,siteid,lidar,'quicklooks_lidar')
os.makedirs(fpathout, exist_ok=True)

# midnight at the end of the processing day
end_of_day=datetime.datetime.strptime(datestring+'000000','%Y%m%d%H%M%S')+datetime.timedelta(days=1)

#time to process, for realtime processing several times a day on site this is current time
#statustime=datetime.datetime.now()
#for processing at NOAA lab, end of the processing day (midnight)
statustime=end_of_day
#for testing, create time yourself
#statustime=datetime.datetime.strptime(datestring+'030000','%Y%m%d%H%M%S')
#statustime=datetime.datetime.strptime(datestring+'185900','%Y%m%d%H%M%S')
#statustime=datetime.datetime.strptime(datestring+'235900','%Y%m%d%H%M%S')

# calendar day before statustime; with statustime=end_of_day this is
# datestring itself
yesterday = (statustime-pd.to_timedelta('1D')).strftime('%Y%m%d')

# if yesterday is processed do full day
if yesterday == datestring:
    statustime=end_of_day

plotinfo={}

#%%
#Lidar data
# files of the day before statustime and of the processing day. When
# statustime is the end of datestring both patterns match the same day, so
# de-duplicate (order-preserving) to avoid loading every file twice.
listdiry = sorted(glob.glob(os.path.join(fpathin,'netcdf',siteid+'.'+lidarraw+'**'+yesterday+'*.cdf')))
listdirt = sorted(glob.glob(os.path.join(fpathin,'netcdf',siteid+'.'+lidarraw+'**'+datestring+'*.cdf')))
listdir = list(dict.fromkeys(listdiry + listdirt))
if len(listdir) == 0:
    print('no netcdf lidardata found for '+datestring)
    sys.exit('No lidar data found')
# one concat over all files instead of repeated pairwise concatenation;
# a single file is used as-is, as before
datasets = [xr.open_dataset(f) for f in listdir]
IN = datasets[0] if len(datasets) == 1 else xr.concat(datasets, dim='time')

IN['range']=IN['range']*1000 # convert to m

IN=IN.sortby('time')
#filter radial velocity: keep samples with intensity (SNR+1) > 1.005
mask = IN['intensity']>1.005
IN['velocity']=IN['velocity'].where(mask)
#log10 of backscatter; assigning .values keeps the variable's attributes
# NOTE(review): plots label this "Backscatter (dB)" but log10 is not
# 10*log10 -- confirm the intended unit
IN['backscatter'].values=np.log10(IN['backscatter']).values

# subsets per scan elevation (within 0.1 deg)
mask = np.abs((IN.elevation-90))<0.1
IN90=IN.sel(time=mask)
mask = np.abs((IN.elevation-70))<0.1
IN70=IN.sel(time=mask)
mask = np.abs((IN.elevation-10))<0.1
IN10=IN.sel(time=mask)
mask = np.abs((IN.elevation-60))<0.1
IN60=IN.sel(time=mask)

#compute 30-min variance and skewness of the vertical-stare velocity;
#optional: on failure report the reason and continue with the other plots
try:
    w90=IN90.dropna(dim='time',how='all').velocity
    davar=calc_variance(w90)
    daskew=calc_skewness(w90)
    INstat=xr.merge([davar,daskew])
    stat_=True
except Exception as err:
    print('could not compute velocity statistics: '+str(err))
    stat_=False

#plot time-height sections of the raw fields for each scan elevation
varsall=[['velocity'],['intensity'],['backscatter'],['spectral_width']]
clabelstr=['Vertical velocity (m s$^{-1}$)','Intensity (SNR+1)','Backscatter (dB)','Spectral width (m s$^{-1}$)']
clim=[[-3,3],[1,2],[-6,-3],[2,6]]
ylim=[0,4000]
fout=os.path.join(fpathout,datestring+'_'+lidar+'_fig01_time_height')

set_1={'varsall': varsall,'clabelstr': clabelstr,
       'ylim': ylim,'clim':clim,'fout':fout}

drange_=[1]
# drange_=[90]

# The loop variable must not be named IN: that name holds the full dataset,
# which the angle time series further down still needs.
for INel in [IN90,IN70,IN10,IN60]:
    if len(INel.time) == 0:
        continue
    # NOTE(review): fout does not depend on the elevation, so each non-empty
    # elevation overwrites the previous one's figure -- confirm intended
    for drange in drange_:
        tstart=statustime-datetime.timedelta(days=drange)
        xlim=[mdates.date2num(tstart),mdates.date2num(statustime)]
        set_1['xlim']=xlim
        # slice with exact datetimes instead of a date2num float round trip
        INplot=INel.sel(time=slice(tstart,statustime))
        data={'lidar': INplot}
        plot_time_height(data,set_1,drange)


if stat_:
    #plot time-height sections of the 30-min velocity variance and skewness
    varsall=[['varw'],['skeww']]
    clabelstr=['Vertical velocity variance (m$^2$ s$^{-2}$)','Vertical velocity skewness']
    clim=[[0,3],[-1,2]]
    ylim=[0,4000]
    fout=os.path.join(fpathout,datestring+'_'+lidar+'_fig02_time_height')

    set_1={'varsall': varsall,'clabelstr': clabelstr,
           'ylim': ylim,'clim':clim,'fout':fout}

    drange_=[1]
    # drange_=[90]

    for drange in drange_:
        tstart=statustime-datetime.timedelta(days=drange)
        xlim=[mdates.date2num(tstart),mdates.date2num(statustime)]
        set_1['xlim']=xlim
        # slice with exact datetimes instead of a date2num float round trip
        INplot=INstat.sel(time=slice(tstart,statustime))
        data={'lidar': INplot}
        plot_time_height(data,set_1,drange)


# --- time series of the scan/platform angles at native resolution ---
varsall = [['elevation'], ['azimuth'], ['pitch'], ['roll']]
ylabelstr = ['Elevation angle (deg)', 'Azimuth angle (deg)',
             'Pitch angle (deg)', 'Roll angle (deg)']
# empty entries: no fixed y-limits for pitch and roll
ylim = [[0, 92], [0, 360], [], []]
markers = ['.'] * 4
fout = os.path.join(fpathout, datestring + '_' + lidar + '_fig03_timeseries')

set_1 = dict(varsall=varsall, ylabelstr=ylabelstr, ylims=ylim,
             markers=markers, fout=fout)
drange_ = [1]

for drange in drange_:
    tstart_num = mdates.date2num(statustime - datetime.timedelta(days=drange))
    tend_num = mdates.date2num(statustime)
    xlim = [tstart_num, tend_num]
    set_1['xlim'] = xlim

    window = slice(pd.to_datetime(xlim[0], origin='unix', unit='D'),
                   pd.to_datetime(xlim[1], origin='unix', unit='D'))
    INplot = IN.sel(time=window)
    data = {'lidar': INplot}

    plot_timeseries_nativeres(data, set_1, drange)
