import unittest

from ConnectToMySQL import ConnectToMySQL
from NR08a.DataTypes import Activity, Allocation, DailyTemps, GrowthIndicators, Growth, GrowthScrappage, GrowthAltScrappage, Population, RetroFit, Regions, MonthlyAdjFactors, DailyAdjFactors
from NR08a.DataFile import DataFile

class Test(unittest.TestCase):
    """Integration tests for the NR08a data-file importers.

    Each test parses one sample file from ``NR08a/tests`` with
    ``DataFile(path).read(<DataType>)``, writes the result into a scratch
    MySQL database with ``.write(engine)``, and then checks the value stored
    in each listed column of the first row returned from that data type's
    table.
    """

    # Scratch database (re)created in setUpClass and dropped in tearDownClass.
    DB_NAME = 'test_nr_importer_unittest'
    # Fetches one value of one column; formatted with (column, table).
    # NOTE(review): there is no ORDER BY, so SQL does not guarantee which row
    # LIMIT 1 returns; the expected values below assume it is the first row
    # read from the input file -- confirm this holds for every table.
    query = 'SELECT {} FROM {} LIMIT 1'

    @classmethod
    def setUpClass(cls):
        """Recreate the scratch database and open a shared connection to it.

        Connection settings come from ``mysql.ini``. Any database left over
        from a previous run is dropped first so every run starts empty.
        """
        # Connect without selecting a database so the scratch one can be
        # dropped and recreated.
        admin_engine = ConnectToMySQL.create_engine('mysql.ini')
        admin_con = admin_engine.connect()
        try:
            admin_con.execute('DROP DATABASE IF EXISTS {}'.format(cls.DB_NAME))
            admin_con.execute('CREATE DATABASE IF NOT EXISTS {}'.format(cls.DB_NAME))
        finally:
            admin_con.close()
            # The server-level engine is no longer needed; release its pool
            # instead of leaking it when it goes out of scope.
            admin_engine.dispose()
        cls.engine = ConnectToMySQL.create_engine('mysql.ini', db=cls.DB_NAME)
        cls.con = cls.engine.connect()

    @classmethod
    def tearDownClass(cls):
        """Drop the scratch database, then release the connection and engine."""
        try:
            cls.con.execute('DROP DATABASE IF EXISTS {}'.format(cls.DB_NAME))
        finally:
            # Close and dispose even if the DROP fails.
            cls.con.close()
            cls.engine.dispose()

    def _first_value(self, column, table):
        """Return the value of ``column`` in the first row ``query`` yields from ``table``."""
        return self.con.execute(self.query.format(column, table)).fetchone()[0]

    def _assert_first_row(self, datatype, expected):
        """Assert the first row of ``datatype.table`` holds the expected values.

        :param datatype: NR08a data type class whose ``table`` was just written.
        :param expected: sequence of ``(column, value)`` pairs. ``column`` is
            one of the data type's column attributes, and its ``.name`` is used
            as the SQL column name.
        """
        for column, value in expected:
            self.assertEqual(
                self._first_value(column.name, datatype.table),
                value,
                msg='{}.{}'.format(datatype.table, column.name),
            )

    def test_Activity(self):
        DataFile('NR08a/tests/ACTIVITY.DAT').read(Activity).write(self.engine)
        self._assert_first_row(Activity, [
            (Activity._scc, '2260001010'),
            (Activity._description, '2-Stroke Motorcycles: Off-Road'),
            (Activity._region, ''),
            (Activity._minhp, 0),
            (Activity._maxhp, 9999),
            (Activity._lf, 1.00),
            (Activity._units, 'Hrs/Yr'),
            (Activity._activity, 1600),
            (Activity._ageadj, 'DEFAULT'),
        ])

    def test_Allocation(self):
        DataFile('NR08a/tests/AK_AIRTR.ALO').read(Allocation).write(self.engine)
        self._assert_first_row(Allocation, [
            (Allocation._indicator, 'AIR'),
            (Allocation._fips, '02000'),
            (Allocation._subregion, ''),
            (Allocation._year, 2002),
            (Allocation._allocation, 1305.879),
            (Allocation._description, 'AK'),
        ])

    def test_DailyTemps(self):
        DataFile('NR08a/tests/DAYTMPRV.DAT').read(DailyTemps).write(self.engine)
        self._assert_first_row(DailyTemps, [
            (DailyTemps._statefips, '0'),
            (DailyTemps._state, 'US'),
            (DailyTemps._parameter, 'TMAX'),
            (DailyTemps._month, 1),
            (DailyTemps._day, 1),
            (DailyTemps._value, 46.6),
        ])

    def test_GrowthIndicators(self):
        DataFile('NR08a/tests/NATION.GRW').read(GrowthIndicators).write(self.engine)
        self._assert_first_row(GrowthIndicators, [
            (GrowthIndicators._fips, '00000'),
            (GrowthIndicators._indicator, '092'),
            (GrowthIndicators._scc, '2260001000'),
            (GrowthIndicators._minhp, 0),
            (GrowthIndicators._maxhp, 9999),
            (GrowthIndicators._techtype, 'ALL'),
            (GrowthIndicators._description, '2-Stroke Recreational Vehicles'),
        ])

    def test_Growth(self):
        DataFile('NR08a/tests/NATION.GRW').read(Growth).write(self.engine)
        self._assert_first_row(Growth, [
            (Growth._fips, '00000'),
            (Growth._subregion, ''),
            (Growth._year, 1970),
            (Growth._indicator, '095'),
            (Growth._growth, 1),
        ])

    def test_GrowthScrappage(self):
        DataFile('NR08a/tests/NATION.GRW').read(GrowthScrappage).write(self.engine)
        self._assert_first_row(GrowthScrappage, [
            (GrowthScrappage._ulf, 0),
            (GrowthScrappage._scrapped, '0.00'),
        ])

    def test_GrowthAltScrappage(self):
        DataFile('NR08a/tests/NATION.GRW').read(GrowthAltScrappage).write(self.engine)
        self._assert_first_row(GrowthAltScrappage, [
            (GrowthAltScrappage._ulf, 0),
            (GrowthAltScrappage._equip, 'MOTORCYC'),
            (GrowthAltScrappage._scrapped, '0'),
        ])

    def test_Population(self):
        DataFile('NR08a/tests/AK.POP').read(Population).write(self.engine)
        self._assert_first_row(Population, [
            (Population._fips, '02000'),
            (Population._subregion, ''),
            (Population._year, 1999),
            (Population._scc, '2260001020'),
            (Population._description, '2-Str Snowmobiles'),
            (Population._minhp, 1),
            (Population._maxhp, 3),
            (Population._avghp, '2.5'),
            (Population._usefullife, 252),
            (Population._scrapdistequip, 'DEFAULT'),
            (Population._pop, 241.2),
        ])

    def test_RetroFit(self):
        DataFile('NR08a/tests/retrotst.dat').read(RetroFit).write(self.engine)
        self._assert_first_row(RetroFit, [
            (RetroFit._yearstart, '2008'),
            (RetroFit._yearend, '2009'),
            (RetroFit._mystart, '1996'),
            (RetroFit._myend, '1997'),
            (RetroFit._scc, '2270002000'),
            (RetroFit._techtype, 'ALL'),
            (RetroFit._minhp, 50),
            (RetroFit._maxhp, 300),
            (RetroFit._fraction, 0.05),
            (RetroFit._effectiveness, 0.5),
            (RetroFit._pollutant, 'PM'),
            (RetroFit._identifier, 1),
            (RetroFit._description, 'funded by xxxx'),
        ])

    def test_Regions(self):
        DataFile('NR08a/tests/SEASON.DAT').read(Regions).write(self.engine)
        self._assert_first_row(Regions, [
            (Regions._region, 'US'),
            (Regions._description, 'National'),
            (Regions._fips, '00000'),
            (Regions._name, 'Nation'),
        ])

    def test_MonthlyAdjFactors(self):
        DataFile('NR08a/tests/SEASON.DAT').read(MonthlyAdjFactors).write(self.engine)
        self._assert_first_row(MonthlyAdjFactors, [
            (MonthlyAdjFactors._subregion, 'CW'),
            (MonthlyAdjFactors._scc, '2260000000'),
            (MonthlyAdjFactors._description, 'Average'),
            (MonthlyAdjFactors._month, 1),
            (MonthlyAdjFactors._adjfactor, 0.081),
        ])

    def test_DailyAdjFactors(self):
        DataFile('NR08a/tests/SEASON.DAT').read(DailyAdjFactors).write(self.engine)
        self._assert_first_row(DailyAdjFactors, [
            (DailyAdjFactors._subregion, ''),
            (DailyAdjFactors._scc, '2260001000'),
            (DailyAdjFactors._description, 'Recreational Equipment'),
            (DailyAdjFactors._dayID, 5),
            (DailyAdjFactors._adjfactor, 0.1111111),
        ])



if __name__ == "__main__":
    # Run every test in this module when executed directly as a script.
    unittest.main()