#!/usr/local/bin/python2.6
# Aggregates a gridded emissions file by a factor of 3 squared
# This is designed specifically to move from a 12km grid to a 36km grid
# 5/12/11 James Beidler <beidler.james@epa.gov>

from numpy import *
from scipy.io.netcdf import *
import sys, os

#if len(sys.argv) != 2:
#        print "You must provide an input file name"
#        print "agg.py <infile.py>"
#        sys.exit()

#inFileName = sys.argv[1]

# Attributes of each possible aggregate grid, keyed by grid abbreviation
gridDict = {
	'36US1': {
		'name': '36US1_148X112',
		'xorig': -2736000,
		'yorig': -2088000,
		'nrows': 112,
		'ncols': 148,
	},
}

# Hours in the file - generally 25, but kept as a setting just in case
hours = 25

class gridAtt(object):
	"""
	Grid description for one aggregate grid, exposed as instance attributes.

	The abbreviation is looked up in gridDict and the entry's name, origin
	and row/column counts are copied onto the instance.  The cell size is
	fixed at 36000 in both directions.  An abbreviation that is not a key
	of gridDict raises KeyError.
	"""

	def __init__(self, gridAbbrev):
		entry = gridDict[gridAbbrev]

		self.abbrev = gridAbbrev
		# Cell size is hard-coded; change if moving to a different size
		self.xcell = self.ycell = 36000
		# Everything else comes straight from the gridDict entry
		for key in ('name', 'xorig', 'yorig', 'nrows', 'ncols'):
			setattr(self, key, entry[key])

def openNCF(fileName, accessType = 'r'):
	"""
	Open a NetCDF file, exiting the program if it cannot be opened.

	fileName:   path of the NetCDF file
	accessType: mode handed through to netcdf_file ('r' by default)

	Returns the open netcdf_file object.  If opening fails for any reason an
	error naming the file is printed and the program exits with status 1.
	"""
	# Catch Exception instead of using a bare except so that KeyboardInterrupt
	# and SystemExit propagate rather than being reported as a missing file.
	try:
		ncFile = netcdf_file(fileName, accessType)
	except Exception:
		print("ERROR: %s not available for access." % fileName)
		sys.exit(1)
	return ncFile

def checkEV(evName):
	"""
	Return the value of an environment variable, exiting if it is not set.

	evName: name of the environment variable to read

	Returns the variable's value.  If the variable is not defined an error
	naming it is printed and the program exits with status 1.
	"""
	# Only a missing key means "not defined"; anything else (including
	# KeyboardInterrupt) should not be reported as an unset variable.
	try:
		value = os.environ[evName]
	except KeyError:
		print("ERROR: Environment variable '%s' is not defined." % evName)
		sys.exit(1)
	return value

# Run settings, all taken from required environment variables
esdate = checkEV('ESDATE')    # Output date as YYYYMMDD
grid = checkEV('GRID')        # Grid (part of the input file name)
imPath = checkEV('IMD_ROOT')  # Intermediate path for the case
spec = checkEV('EMF_SPC')     # Speciation
case = checkEV('CASE')        # Case
sector = checkEV('SECTOR')    # Sector

# The input file lives under <imPath>/<sector>/
inFileName = '%s/%s/emis_mole_%s_%s_%s_%s_%s.ncf' %(imPath, sector, sector, esdate, grid, spec, case)

# Open the input file read-only
inFile = openNCF(inFileName, 'r')

# Dimension sizes of the input file
inDims = inFile.dimensions
nlays, nvars, tstep = inDims['LAY'], inDims['VAR'], inDims['TSTEP']
inNcols, inNrows = inDims['COL'], inDims['ROW']

# Species names: VAR-LIST split on single spaces, padding stripped, blanks dropped
speciesList = []
for token in getattr(inFile, 'VAR-LIST').split(' '):
	token = token.strip()
	if token != '':
		speciesList.append(token)

# Origin and cell size of the input grid
inXorig = inFile.XORIG
inYorig = inFile.YORIG
inXcell = inFile.XCELL  # Our grids tend to be square, so only one cell size is needed.

# Loop over the grids to aggregate to.  For each one an output file is created
# next to the input file, the global attributes are copied across (with the
# grid-specific ones replaced), and every species is summed from blocks of
# 3x3 input cells into one output cell.
for gridAbbrev in ['36US1',]:
	print('Aggregating %s to %s...' % (inFileName, gridAbbrev))

	# Set the output file name for the grid
	outFileName = '%s/%s/emis_mole_%s_%s_%s_%s_%s.ncf' %(imPath, sector, sector, esdate, gridAbbrev, spec, case)
	outFile = openNCF(outFileName, 'w')

	# Grid object for the output grid.  Deliberately not named "grid": that name
	# already holds the GRID environment value read at the top of the script.
	outGrid = gridAtt(gridAbbrev)

	# Set outfile dimensions
	outDims = { 'VAR': int(nvars), 'TSTEP': tstep, 'DATE-TIME': 2, 'ROW': outGrid.nrows, 'COL': outGrid.ncols, 'LAY': int(nlays) }
	for dim, value in outDims.items(): outFile.createDimension(dim, value)

	# Set global attributes
	# Ignore automatically created attributes, plus the grid specific ones and history because that doesn't play well for some reason.
	ignoredAtts = ('close', 'createDimension', 'createVariable', 'flush', 'sync', 'NROWS', 'NCOLS', 'XORIG', 'YORIG', 'GDNAM', 'HISTORY', 'XCELL', 'YCELL')
	for attName in dir(inFile):
		if attName in ignoredAtts: continue
		if attName != attName.upper(): continue  # Checks to see if the attribute is in caps.  All the ones that we create should be.
		setattr( outFile, attName, getattr(inFile, attName) )
	setattr(outFile, 'NROWS', outGrid.nrows)
	setattr(outFile, 'NCOLS', outGrid.ncols)
	# IOAPI requires float 64 types for these attributes
	setattr(outFile, 'XORIG', float64(outGrid.xorig) )
	setattr(outFile, 'YORIG', float64(outGrid.yorig) )
	setattr(outFile, 'GDNAM', outGrid.name)
	setattr(outFile, 'XCELL', float64(outGrid.xcell) )
	setattr(outFile, 'YCELL', float64(outGrid.ycell) )
	setattr(outFile, 'HISTORY', ' ')

	# Calculate the raw difference for the distance between the origination points on the new and old grids
	xDist = (outGrid.xorig - inXorig)
	yDist = (outGrid.yorig - inYorig)

	# Calculate the starting points for copying on both grids
	# Keep in mind that the grid starts at the SW corner
	if xDist < 0: # If the old x origination is inside the new grid
		inCorig = 0 # Start at the first old grid cell
		outCorig = abs( xDist / outGrid.xcell ) # Offset some new grid cells
	else: # If the old x origination is outside of or equal to the orig on new grid
		inCorig = abs( xDist / inXcell )
		outCorig = 0
	if yDist < 0:
		inRorig = 0
		outRorig = abs( yDist / outGrid.ycell )
	else:
		inRorig = abs( yDist / inXcell )
		outRorig = 0

	# Calculate the end points to form boundary fields
	inXend = inXorig + ( inNcols * inXcell )
	inYend = inYorig + ( inNrows * inXcell )
	outXend = outGrid.xorig + ( outGrid.ncols * outGrid.xcell )
	outYend = outGrid.yorig + ( outGrid.nrows * outGrid.ycell )
	xDist = outXend - inXend
	yDist = outYend - inYend

	if xDist > 0: # If the endpoint of the old is inside the new

		inCend = inNcols
		outCend = outGrid.ncols - abs( xDist / outGrid.xcell ) # Calculate the last grid cell to aggregate

		if outCend > int(outCend): # Check to see if there is anything after the decimal ie. a partial column calculated.
			outCend = int(outCend) + 1  # If there is a partial column then add a column to be processed to hold the extra input data
	else:
		inCend = inNcols - abs( xDist / inXcell )
		outCend = outGrid.ncols
	if yDist > 0:
		inRend = inNrows
		outRend = outGrid.nrows - abs( yDist / outGrid.ycell )

		if outRend > int(outRend):
			outRend = int(outRend) + 1
	else:
		inRend = inNrows - abs( yDist / inXcell )
		outRend = outGrid.nrows


#	print "In start col: %s  End: %s    Out start: %s  End: %s" %(inCorig, inCend, outCorig, outCend)
#	print "In start row: %s  End: %s    Out start: %s  End: %s" %(inRorig, inRend, outRorig, outRend)

	# Truncate every boundary to a whole cell index once.  The values above can
	# be floats (they come from dividing distances by cell sizes); using the
	# truncated integers everywhere keeps the array indices and the bounds
	# checks below consistent with each other.
	inCstart, inCstop = int(inCorig), int(inCend)
	inRstart, inRstop = int(inRorig), int(inRend)

	# Column and row ranges to fill on the new (out) grid
	outCrange = range(int(outCorig), int(outCend))
	outRrange = range(int(outRorig), int(outRend))

	# Loop through and aggregate each species variable
	for speciesName in speciesList:
		speciesIn = inFile.variables[speciesName]
		dataIn = speciesIn[:]
		# Size the buffers from the time steps actually present in the data
		# instead of assuming a fixed number of hours.
		nsteps = dataIn.shape[0]
		dataOut = zeros([nsteps, nlays, outGrid.nrows, outGrid.ncols], 'f')
		speciesOut = outFile.createVariable(speciesName, 'f', ('TSTEP', 'LAY', 'ROW', 'COL'))
		speciesOut.long_name = speciesIn.long_name
		speciesOut.units = speciesIn.units
		speciesOut.var_desc = speciesIn.var_desc

		# colCnt/rowCnt count the new-grid cells processed so far along each axis
		for colCnt, col in enumerate(outCrange):
			iCol = inCstart + (colCnt * 3) # Move over three old columns from the old grid origination point for each new column

			for rowCnt, row in enumerate(outRrange):
				iRow = inRstart + (rowCnt * 3)

				cellOut = zeros([nsteps, nlays], 'f')
				# Loop over three cells for each dimension on the old grid.
				# Indices only increase, so the first one past the edge of
				# the old grid ends that axis.
				for xBlk in range(iCol, iCol + 3):
					if not (inCstart <= xBlk < inCstop): break

					for yBlk in range(iRow, iRow + 3):
						if not (inRstart <= yBlk < inRstop): break

#						print 'In (col,row): %s,%s   Out: %s,%s' %(xBlk, yBlk, col, row)

						cellOut = cellOut + dataIn[:,:,yBlk,xBlk]

				dataOut[:,:,row,col] = cellOut

		speciesOut[:] = dataOut

	# Write the timestep flag
	tflagOut = outFile.createVariable('TFLAG', 'i', ('TSTEP', 'VAR', 'DATE-TIME'))
	tflagOut.long_name = 'TFLAG'
	tflagOut.units = '<YYYYDDD,HHMMSS>'
	tflagOut.var_desc = 'Timestep-valid flags:  (1) YYYYDD or (2) HHMMSS'
	tflagIn = inFile.variables['TFLAG'][:]
	tflagOut[:] = tflagIn

	# Close explicitly so the output is written out now rather than whenever
	# (or if) the file object happens to be garbage collected.
	outFile.close()

