#!/usr/bin/python
# The "new1" program for approach 2.
# Creates temperature adjusted onroad netCDF files from unadjusted onroad netCDF files.
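# Requires the Python Numeric package (including its MA module) and ScientificPython (Scientific.IO.NetCDF).
# The code imports Numeric directly, so NumPy alone is not sufficient.  Numeric/NumPy info: http://numpy.scipy.org/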
# James Beidler <beidler.james@epa.gov> 9/29/08 
# Update 5/27/09 to include functional tagging
# Update 7/9/09 by C. Allen - changed definition of dayPath so that this script works for different years

from Numeric import *
from Scientific.IO.NetCDF import *
import sys, csv, time, os 
from MA import *

def checkEV(evName):
	"""
	Returns the value of the environment variable named evName.

	If the variable is not set, prints an error message and exits with status 1.
	"""
	try: return os.environ[evName]
	except KeyError:   # only a missing variable means "not defined"
		print("ERROR: Environment variable '%s' is not defined." % evName)
		sys.exit(1)

# Modeling year (from BASE_YEAR); it selects the per-year smk_merge_dates directory
baseYear = checkEV('BASE_YEAR')
# Root directory holding one smk_merge_dates subdirectory per year
smkDatesRoot = '/orchid/share/em_v4/inputs/ge_dat/smk_dates/'
dayPath = smkDatesRoot + baseYear   # Path to smk_merge_dates files for day mapping purposes
# Species that are skipped (not passed through) when copying the in file
ignoredSpecies = ('PMFINE_72', 'TFLAG', 'PMC_72')
# Species that are temperature adjusted
speciesList = ('POC_72', 'PEC_72', 'NAPHTH_72')
# End User Config


def readMADJUST(inFile):
	"""
	Reads the MADJUST factor file into a dictionary and returns it.

	inFile is the path to the MADJUST CSV file.  Blank rows and rows whose first
	field starts with '#' are skipped.  Only rows whose second field equals the
	module-level runType (RUN_TYPE) are used.  For those rows, the third field
	(an integer temperature, used as a key by calcFactorTable) maps to the fourth
	field (a float factor).  Surrounding double quotes on those two fields are
	stripped.
	"""
	factorDict = {}
	madjFile = openFile(inFile)   # previously ignored inFile and read the global factorFile
	try:
		lines = madjFile.readlines()
	finally:
		madjFile.close()
	for line in csv.reader(lines):
		# csv yields [] for a blank line; skip it rather than raising IndexError
		if not line or line[0].startswith('#'): continue
		if line[1] == runType:
			factorDict[int(line[2].strip('"'))] = float(line[3].strip('"'))
	return factorDict

def readDAYTABLE(dayPath, Year, Mon):
	"""
	Reads the smk_merge_dates file used for day mapping.

	dayPath: directory containing the smk_merge_dates_<Year><Mon>.txt files
	Year, Mon: year and month used to build the file name (e.g. '2008', '01')
	Returns a list of rows, each a list of CSV fields.  The first line of the
	file is skipped (presumably a header).
	"""
	dayFileName = os.path.join(dayPath, 'smk_merge_dates_%s%s.txt' % (Year, Mon))
	dayFile = openFile(dayFileName)
	try:
		lines = dayFile.readlines()[1:]
	finally:
		dayFile.close()   # release the handle instead of relying on garbage collection
	return list(csv.reader(lines))
		
def openFile(fileName, accessType = 'r'):
	"""
	Opens fileName with the given access mode and returns the file object.

	If the file cannot be opened (IOError/OSError), prints an error message and
	exits with status 1.  Other errors (e.g. an invalid mode string) propagate.
	"""
	try: fileObj = open(fileName, accessType)
	except (IOError, OSError):
		print("ERROR: %s not available for access." % fileName)
		sys.exit(1)
	return fileObj

def openNCF(fileName, accessType = 'r'):
	"""
	Opens fileName as a NetCDFFile with the given access mode and returns it.

	If opening fails, prints an error message and exits with status 1.
	"""
	try: ncFile = NetCDFFile(fileName, accessType)
	except Exception:   # not a bare except: let KeyboardInterrupt/SystemExit propagate
		print("ERROR: %s not available for access." % fileName)
		sys.exit(1)
	return ncFile

class onroadAdj(object):
	"""
	Onroad temperature adjustment class.

	Holds the run settings (date, grid, sector, speciation, case and paths).
	Provides the steps the main script uses to build the adjusted file:
	SMOKE file naming, copying netCDF settings, temperature-based factors,
	adjusting or passing through species, rebuilding TFLAG, and deriving
	PMFINE and PMC.
	convertSpeciesName (and therefore parseVarList) reads the module-level
	speciesList and ignoredSpecies tuples; processSpecies reads speciesList.
	"""

	def __init__(self, Year = "", Mon = "", Day = "", grid = "", sector = "", spec = "", case = "", imPath = "", \
		metPath = "", runType = ""):
		"""
		Stores the run settings.

		Year, Mon and Day are compared and concatenated as strings elsewhere in
		the class (e.g. Mon == '12'), so pass them as zero-padded strings such as
		'2008', '01', '05'.
		"""
		self.Year = Year
		self.Mon = Mon
		self.Day = Day
		self.grid = grid
		self.sector = sector
		self.spec = spec
		self.case = case
		self.imPath = imPath
		self.metPath = metPath
		self.runType = runType

	def threeDigit(self, x):
		"""
		Returns str(x) left-padded with '0' to at least three characters
		(5 -> '005', 42 -> '042', 123 -> '123').
		"""
		return str(x).rjust(3, '0')

	def conv2jul(self, year, month, day):
		"""
		Returns the Julian date YYYYDDD (as an int) for the given year, month and day.

		Uses calendar.timegm + time.gmtime so the result does not depend on the
		machine's time zone.  The previous time.mktime version built local
		midnight, which gmtime maps to the previous UTC day east of Greenwich.
		"""
		import calendar
		t = calendar.timegm((int(year), int(month), int(day), 0, 0, 0, 0, 0, 0))
		utc = time.gmtime(t)
		return int(str(utc[0]) + self.threeDigit(utc[7]))

	def nameInFile(self, inDate):
		"""
		Returns the in file path using the SMOKE naming convention:
		<imPath>/<sector>/emis_mole_<sector>_<inDate>_<grid>_<spec>_<case>.ncf
		"""
		inFileName = 'emis_mole_' + self.sector + '_%s_' %inDate + self.grid + '_' + self.spec + '_' + self.case + '.ncf'
		inPath = os.path.join(self.imPath, self.sector)
		return os.path.join(inPath, inFileName)

	def nameOutFile(self, outDate):
		"""
		Returns the out file path using the SMOKE naming convention:
		<imPath>/<sector>_adj/emis_mole_<sector>_adj_<outDate>_<grid>_<spec>_<case>.ncf
		"""
		outFileName = 'emis_mole_' + self.sector + '_adj_%s_' %outDate + self.grid + '_' + self.spec + '_' + self.case + '.ncf'
		outPath = os.path.join(self.imPath, self.sector + '_adj')
		return os.path.join(outPath, outFileName)

	def outFileSettings(self, inFile, outFile):
		"""
		Copies dimensions and global attributes from inFile to outFile, then sets SDATE.

		The VAR dimension is not copied; the main script creates it once the
		output variable list is known.  Names from dir(inFile) in ignoredAtt are
		skipped: these are file-object methods, plus attributes that are set
		separately (SDATE, VAR-LIST, NVARS).
		"""
		for dimName in inFile.dimensions.keys():
			if dimName != 'VAR': outFile.createDimension(dimName, inFile.dimensions[dimName])
		ignoredAtt = ('close', 'createDimension', 'createVariable', 'flush', 'sync', 'SDATE', 'VAR-LIST', 'NVARS')
		for attName in dir(inFile):
			if attName in ignoredAtt: continue
			setattr(outFile, attName, getattr(inFile, attName))
		setattr(outFile, 'SDATE', self.conv2jul(self.Year, self.Mon, self.Day))

	def convertSpeciesName(self, speciesName):
		"""
		Splits a species name into (base, tag).

		tag is '_<last underscore field>', or '' when the name has no underscore.
		A plain '<base>_72' name found in speciesList or ignoredSpecies gets an
		empty tag, i.e. the '_72' suffix is dropped.  The base 'NAPHTH' is
		renamed 'NAPHTHALENE'.
		"""
		splitSpecies = speciesName.split('_')
		speciesBase = splitSpecies[0]
		if len(splitSpecies) > 1: tag = '_%s' % splitSpecies[-1]
		else: tag = ''
		if len(splitSpecies) == 2 and tag == '_72' and (speciesBase + tag in speciesList or speciesBase + tag in ignoredSpecies):
			tag = ''
		if speciesBase == 'NAPHTH': speciesBase = 'NAPHTHALENE'   # Handle NAPHTH
		return (speciesBase, tag)

	def parseVarList(self, variableNames):
		"""
		Returns the list of output species names for the VAR-LIST global attribute.

		TFLAG is skipped.  Each name is converted with convertSpeciesName, then
		truncated or space-padded to exactly 16 characters.
		"""
		speciesTuple = []
		for species in variableNames:
			if species == 'TFLAG': continue
			species = ''.join(self.convertSpeciesName(species))
			speciesTuple.append(species[:16].ljust(16))
		return speciesTuple

	def createTagList(self, speciesTuple):
		"""
		Returns the distinct tags found in speciesTuple, always starting with ''.

		A tag is the stripped text after the underscore in names that contain
		exactly one underscore.
		"""
		tagList = ['',]
		for species in speciesTuple:
			if len(species.split('_')) == 2:
				tag = species.split('_')[1].strip()
				if tag not in tagList: tagList.append(tag)
		return tagList

	def kelToFar(self, temp):
		"""
		Converts a Kelvin temperature to Fahrenheit.
		"""
		return (((float(temp) - 273.15) * 9) / 5) + 32

	def calcFactorTable(self, metTemp, factorDict, hours, rows, cols):
		"""
		Builds an [hours, 1, rows, cols] array of adjustment factors from the
		temperatures in metTemp.

		metTemp is indexed as metTemp[hour][0][row][col] and converted from
		Kelvin to Fahrenheit.  factorDict maps integer Fahrenheit temperatures
		to factors (see readMADJUST).  For each cell, with T the temperature:
		  -20 < T < 72 : linear interpolation between factorDict[floor(T)] and
		                 factorDict[floor(T) + 1] (factorDict[T] if T is whole)
		  T <= -20     : factorDict[-20]
		  otherwise    : 1
		"""
		from math import floor
		factorTable = ones([hours, 1, rows, cols], 'f')
		for hour in range(hours):
			for row in range(rows):
				for col in range(cols):
					temp2 = self.kelToFar(metTemp[hour][0][row][col])
					if (temp2 < 72) and (temp2 > -20):
						tempLow = int(floor(temp2))
						if tempLow == temp2:
							factor = factorDict[tempLow]
						else:
							factor = factorDict[tempLow] + (tempLow - temp2) * (factorDict[tempLow] - factorDict[tempLow + 1])
					elif temp2 <= -20:
						factor = factorDict[-20]
					else:
						factor = 1
					factorTable[hour,0,row,col] = factor
		return factorTable

	def processSpecies(self, inFile, outFile, speciesName, factorTable):
		"""
		Writes one species from inFile to outFile, adjusted or passed through.

		If the species base name + '_72' is in speciesList, the data are
		multiplied by factorTable.  This also applies to tagged variants
		(e.g. POC_<tag>).  Otherwise the data are copied unchanged.  The output
		variable is named with convertSpeciesName.
		"""
		# Fetch a variable from the in file
		species = inFile.variables[speciesName]
		dataIn = species.getValue()

		outSpeciesName = ''.join(self.convertSpeciesName(speciesName))

		speciesOut = outFile.createVariable(outSpeciesName, 'f', ('TSTEP','LAY','ROW','COL'))
		setattr(speciesOut, 'long_name', outSpeciesName)
		setattr(speciesOut, 'units', getattr(species, 'units'))
		setattr(speciesOut, 'var_desc', 'Model species ' + outSpeciesName)

		if speciesName.split('_')[0] + '_72' in speciesList:
			speciesOut[:] = dataIn * factorTable   # adjusted species
		else:
			speciesOut[:] = dataIn   # pass-through species
		outFile.sync()

	def createTFLAG(self, inFile, outFile):
		"""
		Writes a TFLAG variable sized to the out file's VAR dimension.

		All variables share the same [date, time] per timestep:
		- Steps 0..N-2 carry this run's Julian date and time = step * 10000 (HHMMSS).
		- The last step carries the next day's Julian date at time 0.  On Dec 31
		  this is YYYY001 of the following year.
		The attributes long_name, units and var_desc are copied from the in
		file's TFLAG.
		"""
		TFLAG = inFile.variables['TFLAG']
		nSteps = TFLAG.shape[0]
		# VAR was created from the output species list, so its length is the
		# number of output variables (avoids depending on a module-level global).
		nVars = outFile.dimensions['VAR']

		speciesOut = outFile.createVariable('TFLAG', 'i', ('TSTEP', 'VAR', 'DATE-TIME'))
		setattr(speciesOut, 'long_name', getattr(TFLAG, 'long_name'))
		setattr(speciesOut, 'units', getattr(TFLAG, 'units'))
		setattr(speciesOut, 'var_desc', getattr(TFLAG, 'var_desc'))

		today = self.conv2jul(self.Year, self.Mon, self.Day)
		# Step to next year if last hour of last day of year
		if (self.Mon == '12') and (self.Day == '31'): nextDay = self.conv2jul(str(int(self.Year) + 1), '01', '01')
		else: nextDay = today + 1

		dataOut = zeros([nSteps, nVars, TFLAG.shape[2]], 'i')
		for tStep in range(nSteps):
			if tStep == nSteps - 1:   # last hour maps to hour 0 of the next day
				stepDate, stepTime = nextDay, 0
			else:
				stepDate, stepTime = today, tStep * 10000
			for tVar in range(nVars):
				dataOut[tStep,tVar,0] = stepDate
				dataOut[tStep,tVar,1] = stepTime
		speciesOut[:] = dataOut
		outFile.sync()

	def calcPMFINE(self, outFile, tag = ''):
		"""
		Computes PMFINE<tag> = OTHER<tag> + 0.2 * POC<tag> from variables already
		in outFile, and writes it to outFile.  A non-empty tag is used as '_<tag>'.
		"""
		if tag != '': tag = '_%s' %tag
		POC = outFile.variables['POC' + tag]
		OTHER = outFile.variables['OTHER' + tag]

		POCIn = POC.getValue()
		OTHERIn = OTHER.getValue()

		speciesOut = outFile.createVariable('PMFINE' + tag, 'f', ('TSTEP','LAY','ROW','COL'))
		setattr(speciesOut, 'long_name', 'PMFINE' + tag)
		setattr(speciesOut, 'units', getattr(POC, 'units'))
		setattr(speciesOut, 'var_desc', 'Model species PMFINE' + tag)

		POCIn.savespace(1)   # Hold array type
		POCIn = POCIn * 0.2  # Multiply the array with held type by the factor
		speciesOut[:] = OTHERIn + POCIn
		outFile.sync()

	def calcPMC(self, PMC, outFile, tag = ''):
		"""
		Returns PMC + 0.086 * (POC + PEC + PSO4 + PNO3 + PMFINE) for the given tag.

		Reads the tagged variables from outFile; a non-empty tag is used as
		'_<tag>'.  Nothing is written here; writePMC writes the accumulated
		result.
		"""
		print('PMC_%s' % tag)
		if tag != '': tag = '_%s' %tag
		POC = outFile.variables['POC' + tag]
		PEC = outFile.variables['PEC' + tag]
		PSO4 = outFile.variables['PSO4' + tag]
		PNO3 = outFile.variables['PNO3' + tag]
		PMFINE = outFile.variables['PMFINE' + tag]

		# Get the values of POC, PEC, PSO4, PNO3, PMFINE
		POCIn = POC.getValue()
		PECIn = PEC.getValue()
		PSO4In = PSO4.getValue()
		PNO3In = PNO3.getValue()
		PMFINEIn = PMFINE.getValue()

		dataOut = PECIn + POCIn + PSO4In + PNO3In + PMFINEIn
		dataOut.savespace(1)   # Hold array type
		dataOut = dataOut * 0.086  # Multiply the array with held type by the factor
		return PMC + dataOut

	def writePMC(self, PMC, outFile):
		"""
		Writes the PMC array to outFile as variable 'PMC', using POC's units.
		"""
		POC = outFile.variables['POC']
		speciesOut = outFile.createVariable('PMC', 'f', ('TSTEP','LAY','ROW','COL'))
		setattr(speciesOut, 'long_name', 'PMC')
		setattr(speciesOut, 'units', getattr(POC, 'units'))
		setattr(speciesOut, 'var_desc', 'Model species PMC')
		speciesOut[:] = PMC
		outFile.sync()

### Main script

factorFile = checkEV('ADJ_FACS') # Path to MADJUST factor file
esdate = checkEV('ESDATE')  # Output date as YYYYMMDD
grid = checkEV('GRID')    # Grid
imPath = checkEV('IMD_ROOT')  # Intermediate path for the case
sector = checkEV('SECTOR')   # Sector
metPath = os.path.join(checkEV('MET_ROOT'), grid) + '/mcip_out'   # Path to meteorology for the grid
runType = checkEV('RUN_TYPE')   # RUNEXH or STARTEXH
spec = checkEV('EMF_SPC')   # Speciation
case = checkEV('CASE')    # Case

Mon = esdate[4:6]
Day = esdate[6:8]
Year = esdate[:4]

factorDict = readMADJUST(factorFile)
dayTable = readDAYTABLE(dayPath, Year, Mon)

# Create new class object
adjust = onroadAdj(Year, Mon, Day, grid, sector, spec, case, imPath, metPath, runType)

# The in file date comes from column 7 of this day's row in the day table;
# the out file date is the run date itself.
inDate = dayTable[int(Day) - 1][6].strip()
outDate = Year + Mon + Day

inFileName = adjust.nameInFile(inDate)
outFileName = adjust.nameOutFile(outDate)

print("In File: " + inFileName)
inFile = openNCF(inFileName, 'r')
print("Out File: " + outFileName)
outFile = openNCF(outFileName, 'w')

metFileName = os.path.join(metPath, 'METCRO2D_' + Year[2:] + Mon + Day)
print("Met File: " + metFileName)
metFile = openNCF(metFileName, 'r')

# Create the NCF settings for the outfile based on the infile
adjust.outFileSettings(inFile, outFile)

# Get the list of variables from the infile
variableNames = inFile.variables.keys()

# Define outfile attributes related to the number of variables
speciesTuple = adjust.parseVarList(variableNames)
setattr(outFile, 'VAR-LIST', ''.join(speciesTuple))
setattr(outFile, 'NVARS', len(speciesTuple))
outFile.createDimension('VAR', len(speciesTuple))

### Run main loop
# 25 timesteps per file; createTFLAG dates the last one as hour 0 of the next day.
factorTable = adjust.calcFactorTable(metFile.variables['TEMP2'].getValue(), factorDict, 25, inFile.dimensions['ROW'], inFile.dimensions['COL'])
metFile.close()   # the met file is not used after the factor table is built

for speciesName in variableNames:
	if speciesName.split('_')[0] in ignoredSpecies or speciesName.split('_')[0] + '_72' in ignoredSpecies: continue  # Skip ignored species
	adjust.processSpecies(inFile, outFile, speciesName, factorTable)

adjust.createTFLAG(inFile, outFile)
inFile.close()

# Accumulate PMC over all tags; each tag's PMFINE is computed first because calcPMC reads it
POC = outFile.variables['POC']
PMC = zeros([POC.shape[0],POC.shape[1],POC.shape[2],POC.shape[3]], 'f')  # Create an array of the proper variable shape
for tag in adjust.createTagList(speciesTuple):
	adjust.calcPMFINE(outFile, tag)
	PMC = adjust.calcPMC(PMC, outFile, tag)

adjust.writePMC(PMC, outFile)
outFile.close()
