class ObservationMetaDataGeneratorTest(unittest.TestCase):

    longMessage = True

    @classmethod
    def tearDownClass(cls):
        """
        Call sims_clean_up() once, after all of the tests in this class
        have run.
        """
        sims_clean_up()

    def setUp(self):
        """
        Give each test a fresh ObservationMetaDataGenerator in self.gen,
        connected through the sqlite driver to the OpSim database file
        found under the sims_data package directory.
        """
        sims_data_dir = getPackageDir('sims_data')
        opsim_db_path = os.path.join(sims_data_dir,
                                     'OpSimData/opsimblitz1_1133_sqlite.db')

        self.gen = ObservationMetaDataGenerator(driver='sqlite',
                                                database=opsim_db_path)

    def tearDown(self):
        """
        Drop this test's reference to the generator created in setUp().
        """
        del self.gen

    def testExceptions(self):
        """
        Make sure that RuntimeErrors get raised when they should
        """
        gen = self.gen
        self.assertRaises(RuntimeError, gen.getObservationMetaData)
        self.assertRaises(RuntimeError, gen.getObservationMetaData, fieldRA=(1.0, 2.0, 3.0))

    def testQueryOnRanges(self):
        """
        Query for ObservationMetaData with constraints of the form
        min < value < max
        and verify that every ObservationMetaData returned really does fall
        inside the requested range.

        Each column is queried on its own, and then every pair of columns
        is queried together.
        """
        # A list of (column name, (min, max)) pairs bounding our queries.
        # The order of the tuples must correspond to the order of
        # self.columnMapping in ObservationMetaDataGenerator.
        # These were generated with a separate script which printed
        # the median and maximum values of all of the quantities
        # in our test opsim database
        bounds = [('obsHistID', (5973, 7000)),
                  ('fieldRA', (np.degrees(1.370916), np.degrees(1.40))),
                  ('rawSeeing', (0.728562, 0.9)),
                  ('seeing', (0.7, 0.9)),
                  ('dist2Moon', (np.degrees(1.570307), np.degrees(1.9))),
                  ('expMJD', (49367.129396, 49370.0)),
                  ('m5', (22.815249, 23.0)),
                  ('skyBrightness', (19.017605, 19.5))]

        # one constraint at a time
        for tag, limits in bounds:
            lower, upper = limits
            msg = "failed querying on %s" % tag

            results = self.gen.getObservationMetaData(**{tag: limits})
            self.assertGreater(len(results), 0, msg=msg)

            for obs in results:
                val = get_val_from_obs(tag, obs)
                self.assertGreaterEqual(val, lower, msg=msg)
                self.assertLessEqual(val, upper, msg=msg)

        # every unordered pair of constraints
        for ix, (tag1, limits1) in enumerate(bounds):
            lower1, upper1 = limits1

            for tag2, limits2 in bounds[ix+1:]:
                lower2, upper2 = limits2
                msg = "failed querying %s and %s" % (tag1, tag2)

                results = self.gen.getObservationMetaData(**{tag1: limits1,
                                                             tag2: limits2})
                self.assertGreater(len(results), 0, msg=msg)

                for obs in results:
                    v1 = get_val_from_obs(tag1, obs)
                    v2 = get_val_from_obs(tag2, obs)
                    self.assertGreaterEqual(v1, lower1, msg=msg)
                    self.assertLessEqual(v1, upper1, msg=msg)
                    self.assertGreaterEqual(v2, lower2, msg=msg)
                    self.assertLessEqual(v2, upper2, msg=msg)

    def testOpSimQueryOnRanges(self):
        """
        Test that getOpimRecords() returns correct results
        """
        bounds = [('obsHistID', (5973, 7000)),
                  ('fieldRA', (np.degrees(1.370916), np.degrees(1.40))),
                  ('rawSeeing', (0.728562, 0.9)),
                  ('seeing', (0.7, 0.9)),
                  ('dist2Moon', (np.degrees(1.570307), np.degrees(1.9))),
                  ('expMJD', (49367.129396, 49370.0)),
                  ('m5', (22.815249, 23.0)),
                  ('skyBrightness', (19.017605, 19.5))]

        for line in bounds:
            tag = line[0]
            args = {tag: line[1]}
            results = self.gen.getOpSimRecords(**args)
            msg = 'failed querying %s ' % tag
            self.assertGreater(len(results), 0)
            for rec in results:
                val = get_val_from_rec(tag, rec)
                self.assertGreaterEqual(val, line[1][0], msg=msg)
                self.assertLessEqual(val, line[1][1], msg=msg)

        for ix in range(len(bounds)):
            tag1 = bounds[ix][0]
            for jx in range(ix+1, len(bounds)):
                tag2 = bounds[jx][0]
                args = {tag1: bounds[ix][1], tag2: bounds[jx][1]}
                results = self.gen.getOpSimRecords(**args)
                msg = 'failed while querying %s and %s' % (tag1, tag2)
                self.assertGreater(len(results), 0)
                for rec in results:
                    v1 = get_val_from_rec(tag1, rec)
                    v2 = get_val_from_rec(tag2, rec)
                    self.assertGreaterEqual(v1, bounds[ix][1][0], msg=msg)
                    self.assertLessEqual(v1, bounds[ix][1][1], msg=msg)
                    self.assertGreaterEqual(v2, bounds[jx][1][0], msg=msg)
                    self.assertLessEqual(v2, bounds[jx][1][1], msg=msg)

    def testQueryExactValues(self):
        """
        Query for ObservationMetaData demanding an exact value of a single
        column, and verify that every ObservationMetaData returned carries
        exactly that value.
        """
        # (column name, exact value to query on) pairs
        exact_values = [('obsHistID', 5973),
                        ('expDate', 1220779),
                        ('fieldRA', np.degrees(1.370916)),
                        ('fieldDec', np.degrees(-0.456238)),
                        ('moonRA', np.degrees(2.914132)),
                        ('moonDec', np.degrees(0.06305)),
                        ('rotSkyPos', np.degrees(3.116656)),
                        ('telescopeFilter', 'i'),
                        ('rawSeeing', 0.728562),
                        ('seeing', 0.88911899999999999),
                        ('sunAlt', np.degrees(-0.522905)),
                        ('moonAlt', np.degrees(0.099096)),
                        ('dist2Moon', np.degrees(1.570307)),
                        ('moonPhase', 52.2325),
                        ('expMJD', 49367.129396),
                        ('visitExpTime', 30.0),
                        ('m5', 22.815249),
                        ('skyBrightness', 19.017605)]

        for tag, expected in exact_values:
            msg = 'failed querying %s' % tag
            results = self.gen.getObservationMetaData(**{tag: expected})

            # the query must find something...
            self.assertGreater(len(results), 0, msg=msg)

            # ...and everything it finds must match exactly
            for obs in results:
                self.assertEqual(get_val_from_obs(tag, obs), expected, msg=msg)

    def testOpSimQueryExact(self):
        """
        Query getOpSimRecords() demanding an exact value of a single column,
        and verify that every record returned carries exactly that value.
        """
        # (column name, exact value to query on) pairs
        exact_values = [('obsHistID', 5973),
                        ('expDate', 1220779),
                        ('fieldRA', np.degrees(1.370916)),
                        ('fieldDec', np.degrees(-0.456238)),
                        ('moonRA', np.degrees(2.914132)),
                        ('moonDec', np.degrees(0.06305)),
                        ('rotSkyPos', np.degrees(3.116656)),
                        ('telescopeFilter', 'i'),
                        ('rawSeeing', 0.728562),
                        ('seeing', 0.88911899999999999),
                        ('sunAlt', np.degrees(-0.522905)),
                        ('moonAlt', np.degrees(0.099096)),
                        ('dist2Moon', np.degrees(1.570307)),
                        ('moonPhase', 52.2325),
                        ('expMJD', 49367.129396),
                        ('visitExpTime', 30.0),
                        ('m5', 22.815249),
                        ('skyBrightness', 19.017605)]

        for tag, expected in exact_values:
            msg = 'failed while querying %s' % tag
            records = self.gen.getOpSimRecords(**{tag: expected})

            # the query must find something...
            self.assertGreater(len(records), 0, msg=msg)

            # ...and everything it finds must match exactly
            for rec in records:
                self.assertEqual(get_val_from_rec(tag, rec), expected, msg=msg)

    def testPassInOtherQuery(self):
        """
        Check that OpSim pointings obtained elsewhere can be handed to an
        ObservationMetaDataGenerator and converted into ObservationMetaData,
        both as a whole array and one pointing at a time
        """
        opsim_records = self.gen.getOpSimRecords(fieldRA=np.degrees(1.370916))
        self.assertGreater(len(opsim_records), 1)

        # this generator is constructed with no arguments; it is only used
        # to convert the records queried above
        detached_gen = ObservationMetaDataGenerator()

        # converting the whole array yields one result per pointing
        converted = detached_gen.ObservationMetaDataFromPointingArray(opsim_records)
        self.assertEqual(len(converted), len(opsim_records))

        # converting a single pointing yields an ObservationMetaData
        for record in opsim_records:
            single = detached_gen.ObservationMetaDataFromPointing(record)
            self.assertIsInstance(single, ObservationMetaData)

    def testQueryLimit(self):
        """
        Test that, when we specify a limit on the number of ObservationMetaData we want returned,
        that limit is respected
        """
        gen = self.gen
        results = gen.getObservationMetaData(fieldRA=(np.degrees(1.370916), np.degrees(1.5348635)),
                                             limit=20)
        self.assertEqual(len(results), 20)

    def testQueryOnFilter(self):
        """
        Test that queries on the filter work.
        """
        gen = self.gen
        results = gen.getObservationMetaData(fieldRA=np.degrees(1.370916), telescopeFilter='i')
        ct = 0
        for obs_metadata in results:
            self.assertAlmostEqual(obs_metadata._pointingRA, 1.370916)
            self.assertEqual(obs_metadata.bandpass, 'i')
            ct += 1

        # Make sure that more than zero ObservationMetaData were returned
        self.assertGreater(ct, 0)

    def testObsMetaDataBounds(self):
        """
        Make sure that the bound specifications (i.e. a circle or a box on the
        sky) are correctly passed through to the resulting ObservationMetaData

        When only boundLength is given, the bounds are expected to be a
        CircleBounds.  When boundType='box' is given, the bounds are expected
        to be a BoxBounds; the box is requested both with a scalar boundLength
        and with a two-element boundLength, whose first element is checked
        against the RA extent and whose second against the Dec extent.
        """

        gen = self.gen

        # Test a circle with a specified radius
        results = gen.getObservationMetaData(fieldRA=np.degrees(1.370916),
                                             telescopeFilter='i',
                                             boundLength=0.9)
        ct = 0
        for obs_metadata in results:
            self.assertIsInstance(obs_metadata.bounds, CircleBounds,
                                  msg='obs_metadata.bounds is not an instance of '
                                  'CircleBounds')

            # include some wiggle room, in case ObservationMetaData needs to
            # adjust the boundLength to accommodate the transformation between
            # ICRS and observed coordinates
            self.assertGreaterEqual(obs_metadata.bounds.radiusdeg, 0.9)
            self.assertLessEqual(obs_metadata.bounds.radiusdeg, 0.95)

            # the bounds' center is compared (in radians) to the pointing
            self.assertAlmostEqual(obs_metadata.bounds.RA,
                                   np.radians(obs_metadata.pointingRA), 5)
            self.assertAlmostEqual(obs_metadata.bounds.DEC,
                                   np.radians(obs_metadata.pointingDec), 5)
            ct += 1

        # Make sure that some ObservationMetaData were tested
        self.assertGreater(ct, 0)

        # Test a box, specified first by a single length and then by a
        # separate length for each of RA and Dec
        for boundLength in (1.2, (1.2, 0.6)):
            results = gen.getObservationMetaData(fieldRA=np.degrees(1.370916),
                                                 telescopeFilter='i',
                                                 boundType='box',
                                                 boundLength=boundLength)

            if hasattr(boundLength, '__len__'):
                dra, ddec = boundLength[0], boundLength[1]
            else:
                # a scalar applies to both RA and Dec
                dra = ddec = boundLength

            ct = 0
            for obs_metadata in results:
                RAdeg = obs_metadata.pointingRA
                DECdeg = obs_metadata.pointingDec
                self.assertIsInstance(obs_metadata.bounds, BoxBounds,
                                      msg='obs_metadata.bounds is not an instance of '
                                      'BoxBounds')

                # the edges of the box, in degrees
                self.assertAlmostEqual(obs_metadata.bounds.RAminDeg, RAdeg-dra, 10)
                self.assertAlmostEqual(obs_metadata.bounds.RAmaxDeg, RAdeg+dra, 10)
                self.assertAlmostEqual(obs_metadata.bounds.DECminDeg, DECdeg-ddec, 10)
                self.assertAlmostEqual(obs_metadata.bounds.DECmaxDeg, DECdeg+ddec, 10)

                # the center of the box, in radians
                self.assertAlmostEqual(obs_metadata.bounds.RA, np.radians(RAdeg), 5)
                self.assertAlmostEqual(obs_metadata.bounds.DEC, np.radians(DECdeg), 5)

                ct += 1

            # Make sure that some ObservationMetaData were tested
            self.assertGreater(ct, 0)

    def testQueryOnNight(self):
        """
        Check that the ObservationMetaDataGenerator can query on the 'night'
        column in the OpSim summary table
        """

        # the test database goes from night=0 to night=28
        # corresponding to 49353.032079 <= mjd <= 49381.38533
        night0 = 49353.032079

        results = self.gen.getObservationMetaData(night=(11, 13))

        self.assertGreater(len(results), 1800)
        # there should be about 700 observations a night;
        # make sure we get at least 600

        for obs in results:
            self.assertGreaterEqual(obs.mjd.TAI, night0+11.0)
            self.assertLessEqual(obs.mjd.TAI, night0+13.5)
            # the 0.5 is there because the last observation on night 13 could be
            # 13 days and 8 hours after the first observation on night 0
            self.assertGreaterEqual(obs._OpsimMetaData['night'], 11)
            self.assertLessEqual(obs._OpsimMetaData['night'], 13)

        # query for an exact night
        results = self.gen.getObservationMetaData(night=15)

        self.assertGreater(len(results), 600)
        # there should be about 700 observations a night;
        # make sure we get at least 600

        for obs in results:
            self.assertEqual(obs._OpsimMetaData['night'], 15)
            self.assertGreaterEqual(obs.mjd.TAI, night0+14.9)
            self.assertLessEqual(obs.mjd.TAI, night0+15.9)

    def testCreationOfPhoSimCatalog(self):
        """
        Make sure that we can create PhoSim input catalogs using the returned
        ObservationMetaData.

        The header map is set to an empty dict, so this test only checks that
        write_catalog() runs without raising; the contents of the catalog
        that gets written are not inspected.
        """

        dbName = tempfile.mktemp(dir=ROOT, prefix='obsMetaDataGeneratorTest-', suffix='.db')
        try:
            makePhoSimTestDB(filename=dbName)
            bulgeDB = testGalaxyBulgeDBObj(driver='sqlite', database=dbName)
            results = self.gen.getObservationMetaData(fieldRA=np.degrees(1.370916),
                                                      telescopeFilter='i')
            # fail with a clear message, rather than an IndexError below,
            # if the query comes back empty
            self.assertGreater(len(results), 0,
                               msg='query returned no ObservationMetaData')
            testCat = PhoSimCatalogSersic2D(bulgeDB, obs_metadata=results[0])
            testCat.phoSimHeaderMap = {}
            with lsst.utils.tests.getTempFilePath('.txt') as catName:
                testCat.write_catalog(catName)
        finally:
            # remove the scratch database even if something above failed
            if os.path.exists(dbName):
                os.unlink(dbName)

    def testCreationOfPhoSimCatalog_2(self):
        """
        Make sure that we can create PhoSim input catalogs using the returned
        ObservationMetaData.

        Use the actual DefaultPhoSimHeader map; make sure that opsim_version
        does not make it into the header.
        """

        dbName = tempfile.mktemp(dir=ROOT, prefix='obsMetaDataGeneratorTest-', suffix='.db')
        try:
            makePhoSimTestDB(filename=dbName)
            bulgeDB = testGalaxyBulgeDBObj(driver='sqlite', database=dbName)
            results = self.gen.getObservationMetaData(fieldRA=np.degrees(1.370916),
                                                      telescopeFilter='i')
            # fail with a clear message, rather than an IndexError below,
            # if the query comes back empty
            self.assertGreater(len(results), 0,
                               msg='query returned no ObservationMetaData')
            testCat = PhoSimCatalogSersic2D(bulgeDB, obs_metadata=results[0])
            testCat.phoSimHeaderMap = DefaultPhoSimHeaderMap
            with lsst.utils.tests.getTempFilePath('.txt') as catName:
                testCat.write_catalog(catName)
                ct_lines = 0
                with open(catName, 'r') as in_file:
                    # no line of the written catalog may mention opsim_version
                    for line in in_file:
                        ct_lines += 1
                        self.assertNotIn('opsim_version', line)
                    self.assertGreater(ct_lines, 10)  # check that some lines did get written
        finally:
            # remove the scratch database even if something above failed
            if os.path.exists(dbName):
                os.unlink(dbName)

    def testCreationOfPhoSimCatalog_3(self):
        """
        Make sure that we can create PhoSim input catalogs using the returned
        ObservationMetaData.

        Test that an error is actually raised if we try to build a PhoSim catalog
        with a v3 header map using a v4 ObservationMetaData
        """

        dbName = tempfile.mktemp(dir=ROOT, prefix='obsMetaDataGeneratorTest-', suffix='.db')
        try:
            makePhoSimTestDB(filename=dbName)
            bulgeDB = testGalaxyBulgeDBObj(driver='sqlite', database=dbName)
            opsim_db = os.path.join(getPackageDir('sims_data'), 'OpSimData',
                                    'astro-lsst-01_2014.db')
            # use a unittest assertion (not a bare assert, which is stripped
            # under -O) so that a missing database is always reported
            self.assertTrue(os.path.isfile(opsim_db),
                            msg='%s is not a file' % opsim_db)
            # this test queries a different OpSim database than the one
            # behind self.gen, so it builds its own generator
            gen = ObservationMetaDataGenerator(opsim_db, driver='sqlite')
            results = gen.getObservationMetaData(fieldRA=(70.0, 85.0),
                                                 telescopeFilter='i')
            self.assertGreater(len(results), 0)
            testCat = PhoSimCatalogSersic2D(bulgeDB, obs_metadata=results[0])
            testCat.phoSimHeaderMap = DefaultPhoSimHeaderMap
            with lsst.utils.tests.getTempFilePath('.txt') as catName:
                with self.assertRaises(RuntimeError):
                    testCat.write_catalog(catName)
        finally:
            # remove the scratch database even if something above failed
            if os.path.exists(dbName):
                os.unlink(dbName)