# coding: utf-8
import gc
from netCDF4 import Dataset
import numpy as np


def area(lon1de, lat1de, Re=6370000.):
    """
    Compute spherical-earth areas of grid cells bounded by 1D edge arrays.

    Each cell is the slice of the latitude band between two adjacent
    latitude edges that lies between two adjacent longitude edges. The band
    area is 2*pi*Re**2*|sin(lat_a) - sin(lat_b)|, and each cell gets the
    fraction of it equal to its longitude width / 360 degrees.

    Arguments
    ---------
    lon1de : array_like
        1d longitude edges in degrees (length nlon)
    lat1de : array_like
        1d latitude edges in degrees (length nlat)
    Re : scalar
        spherical earth radius in meters, default = 6370000. meters

    Returns
    -------
    area : ndarray
        2D array of shape (nlat - 1, nlon - 1) with grid-cell areas in
        squared units of Re (m**2 for the default Re)
    """
    # asarray lets plain lists/tuples work (list slices cannot be subtracted)
    lon = np.asarray(lon1de)
    lat = np.radians(np.asarray(lat1de))
    # Latitude-band areas between adjacent latitude edges.
    dsinlat = np.abs(np.sin(lat[1:]) - np.sin(lat[:-1]))
    partialcap = 2 * np.pi * Re**2 * dsinlat
    # Longitude width of each column of cells, in degrees.
    dlon = np.abs(lon[1:] - lon[:-1])
    # Outer product: rows follow latitude, columns follow longitude.
    return partialcap[:, None] * dlon[None, :] / 360


def h2d(v, inst=True):
    """
    Total an array of values over all of its elements.

    Arguments
    ---------
    v : array
        values to total; the first axis (presumably time) is the one
        averaged when ``inst`` is True. Any sliceable object with
        ``__getitem__`` returning arrays (e.g. a netCDF variable) works.
    inst : bool
        if True, treat entries along the first axis as instantaneous samples
        at step boundaries: adjacent samples are averaged (trapezoidal rule)
        before summing, so N samples yield N - 1 interval values.
        If False, every element is summed as-is.

    Returns
    -------
    d : scalar
        the total
    """
    if inst:
        # Mean of each adjacent pair along axis 0, then sum everything.
        return (v[1:] + v[:-1]).sum() / 2.
    # v[:] materializes netCDF variables; it is a no-op view for ndarrays.
    return v[:].sum()


# Previous HEMCO sector root, superseded by the assignment below:
# hroot = '/work/ROMO/global/emissions/EPA/FSL2016/sectors'
# HEMCO-ready per-sector files live under <hroot>/<sector>/
hroot = '/work/EMIS/users/ktalgo/WO-144.3_hemi/Task06/output_diurnal/'
# SMOKE per-sector monthly/diurnal files
sroot = '/work/EMIS/users/ktalgo/WO-144.3_hemi/Task06/input_monthly_diurnal/'
# Merged root
mroot = 'anthro_merged/WO-144.3_hemi/Task06/output_diurnal/anthro_merged/'
# Sector names are whitespace-separated tokens; the context manager closes
# the file handle as soon as it has been read.
with open('sectors/sectors.txt', 'r') as sectorfile:
    sectors = sorted(sectorfile.read().split())
# sectors = ['rail', 'rwc']
verbose = 1
spcs = ['NO', 'NO2', 'NOX', 'CO', 'SO2', 'PNH4']
# Molecular weights of the SMOKE input species; species missing here are
# looked up with a default of 1 (i.e. no conversion).
# NOTE(review): NO uses 46 (NO2) rather than 30 -- presumably NO mass is
# reported "as NO2"; confirm against the inventory convention.
inmolwts = dict(
    NO=46.,
    NOX=46.,
    NO2=46.,
    CO=28.,
    SO2=64.
)
# Molecular weights of the HEMCO input species; kept as a separate copy so
# the two inputs can diverge if their conventions differ.
inmolwth = inmolwts.copy()

# Output molecular weights: nitrogen species are reported as N (14 g/mol);
# unlisted species fall back to the HEMCO input weight (no conversion).
outmolwt = dict(
    NO=14.,
    NO2=14.,
    NOX=14.
)

# Output species label used in the unit column; unlisted species keep
# their own name.
outspcs = dict(
    NO='N',
    NO2='N',
    NOX='N'
)

def printmrg(month, dt, outfile):
    """
    Write HEMCO totals from the merged anthropogenic file as CSV rows.

    Arguments
    ---------
    month : int
        month number, used in the file name and the MONTH column
    dt : str
        day type used in the file name and DAYTYPE column
        (e.g. 'weekday' or 'weekend')
    outfile : file
        open text file; one row is written per species in ``spcs`` that
        is present in the merged file
    """
    path = mroot + '2016fh_16j_0.1x0.1_anthro_merged_2016{0:02d}_{1}.ncf'.format(month, dt)
    # The context manager guarantees the file is closed, even on error.
    with Dataset(path, mode='rs') as hwdf:
        lon = hwdf.variables['lon'][:].astype('d')
        lat = hwdf.variables['lat'][:].astype('d')
        # area() returns (nlat - 1, nlon - 1) cells, matching the [1:, 1:]
        # slices below. NOTE(review): this treats lon/lat as cell edges --
        # confirm.
        a = area(lon, lat)
        for spc in spcs:
            if spc not in hwdf.variables:
                continue
            molwtih = inmolwth.get(spc, 1)
            molwto = outmolwt.get(spc, molwtih)
            outs = outspcs.get(spc, spc)
            # Assumes flux in kg/m2/s with hourly steps:
            # x area (m2) x 3600 s, then kg -> Mg; summed over all steps.
            hemco_Mg = h2d(hwdf.variables[spc][:, :, 1:, 1:] * a * 3600 / 1e3, inst=False)
            hemco_MgS = hemco_Mg * molwto / molwtih
            hemco_tons = hemco_Mg * 1.102  # Mg -> short tons
            print(
                '2016%02d,%s,mHEMCO,%10s,%7.1f,%7s,%7.1f,%6s' %
                (month, dt, 'merged', hemco_tons, 'ton' + spc, hemco_MgS, 'Mg' + outs),
                file=outfile, flush=True
            )
    # Release the large gridded arrays before the next file is opened.
    gc.collect()

def printspc(month, dt, sector, outfile):
    """
    Write SMOKE and HEMCO totals for one sector as paired CSV rows.

    Arguments
    ---------
    month : int
        month number, used in the file names and the MONTH column
    dt : str
        day type used in the file names and DAYTYPE column
        (e.g. 'weekday' or 'weekend')
    sector : str
        sector name, used in both file paths and the SECTOR column
    outfile : file
        open text file; for each species in ``spcs`` present in the HEMCO
        file, a SMOKE row and a HEMCO row are written

    Species variables with more than 4 dimensions are skipped with a
    message on stdout. A species present in the HEMCO file but missing
    from the SMOKE file raises KeyError (both files are still closed).
    """
    hpath = '{0}/{1}/2016fh_16j_0.1x0.1_{1}_2016{2:02d}_{3}.ncf'.format(hroot, sector, month, dt)
    spath = '{0}/2016fh_16j_{1}_12US1_month_{2:02d}_{3}.ncf'.format(sroot, sector, month, dt)
    # Context managers close both files even if a species raises.
    with Dataset(hpath, mode='rs') as hwdf, Dataset(spath, mode='rs') as swdf:
        lon = hwdf.variables['lon'][:].astype('d')
        lat = hwdf.variables['lat'][:].astype('d')
        # area() returns (nlat - 1, nlon - 1) cells, matching the [1:, 1:]
        # slice below. NOTE(review): this treats lon/lat as cell edges --
        # confirm.
        a = area(lon, lat)
        for spc in spcs:
            if spc not in hwdf.variables:
                continue
            var = hwdf.variables[spc]
            if len(var.dimensions) > 4:
                print('Skipping', month, dt, sector, spc)
                print('Too many dimensions')
                continue
            molwtis = inmolwts.get(spc, 1)
            molwtih = inmolwth.get(spc, 1)
            molwto = outmolwt.get(spc, molwtih)
            outs = outspcs.get(spc, spc)

            # SMOKE values are treated as instantaneous samples (trapezoidal
            # total); assumed to be tons/day.
            smoke_tons = h2d(swdf.variables[spc][:], inst=True)
            smoke_MgS = smoke_tons / 1.102 * molwto / molwtis
            # Assumes flux in kg/m2/s with hourly steps:
            # x area (m2) x 3600 s, then kg -> Mg; summed over all steps.
            hemco_Mg = h2d(var[:, :, 1:, 1:] * a * 3600 / 1e3, inst=False)
            hemco_MgS = hemco_Mg * molwto / molwtih
            hemco_tons = hemco_Mg * 1.102  # Mg -> short tons
            print(
                '2016%02d,%s,SMOKE,%10s,%7.1f,%7s,%7.1f,%6s' %
                (month, dt, sector, smoke_tons, 'ton' + spc, smoke_MgS, 'Mg' + outs),
                file=outfile, flush=True
            )
            print(
                '2016%02d,%s,HEMCO,%10s,%7.1f,%7s,%7.1f,%6s' %
                (month, dt, sector, hemco_tons, 'ton' + spc, hemco_MgS, 'Mg' + outs),
                file=outfile, flush=True
            )
    # Release the large gridded arrays before the next sector is opened.
    gc.collect()


# Driver: write merged HEMCO totals, then per-sector SMOKE/HEMCO totals, to
# emis.csv. The context manager closes the CSV even if a sector fails.
with open('emis.csv', 'w') as outfile:
    print('MONTH,DAYTYPE,SOURC,SECTOR,VALUE01,UNIT01,VALUE02,UNIT02', file=outfile, flush=True)
    for month in [1]:
        for dt in ['weekday', 'weekend']:
            printmrg(month, dt, outfile)
        for sector in sectors:
            for dt in ['weekday', 'weekend']:
                if verbose:
                    print(month, sector, dt)
                printspc(month, dt, sector, outfile)
