Example #1
    def testColumnNames(self):
        """
        Test the method that returns the names of columns in a table
        """
        dbobj = DBObject(driver=self.driver, database=self.db_name)
        names = dbobj.get_column_names('doubleTable')
        self.assertEqual(len(names), 3)
        self.assertIn('id', names)
        self.assertIn('sqrt', names)
        self.assertIn('log', names)

        names = dbobj.get_column_names('intTable')
        self.assertEqual(len(names), 3)
        self.assertIn('id', names)
        self.assertIn('twice', names)
        self.assertIn('thrice', names)

        names = dbobj.get_column_names()
        keys = ['doubleTable', 'intTable', 'junkTable']
        for kk in names:
            self.assertIn(kk, keys)

        self.assertEqual(len(names['doubleTable']), 3)
        self.assertEqual(len(names['intTable']), 3)
        self.assertIn('id', names['doubleTable'])
        self.assertIn('sqrt', names['doubleTable'])
        self.assertIn('log', names['doubleTable'])
        self.assertIn('id', names['intTable'])
        self.assertIn('twice', names['intTable'])
        self.assertIn('thrice', names['intTable'])
Example #2
    def testTableNames(self):
        """
        Test the method that returns the names of tables in a database
        """
        dbobj = DBObject(driver=self.driver, database=self.db_name)
        names = dbobj.get_table_names()
        self.assertEqual(len(names), 3)
        self.assertIn('doubleTable', names)
        self.assertIn('intTable', names)
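
The two tests above exercise get_table_names() and get_column_names(). A minimal, self-contained sketch of the same calls against a throwaway sqlite file (the file name and table layout are hypothetical, chosen to mirror the test fixtures):

import sqlite3
from lsst.sims.catalogs.db import DBObject

# build a scratch database with one table (hypothetical fixture)
conn = sqlite3.connect('scratch.db')
conn.execute('CREATE TABLE intTable (id int, twice int, thrice int)')
conn.commit()
conn.close()

db = DBObject(database='scratch.db', driver='sqlite')
print(db.get_table_names())             # expect: ['intTable']
print(db.get_column_names('intTable'))  # expect: ['id', 'twice', 'thrice']
print(db.get_column_names())            # with no argument: a dict keyed on table name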
Example #3
    def testPassingConnection(self):
        """
        Repeat the test from testJoin, but with a DBObject whose connection was passed
        directly from another DBObject, to make sure that passing a connection works
        """
        dbobj_base = DBObject(driver=self.driver, database=self.db_name)
        dbobj = DBObject(connection=dbobj_base.connection)
        query = 'SELECT doubleTable.id, intTable.id, doubleTable.log, intTable.thrice '
        query += 'FROM doubleTable, intTable WHERE doubleTable.id = intTable.id'
        results = dbobj.get_chunk_iterator(query, chunk_size=10)

        dtype = [
            ('id', int),
            ('id_1', int),
            ('log', float),
            ('thrice', int)]

        i = 0
        for chunk in results:
            if i < 90:
                self.assertEqual(len(chunk), 10)
            for row in chunk:
                self.assertEqual(2*(i+1), row[0])
                self.assertEqual(row[0], row[1])
                self.assertAlmostEqual(np.log(row[0]), row[2], 6)
                self.assertEqual(3*row[0], row[3])
                self.assertEqual(dtype, row.dtype)
                i += 1
        self.assertEqual(i, 99)
        # make sure that we found all the matches we should have

        results = dbobj.execute_arbitrary(query)
        self.assertEqual(dtype, results.dtype)
        i = 0
        for row in results:
            self.assertEqual(2*(i+1), row[0])
            self.assertEqual(row[0], row[1])
            self.assertAlmostEqual(np.log(row[0]), row[2], 6)
            self.assertEqual(3*row[0], row[3])
            i += 1
        self.assertEqual(i, 99)
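
The pattern worth noting in this test is constructing one DBObject directly from another's live connection. A minimal sketch of just that hand-off (the database name is a hypothetical placeholder):

from lsst.sims.catalogs.db import DBObject

base = DBObject(driver='sqlite', database='scratch.db')
shared = DBObject(connection=base.connection)  # reuses base's connection
# 'shared' now queries the same database session as 'base'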
Example #4
    def testMinMax(self):
        """
        Test queries on SQL functions by using the MIN and MAX functions
        """
        dbobj = DBObject(driver=self.driver, database=self.db_name)
        query = 'SELECT MAX(thrice), MIN(thrice) FROM intTable'
        results = dbobj.execute_arbitrary(query)
        self.assertEqual(results[0][0], 594)
        self.assertEqual(results[0][1], 0)

        dtype = [('MAXthrice', int), ('MINthrice', int)]
        self.assertEqual(results.dtype, dtype)
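
As the dtype assertion shows, when no explicit dtype is supplied the detected field names for SQL functions drop the parentheses, so MAX(thrice) surfaces as MAXthrice. A short sketch of relying on those derived names (assuming the same dbobj and intTable fixture as the test):

results = dbobj.execute_arbitrary('SELECT MAX(thrice), MIN(thrice) FROM intTable')
print(results['MAXthrice'], results['MINthrice'])  # field names derived from the SQL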
Example #5
    def testReadOnlyFilter(self):
        """
        Test that the filters we placed on queries made with execute_arbitrary()
        work
        """
        dbobj = DBObject(driver=self.driver, database=self.db_name)
        controlQuery = 'SELECT doubleTable.id, intTable.id, doubleTable.log, intTable.thrice '
        controlQuery += 'FROM doubleTable, intTable WHERE doubleTable.id = intTable.id'
        dbobj.execute_arbitrary(controlQuery)

        # make sure that execute_arbitrary only accepts strings
        query = ['a', 'list']
        self.assertRaises(RuntimeError, dbobj.execute_arbitrary, query)

        # check that our filter catches different capitalization permutations of the
        # verboten commands
        query = 'DROP TABLE junkTable'
        self.assertRaises(RuntimeError, dbobj.execute_arbitrary, query)
        self.assertRaises(RuntimeError, dbobj.execute_arbitrary, query.lower())
        query = 'DELETE FROM junkTable WHERE id=4'
        self.assertRaises(RuntimeError, dbobj.execute_arbitrary, query)
        self.assertRaises(RuntimeError, dbobj.execute_arbitrary, query.lower())
        query = 'UPDATE junkTable SET sqrt=0.0, log=0.0 WHERE id=4'
        self.assertRaises(RuntimeError, dbobj.execute_arbitrary, query)
        self.assertRaises(RuntimeError, dbobj.execute_arbitrary, query.lower())
        query = 'INSERT INTO junkTable VALUES (9999, 1.0, 1.0)'
        self.assertRaises(RuntimeError, dbobj.execute_arbitrary, query)
        self.assertRaises(RuntimeError, dbobj.execute_arbitrary, query.lower())

        query = 'Drop Table junkTable'
        self.assertRaises(RuntimeError, dbobj.execute_arbitrary, query)
        query = 'Delete FROM junkTable WHERE id=4'
        self.assertRaises(RuntimeError, dbobj.execute_arbitrary, query)
        query = 'Update junkTable SET sqrt=0.0, log=0.0 WHERE id=4'
        self.assertRaises(RuntimeError, dbobj.execute_arbitrary, query)
        query = 'Insert INTO junkTable VALUES (9999, 1.0, 1.0)'
        self.assertRaises(RuntimeError, dbobj.execute_arbitrary, query)

        query = 'dRoP TaBlE junkTable'
        self.assertRaises(RuntimeError, dbobj.execute_arbitrary, query)
        query = 'dElEtE FROM junkTable WHERE id=4'
        self.assertRaises(RuntimeError, dbobj.execute_arbitrary, query)
        query = 'uPdAtE junkTable SET sqrt=0.0, log=0.0 WHERE id=4'
        self.assertRaises(RuntimeError, dbobj.execute_arbitrary, query)
        query = 'iNsErT INTO junkTable VALUES (9999, 1.0, 1.0)'
        self.assertRaises(RuntimeError, dbobj.execute_arbitrary, query)
Example #6
    def testValidationErrors(self):
        """ Test that appropriate errors and warnings are thrown when connecting
        """

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            DBObject('sqlite:///' + self.db_name)
            assert len(w) == 1

        # missing database
        self.assertRaises(AttributeError, DBObject, driver=self.driver)
        # missing driver
        self.assertRaises(AttributeError, DBObject, database=self.db_name)
        # missing host
        self.assertRaises(AttributeError, DBObject, driver='mssql+pymssql')
        # missing port
        self.assertRaises(AttributeError, DBObject, driver='mssql+pymssql', host='localhost')
Example #7
import json
import os

import numpy as np
from lsst.sims.catalogs.db import DBObject


def get_table_mins(table_tag, dmag_dict, _out_dir):
    db = DBObject(database='LSSTCATSIM', host='fatboy.phys.washington.edu',
                  port=1433, driver='mssql+pymssql')

    query = 'SELECT '
    query += 'htmid, simobjid, varParamStr '
    query += 'FROM stars_obafgk_part_%s' % table_tag

    dtype = np.dtype([('htmid', int), ('simobjid', int),
                      ('varParamStr', str, 200)])

    data_iter = db.get_arbitrary_chunk_iterator(query, dtype=dtype, chunk_size=10000)
    with open(os.path.join(_out_dir, 'dmag_%s.txt' % table_tag), 'w') as out_file:
        for chunk in data_iter:
            for star in chunk:
                param_dict = json.loads(star['varParamStr'])
                lc_id = param_dict['p']['lc']
                out_file.write('%d %d %d\n' % (star['htmid'], star['simobjid'], dmag_dict[lc_id]))
Example #8
    def testSingleTableQuery(self):
        """
        Test a query on a single table (using chunk iterator)
        """
        dbobj = DBObject(driver=self.driver, database=self.db_name)
        query = 'SELECT id, sqrt FROM doubleTable'
        results = dbobj.get_chunk_iterator(query)

        dtype = [('id', int),
                 ('sqrt', float)]

        i = 1
        for chunk in results:
            for row in chunk:
                self.assertEqual(row[0], i)
                self.assertAlmostEqual(row[1], np.sqrt(i))
                self.assertEqual(dtype, row.dtype)
                i += 1

        self.assertEqual(i, 201)
Example #9
    def testDtype(self):
        """
        Test that passing dtype to a query works

        (also test a query on a single table using .execute_arbitrary() directly)
        """
        dbobj = DBObject(driver=self.driver, database=self.db_name)
        query = 'SELECT id, log FROM doubleTable'
        dtype = [('id', int), ('log', float)]
        results = dbobj.execute_arbitrary(query, dtype=dtype)

        self.assertEqual(results.dtype, dtype)
        for xx in results:
            self.assertAlmostEqual(np.log(xx[0]), xx[1], 6)

        self.assertEqual(len(results), 200)

        results = dbobj.get_chunk_iterator(query, chunk_size=10, dtype=dtype)
        next(results)
        for chunk in results:
            self.assertEqual(chunk.dtype, dtype)
Example #10
    def testJoin(self):
        """
        Test a join
        """
        dbobj = DBObject(driver=self.driver, database=self.db_name)
        query = 'SELECT doubleTable.id, intTable.id, doubleTable.log, intTable.thrice '
        query += 'FROM doubleTable, intTable WHERE doubleTable.id = intTable.id'
        results = dbobj.get_chunk_iterator(query, chunk_size=10)

        dtype = [
            ('id', int),
            ('id_1', int),
            ('log', float),
            ('thrice', int)]

        i = 0
        for chunk in results:
            if i < 90:
                self.assertEqual(len(chunk), 10)
            for row in chunk:
                self.assertEqual(2*(i+1), row[0])
                self.assertEqual(row[0], row[1])
                self.assertAlmostEqual(np.log(row[0]), row[2], 6)
                self.assertEqual(3*row[0], row[3])
                self.assertEqual(dtype, row.dtype)
                i += 1
        self.assertEqual(i, 99)
        # make sure that we found all the matches we should have

        results = dbobj.execute_arbitrary(query)
        self.assertEqual(dtype, results.dtype)
        i = 0
        for row in results:
            self.assertEqual(2*(i+1), row[0])
            self.assertEqual(row[0], row[1])
            self.assertAlmostEqual(np.log(row[0]), row[2], 6)
            self.assertEqual(3*row[0], row[3])
            i += 1
        self.assertEqual(i, 99)
Example #11
import numpy as np
from lsst.sims.catalogs.db import DBObject

db = DBObject(database='LSST',
              host='localhost',
              port=51432,
              driver='mssql+pymssql')

rng = np.random.RandomState(7153323)
dtype = np.dtype([('id', int), ('htmid', int)])
query = 'SELECT id, htmid FROM galaxy'

data_iter = db.get_arbitrary_chunk_iterator(query,
                                            dtype=dtype,
                                            chunk_size=10000)

with open('galaxy_sne_flag.txt', 'w') as out_file:
    for chunk in data_iter:
        flag_vals = rng.randint(0, 10, size=len(chunk))
        for hh, ii, ff in zip(chunk['htmid'], chunk['id'], flag_vals):
            out_file.write('%d;%d;%d\n' % (hh, ii, ff))
Example #12
import os
import time

import numpy as np
from lsst.sims.catalogs.db import DBObject
from mdwarf_utils import activity_type_from_color_z
from mdwarf_utils import xyz_from_lon_lat_px

# 'args' is assumed to come from an argparse.ArgumentParser defined
# elsewhere in the original script fragment.

if args.use_tunnel:
    from lsst.sims.catUtils.baseCatalogModel import BaseCatalogConfig
    config = BaseCatalogConfig
    host = config.host
    port = config.port
    database = config.database
    driver = config.driver
else:
    host = 'fatboy.phys.washington.edu'
    port = 1433
    database = 'LSSTCATSIM'
    driver = 'mssql+pymssql'

db = DBObject(database=database, host=host, port=port, driver=driver)

if args.limit is None:
    query = 'SELECT '
else:
    query = 'SELECT TOP %d ' % args.limit

query += 'htmid, simobjid, gal_l, gal_b, parallax, sdssr, sdssi, sdssz '
query += 'FROM %s ' % args.table

Example #13
    def testDetectDtype(self):
        """
        Test that DBObject.execute_arbitrary can correctly detect the dtypes
        of the rows it is returning
        """
        db_name = os.path.join(self.scratch_dir, 'testDBObject_dtype_DB.db')
        if os.path.exists(db_name):
            os.unlink(db_name)

        conn = sqlite3.connect(db_name)
        c = conn.cursor()
        try:
            c.execute('''CREATE TABLE testTable (id int, val real, sentence int)''')
            conn.commit()
        except sqlite3.Error:
            raise RuntimeError("Error creating database.")

        for ii in range(10):
            cmd = '''INSERT INTO testTable VALUES (%d, %.5f, %s)''' % (ii, 5.234*ii, "'this, has; punctuation'")
            c.execute(cmd)

        conn.commit()
        conn.close()

        db = DBObject(database=db_name, driver='sqlite')
        query = 'SELECT id, val, sentence FROM testTable WHERE id%2 = 0'
        results = db.execute_arbitrary(query)

        np.testing.assert_array_equal(results['id'], np.arange(0, 9, 2, dtype=int))
        np.testing.assert_array_almost_equal(results['val'], 5.234*np.arange(0, 9, 2), decimal=5)
        for sentence in results['sentence']:
            self.assertEqual(sentence, 'this, has; punctuation')

        self.assertEqual(str(results.dtype['id']), 'int64')
        self.assertEqual(str(results.dtype['val']), 'float64')
        if sys.version_info.major == 2:
            self.assertEqual(str(results.dtype['sentence']), '|S22')
        else:
            self.assertEqual(str(results.dtype['sentence']), '<U22')
        self.assertEqual(len(results.dtype), 3)

        # now test that it works when getting a ChunkIterator
        chunk_iter = db.get_arbitrary_chunk_iterator(query, chunk_size=3)
        ct = 0
        for chunk in chunk_iter:

            self.assertEqual(str(chunk.dtype['id']), 'int64')
            self.assertEqual(str(chunk.dtype['val']), 'float64')
            if sys.version_info.major == 2:
                self.assertEqual(str(chunk.dtype['sentence']), '|S22')
            else:
                self.assertEqual(str(chunk.dtype['sentence']), '<U22')
            self.assertEqual(len(chunk.dtype), 3)

            for line in chunk:
                ct += 1
                self.assertEqual(line['sentence'], 'this, has; punctuation')
                self.assertAlmostEqual(line['val'], line['id']*5.234, 5)
                self.assertEqual(line['id']%2, 0)

        self.assertEqual(ct, 5)

        # test that doing a different query does not spoil dtype detection
        query = 'SELECT id, sentence FROM testTable WHERE id%2 = 0'
        results = db.execute_arbitrary(query)
        self.assertGreater(len(results), 0)
        self.assertEqual(len(results.dtype.names), 2)
        self.assertEqual(str(results.dtype['id']), 'int64')
        if sys.version_info.major == 2:
            self.assertEqual(str(results.dtype['sentence']), '|S22')
        else:
            self.assertEqual(str(results.dtype['sentence']), '<U22')

        query = 'SELECT id, val, sentence FROM testTable WHERE id%2 = 0'
        chunk_iter = db.get_arbitrary_chunk_iterator(query, chunk_size=3)
        ct = 0
        for chunk in chunk_iter:

            self.assertEqual(str(chunk.dtype['id']), 'int64')
            self.assertEqual(str(chunk.dtype['val']), 'float64')
            if sys.version_info.major == 2:
                self.assertEqual(str(chunk.dtype['sentence']), '|S22')
            else:
                self.assertEqual(str(chunk.dtype['sentence']), '<U22')
            self.assertEqual(len(chunk.dtype), 3)

            for line in chunk:
                ct += 1
                self.assertEqual(line['sentence'], 'this, has; punctuation')
                self.assertAlmostEqual(line['val'], line['id']*5.234, 5)
                self.assertEqual(line['id']%2, 0)

        self.assertEqual(ct, 5)

        if os.path.exists(db_name):
            os.unlink(db_name)
Example #14
def get_pointing_htmid(pointing_dir,
                       opsim_db_name,
                       ra_colname='descDitheredRA',
                       dec_colname='descDitheredDec',
                       rottel_colname='descDitheredRotTelPos'):
    """
    For a list of OpSim pointings, find dicts mapping those pointings to:
    - The trixels filling the pointings
    - The MJDs of the pointings
    - The telescope filters of the pointings

    Parameters
    ----------
    pointing_dir contains a series of files that are two columns: obshistid, mjd.
    The files must each have 'visits' in their name.  These specify the pointings
    for which we are assembling data.  See:
        https://github.com/LSSTDESC/DC2_Repo/tree/master/data/Run1.1
    for an example.

    opsim_db_name is the path to the OpSim database from which to take those pointings

    ra_colname is the column used for RA of the pointing (default:
    descDitheredRA)

    dec_colname is the column used for the Dec of the pointing (default:
    descDitheredDec)

    rottel_colname is the column used for the rotTelPos of the pointing
    (default: descDitheredRotTelPos)

    Returns
    -------
    htmid_bound_dict -- a dict keyed on ObsHistID.  Values are the list of trixels filling
    the OpSim pointing, as returned by lsst.sims.utils.HalfSpace.findAllTrixels

    mjd_dict -- a dict keyed on ObsHistID.  Values are the MJD(TAI) of the OpSim pointings.

    filter_dict -- a dict keyed on ObsHistID.  Values are the 'ugrizy' filter of the
    OpSim pointings.

    obsmd_dict -- a dict keyed on ObsHistID.  The values are ObservationMetaData with the
    RA, Dec, MJD, and rotSkyPos of the pointings (for use in focal plane geometry calculations)
    """

    radius = 2.0  # field of view of a pointing in degrees

    if not os.path.isfile(opsim_db_name):
        raise RuntimeError("%s is not a valid file name" % opsim_db_name)

    if not os.path.isdir(pointing_dir):
        raise RuntimeError("%s is not a valid dir name" % pointing_dir)

    dtype = np.dtype([('obshistid', int), ('mjd', float)])
    obs_data = None
    for file_name in os.listdir(pointing_dir):
        if 'visits' in file_name:
            full_name = os.path.join(pointing_dir, file_name)
            data = np.genfromtxt(full_name, dtype=dtype)
            if obs_data is None:
                obs_data = data['obshistid']
            else:
                obs_data = np.concatenate((obs_data, data['obshistid']),
                                          axis=0)

    obs_data = np.sort(obs_data)

    db = DBObject(opsim_db_name, driver='sqlite')
    dtype = np.dtype([('obshistid', int), ('mjd', float), ('band', str, 1),
                      ('ra', float), ('dec', float), ('rotTelPos', float)])

    htmid_bound_dict = {}
    mjd_dict = {}
    filter_dict = {}
    obsmd_dict = {}

    d_obs = len(obs_data) // 5
    for i_start in range(0, len(obs_data), d_obs):
        i_end = i_start + d_obs
        if len(obs_data) - i_start < d_obs:
            i_end = len(obs_data)

        subset = obs_data[i_start:i_end]

        query = 'SELECT obsHistId, expMJD, filter,'
        query += ' %s, %s, %s FROM Summary' % (ra_colname, dec_colname,
                                               rottel_colname)
        query += ' WHERE obsHistID BETWEEN %d and %d' % (subset.min(),
                                                         subset.max())
        query += ' GROUP BY obsHistID'

        results = db.execute_arbitrary(query, dtype=dtype)

        for ii in range(len(results)):
            obshistid = results['obshistid'][ii]
            if obshistid not in obs_data:
                continue

            hs = halfSpaceFromRaDec(np.degrees(results['ra'][ii]),
                                    np.degrees(results['dec'][ii]), radius)

            trixel_bounds = hs.findAllTrixels(_truth_trixel_level)
            htmid_bound_dict[obshistid] = trixel_bounds
            mjd_dict[obshistid] = results['mjd'][ii]
            filter_dict[obshistid] = results['band'][ii]
            obs_md = ObservationMetaData(
                pointingRA=np.degrees(results['ra'][ii]),
                pointingDec=np.degrees(results['dec'][ii]),
                mjd=results['mjd'][ii])

            rotsky_rad = _getRotSkyPos(results['ra'][ii], results['dec'][ii],
                                       obs_md, results['rotTelPos'][ii])
            obsmd_dict[obshistid] = ObservationMetaData(
                pointingRA=np.degrees(results['ra'][ii]),
                pointingDec=np.degrees(results['dec'][ii]),
                mjd=results['mjd'][ii],
                rotSkyPos=np.degrees(rotsky_rad))
    assert len(obs_data) == len(htmid_bound_dict)

    return htmid_bound_dict, mjd_dict, filter_dict, obsmd_dict
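
A hedged usage sketch of the function above; the pointing directory and OpSim database path are hypothetical placeholders:

(htmid_bound_dict,
 mjd_dict,
 filter_dict,
 obsmd_dict) = get_pointing_htmid('Run1.1_visit_dir', 'opsim_sqlite.db')

# each dict is keyed on obsHistID, per the docstring
for obshistid in htmid_bound_dict:
    print(obshistid, mjd_dict[obshistid], filter_dict[obshistid])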
Example #15
    def _postprocess_results(self, master_chunk):
        """
        query the database specified by agn_params_db to
        find the AGN varParamStr associated with each AGN
        """

        if self.agn_objid is None:
            gid_name = 'galaxy_id'
            varpar_name = 'varParamStr'
            magnorm_name = 'magNorm'
        else:
            gid_name = self.agn_objid + '_' + 'galaxy_id'
            varpar_name = self.agn_objid + '_' + 'varParamStr'
            magnorm_name = self.agn_objid + '_' + 'magNorm'

        if self.agn_params_db is None:
            return master_chunk

        if not os.path.exists(self.agn_params_db):
            raise RuntimeError('\n%s\n\ndoes not exist' % self.agn_params_db)

        if not hasattr(self, '_agn_dbo'):
            self._agn_dbo = DBObject(database=self.agn_params_db,
                                     driver='sqlite')

            self._agn_dtype = np.dtype([('galaxy_id', int),
                                        ('magNorm', float),
                                        ('varParamStr', str, 500)])

        gid_arr = master_chunk[gid_name].astype(float)

        gid_min = np.nanmin(gid_arr)
        gid_max = np.nanmax(gid_arr)

        query = 'SELECT galaxy_id, magNorm, varParamStr '
        query += 'FROM agn_params '
        query += 'WHERE galaxy_id BETWEEN %d AND %d ' % (gid_min, gid_max)
        query += 'ORDER BY galaxy_id'

        agn_data_iter = self._agn_dbo.get_arbitrary_chunk_iterator(query,
                                                        dtype=self._agn_dtype,
                                                        chunk_size=1000000)


        m_sorted_dex = np.argsort(gid_arr)
        m_sorted_id = gid_arr[m_sorted_dex]
        for agn_chunk in agn_data_iter:

            # find the indices of the elements in master_chunk
            # that correspond to elements in agn_chunk
            m_elements = np.in1d(m_sorted_id, agn_chunk['galaxy_id'])
            m_dex = m_sorted_dex[m_elements]

            # find the indices of the elements in agn_chunk
            # that correspond to elements in master_chunk
            a_dex = np.in1d(agn_chunk['galaxy_id'], m_sorted_id)

            # make sure we have matched elements correctly
            np.testing.assert_array_equal(agn_chunk['galaxy_id'][a_dex],
                                          master_chunk[gid_name][m_dex])

            if varpar_name in master_chunk.dtype.names:
                master_chunk[varpar_name][m_dex] = agn_chunk['varParamStr'][a_dex]

            if magnorm_name in master_chunk.dtype.names:
                master_chunk[magnorm_name][m_dex] = agn_chunk['magNorm'][a_dex]

        return self._final_pass(master_chunk)
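
The index bookkeeping above (one argsort plus two in1d masks) is the core trick: it aligns rows of master_chunk with rows of each agn_chunk without a Python-level join. A standalone sketch with made-up ids, assuming unique ids on both sides and a sorted chunk (guaranteed here by the ORDER BY galaxy_id clause):

import numpy as np

master_id = np.array([7, 3, 9, 1, 5])   # arbitrary order
chunk_id = np.array([3, 5, 9])          # sorted subset of master_id

m_sorted_dex = np.argsort(master_id)    # indices that sort master_id
m_sorted_id = master_id[m_sorted_dex]   # [1, 3, 5, 7, 9]

m_dex = m_sorted_dex[np.in1d(m_sorted_id, chunk_id)]  # rows of master to update
a_dex = np.in1d(chunk_id, m_sorted_id)                # rows of chunk to copy from

np.testing.assert_array_equal(master_id[m_dex], chunk_id[a_dex])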
Example #16
    def __init__(self, database=None, driver='sqlite', host=None, port=None):
        """
        Constructor for the class

        Parameters
        ----------
        database : string
            absolute path to the OpSim database
        driver : string, optional, defaults to 'sqlite'
            driver/dialect for the SQL database
        host : string, optional, defaults to None
            host name; None is appropriate for a local database
        port : int, optional, defaults to None
            port; None is appropriate for a local database

        Returns
        -------
        Instance of the ObservationMetaDataGenerator class

        .. note:: For testing purposes a small OpSim database is available at
        `os.path.join(getPackageDir('sims_data'), 'OpSimData/opsimblitz1_1133_sqlite.db')`
        """
        self._opsim_version = None
        self.driver = driver
        self.host = host
        self.port = port
        self.database = database
        self._seeing_column = 'FWHMeff'

        if self.database is None:
            return

        if not os.path.isfile(self.database):
            raise RuntimeError('%s is not a file' % self.database)

        self.opsimdb = DBObject(driver=self.driver,
                                database=self.database,
                                host=self.host,
                                port=self.port)

        # 27 January 2016
        # Detect whether the OpSim db you are connecting to uses 'finSeeing'
        # as its seeing column (deprecated), or FWHMeff, which is the modern
        # standard

        list_of_tables = self.opsimdb.get_table_names()
        if 'Summary' in list_of_tables:
            self._opsim_version = 3
        else:
            self._opsim_version = 4

        self._summary_columns = self.opsimdb.get_column_names(self.table_name)
        self._set_seeing_column(self._summary_columns)

        # Set up self.dtype containing the dtype of the recarray we expect back from the SQL query.
        # Also set up baseQuery, which is just the SELECT clause of the SQL query
        #
        # self.active_columns will be a list containing the subset of OpSim database columns
        # (specified in self.user_interface_to_opsim) that actually exist in this opsim database
        dtypeList = []
        self.baseQuery = 'SELECT'
        self.active_columns = []

        self._queried_columns = []  # This will be a list of all of the
        # OpSim columns queried
        # Note: here we will refer to the
        # columns by their names in OpSim

        for column in self.user_interface_to_opsim:
            rec = self.user_interface_to_opsim[column]
            if rec[0] in self._summary_columns:
                self.active_columns.append(column)
                dtypeList.append((rec[0], rec[2]))
                if self.baseQuery != 'SELECT':
                    self.baseQuery += ','
                self.baseQuery += ' ' + rec[0]
                self._queried_columns.append(rec[0])

        # Now loop over self._summary_columns, adding any columns
        # to the query that have not already been included therein.
        # Since we do not have explicit information about the
        # data types of these columns, we will assume they are floats.
        for column in self._summary_columns:
            if column not in self._queried_columns:
                self.baseQuery += ', ' + column
                dtypeList.append((column, float))

        self.dtype = np.dtype(dtypeList)
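
A hedged construction sketch using the small test database the docstring points to; the class name follows the docstring's Returns section, and getPackageDir is assumed to come from lsst.utils:

import os
from lsst.utils import getPackageDir

opsim_db = os.path.join(getPackageDir('sims_data'),
                        'OpSimData/opsimblitz1_1133_sqlite.db')
gen = ObservationMetaDataGenerator(database=opsim_db, driver='sqlite')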
Example #17
    def test_alert_data_generation(self):

        dmag_cutoff = 0.005
        mag_name_to_int = {'u': 0, 'g': 1, 'r': 2, 'i': 3, 'z': 4, 'y': 5}

        _max_var_param_str = self.max_str_len

        class StarAlertTestDBObj(StellarAlertDBObjMixin, CatalogDBObject):
            objid = 'star_alert'
            tableid = 'stars'
            idColKey = 'simobjid'
            raColName = 'ra'
            decColName = 'dec'
            objectTypeId = 0
            columns = [('raJ2000', 'ra*0.01745329252'),
                       ('decJ2000', 'dec*0.01745329252'),
                       ('parallax', 'px*0.01745329252/3600.0'),
                       ('properMotionRa', 'pmra*0.01745329252/3600.0'),
                       ('properMotionDec', 'pmdec*0.01745329252/3600.0'),
                       ('radialVelocity', 'vrad'),
                       ('variabilityParameters', 'varParamStr', str, _max_var_param_str)]

        class TestAlertsVarCatMixin(object):

            @register_method('alert_test')
            def applyAlertTest(self, valid_dexes, params, expmjd, variability_cache=None):
                if len(params) == 0:
                    return np.array([[], [], [], [], [], []])

                if isinstance(expmjd, numbers.Number):
                    dmags_out = np.zeros((6, self.num_variable_obj(params)))
                else:
                    dmags_out = np.zeros((6, self.num_variable_obj(params), len(expmjd)))

                for i_star in range(self.num_variable_obj(params)):
                    if params['amp'][i_star] is not None:
                        dmags = params['amp'][i_star]*np.cos(params['per'][i_star]*expmjd)
                        for i_filter in range(6):
                            dmags_out[i_filter][i_star] = dmags

                return dmags_out

        class TestAlertsVarCat(TestAlertsVarCatMixin, AlertStellarVariabilityCatalog):
            pass

        class TestAlertsTruthCat(TestAlertsVarCatMixin, CameraCoords, AstrometryStars,
                                 Variability, InstanceCatalog):
            column_outputs = ['uniqueId', 'chipName', 'dmagAlert', 'magAlert']

            camera = obs_lsst_phosim.PhosimMapper().camera

            @compound('delta_umag', 'delta_gmag', 'delta_rmag',
                      'delta_imag', 'delta_zmag', 'delta_ymag')
            def get_TruthVariability(self):
                return self.applyVariability(self.column_by_name('varParamStr'))

            @cached
            def get_dmagAlert(self):
                return self.column_by_name('delta_%smag' % self.obs_metadata.bandpass)

            @cached
            def get_magAlert(self):
                return self.column_by_name('%smag' % self.obs_metadata.bandpass) + \
                       self.column_by_name('dmagAlert')

        star_db = StarAlertTestDBObj(database=self.star_db_name, driver='sqlite')

        # assemble the true light curves for each object; we need to figure out
        # if their np.max(dMag) ever goes over dmag_cutoff; then we will know if
        # we are supposed to simulate them
        true_lc_dict = {}
        true_lc_obshistid_dict = {}
        is_visible_dict = {}
        obs_dict = {}
        max_obshistid = -1
        n_total_observations = 0
        for obs in self.obs_list:
            obs_dict[obs.OpsimMetaData['obsHistID']] = obs
            obshistid = obs.OpsimMetaData['obsHistID']
            if obshistid > max_obshistid:
                max_obshistid = obshistid
            cat = TestAlertsTruthCat(star_db, obs_metadata=obs)

            for line in cat.iter_catalog():
                if line[1] is None:
                    continue

                n_total_observations += 1
                if line[0] not in true_lc_dict:
                    true_lc_dict[line[0]] = {}
                    true_lc_obshistid_dict[line[0]] = []

                true_lc_dict[line[0]][obshistid] = line[2]
                true_lc_obshistid_dict[line[0]].append(obshistid)

                if line[0] not in is_visible_dict:
                    is_visible_dict[line[0]] = False

                if line[3] <= self.obs_mag_cutoff[mag_name_to_int[obs.bandpass]]:
                    is_visible_dict[line[0]] = True

        obshistid_bits = int(np.ceil(np.log(max_obshistid)/np.log(2)))

        skipped_due_to_mag = 0

        objects_to_simulate = []
        obshistid_unqid_set = set()
        for obj_id in true_lc_dict:

            dmag_max = -1.0
            for obshistid in true_lc_dict[obj_id]:
                if np.abs(true_lc_dict[obj_id][obshistid]) > dmag_max:
                    dmag_max = np.abs(true_lc_dict[obj_id][obshistid])

            if dmag_max >= dmag_cutoff:
                if not is_visible_dict[obj_id]:
                    skipped_due_to_mag += 1
                    continue

                objects_to_simulate.append(obj_id)
                for obshistid in true_lc_obshistid_dict[obj_id]:
                    obshistid_unqid_set.add((obj_id << obshistid_bits) + obshistid)

        self.assertGreater(len(objects_to_simulate), 10)
        self.assertGreater(skipped_due_to_mag, 0)

        log_file_name = tempfile.mktemp(dir=self.output_dir, suffix='log.txt')
        alert_gen = AlertDataGenerator(testing=True)

        alert_gen.subdivide_obs(self.obs_list, htmid_level=6)

        for htmid in alert_gen.htmid_list:
            alert_gen.alert_data_from_htmid(htmid, star_db,
                                            photometry_class=TestAlertsVarCat,
                                            output_prefix='alert_test',
                                            output_dir=self.output_dir,
                                            dmag_cutoff=dmag_cutoff,
                                            log_file_name=log_file_name)

        dummy_sed = Sed()

        bp_dict = BandpassDict.loadTotalBandpassesFromFiles()

        phot_params = PhotometricParameters()

        # First, verify that the contents of the sqlite files are all correct

        n_tot_simulated = 0

        alert_query = 'SELECT alert.uniqueId, alert.obshistId, meta.TAI, '
        alert_query += 'meta.band, quiescent.flux, alert.dflux, '
        alert_query += 'quiescent.snr, alert.snr, '
        alert_query += 'alert.ra, alert.dec, alert.chipNum, '
        alert_query += 'alert.xPix, alert.yPix, ast.pmRA, ast.pmDec, '
        alert_query += 'ast.parallax '
        alert_query += 'FROM alert_data AS alert '
        alert_query += 'INNER JOIN metadata AS meta ON meta.obshistId=alert.obshistId '
        alert_query += 'INNER JOIN quiescent_flux AS quiescent '
        alert_query += 'ON quiescent.uniqueId=alert.uniqueId '
        alert_query += 'AND quiescent.band=meta.band '
        alert_query += 'INNER JOIN baseline_astrometry AS ast '
        alert_query += 'ON ast.uniqueId=alert.uniqueId'

        alert_dtype = np.dtype([('uniqueId', int), ('obshistId', int),
                                ('TAI', float), ('band', int),
                                ('q_flux', float), ('dflux', float),
                                ('q_snr', float), ('tot_snr', float),
                                ('ra', float), ('dec', float),
                                ('chipNum', int), ('xPix', float), ('yPix', float),
                                ('pmRA', float), ('pmDec', float), ('parallax', float)])

        sqlite_file_list = os.listdir(self.output_dir)

        n_tot_simulated = 0
        obshistid_unqid_simulated_set = set()
        for file_name in sqlite_file_list:
            if not file_name.endswith('db'):
                continue
            full_name = os.path.join(self.output_dir, file_name)
            self.assertTrue(os.path.exists(full_name))
            alert_db = DBObject(full_name, driver='sqlite')
            alert_data = alert_db.execute_arbitrary(alert_query, dtype=alert_dtype)
            if len(alert_data) == 0:
                continue

            mjd_list = ModifiedJulianDate.get_list(TAI=alert_data['TAI'])
            for i_obj in range(len(alert_data)):
                n_tot_simulated += 1
                obshistid_unqid_simulated_set.add((alert_data['uniqueId'][i_obj] << obshistid_bits) +
                                                  alert_data['obshistId'][i_obj])

                unq = alert_data['uniqueId'][i_obj]
                obj_dex = (unq//1024)-1
                self.assertAlmostEqual(self.pmra_truth[obj_dex], 0.001*alert_data['pmRA'][i_obj], 4)
                self.assertAlmostEqual(self.pmdec_truth[obj_dex], 0.001*alert_data['pmDec'][i_obj], 4)
                self.assertAlmostEqual(self.px_truth[obj_dex], 0.001*alert_data['parallax'][i_obj], 4)

                ra_truth, dec_truth = applyProperMotion(self.ra_truth[obj_dex], self.dec_truth[obj_dex],
                                                        self.pmra_truth[obj_dex], self.pmdec_truth[obj_dex],
                                                        self.px_truth[obj_dex], self.vrad_truth[obj_dex],
                                                        mjd=mjd_list[i_obj])
                distance = angularSeparation(ra_truth, dec_truth,
                                             alert_data['ra'][i_obj], alert_data['dec'][i_obj])

                distance_arcsec = 3600.0*distance
                msg = '\ntruth: %e %e\nalert: %e %e\n' % (ra_truth, dec_truth,
                                                          alert_data['ra'][i_obj],
                                                          alert_data['dec'][i_obj])

                self.assertLess(distance_arcsec, 0.0005, msg=msg)

                obs = obs_dict[alert_data['obshistId'][i_obj]]


                chipname = chipNameFromRaDec(self.ra_truth[obj_dex], self.dec_truth[obj_dex],
                                             pm_ra=self.pmra_truth[obj_dex],
                                             pm_dec=self.pmdec_truth[obj_dex],
                                             parallax=self.px_truth[obj_dex],
                                             v_rad=self.vrad_truth[obj_dex],
                                             obs_metadata=obs,
                                             camera=self.camera)

                chipnum = int(chipname.replace('R', '').replace('S', '').
                              replace(' ', '').replace(';', '').replace(',', '').
                              replace(':', ''))

                self.assertEqual(chipnum, alert_data['chipNum'][i_obj])

                xpix, ypix = pixelCoordsFromRaDec(self.ra_truth[obj_dex], self.dec_truth[obj_dex],
                                                  pm_ra=self.pmra_truth[obj_dex],
                                                  pm_dec=self.pmdec_truth[obj_dex],
                                                  parallax=self.px_truth[obj_dex],
                                                  v_rad=self.vrad_truth[obj_dex],
                                                  obs_metadata=obs,
                                                  camera=self.camera)

                self.assertAlmostEqual(alert_data['xPix'][i_obj], xpix, 4)
                self.assertAlmostEqual(alert_data['yPix'][i_obj], ypix, 4)

                dmag_sim = -2.5*np.log10(1.0+alert_data['dflux'][i_obj]/alert_data['q_flux'][i_obj])
                self.assertAlmostEqual(true_lc_dict[alert_data['uniqueId'][i_obj]][alert_data['obshistId'][i_obj]],
                                       dmag_sim, 3)

                mag_name = ('u', 'g', 'r', 'i', 'z', 'y')[alert_data['band'][i_obj]]
                m5 = obs.m5[mag_name]

                q_mag = dummy_sed.magFromFlux(alert_data['q_flux'][i_obj])
                self.assertAlmostEqual(self.mag0_truth_dict[alert_data['band'][i_obj]][obj_dex],
                                       q_mag, 4)

                snr, gamma = calcSNR_m5(self.mag0_truth_dict[alert_data['band'][i_obj]][obj_dex],
                                        bp_dict[mag_name],
                                        self.obs_mag_cutoff[alert_data['band'][i_obj]],
                                        phot_params)

                self.assertAlmostEqual(snr/alert_data['q_snr'][i_obj], 1.0, 4)

                tot_mag = self.mag0_truth_dict[alert_data['band'][i_obj]][obj_dex] + \
                          true_lc_dict[alert_data['uniqueId'][i_obj]][alert_data['obshistId'][i_obj]]

                snr, gamma = calcSNR_m5(tot_mag, bp_dict[mag_name],
                                        m5, phot_params)
                self.assertAlmostEqual(snr/alert_data['tot_snr'][i_obj], 1.0, 4)

        for val in obshistid_unqid_set:
            self.assertIn(val, obshistid_unqid_simulated_set)
        self.assertEqual(len(obshistid_unqid_set), len(obshistid_unqid_simulated_set))

        astrometry_query = 'SELECT uniqueId, ra, dec, TAI '
        astrometry_query += 'FROM baseline_astrometry'
        astrometry_dtype = np.dtype([('uniqueId', int),
                                     ('ra', float),
                                     ('dec', float),
                                     ('TAI', float)])

        tai_list = []
        for obs in self.obs_list:
            tai_list.append(obs.mjd.TAI)
        tai_list = np.array(tai_list)

        n_tot_ast_simulated = 0
        for file_name in sqlite_file_list:
            if not file_name.endswith('db'):
                continue
            full_name = os.path.join(self.output_dir, file_name)
            self.assertTrue(os.path.exists(full_name))
            alert_db = DBObject(full_name, driver='sqlite')
            astrometry_data = alert_db.execute_arbitrary(astrometry_query, dtype=astrometry_dtype)

            if len(astrometry_data) == 0:
                continue

            mjd_list = ModifiedJulianDate.get_list(TAI=astrometry_data['TAI'])
            for i_obj in range(len(astrometry_data)):
                n_tot_ast_simulated += 1
                obj_dex = (astrometry_data['uniqueId'][i_obj]//1024) - 1
                ra_truth, dec_truth = applyProperMotion(self.ra_truth[obj_dex], self.dec_truth[obj_dex],
                                                        self.pmra_truth[obj_dex], self.pmdec_truth[obj_dex],
                                                        self.px_truth[obj_dex], self.vrad_truth[obj_dex],
                                                        mjd=mjd_list[i_obj])

                distance = angularSeparation(ra_truth, dec_truth,
                                             astrometry_data['ra'][i_obj],
                                             astrometry_data['dec'][i_obj])

                self.assertLess(3600.0*distance, 0.0005)

        del alert_gen
        gc.collect()
        self.assertGreater(n_tot_simulated, 10)
        self.assertGreater(len(obshistid_unqid_simulated_set), 10)
        self.assertLess(len(obshistid_unqid_simulated_set), n_total_observations)
        self.assertGreater(n_tot_ast_simulated, 0)
Example #18
    def write_alerts(self, obshistid, data_dir, prefix_list,
                     htmid_list, out_dir, out_prefix,
                     dmag_cutoff, lock=None, log_file_name=None):
        """
        Write the alerts for an obsHistId to a properly formatted avro file.

        Parameters
        ----------
        obshistid is the integer uniquely identifying the OpSim pointing
        being simulated

        data_dir is the directory containing the sqlite files created by
        the AlertDataGenerator

        prefix_list is a list of prefixes for those sqlite files.

        htmid_list is the list of htmids identifying the trixels that overlap
        this obshistid's field of view. For each htmid in htmid_list and each
        prefix in prefix_list, this method will process the files
            data_dir/prefix_htmid_sqlite.db
        searching for alerts that correspond to this obshistid

        out_dir is the directory to which the avro files should be written

        out_prefix is the prefix of the avro file names

        dmag_cutoff is the minimum delta magnitude needed to trigger an alert

        lock is an optional multiprocessing.Lock() for use when running many
        instances of this method. It prevents multiple processes from writing to
        the logfile or stdout at once.

        log_file_name is the name of an optional text file to which progress is
        written.
        """

        out_name = os.path.join(out_dir, '%s_%d.avro' % (out_prefix, obshistid))
        if os.path.exists(out_name):
            os.unlink(out_name)

        with DataFileWriter(open(out_name, "wb"),
                            DatumWriter(), self._alert_schema) as data_writer:

            diasource_query = 'SELECT alert.uniqueId, alert.xPix, alert.yPix, '
            diasource_query += 'alert.chipNum, alert.dflux, alert.snr, alert.ra, alert.dec, '
            diasource_query += 'meta.band, meta.TAI, quiescent.flux, quiescent.snr '
            diasource_query += 'FROM alert_data as alert '
            diasource_query += 'INNER JOIN metadata AS meta ON alert.obshistId=meta.obshistId '
            diasource_query += 'INNER JOIN quiescent_flux AS quiescent ON quiescent.uniqueId=alert.uniqueId '
            diasource_query += 'AND quiescent.band=meta.band '
            diasource_query += 'WHERE alert.obshistId=%d ' % obshistid
            diasource_query += 'ORDER BY alert.uniqueId'

            diasource_dtype = np.dtype([('uniqueId', int), ('xPix', float), ('yPix', float),
                                        ('chipNum', int), ('dflux', float), ('tot_snr', float),
                                        ('ra', float), ('dec', float), ('band', int), ('TAI', float),
                                        ('quiescent_flux', float), ('quiescent_snr', float)])

            diaobject_query = 'SELECT uniqueId, ra, dec, TAI, pmRA, pmDec, parallax '
            diaobject_query += 'FROM baseline_astrometry'

            diaobject_dtype = np.dtype([('uniqueId', int), ('ra', float), ('dec', float),
                                        ('TAI', float), ('pmRA', float), ('pmDec', float),
                                        ('parallax', float)])

            t_start = time.time()
            alert_ct = 0
            for htmid in htmid_list:
                for prefix in prefix_list:
                    db_name = os.path.join(data_dir, '%s_%d_sqlite.db' % (prefix, htmid))
                    if not os.path.exists(db_name):
                        warnings.warn('%s does not exist' % db_name)
                        continue

                    db_obj = DBObject(db_name, driver='sqlite')

                    diaobject_data = db_obj.execute_arbitrary(diaobject_query,
                                                              dtype=diaobject_dtype)

                    diaobject_dict = self._create_objects(diaobject_data)

                    diasource_data = db_obj.execute_arbitrary(diasource_query,
                                                              dtype=diasource_dtype)

                    dmag = 2.5*np.log10(1.0+diasource_data['dflux']/diasource_data['quiescent_flux'])
                    valid_alerts = np.where(np.abs(dmag) >= dmag_cutoff)
                    diasource_data = diasource_data[valid_alerts]
                    avro_diasource_list = self._create_sources(obshistid, diasource_data)

                    for i_source in range(len(avro_diasource_list)):
                        alert_ct += 1
                        unq = diasource_data[i_source]['uniqueId']
                        diaobject = diaobject_dict[unq]
                        diasource = avro_diasource_list[i_source]

                        avro_alert = {}
                        avro_alert['alertId'] = int((obshistid << 20) + alert_ct)
                        avro_alert['l1dbId'] = int(unq)
                        avro_alert['diaSource'] = diasource
                        avro_alert['diaObject'] = diaobject

                        data_writer.append(avro_alert)

        if lock is not None:
            lock.acquire()

        elapsed = (time.time()-t_start)/3600.0

        msg = 'finished obshistid %d; %d alerts in %.2e hrs' % (obshistid, alert_ct, elapsed)

        print(msg)

        if log_file_name is not None:
            with open(log_file_name, 'a') as out_file:
                out_file.write(msg)
                out_file.write('\n')

        if lock is not None:
            lock.release()
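
A hedged call sketch for write_alerts(); every path, prefix, and id below is a hypothetical placeholder, and 'gen' stands for an already-constructed instance of whatever class defines this method:

from multiprocessing import Lock

gen.write_alerts(obshistid=1234,
                 data_dir='alert_sqlite_dir',
                 prefix_list=['alert_test'],
                 htmid_list=[8965, 8966],
                 out_dir='avro_out',
                 out_prefix='lsst_alerts',
                 dmag_cutoff=0.005,
                 lock=Lock(),
                 log_file_name='write_alerts_log.txt')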
Example #19
    def __init__(self, database=None, driver='sqlite', host=None, port=None):
        """
        Constructor for the class

        Parameters
        ----------
        database : string
            absolute path to the OpSim database
        driver : string, optional, defaults to 'sqlite'
            driver/dialect for the SQL database
        host : string, optional, defaults to None
            host name; None is appropriate for a local database
        port : int, optional, defaults to None
            port; None is appropriate for a local database

        Returns
        -------
        Instance of the ObservationMetaDataGenerator class

        .. note:: For testing purposes a small OpSim database is available at
        `os.path.join(getPackageDir('sims_data'), 'OpSimData/opsimblitz1_1133_sqlite.db')`
        """
        self.driver = driver
        self.host = host
        self.port = port
        self.database = database
        self._seeing_column = 'FWHMeff'

        # a dict keyed on the user interface names of the OpSim data columns
        # (i.e. the args to getObservationMetaData).  Returns a tuple that is the
        # (name of data column in OpSim, transformation to go from user interface to OpSim units,
        # dtype in OpSim)
        #
        # Note: this dict will contain entries for every column (except propID) in the OpSim
        # summary table, not just those the ObservationMetaDataGenerator is designed to query
        # on.  The idea is that ObservationMetaData generated by this class will carry around
        # records of the values of all of the associated OpSim Summary columns so that users
        # can pass those values on to PhoSim/other tools at their own discretion.
        self._user_interface_to_opsim = {'obsHistID': ('obsHistID', None, np.int64),
                                         'expDate': ('expDate', None, int),
                                         'fieldRA': ('fieldRA', np.radians, float),
                                         'fieldDec': ('fieldDec', np.radians, float),
                                         'moonRA': ('moonRA', np.radians, float),
                                         'moonDec': ('moonDec', np.radians, float),
                                         'rotSkyPos': ('rotSkyPos', np.radians, float),
                                         'telescopeFilter':
                                             ('filter', lambda x: '\'{}\''.format(x), (str, 1)),
                                         'rawSeeing': ('rawSeeing', None, float),
                                         'sunAlt': ('sunAlt', np.radians, float),
                                         'moonAlt': ('moonAlt', np.radians, float),
                                         'dist2Moon': ('dist2Moon', np.radians, float),
                                         'moonPhase': ('moonPhase', None, float),
                                         'expMJD': ('expMJD', None, float),
                                         'altitude': ('altitude', np.radians, float),
                                         'azimuth': ('azimuth', np.radians, float),
                                         'visitExpTime': ('visitExpTime', None, float),
                                         'airmass': ('airmass', None, float),
                                         'm5': ('fiveSigmaDepth', None, float),
                                         'skyBrightness': ('filtSkyBrightness', None, float),
                                         'sessionID': ('sessionID', None, int),
                                         'fieldID': ('fieldID', None, int),
                                         'night': ('night', None, int),
                                         'visitTime': ('visitTime', None, float),
                                         'finRank': ('finRank', None, float),
                                         'FWHMgeom': ('FWHMgeom', None, float),
                                         # do not include FWHMeff; that is detected by
                                         # self._set_seeing_column()
                                         'transparency': ('transparency', None, float),
                                         'vSkyBright': ('vSkyBright', None, float),
                                         'rotTelPos': ('rotTelPos', None, float),
                                         'lst': ('lst', None, float),
                                         'solarElong': ('solarElong', None, float),
                                         'moonAz': ('moonAz', None, float),
                                         'sunAz': ('sunAz', None, float),
                                         'phaseAngle': ('phaseAngle', None, float),
                                         'rScatter': ('rScatter', None, float),
                                         'mieScatter': ('mieScatter', None, float),
                                         'moonBright': ('moonBright', None, float),
                                         'darkBright': ('darkBright', None, float),
                                         'wind': ('wind', None, float),
                                         'humidity': ('humidity', None, float),
                                         'slewDist': ('slewDist', None, float),
                                         'slewTime': ('slewTime', None, float),
                                         'ditheredRA': ('ditheredRA', None, float),
                                         'ditheredDec': ('ditheredDec', None, float)}

        if self.database is None:
            return

        if not os.path.exists(self.database):
            raise RuntimeError('%s does not exist' % self.database)

        self.opsimdb = DBObject(driver=self.driver, database=self.database,
                                host=self.host, port=self.port)

        # 27 January 2016
        # Detect whether the OpSim db you are connecting to uses 'finSeeing'
        # as its seeing column (deprecated), or FWHMeff, which is the modern
        # standard
        self._summary_columns = self.opsimdb.get_column_names('Summary')
        self._set_seeing_column(self._summary_columns)

        # Set up self.dtype containing the dtype of the recarray we expect back from the SQL query.
        # Also set up baseQuery, which is just the SELECT clause of the SQL query
        #
        # self.active_columns will be a list containing the subset of OpSim database columns
        # (specified in self._user_interface_to_opsim) that actually exist in this opsim database
        dtypeList = []
        self.baseQuery = 'SELECT'
        self.active_columns = []

        self._queried_columns = []  # This will be a list of all of the
                                    # OpSim columns queried
                                    # Note: here we will refer to the
                                    # columns by their names in OpSim

        for column in self._user_interface_to_opsim:
            rec = self._user_interface_to_opsim[column]
            if rec[0] in self._summary_columns:
                self.active_columns.append(column)
                dtypeList.append((rec[0], rec[2]))
                if self.baseQuery != 'SELECT':
                    self.baseQuery += ','
                self.baseQuery += ' ' + rec[0]
                self._queried_columns.append(rec[0])

        # Now loop over self._summary_columns, adding any columns
        # to the query that have not already been included therein.
        # Since we do not have explicit information about the
        # data types of these columns, we will assume they are floats.
        for column in self._summary_columns:
            if column not in self._queried_columns:
                self.baseQuery += ', ' + column
                dtypeList.append((column, float))

        self.dtype = np.dtype(dtypeList)
Example #20
def write_sprinkled_lc(out_file_name, total_obs_md,
                       pointing_dir, opsim_db_name,
                       ra_colname='descDitheredRA',
                       dec_colname='descDitheredDec',
                       rottel_colname='descDitheredRotTelPos',
                       sql_file_name=None,
                       bp_dict=None):

    """
    Create database of light curves

    Note: this is still under development.  It has not yet been
    used for a production-level truth catalog

    Parameters
    ----------
    out_file_name is the name of the sqlite file to be written

    total_obs_md is an ObservationMetaData covering the whole
    survey area

    pointing_dir contains a series of files that are two columns: obshistid, mjd.
    The files must each have 'visits' in their name.  These specify the pointings
    for which we are assembling data.  See:
        https://github.com/LSSTDESC/DC2_Repo/tree/master/data/Run1.1
    for an example.

    opsim_db_name is the name of the OpSim database to be queried for pointings

    ra_colname is the column used for RA of the pointing (default:
    descDitheredRA)

    dec_colname is the column used for the Dec of the pointing (default:
    descDitheredDec)

    rottel_colname is the column used for the rotTelPos of the pointing
    (default: descDitheredRotTelPos)

    sql_file_name is the name of the parameter database produced by
    write_sprinkled_param_db to be used

    bp_dict is a BandpassDict of the telescope filters to be used

    Returns
    -------
    None

    Writes out a database to out_file_name.  The tables of this database and
    their columns are:

    light_curves:
        - uniqueId -- an int unique to all objects
        - obshistid -- an int unique to all pointings
        - mag -- the magnitude observed for this object at that pointing

    obs_metadata:
        - obshistid -- an int unique to all pointings
        - mjd -- the date of the pointing
        - filter -- an int corresponding to the telescope filter (0==u, 1==g, ...)

    variables_and_transients:
        - uniqueId -- an int unique to all objects
        - galaxy_id -- an int indicating the host galaxy
        - ra -- in degrees
        - dec -- in degrees
        - agn -- ==1 if object is an AGN
        - sn -- ==1 if object is a supernova
    """

    t0_master = time.time()

    if not os.path.isfile(sql_file_name):
        raise RuntimeError('%s does not exist' % sql_file_name)


    sn_simulator = SneSimulator(bp_dict)
    sed_dir = os.environ['SIMS_SED_LIBRARY_DIR']

    create_sprinkled_sql_file(out_file_name)

    t_start = time.time()

    # get data about the pointings being simulated
    (htmid_dict,
     mjd_dict,
     filter_dict,
     obsmd_dict) = get_pointing_htmid(pointing_dir, opsim_db_name,
                                      ra_colname=ra_colname,
                                      dec_colname=dec_colname)

    t_htmid_dict = time.time()-t_start

    bp_to_int = {'u':0, 'g':1, 'r':2, 'i':3, 'z':4, 'y':5}

    # put the data about the pointings in the obs_metadata table
    with sqlite3.connect(out_file_name) as conn:
        cursor = conn.cursor()
        values = ((int(obs), mjd_dict[obs], bp_to_int[filter_dict[obs]])
                  for obs in mjd_dict)
        cursor.executemany('''INSERT INTO obs_metadata VALUES (?,?,?)''', values)
        cursor.execute('''CREATE INDEX obs_filter
                       ON obs_metadata (obshistid, filter)''')
        conn.commit()

    print('\ngot htmid_dict -- %d in %e seconds' % (len(htmid_dict), t_htmid_dict))

    db = DBObject(sql_file_name, driver='sqlite')

    # get a list of htmid corresponding to trixels in which
    # variables and transients can be found
    query = 'SELECT DISTINCT htmid FROM zpoint WHERE is_agn=1 OR is_sn=1'
    dtype = np.dtype([('htmid', int)])

    results = db.execute_arbitrary(query, dtype=dtype)

    object_htmid = results['htmid']

    agn_dtype = np.dtype([('uniqueId', int), ('galaxy_id', int),
                          ('ra', float), ('dec', float),
                          ('redshift', float), ('sed', str, 500),
                          ('magnorm', float), ('varParamStr', str, 500),
                          ('is_sprinkled', int)])

    agn_base_query = 'SELECT uniqueId, galaxy_id, '
    agn_base_query += 'raJ2000, decJ2000, '
    agn_base_query += 'redshift, sedFilepath, '
    agn_base_query += 'magNorm, varParamStr, is_sprinkled '
    agn_base_query += 'FROM zpoint WHERE is_agn=1 '

    sn_dtype = np.dtype([('uniqueId', int), ('galaxy_id', int),
                         ('ra', float), ('dec', float),
                         ('redshift', float), ('sn_truth_params', str, 500),
                         ('is_sprinkled', int)])

    sn_base_query = 'SELECT uniqueId, galaxy_id, '
    sn_base_query += 'raJ2000, decJ2000, '
    sn_base_query += 'redshift, sn_truth_params, is_sprinkled '
    sn_base_query += 'FROM zpoint WHERE is_sn=1 '

    filter_to_int = {'u':0, 'g':1, 'r':2, 'i':3, 'z':4, 'y':5}

    n_floats = 0
    with sqlite3.connect(out_file_name) as conn:
        cursor = conn.cursor()
        t_before_htmid = time.time()

        # loop over trixels containing variables and transients, simulating
        # the light curves of those objects
        for htmid_dex, htmid in enumerate(object_htmid):
            if htmid_dex > 0:
                htmid_duration = (time.time()-t_before_htmid)/3600.0
                htmid_prediction = len(object_htmid)*htmid_duration/htmid_dex
                print('%d htmid out of %d in %e hours; predict %e hours remaining' %
                      (htmid_dex, len(object_htmid), htmid_duration,
                       htmid_prediction-htmid_duration))
            mjd_arr = []
            obs_arr = []
            filter_arr = []

            # Find only those pointings which overlap the current trixel
            for obshistid in htmid_dict:
                is_contained = False
                for bounds in htmid_dict[obshistid]:
                    if bounds[0] <= htmid <= bounds[1]:
                        is_contained = True
                        break
                if is_contained:
                    mjd_arr.append(mjd_dict[obshistid])
                    obs_arr.append(obshistid)
                    filter_arr.append(filter_to_int[filter_dict[obshistid]])
            if len(mjd_arr) == 0:
                continue
            mjd_arr = np.array(mjd_arr)
            obs_arr = np.array(obs_arr)
            filter_arr = np.array(filter_arr)
            sorted_dex = np.argsort(mjd_arr)
            mjd_arr = mjd_arr[sorted_dex]
            obs_arr = obs_arr[sorted_dex]
            filter_arr = filter_arr[sorted_dex]

            agn_query = agn_base_query + 'AND htmid=%d' % htmid

            agn_iter = db.get_arbitrary_chunk_iterator(agn_query,
                                                       dtype=agn_dtype,
                                                       chunk_size=10000)

            # put static data about the AGN (position, etc.) into the
            # variables_and_transients table
            for i_chunk, agn_results in enumerate(agn_iter):
                values = ((int(agn_results['uniqueId'][i_obj]),
                           int(agn_results['galaxy_id'][i_obj]),
                           np.degrees(agn_results['ra'][i_obj]),
                           np.degrees(agn_results['dec'][i_obj]),
                           int(agn_results['is_sprinkled'][i_obj]),
                           1,0)
                          for i_obj in range(len(agn_results)))

                cursor.executemany('''INSERT INTO variables_and_transients VALUES
                                      (?,?,?,?,?,?,?)''', values)

                agn_simulator = AgnSimulator(agn_results['redshift'])

                quiescent_mag = np.zeros((len(agn_results), 6), dtype=float)
                for i_obj, (sed_name, zz, mm) in enumerate(zip(agn_results['sed'],
                                                               agn_results['redshift'],
                                                               agn_results['magnorm'])):
                    spec = Sed()
                    spec.readSED_flambda(os.path.join(sed_dir, sed_name))
                    fnorm = getImsimFluxNorm(spec, mm)
                    spec.multiplyFluxNorm(fnorm)
                    spec.redshiftSED(zz, dimming=True)
                    mag_list = bp_dict.magListForSed(spec)
                    quiescent_mag[i_obj] = mag_list

                # simulate AGN variability
                dmag = agn_simulator.applyVariability(agn_results['varParamStr'],
                                                      expmjd=mjd_arr)

                # loop over pointings that overlap the current trixel, writing
                # out simulated photometry for each AGN observed in that pointing
                for i_time, obshistid in enumerate(obs_arr):

                    # only include objects that were actually on a detector
                    are_on_chip = _actually_on_chip(np.degrees(agn_results['ra']),
                                                    np.degrees(agn_results['dec']),
                                                    obsmd_dict[obshistid])

                    valid_agn = np.where(are_on_chip)

                    if len(valid_agn[0]) == 0:
                        continue

                    values = ((int(agn_results['uniqueId'][i_obj]),
                               int(obs_arr[i_time]),
                               quiescent_mag[i_obj][filter_arr[i_time]]+
                               dmag[filter_arr[i_time]][i_obj][i_time])
                              for i_obj in valid_agn[0])
                    cursor.executemany('''INSERT INTO light_curves VALUES
                                       (?,?,?)''', values)

                conn.commit()
                n_floats += len(dmag.flatten())

            sn_query = sn_base_query + 'AND htmid=%d' % htmid

            sn_iter = db.get_arbitrary_chunk_iterator(sn_query,
                                                      dtype=sn_dtype,
                                                      chunk_size=10000)

            for sn_results in sn_iter:
                t0_sne = time.time()

                # write static information about SNe to the
                # variables_and_transients table
                values = ((int(sn_results['uniqueId'][i_obj]),
                           int(sn_results['galaxy_id'][i_obj]),
                           np.degrees(sn_results['ra'][i_obj]),
                           np.degrees(sn_results['dec'][i_obj]),
                           int(sn_results['is_sprinkled'][i_obj]),
                           0,1)
                          for i_obj in range(len(sn_results)))

                cursor.executemany('''INSERT INTO variables_and_transients VALUES
                                      (?,?,?,?,?,?,?)''', values)
                conn.commit()

                sn_mags = sn_simulator.calculate_sn_magnitudes(sn_results['sn_truth_params'],
                                                               mjd_arr, filter_arr)
                print('    did %d sne in %e seconds' % (len(sn_results), time.time()-t0_sne))

                # loop over pointings that overlap the current trixel, writing
                # out simulated photometry for each SN observed in that pointing
                for i_time, obshistid in enumerate(obs_arr):

                    # only include objects that fell on a detector
                    are_on_chip = _actually_on_chip(np.degrees(sn_results['ra']),
                                                    np.degrees(sn_results['dec']),
                                                    obsmd_dict[obshistid])

                    valid_obj = np.where(np.logical_and(np.isfinite(sn_mags[:,i_time]),
                                                        are_on_chip))

                    if len(valid_obj[0]) == 0:
                        continue

                    values = ((int(sn_results['uniqueId'][i_obj]),
                               int(obs_arr[i_time]),
                               sn_mags[i_obj][i_time])
                              for i_obj in valid_obj[0])

                    cursor.executemany('''INSERT INTO light_curves VALUES (?,?,?)''', values)
                    conn.commit()
                    n_floats += len(valid_obj[0])

        cursor.execute('CREATE INDEX unq_obs ON light_curves (uniqueId, obshistid)')
        conn.commit()

    print('n_floats %d' % n_floats)
    print('in %e seconds' % (time.time()-t0_master))
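
Because the output is a plain sqlite file with the schema documented in the
docstring above, a single object's light curve can be read back by joining
light_curves against obs_metadata.  A minimal sketch (the file name and
uniqueId value are placeholders):

    import sqlite3

    with sqlite3.connect('sprinkled_lc.db') as conn:
        cursor = conn.cursor()
        cursor.execute('''SELECT o.mjd, o.filter, l.mag
                          FROM light_curves AS l
                          JOIN obs_metadata AS o ON l.obshistid = o.obshistid
                          WHERE l.uniqueId = ?
                          ORDER BY o.mjd''', (1234,))
        for mjd, bandpass, mag in cursor.fetchall():
            # filter is stored as an int (0==u, 1==g, ...)
            print(mjd, 'ugrizy'[bandpass], mag)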
Example #21
    parser.add_argument('--partition', type=str, default=None,
                        help='htmid tag of stars_partition_* table to run')

    parser.add_argument('--outdir', type=str, default=None,
                        help='dir to write output to')

    parser.add_argument('--n_procs', type=int, default=20)
    args = parser.parse_args()

    assert args.partition is not None
    assert args.outdir is not None
    assert os.path.isdir(args.outdir)

    try:
        db = DBObject(database='LSST',
                      port=1433,
                      host='epyc.astro.washington.edu',
                      driver='mssql+pymssql')
    except Exception:
        db = DBObject(database='LSST',
                      port=51432,
                      host='localhost',
                      driver='mssql+pymssql')


    table_name = 'stars_partition_%s' % args.partition
    out_name = os.path.join(args.outdir,'isvar_lookup_%s.txt' % args.partition)
    if os.path.isfile(out_name):
        os.unlink(out_name)
        #raise RuntimeError("\n%s\nexists\n" % out_name)

    query = "SELECT "
    bp_dict = BandpassDict.loadTotalBandpassesFromFiles()
    bp = bp_dict['i']

    z_grid = np.arange(0.0, 16.0, 0.01)
    k_grid = np.zeros(len(z_grid), dtype=float)

    for i_z, zz in enumerate(z_grid):
        ss = Sed(flambda=base_sed.flambda, wavelen=base_sed.wavelen)
        ss.redshiftSED(zz, dimming=True)
        k = k_correction(ss, bp, zz)
        k_grid[i_z] = k

    cosmo = CosmologyObject()

    db = DBObject(database='LSSTCATSIM',
                  host='fatboy.phys.washington.edu',
                  port=1433,
                  driver='mssql+pymssql')

    query = 'SELECT magnorm_agn, redshift, varParamStr FROM '
    query += 'galaxy WHERE varParamStr IS NOT NULL '
    query += 'AND dec BETWEEN -2.5 AND 2.5 '
    query += 'AND (ra<2.5 OR ra>357.5)'

    dtype = np.dtype([('magnorm', float), ('redshift', float),
                      ('varParamStr', str, 400)])

    data_iter = db.get_arbitrary_chunk_iterator(query,
                                                dtype=dtype,
                                                chunk_size=10000)

    with open('data/dc1_agn_params.txt', 'w') as out_file:
        # assumed minimal completion (the source snippet is truncated here):
        # record the three queried columns, one object per line
        for chunk in data_iter:
            for row in chunk:
                out_file.write('%e %e %s\n' % (row['magnorm'],
                                               row['redshift'],
                                               row['varParamStr']))
Example #23
    def test_catalog_db_object_cacheing(self):
        """
        Test that opening multiple CatalogDBObjects that connect to the same
        database only results in one connection being opened and used.  We
        will test this by instantiating two CatalogDBObjects and a DBObject
        that connect to the same database.  We will then test that the two
        CatalogDBObjects' connections are identical, but that the DBObject has
        its own connection.
        """

        self.assertEqual(len(CatalogDBObject._connection_cache), 0)

        class DbClass1(CatalogDBObject):
            database = self.db_name
            port = None
            host = None
            driver = 'sqlite'
            tableid = 'test'
            idColKey = 'id'
            objid = 'test_db_class_1'

            columns = [('identification', 'id')]

        class DbClass2(CatalogDBObject):
            database = self.db_name
            port = None
            host = None
            driver = 'sqlite'
            tableid = 'test'
            idColKey = 'id'
            objid = 'test_db_class_2'

            columns = [('other', 'i1')]

        db1 = DbClass1()
        db2 = DbClass2()
        self.assertEqual(id(db1.connection), id(db2.connection))
        self.assertEqual(len(CatalogDBObject._connection_cache), 1)

        db3 = DBObject(database=self.db_name,
                       driver='sqlite',
                       host=None,
                       port=None)
        self.assertNotEqual(id(db1.connection), id(db3.connection))

        self.assertEqual(len(CatalogDBObject._connection_cache), 1)

        # check that if we had passed db1.connection to a DBObject,
        # the connections would be identical
        db4 = DBObject(connection=db1.connection)
        self.assertEqual(id(db4.connection), id(db1.connection))

        self.assertEqual(len(CatalogDBObject._connection_cache), 1)

        # verify that db1 and db2 are both useable
        results = db1.query_columns(
            colnames=['id', 'i1', 'i2', 'identification'])
        results = next(results)
        self.assertEqual(len(results), 5)
        np.testing.assert_array_equal(results['id'], list(range(5)))
        np.testing.assert_array_equal(results['id'], results['identification'])
        np.testing.assert_array_equal(results['id']**2, results['i1'])
        np.testing.assert_array_equal(results['id'] * (-1), results['i2'])

        results = db2.query_columns(colnames=['id', 'i1', 'i2', 'other'])
        results = next(results)
        self.assertEqual(len(results), 5)
        np.testing.assert_array_equal(results['id'], list(range(5)))
        np.testing.assert_array_equal(results['id']**2, results['i1'])
        np.testing.assert_array_equal(results['i1'], results['other'])
        np.testing.assert_array_equal(results['id'] * (-1), results['i2'])
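
The caching behaviour verified above can be pictured as a class-level
dictionary keyed on the connection parameters.  A minimal sketch of the idea
(illustrative only, not the actual CatalogDBObject implementation; only the
sqlite case is filled in):

    import sqlite3

    class CachedConnectionMixin(object):
        # one shared connection per unique (driver, database, host, port)
        _connection_cache = {}

        @classmethod
        def get_connection(cls, driver, database, host=None, port=None):
            key = (driver, database, host, port)
            if key not in cls._connection_cache:
                # this sketch only implements the sqlite backend
                assert driver == 'sqlite'
                cls._connection_cache[key] = sqlite3.connect(database)
            return cls._connection_cache[key]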