import unittest
import inspect
import os

from ConnectToMySQL import ConnectToMySQL
from NR08a.DataFile import DataFile
from NR08a.DataTypes import Activity, GrowthAltScrappage

class Test(unittest.TestCase):
    """Parse the fixture with DataFile and write it into an explicit schema.

    The scratch schema ``test_nr_importer_unittest`` is dropped in
    ``setUpClass`` and again after every test (``tearDown``). Each test
    therefore starts without the schema, and the tests do not depend on the
    order in which unittest runs them.
    """

    # Fixture file (relative to the current working directory), rewritten by setUpClass.
    DATA_PATH = 'NR08a/tests/test.dat'
    # Scratch schema the tests write into; every test starts without it.
    DB_NAME = 'test_nr_importer_unittest'
    # Fixture contents: two header lines, an /ACTIVITY/ section with two
    # fixed-width records, and a /DATASECTION2/ section with three records.
    FIXTURE_LINES = (
        'Headers\n',
        'More Headers\n',
        '/ACTIVITY/\n',
        '2260001010 2-Stroke Motorcycles: Off-Road                             0 9999 1.00        Hrs/Yr       1600  DEFAULT\n',
        '2260001020 2-Stroke Snowmobiles                                       0 9999 0.34        Hrs/Yr         57  DEFAULT\n',
        '/END/\n',
        '\n',
        '/DATASECTION2/\n',
        '4321 edcba 2004 hrs/yr 123.45678\n',
        '8765 vwxyz 2005 hrs/yr 987.65432\n',
        '56     jkl   06 hrs/yr   9.45\n',
        '/END/\n',
        '\n',
    )
    # First /ACTIVITY/ record as parsed by DataFile: fields are strings, by position.
    EXPECTED_PARSED = (
        '2260001010',
        '2-Stroke Motorcycles: Off-Road',
        '',
        '0',
        '9999',
        '1.00',
        'Hrs/Yr',
        '1600',
        'DEFAULT',
    )
    # First /ACTIVITY/ row as read back from the database, as (column, value) pairs.
    EXPECTED_ROW = (
        ('scc', '2260001010'),
        ('description', '2-Stroke Motorcycles: Off-Road'),
        ('region', ''),
        ('minHP', 0),
        ('maxHP', 9999),
        ('lf', 1.00),
        ('units', 'Hrs/Yr'),
        ('activity', 1600),
        ('ageadj', 'DEFAULT'),
    )

    @classmethod
    def setUpClass(cls):
        """Write the fixture file, connect, and make sure the scratch schema is absent."""
        # One write of the whole fixture. The previous code called writelines()
        # with single strings, which iterates each string character by character.
        with open(cls.DATA_PATH, 'w', encoding='utf-8') as f:
            f.write(''.join(cls.FIXTURE_LINES))

        cls.engine = ConnectToMySQL.create_engine('mysql.ini')
        cls.engine.execute('DROP DATABASE IF EXISTS ' + cls.DB_NAME)

    def tearDown(self):
        """Drop the scratch schema after every test.

        This is deliberately a per-test instance hook, not a classmethod.
        test_GenericDataDBFuncs asserts the schema does not exist yet, and
        test_GenericDataWrite expects exactly two rows after its first write.
        The fixture file is not removed.
        """
        self.engine.execute('DROP DATABASE IF EXISTS ' + self.DB_NAME)

    def _count_rows(self, table):
        """Return ``count(*)`` for ``table`` (a schema-qualified table name)."""
        return self.engine.execute('SELECT count(*) FROM ' + table).fetchone()[0]

    def _first_activity_row(self, table):
        """Fetch the EXPECTED_ROW columns of the row with the lowest ``scc``."""
        columns = ', '.join(name for name, _ in self.EXPECTED_ROW)
        # Without ORDER BY, LIMIT 1 may return any row; pin it to the lowest scc.
        return self.engine.execute(
            'SELECT ' + columns + ' FROM ' + table + ' ORDER BY scc LIMIT 1'
        ).fetchone()

    def test_DataFileRead(self):
        gd = DataFile(self.DATA_PATH).read(Activity)
        record = gd.data[0]
        for index, expected in enumerate(self.EXPECTED_PARSED):
            with self.subTest(field=index):
                self.assertEqual(record[index], expected)

    def test_GenericDataWrite(self):
        gd = DataFile(self.DATA_PATH).read(Activity)
        table = self.DB_NAME + '.activity'

        # write() must create the missing schema and table (tearDown dropped them).
        gd.write(self.engine, self.DB_NAME)
        row = self._first_activity_row(table)
        self.assertIsNotNone(row)
        for index, (column, expected) in enumerate(self.EXPECTED_ROW):
            with self.subTest(column=column):
                self.assertEqual(row[index], expected)

        # Test append function: a plain write leaves only the two fixture rows,
        # while append=True adds them again.
        gd.write(self.engine, self.DB_NAME)
        self.assertEqual(self._count_rows(table), 2)
        gd.write(self.engine, self.DB_NAME, append=True)
        self.assertEqual(self._count_rows(table), 4)

    def test_GenericDataDBFuncs(self):
        gd = DataFile(self.DATA_PATH).read(Activity)
        self.assertFalse(gd._db_exists(self.engine, self.DB_NAME))
        gd._create_db(self.engine, self.DB_NAME)
        self.assertTrue(gd._db_exists(self.engine, self.DB_NAME))

        self.assertFalse(gd._table_exists(self.engine, 'activity', self.DB_NAME))
        gd._create_table(self.engine, 'activity', self.DB_NAME)
        self.assertTrue(gd._table_exists(self.engine, 'activity', self.DB_NAME))

    def test_GrowthAltScrappage(self):
        # Smoke test only: passes as long as reading and writing this data type does not raise.
        gd = DataFile(self.DATA_PATH).read(GrowthAltScrappage)
        gd.write(self.engine, self.DB_NAME)


class Test_NoSchema(unittest.TestCase):
    """Same checks as Test, but relying on the connection's default schema.

    setUpClass creates ``test_nr_importer_unittest`` and selects it with
    ``USE``. The DataFile/GenericData helpers are then called without a schema
    argument, and queries use unqualified table names. The ``activity`` table
    is dropped after every test, so tests do not depend on run order.
    """

    # Fixture file (relative to the current working directory), rewritten by setUpClass.
    DATA_PATH = 'NR08a/tests/test.dat'
    # Schema created and selected as the default for the whole class.
    DB_NAME = 'test_nr_importer_unittest'
    # Fixture contents: two header lines, an /ACTIVITY/ section with two
    # fixed-width records, and a /DATASECTION2/ section with three records.
    FIXTURE_LINES = (
        'Headers\n',
        'More Headers\n',
        '/ACTIVITY/\n',
        '2260001010 2-Stroke Motorcycles: Off-Road                             0 9999 1.00        Hrs/Yr       1600  DEFAULT\n',
        '2260001020 2-Stroke Snowmobiles                                       0 9999 0.34        Hrs/Yr         57  DEFAULT\n',
        '/END/\n',
        '\n',
        '/DATASECTION2/\n',
        '4321 edcba 2004 hrs/yr 123.45678\n',
        '8765 vwxyz 2005 hrs/yr 987.65432\n',
        '56     jkl   06 hrs/yr   9.45\n',
        '/END/\n',
        '\n',
    )
    # First /ACTIVITY/ record as parsed by DataFile: fields are strings, by position.
    EXPECTED_PARSED = (
        '2260001010',
        '2-Stroke Motorcycles: Off-Road',
        '',
        '0',
        '9999',
        '1.00',
        'Hrs/Yr',
        '1600',
        'DEFAULT',
    )
    # First /ACTIVITY/ row as read back from the database, as (column, value) pairs.
    EXPECTED_ROW = (
        ('scc', '2260001010'),
        ('description', '2-Stroke Motorcycles: Off-Road'),
        ('region', ''),
        ('minHP', 0),
        ('maxHP', 9999),
        ('lf', 1.00),
        ('units', 'Hrs/Yr'),
        ('activity', 1600),
        ('ageadj', 'DEFAULT'),
    )

    @classmethod
    def setUpClass(cls):
        """Write the fixture file, then create a fresh schema and make it the default."""
        # One write of the whole fixture. The previous code called writelines()
        # with single strings, which iterates each string character by character.
        with open(cls.DATA_PATH, 'w', encoding='utf-8') as f:
            f.write(''.join(cls.FIXTURE_LINES))

        cls.engine = ConnectToMySQL.create_engine('mysql.ini')
        # Recreate rather than reuse: a schema left over from an aborted run
        # could already contain an `activity` table.
        cls.engine.execute('DROP DATABASE IF EXISTS ' + cls.DB_NAME)
        cls.engine.execute('CREATE DATABASE ' + cls.DB_NAME)
        # NOTE(review): if the engine pools connections (e.g. SQLAlchemy), USE
        # only affects the connection that ran it. The tests below assume later
        # execute() calls reuse that connection; confirm.
        cls.engine.execute('USE ' + cls.DB_NAME)

    @classmethod
    def tearDownClass(cls):
        """Drop the schema created in setUpClass (the fixture file is not removed)."""
        cls.engine.execute('DROP DATABASE IF EXISTS ' + cls.DB_NAME)

    def tearDown(self):
        """Drop the activity table after every test.

        Without this, test_GenericDataDBFuncs only passes when it happens to
        run before test_GenericDataWrite. The name is schema-qualified so the
        cleanup does not depend on the USE in setUpClass.
        """
        self.engine.execute('DROP TABLE IF EXISTS ' + self.DB_NAME + '.activity')

    def _count_rows(self, table):
        """Return ``count(*)`` for ``table``."""
        return self.engine.execute('SELECT count(*) FROM ' + table).fetchone()[0]

    def _first_activity_row(self, table):
        """Fetch the EXPECTED_ROW columns of the row with the lowest ``scc``."""
        columns = ', '.join(name for name, _ in self.EXPECTED_ROW)
        # Without ORDER BY, LIMIT 1 may return any row; pin it to the lowest scc.
        return self.engine.execute(
            'SELECT ' + columns + ' FROM ' + table + ' ORDER BY scc LIMIT 1'
        ).fetchone()

    def test_DataFileRead(self):
        gd = DataFile(self.DATA_PATH).read(Activity)
        record = gd.data[0]
        for index, expected in enumerate(self.EXPECTED_PARSED):
            with self.subTest(field=index):
                self.assertEqual(record[index], expected)

    def test_GenericDataWrite(self):
        gd = DataFile(self.DATA_PATH).read(Activity)
        # Unqualified on purpose: this class exercises the default schema.
        table = 'activity'

        gd.write(self.engine)
        row = self._first_activity_row(table)
        self.assertIsNotNone(row)
        for index, (column, expected) in enumerate(self.EXPECTED_ROW):
            with self.subTest(column=column):
                self.assertEqual(row[index], expected)

        # Test append function: a plain write leaves only the two fixture rows,
        # while append=True adds them again.
        gd.write(self.engine)
        self.assertEqual(self._count_rows(table), 2)
        gd.write(self.engine, append=True)
        self.assertEqual(self._count_rows(table), 4)

    def test_GenericDataDBFuncs(self):
        gd = DataFile(self.DATA_PATH).read(Activity)

        self.assertFalse(gd._table_exists(self.engine, 'activity'))
        gd._create_table(self.engine, 'activity')
        self.assertTrue(gd._table_exists(self.engine, 'activity'))
     

if __name__ == "__main__":
    # Run every TestCase in this module. To run a subset, name the tests on
    # the command line, e.g. ``python <this file> Test.test_DataFileRead``.
    unittest.main(verbosity=1)