def testColumnNames(self):
    """
    Verify get_column_names(): given a table name it returns that table's
    columns; given no argument it returns a dict keyed on table name.
    """
    db_obj = DBObject(driver=self.driver, database=self.db_name)

    expected_columns = {'doubleTable': ('id', 'sqrt', 'log'),
                        'intTable': ('id', 'twice', 'thrice')}

    # per-table lookups
    for table in ('doubleTable', 'intTable'):
        cols = db_obj.get_column_names(table)
        self.assertEqual(len(cols), 3)
        for col in expected_columns[table]:
            self.assertIn(col, cols)

    # whole-database lookup: every table reported must be a known one
    all_columns = db_obj.get_column_names()
    known_tables = ['doubleTable', 'intTable', 'junkTable']
    for table in all_columns:
        self.assertIn(table, known_tables)

    self.assertEqual(len(all_columns['doubleTable']), 3)
    self.assertEqual(len(all_columns['intTable']), 3)
    for table in ('doubleTable', 'intTable'):
        for col in expected_columns[table]:
            self.assertIn(col, all_columns[table])
def testTableNames(self):
    """
    Test the method that returns the names of tables in a database.

    The test database holds exactly three tables; all three are checked
    by name.
    """
    dbobj = DBObject(driver=self.driver, database=self.db_name)
    names = dbobj.get_table_names()
    self.assertEqual(len(names), 3)
    self.assertIn('doubleTable', names)
    self.assertIn('intTable', names)
    # the third table: testColumnNames lists it among the known tables and
    # testReadOnlyFilter targets it with forbidden commands
    self.assertIn('junkTable', names)
def testPassingConnection(self):
    """
    Repeat the testJoin checks with a DBObject built from another
    DBObject's connection, to make sure that passing a connection works.
    """
    source_obj = DBObject(driver=self.driver, database=self.db_name)
    dbobj = DBObject(connection=source_obj.connection)

    query = ('SELECT doubleTable.id, intTable.id, doubleTable.log, intTable.thrice '
             'FROM doubleTable, intTable WHERE doubleTable.id = intTable.id')
    expected_dtype = [('id', int), ('id_1', int), ('log', float), ('thrice', int)]

    def check_row(n_seen, row):
        # the n-th matched row (0-based) is expected to have id 2*(n+1)
        self.assertEqual(2*(n_seen + 1), row[0])
        self.assertEqual(row[0], row[1])
        self.assertAlmostEqual(np.log(row[0]), row[2], 6)
        self.assertEqual(3*row[0], row[3])

    n_rows = 0
    for chunk in dbobj.get_chunk_iterator(query, chunk_size=10):
        if n_rows < 90:
            self.assertEqual(len(chunk), 10)
        for row in chunk:
            check_row(n_rows, row)
            self.assertEqual(expected_dtype, row.dtype)
            n_rows += 1
    # make sure that we found all the matches we should have
    self.assertEqual(n_rows, 99)

    results = dbobj.execute_arbitrary(query)
    self.assertEqual(expected_dtype, results.dtype)
    n_rows = 0
    for row in results:
        check_row(n_rows, row)
        n_rows += 1
    self.assertEqual(n_rows, 99)
def testMinMax(self):
    """
    Check that SQL aggregate functions (MAX, MIN) can be queried, and that
    the resulting columns are named with the parentheses stripped.
    """
    db_obj = DBObject(driver=self.driver, database=self.db_name)
    results = db_obj.execute_arbitrary('SELECT MAX(thrice), MIN(thrice) FROM intTable')

    first_row = results[0]
    self.assertEqual(first_row[0], 594)
    self.assertEqual(first_row[1], 0)
    self.assertEqual(results.dtype, [('MAXthrice', int), ('MINthrice', int)])
def testReadOnlyFilter(self):
    """
    Test that the filters we placed on queries made with
    execute_arbitrary() work: non-string queries and queries containing
    DROP, DELETE, UPDATE or INSERT (in any capitalization) must raise
    RuntimeError.
    """
    dbobj = DBObject(driver=self.driver, database=self.db_name)

    # a read-only query must be accepted
    control_query = ('SELECT doubleTable.id, intTable.id, doubleTable.log, intTable.thrice '
                     'FROM doubleTable, intTable WHERE doubleTable.id = intTable.id')
    dbobj.execute_arbitrary(control_query)

    # make sure that execute_arbitrary only accepts strings
    self.assertRaises(RuntimeError, dbobj.execute_arbitrary, ['a', 'list'])

    # upper-case verboten commands, each also tried fully lower-cased
    upper_case_commands = ('DROP TABLE junkTable',
                           'DELETE FROM junkTable WHERE id=4',
                           'UPDATE junkTable SET sqrt=0.0, log=0.0 WHERE id=4',
                           'INSERT INTO junkTable VALUES (9999, 1.0, 1.0)')
    for forbidden in upper_case_commands:
        self.assertRaises(RuntimeError, dbobj.execute_arbitrary, forbidden)
        self.assertRaises(RuntimeError, dbobj.execute_arbitrary, forbidden.lower())

    # title-case and alternating-case permutations of the same commands
    mixed_case_commands = ('Drop Table junkTable',
                           'Delete FROM junkTable WHERE id=4',
                           'Update junkTable SET sqrt=0.0, log=0.0 WHERE id=4',
                           'Insert INTO junkTable VALUES (9999, 1.0, 1.0)',
                           'dRoP TaBlE junkTable',
                           'dElEtE FROM junkTable WHERE id=4',
                           'uPdAtE junkTable SET sqrt=0.0, log=0.0 WHERE id=4',
                           'iNsErT INTO junkTable VALUES (9999, 1.0, 1.0)')
    for forbidden in mixed_case_commands:
        self.assertRaises(RuntimeError, dbobj.execute_arbitrary, forbidden)
def testValidationErrors(self):
    """
    Test that appropriate errors and warnings are thrown when connecting
    """
    # passing an 'sqlite:///...' URL as the database is expected to emit
    # exactly one warning
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        DBObject('sqlite:///' + self.db_name)
        # a unittest assertion rather than a bare assert, so the check is
        # not stripped when Python runs with -O
        self.assertEqual(len(w), 1)

    # missing database
    self.assertRaises(AttributeError, DBObject, driver=self.driver)
    # missing driver
    self.assertRaises(AttributeError, DBObject, database=self.db_name)
    # missing host
    self.assertRaises(AttributeError, DBObject, driver='mssql+pymssql')
    # missing port
    self.assertRaises(AttributeError, DBObject, driver='mssql+pymssql', host='localhost')
def get_table_mins(table_tag, dmag_dict, _out_dir):
    """
    Write one line per star of stars_obafgk_part_<table_tag> to
    <_out_dir>/dmag_<table_tag>.txt, formatted as 'htmid simobjid value'.

    Parameters
    ----------
    table_tag : suffix selecting the stars_obafgk_part_* table to read
    dmag_dict : maps the light curve id stored in each star's varParamStr
        (JSON, under ['p']['lc']) to the value written in the third column
    _out_dir : directory in which the output file is created
    """
    db = DBObject(database='LSSTCATSIM', host='fatboy.phys.washington.edu',
                  port=1433, driver='mssql+pymssql')

    query = ('SELECT htmid, simobjid, varParamStr '
             'FROM stars_obafgk_part_%s' % table_tag)

    dtype = np.dtype([('htmid', int), ('simobjid', int), ('varParamStr', str, 200)])

    # The chunk iterator is a method of the DBObject created above; calling
    # it unqualified raised NameError and left `db` unused.
    data_iter = db.get_arbitrary_chunk_iterator(query, dtype=dtype, chunk_size=10000)

    out_name = os.path.join(_out_dir, 'dmag_%s.txt' % table_tag)
    with open(out_name, 'w') as out_file:
        for chunk in data_iter:
            for star in chunk:
                param_dict = json.loads(star['varParamStr'])
                lc_id = param_dict['p']['lc']
                # NOTE(review): '%d' truncates dmag_dict values to integers;
                # confirm that is intended if they are float magnitudes.
                out_file.write('%d %d %d\n' % (star['htmid'], star['simobjid'],
                                               dmag_dict[lc_id]))
def testSingleTableQuery(self):
    """
    Query a single table through get_chunk_iterator and verify every row:
    ids run 1..200 in order and sqrt holds the square root of the id.
    """
    db_obj = DBObject(driver=self.driver, database=self.db_name)
    expected_dtype = [('id', int), ('sqrt', float)]

    expected_id = 1
    for chunk in db_obj.get_chunk_iterator('SELECT id, sqrt FROM doubleTable'):
        for row in chunk:
            self.assertEqual(row[0], expected_id)
            self.assertAlmostEqual(row[1], np.sqrt(expected_id))
            self.assertEqual(expected_dtype, row.dtype)
            expected_id += 1

    # one past the last id seen
    self.assertEqual(expected_id, 201)
def testDtype(self):
    """
    Test that passing dtype to a query works (also test a query on a
    single table using .execute_arbitrary() directly)
    """
    dbobj = DBObject(driver=self.driver, database=self.db_name)
    query = 'SELECT id, log FROM doubleTable'
    dtype = [('id', int), ('log', float)]

    results = dbobj.execute_arbitrary(query, dtype=dtype)
    self.assertEqual(results.dtype, dtype)
    for xx in results:
        self.assertAlmostEqual(np.log(xx[0]), xx[1], 6)
    self.assertEqual(len(results), 200)

    results = dbobj.get_chunk_iterator(query, chunk_size=10, dtype=dtype)
    # Pull the first chunk by hand (exercising the iterator protocol) and
    # check its dtype too, instead of discarding it unchecked.
    first_chunk = next(results)
    self.assertEqual(first_chunk.dtype, dtype)
    for chunk in results:
        self.assertEqual(chunk.dtype, dtype)
def testJoin(self):
    """
    Join doubleTable and intTable on id and verify the rows, both through
    get_chunk_iterator (chunks of 10) and through execute_arbitrary.
    """
    db_obj = DBObject(driver=self.driver, database=self.db_name)
    join_query = ('SELECT doubleTable.id, intTable.id, doubleTable.log, intTable.thrice '
                  'FROM doubleTable, intTable WHERE doubleTable.id = intTable.id')
    # the second 'id' column is expected back under the name 'id_1'
    expected_dtype = [('id', int), ('id_1', int), ('log', float), ('thrice', int)]

    rows_seen = 0
    for chunk in db_obj.get_chunk_iterator(join_query, chunk_size=10):
        # every chunk that starts before row 90 must be full
        if rows_seen < 90:
            self.assertEqual(len(chunk), 10)
        for row in chunk:
            expected_id = 2*(rows_seen + 1)
            self.assertEqual(expected_id, row[0])
            self.assertEqual(row[0], row[1])
            self.assertAlmostEqual(np.log(row[0]), row[2], 6)
            self.assertEqual(3*row[0], row[3])
            self.assertEqual(expected_dtype, row.dtype)
            rows_seen += 1
    # make sure that we found all the matches we should have
    self.assertEqual(rows_seen, 99)

    full_result = db_obj.execute_arbitrary(join_query)
    self.assertEqual(expected_dtype, full_result.dtype)
    rows_seen = 0
    for row in full_result:
        expected_id = 2*(rows_seen + 1)
        self.assertEqual(expected_id, row[0])
        self.assertEqual(row[0], row[1])
        self.assertAlmostEqual(np.log(row[0]), row[2], 6)
        self.assertEqual(3*row[0], row[3])
        rows_seen += 1
    self.assertEqual(rows_seen, 99)
from lsst.sims.catalogs.db import DBObject
import numpy as np

db = DBObject(database='LSST', host='localhost', port=51432, driver='mssql+pymssql')

# fixed seed so the flags are reproducible
rng = np.random.RandomState(7153323)

dtype = np.dtype([('id', int), ('htmid', int)])
query = 'SELECT id, htmid FROM galaxy'
data_iter = db.get_arbitrary_chunk_iterator(query, dtype=dtype, chunk_size=10000)

# For every galaxy write 'htmid;id;flag', with flag an integer drawn
# uniformly from 0..9 (one draw per galaxy, chunk by chunk).
with open('galaxy_sne_flag.txt', 'w') as out_file:
    for chunk in data_iter:
        flag_vals = rng.randint(0, 10, size=len(chunk))
        out_file.writelines('%d;%d;%d\n' % (htmid_val, gal_id, flag)
                            for htmid_val, gal_id, flag
                            in zip(chunk['htmid'], chunk['id'], flag_vals))
from lsst.sims.catalogs.db import DBObject

# NOTE(review): `args` is not defined in this fragment; presumably parsed
# command-line arguments (use_tunnel, limit, table) set up earlier -- verify.

# Pick connection parameters: from BaseCatalogConfig when tunnelling,
# otherwise the direct fatboy SQL Server settings.
if args.use_tunnel:
    from lsst.sims.catUtils.baseCatalogModel import BaseCatalogConfig
    # NOTE(review): this binds the class itself, not an instance. If
    # BaseCatalogConfig is an lsst.pex.config Config, class-level attribute
    # access may not return configured values -- confirm whether
    # BaseCatalogConfig() was intended.
    config = BaseCatalogConfig
    host = config.host
    port = config.port
    database = config.database
    driver = config.driver
else:
    host = 'fatboy.phys.washington.edu'
    port = 1433
    database = 'LSSTCATSIM'
    driver = 'mssql+pymssql'

db = DBObject(database=database, host=host, port=port, driver=driver)

# Optionally cap the number of rows with a TOP clause (SQL Server syntax,
# matching the mssql driver above).
if args.limit is None:
    query = 'SELECT '
else:
    query = 'SELECT TOP %d ' % args.limit
query += 'htmid, simobjid, gal_l, gal_b, parallax, sdssr, sdssi, sdssz '
query += 'FROM %s ' % args.table

from mdwarf_utils import activity_type_from_color_z
from mdwarf_utils import xyz_from_lon_lat_px
import os
import numpy as np
import time
def testDetectDtype(self):
    """
    Test that DBObject.execute_arbitrary can correctly detect the dtypes
    of the rows it is returning, and that chunks returned by
    DBObject.get_arbitrary_chunk_iterator get the same detected dtypes.
    """
    db_name = os.path.join(self.scratch_dir, 'testDBObject_dtype_DB.db')
    if os.path.exists(db_name):
        os.unlink(db_name)

    conn = sqlite3.connect(db_name)
    c = conn.cursor()
    try:
        # 'sentence' is declared int but filled with text below, so the
        # test exercises detection on values that disagree with the
        # declared column type
        c.execute('''CREATE TABLE testTable (id int, val real, sentence int)''')
        conn.commit()
    except sqlite3.Error:
        # translate only database errors; anything else propagates as-is
        conn.close()
        raise RuntimeError("Error creating database.")

    for ii in range(10):
        cmd = '''INSERT INTO testTable VALUES (%d, %.5f, %s)''' % (ii, 5.234*ii, "'this, has; punctuation'")
        c.execute(cmd)
    conn.commit()
    conn.close()

    # expected numpy dtype of the 22-character sentence column
    if sys.version_info.major == 2:
        sentence_dtype = '|S22'
    else:
        sentence_dtype = '<U22'

    db = DBObject(database=db_name, driver='sqlite')
    full_query = 'SELECT id, val, sentence FROM testTable WHERE id%2 = 0'

    def check_chunk_iterator():
        # test that dtype detection works when getting a ChunkIterator
        chunk_iter = db.get_arbitrary_chunk_iterator(full_query, chunk_size=3)
        ct = 0
        for chunk in chunk_iter:
            self.assertEqual(str(chunk.dtype['id']), 'int64')
            self.assertEqual(str(chunk.dtype['val']), 'float64')
            # check the chunk's own dtype, not that of an earlier result
            self.assertEqual(str(chunk.dtype['sentence']), sentence_dtype)
            self.assertEqual(len(chunk.dtype), 3)
            for line in chunk:
                ct += 1
                self.assertEqual(line['sentence'], 'this, has; punctuation')
                self.assertAlmostEqual(line['val'], line['id']*5.234, 5)
                self.assertEqual(line['id'] % 2, 0)
        # ids 0, 2, 4, 6, 8
        self.assertEqual(ct, 5)

    results = db.execute_arbitrary(full_query)
    np.testing.assert_array_equal(results['id'], np.arange(0, 9, 2, dtype=int))
    np.testing.assert_array_almost_equal(results['val'], 5.234*np.arange(0, 9, 2), decimal=5)
    for sentence in results['sentence']:
        self.assertEqual(sentence, 'this, has; punctuation')
    self.assertEqual(str(results.dtype['id']), 'int64')
    self.assertEqual(str(results.dtype['val']), 'float64')
    self.assertEqual(str(results.dtype['sentence']), sentence_dtype)
    self.assertEqual(len(results.dtype), 3)

    check_chunk_iterator()

    # test that doing a different query does not spoil dtype detection
    results = db.execute_arbitrary('SELECT id, sentence FROM testTable WHERE id%2 = 0')
    self.assertGreater(len(results), 0)
    self.assertEqual(len(results.dtype.names), 2)
    self.assertEqual(str(results.dtype['id']), 'int64')
    self.assertEqual(str(results.dtype['sentence']), sentence_dtype)

    # ...and that the chunk iterator still detects the three-column dtype
    check_chunk_iterator()

    if os.path.exists(db_name):
        os.unlink(db_name)
def get_pointing_htmid(pointing_dir, opsim_db_name,
                       ra_colname='descDitheredRA',
                       dec_colname='descDitheredDec',
                       rottel_colname='descDitheredRotTelPos'):
    """
    For a list of OpSim pointings, find dicts mapping those pointings to:

    - The trixels filling the pointings
    - The MJDs of the pointings
    - The telescope filters of the pointings

    Parameters
    ----------
    pointing_dir contains a series of files that are two columns: obshistid, mjd.
    The files must each have 'visits' in their name.  These specify the pointings
    for which we are assembling data.  See:
        https://github.com/LSSTDESC/DC2_Repo/tree/master/data/Run1.1
    for an example.

    opsim_db_name is the path to the OpSim database from which to take those pointings

    ra_colname is the column used for RA of the pointing (default:
    descDitheredRA).  Its values are treated as radians.

    dec_colname is the column used for the Dec of the pointing (default:
    descDitheredDec).  Its values are treated as radians.

    rottel_colname is the column used for the rotTelPos of the pointing
    (default: descDitheredRotTelPos)

    Returns
    -------
    htmid_bound_dict -- a dict keyed on ObsHistID.  Values are the list of trixels
    filling the OpSim pointing, as returned by lsst.sims.utils.HalfSpace.findAllTrixels

    mjd_dict -- a dict keyed on ObsHistID.  Values are the MJD(TAI) of the OpSim pointings.

    filter_dict -- a dict keyed on ObsHistID.  Values are the 'ugrizy' filter of the
    OpSim pointings.

    obsmd_dict -- a dict keyed on ObsHistID.  The values are ObservationMetaData with
    the RA, Dec, MJD, and rotSkyPos of the pointings (for use in focal plane
    geometry calculations)

    Raises
    ------
    RuntimeError if opsim_db_name is not a file, pointing_dir is not a
    directory, or pointing_dir contains no 'visits' files.
    """
    radius = 2.0  # field of view radius of a pointing in degrees

    if not os.path.isfile(opsim_db_name):
        raise RuntimeError("%s is not a valid file name" % opsim_db_name)

    if not os.path.isdir(pointing_dir):
        raise RuntimeError("%s is not a valid dir name" % pointing_dir)

    # Collect every obsHistID listed in the 'visits' files.
    visit_dtype = np.dtype([('obshistid', int), ('mjd', float)])
    id_arrays = []
    for file_name in os.listdir(pointing_dir):
        if 'visits' in file_name:
            full_name = os.path.join(pointing_dir, file_name)
            data = np.genfromtxt(full_name, dtype=visit_dtype)
            # a one-row file yields a 0-d array, which cannot be concatenated
            id_arrays.append(np.atleast_1d(data['obshistid']))

    if not id_arrays:
        raise RuntimeError("no files with 'visits' in their name found in %s" % pointing_dir)

    obs_data = np.sort(np.concatenate(id_arrays, axis=0))
    # O(1) membership tests in the loop below
    obs_id_set = set(obs_data.tolist())

    db = DBObject(opsim_db_name, driver='sqlite')
    dtype = np.dtype([('obshistid', int), ('mjd', float), ('band', str, 1),
                      ('ra', float), ('dec', float), ('rotTelPos', float)])

    htmid_bound_dict = {}
    mjd_dict = {}
    filter_dict = {}
    obsmd_dict = {}

    # Query OpSim in about five batches of contiguous obsHistID ranges;
    # max(1, ...) keeps the range step non-zero for fewer than 5 ids.
    d_obs = max(1, len(obs_data) // 5)
    for i_start in range(0, len(obs_data), d_obs):
        subset = obs_data[i_start:i_start + d_obs]

        query = 'SELECT obsHistId, expMJD, filter,'
        query += ' %s, %s, %s FROM Summary' % (ra_colname, dec_colname, rottel_colname)
        # both bounds as integers: '%e' rounds to 6 significant digits and
        # could exclude the largest obsHistIDs of the batch
        query += ' WHERE obsHistID BETWEEN %d and %d' % (subset.min(), subset.max())
        query += ' GROUP BY obsHistID'

        results = db.execute_arbitrary(query, dtype=dtype)
        for row in results:
            obshistid = row['obshistid']
            # the BETWEEN range can include pointings not in the visits files
            if obshistid not in obs_id_set:
                continue

            ra_deg = np.degrees(row['ra'])
            dec_deg = np.degrees(row['dec'])
            mjd = row['mjd']

            hs = halfSpaceFromRaDec(ra_deg, dec_deg, radius)
            htmid_bound_dict[obshistid] = hs.findAllTrixels(_truth_trixel_level)
            mjd_dict[obshistid] = mjd
            filter_dict[obshistid] = row['band']

            # a first ObservationMetaData (no rotation) is needed to convert
            # rotTelPos into rotSkyPos
            obs_md = ObservationMetaData(pointingRA=ra_deg,
                                         pointingDec=dec_deg,
                                         mjd=mjd)
            rotsky_rad = _getRotSkyPos(row['ra'], row['dec'], obs_md, row['rotTelPos'])
            obsmd_dict[obshistid] = ObservationMetaData(pointingRA=ra_deg,
                                                        pointingDec=dec_deg,
                                                        mjd=mjd,
                                                        rotSkyPos=np.degrees(rotsky_rad))

    # every requested pointing must have been found exactly once (duplicate
    # obsHistIDs across visits files would also trip this check)
    assert len(obs_data) == len(htmid_bound_dict)
    return htmid_bound_dict, mjd_dict, filter_dict, obsmd_dict
def _postprocess_results(self, master_chunk):
    """
    Query the database specified by agn_params_db to find the AGN
    varParamStr (and magNorm) associated with each AGN in master_chunk.

    Parameters
    ----------
    master_chunk : structured numpy array of catalog rows.  Its galaxy id
        column is '<agn_objid>_galaxy_id' when self.agn_objid is set, and
        'galaxy_id' otherwise; the correspondingly named varParamStr and
        magNorm columns, if present, are overwritten in place.

    Returns
    -------
    master_chunk itself when self.agn_params_db is None; otherwise the
    result of self._final_pass(master_chunk) after the columns are filled.

    Raises
    ------
    RuntimeError if self.agn_params_db does not exist on disk.
    """
    if self.agn_objid is None:
        gid_name = 'galaxy_id'
        varpar_name = 'varParamStr'
        magnorm_name = 'magNorm'
    else:
        gid_name = self.agn_objid + '_' + 'galaxy_id'
        varpar_name = self.agn_objid + '_' + 'varParamStr'
        magnorm_name = self.agn_objid + '_' + 'magNorm'

    if self.agn_params_db is None:
        # NOTE(review): this early exit skips self._final_pass(), unlike the
        # normal return at the end -- confirm that is intended.
        return master_chunk

    if not os.path.exists(self.agn_params_db):
        raise RuntimeError('\n%s\n\ndoes not exist' % self.agn_params_db)

    # open the AGN parameter database once and cache it on the instance
    if not hasattr(self, '_agn_dbo'):
        self._agn_dbo = DBObject(database=self.agn_params_db,
                                 driver='sqlite')

        self._agn_dtype = np.dtype([('galaxy_id', int),
                                    ('magNorm', float),
                                    ('varParamStr', str, 500)])

    # cast to float so nanmin/nanmax can be used (presumably because the id
    # column may hold NaN/missing entries -- verify)
    gid_arr = master_chunk[gid_name].astype(float)

    gid_min = np.nanmin(gid_arr)
    gid_max = np.nanmax(gid_arr)

    query = 'SELECT galaxy_id, magNorm, varParamStr '
    query += 'FROM agn_params '
    query += 'WHERE galaxy_id BETWEEN %d AND %d ' % (gid_min, gid_max)
    query += 'ORDER BY galaxy_id'

    agn_data_iter = self._agn_dbo.get_arbitrary_chunk_iterator(query,
                                                               dtype=self._agn_dtype,
                                                               chunk_size=1000000)

    # sort master_chunk's ids so that matched rows come out in ascending
    # galaxy_id order, the same order as the ORDER BY query results
    m_sorted_dex = np.argsort(gid_arr)
    m_sorted_id = gid_arr[m_sorted_dex]
    for agn_chunk in agn_data_iter:

        # find the indices of the elements in master_chunk
        # that correspond to elements in agn_chunk
        # (np.isin replaces np.in1d, deprecated in NumPy 2.0; same result for 1-D)
        m_elements = np.isin(m_sorted_id, agn_chunk['galaxy_id'])
        m_dex = m_sorted_dex[m_elements]

        # find the indices of the elements in agn_chunk
        # that correspond to elements in master_chunk
        a_dex = np.isin(agn_chunk['galaxy_id'], m_sorted_id)

        # make sure we have matched elements correctly
        np.testing.assert_array_equal(agn_chunk['galaxy_id'][a_dex],
                                      master_chunk[gid_name][m_dex])

        if varpar_name in master_chunk.dtype.names:
            master_chunk[varpar_name][m_dex] = agn_chunk['varParamStr'][a_dex]

        if magnorm_name in master_chunk.dtype.names:
            master_chunk[magnorm_name][m_dex] = agn_chunk['magNorm'][a_dex]

    return self._final_pass(master_chunk)
def __init__(self, database=None, driver='sqlite', host=None, port=None):
    """
    Constructor for the class

    Parameters
    ----------
    database : string
        absolute path to the output of the OpSim database.  If None, only
        the connection attributes are stored and no database is opened.
    driver : string, optional, defaults to 'sqlite'
        driver/dialect for the SQL database
    host : hostName, optional, defaults to None,
        hostName, None is good for a local database
    port : port number, optional, defaults to None,
        port, None is good for a local database

    Raises
    ------
    RuntimeError if database is given but is not an existing file.

    ..notes : For testing purposes a small OpSim database is available at
    `os.path.join(getPackageDir('sims_data'), 'OpSimData/opsimblitz1_1133_sqlite.db')`
    """
    self._opsim_version = None
    self.driver = driver
    self.host = host
    self.port = port
    self.database = database
    self._seeing_column = 'FWHMeff'
    if self.database is None:
        return

    if not os.path.isfile(self.database):
        raise RuntimeError('%s is not a file' % self.database)

    self.opsimdb = DBObject(driver=self.driver, database=self.database,
                            host=self.host, port=self.port)

    # A 'Summary' table marks an OpSim v3 database; anything else is
    # treated as v4.
    list_of_tables = self.opsimdb.get_table_names()
    if 'Summary' in list_of_tables:
        self._opsim_version = 3
    else:
        self._opsim_version = 4

    # 27 January 2016
    # Detect whether the OpSim db you are connecting to uses 'finSeeing'
    # as its seeing column (deprecated), or FWHMeff, which is the modern
    # standard.  (self.table_name and self._set_seeing_column are defined
    # elsewhere in the class.)
    self._summary_columns = self.opsimdb.get_column_names(self.table_name)
    self._set_seeing_column(self._summary_columns)

    # Set up self.dtype containing the dtype of the recarray we expect back
    # from the SQL query, and baseQuery, which is just the SELECT clause.
    #
    # self.active_columns will be a list containing the subset of OpSim
    # database columns (specified in self.user_interface_to_opsim) that
    # actually exist in this opsim database
    dtypeList = []
    select_columns = []
    self.active_columns = []
    # OpSim names (as they appear in the database) of the columns requested
    # through self.user_interface_to_opsim
    self._queried_columns = []

    for column in self.user_interface_to_opsim:
        rec = self.user_interface_to_opsim[column]
        if rec[0] in self._summary_columns:
            self.active_columns.append(column)
            dtypeList.append((rec[0], rec[2]))
            select_columns.append(rec[0])
            self._queried_columns.append(rec[0])

    # Now loop over self._summary_columns, adding any columns
    # to the query that have not already been included therein.
    # Since we do not have explicit information about the
    # data types of these columns, we will assume they are floats.
    for column in self._summary_columns:
        if column not in self._queried_columns:
            select_columns.append(column)
            dtypeList.append((column, float))

    # join() places commas only between columns, so the clause stays valid
    # even when no user_interface_to_opsim column exists in this database
    self.baseQuery = 'SELECT ' + ', '.join(select_columns)
    self.dtype = np.dtype(dtypeList)
def test_alert_data_generation(self):
    """
    End-to-end test of AlertDataGenerator.

    Stars are given a sinusoidal light curve via the 'alert_test'
    variability method.  The true light curves are computed with an
    InstanceCatalog, the generator writes its sqlite output to
    self.output_dir, and every alert and baseline astrometry row written
    there is compared against the truth arrays held on the test case
    (self.ra_truth, self.pmra_truth, self.mag0_truth_dict, ...).
    """

    dmag_cutoff = 0.005
    mag_name_to_int = {'u': 0, 'g': 1, 'r': 2, 'i': 3, 'z' : 4, 'y': 5}

    # bound to a local name so the class body below can reference it
    _max_var_param_str = self.max_str_len

    # CatalogDBObject over the test's sqlite star table; the column
    # expressions multiply by 0.01745329252 (pi/180) and divide by 3600,
    # i.e. they convert degrees / arcsec into radians.
    class StarAlertTestDBObj(StellarAlertDBObjMixin, CatalogDBObject):
        objid = 'star_alert'
        tableid = 'stars'
        idColKey = 'simobjid'
        raColName = 'ra'
        decColName = 'dec'
        objectTypeId = 0
        columns = [('raJ2000', 'ra*0.01745329252'),
                   ('decJ2000', 'dec*0.01745329252'),
                   ('parallax', 'px*0.01745329252/3600.0'),
                   ('properMotionRa', 'pmra*0.01745329252/3600.0'),
                   ('properMotionDec', 'pmdec*0.01745329252/3600.0'),
                   ('radialVelocity', 'vrad'),
                   ('variabilityParameters', 'varParamStr', str, _max_var_param_str)]

    class TestAlertsVarCatMixin(object):

        # Variability model used by this test: the same
        # dmag = amp*cos(per*expmjd) is applied in all six bands.
        @register_method('alert_test')
        def applyAlertTest(self, valid_dexes, params, expmjd, variability_cache=None):
            if len(params) == 0:
                return np.array([[], [], [], [], [], []])

            # output shape: (6 bands, n_obj) for a scalar MJD,
            # (6 bands, n_obj, n_mjd) for an array of MJDs
            if isinstance(expmjd, numbers.Number):
                dmags_out = np.zeros((6, self.num_variable_obj(params)))
            else:
                dmags_out = np.zeros((6, self.num_variable_obj(params), len(expmjd)))

            for i_star in range(self.num_variable_obj(params)):
                if params['amp'][i_star] is not None:
                    dmags = params['amp'][i_star]*np.cos(params['per'][i_star]*expmjd)
                    for i_filter in range(6):
                        dmags_out[i_filter][i_star] = dmags

            return dmags_out

    # catalog class handed to AlertDataGenerator as photometry_class
    class TestAlertsVarCat(TestAlertsVarCatMixin, AlertStellarVariabilityCatalog):
        pass

    # truth catalog: one row per star per pointing with
    # (uniqueId, chipName, dmagAlert, magAlert)
    class TestAlertsTruthCat(TestAlertsVarCatMixin, CameraCoords, AstrometryStars,
                             Variability, InstanceCatalog):
        column_outputs = ['uniqueId', 'chipName', 'dmagAlert', 'magAlert']

        camera = obs_lsst_phosim.PhosimMapper().camera

        @compound('delta_umag', 'delta_gmag', 'delta_rmag',
                  'delta_imag', 'delta_zmag', 'delta_ymag')
        def get_TruthVariability(self):
            return self.applyVariability(self.column_by_name('varParamStr'))

        # delta magnitude in the pointing's own bandpass
        @cached
        def get_dmagAlert(self):
            return self.column_by_name('delta_%smag' % self.obs_metadata.bandpass)

        # quiescent magnitude plus variability in the pointing's bandpass
        @cached
        def get_magAlert(self):
            return self.column_by_name('%smag' % self.obs_metadata.bandpass) + \
                   self.column_by_name('dmagAlert')

    star_db = StarAlertTestDBObj(database=self.star_db_name, driver='sqlite')

    # assemble the true light curves for each object; we need to figure out
    # if their np.max(dMag) ever goes over dmag_cutoff; then we will know if
    # we are supposed to simulate them

    true_lc_dict = {}             # uniqueId -> {obshistid: dmag}
    true_lc_obshistid_dict = {}   # uniqueId -> [obshistid, ...]
    is_visible_dict = {}          # uniqueId -> ever at or below the mag cutoff
    obs_dict = {}                 # obshistid -> ObservationMetaData
    max_obshistid = -1
    n_total_observations = 0
    for obs in self.obs_list:
        obs_dict[obs.OpsimMetaData['obsHistID']] = obs
        obshistid = obs.OpsimMetaData['obsHistID']
        if obshistid > max_obshistid:
            max_obshistid = obshistid
        cat = TestAlertsTruthCat(star_db, obs_metadata=obs)

        for line in cat.iter_catalog():
            # a None chipName means the star missed the camera on this pointing
            if line[1] is None:
                continue

            n_total_observations += 1
            if line[0] not in true_lc_dict:
                true_lc_dict[line[0]] = {}
                true_lc_obshistid_dict[line[0]] = []

            true_lc_dict[line[0]][obshistid] = line[2]
            true_lc_obshistid_dict[line[0]].append(obshistid)

            if line[0] not in is_visible_dict:
                is_visible_dict[line[0]] = False

            if line[3] <= self.obs_mag_cutoff[mag_name_to_int[obs.bandpass]]:
                is_visible_dict[line[0]] = True

    # number of bits needed to hold an obshistid; (uniqueId << bits) + obshistid
    # then gives a single integer key per (object, pointing) pair
    obshistid_bits = int(np.ceil(np.log(max_obshistid)/np.log(2)))

    skipped_due_to_mag = 0

    objects_to_simulate = []
    obshistid_unqid_set = set()
    for obj_id in true_lc_dict:

        # largest |dmag| over all pointings that saw this object
        dmag_max = -1.0
        for obshistid in true_lc_dict[obj_id]:
            if np.abs(true_lc_dict[obj_id][obshistid]) > dmag_max:
                dmag_max = np.abs(true_lc_dict[obj_id][obshistid])

        if dmag_max >= dmag_cutoff:
            if not is_visible_dict[obj_id]:
                skipped_due_to_mag += 1
                continue

            objects_to_simulate.append(obj_id)
            for obshistid in true_lc_obshistid_dict[obj_id]:
                obshistid_unqid_set.add((obj_id << obshistid_bits) + obshistid)

    # the fixture must exercise both the simulated and the too-faint paths
    self.assertGreater(len(objects_to_simulate), 10)
    self.assertGreater(skipped_due_to_mag, 0)

    # NOTE(review): tempfile.mktemp is deprecated (race-prone); consider
    # tempfile.mkstemp or NamedTemporaryFile(delete=False).
    log_file_name = tempfile.mktemp(dir=self.output_dir, suffix='log.txt')
    alert_gen = AlertDataGenerator(testing=True)

    alert_gen.subdivide_obs(self.obs_list, htmid_level=6)

    for htmid in alert_gen.htmid_list:
        alert_gen.alert_data_from_htmid(htmid, star_db,
                                        photometry_class=TestAlertsVarCat,
                                        output_prefix='alert_test',
                                        output_dir=self.output_dir,
                                        dmag_cutoff=dmag_cutoff,
                                        log_file_name=log_file_name)

    dummy_sed = Sed()

    bp_dict = BandpassDict.loadTotalBandpassesFromFiles()

    phot_params = PhotometricParameters()

    # First, verify that the contents of the sqlite files are all correct

    n_tot_simulated = 0

    # one row per alert, joined with its pointing metadata, the quiescent
    # flux in that band, and the baseline astrometry of the object
    alert_query = 'SELECT alert.uniqueId, alert.obshistId, meta.TAI, '
    alert_query += 'meta.band, quiescent.flux, alert.dflux, '
    alert_query += 'quiescent.snr, alert.snr, '
    alert_query += 'alert.ra, alert.dec, alert.chipNum, '
    alert_query += 'alert.xPix, alert.yPix, ast.pmRA, ast.pmDec, '
    alert_query += 'ast.parallax '
    alert_query += 'FROM alert_data AS alert '
    alert_query += 'INNER JOIN metadata AS meta ON meta.obshistId=alert.obshistId '
    alert_query += 'INNER JOIN quiescent_flux AS quiescent '
    alert_query += 'ON quiescent.uniqueId=alert.uniqueId '
    alert_query += 'AND quiescent.band=meta.band '
    alert_query += 'INNER JOIN baseline_astrometry AS ast '
    alert_query += 'ON ast.uniqueId=alert.uniqueId'

    alert_dtype = np.dtype([('uniqueId', int), ('obshistId', int),
                            ('TAI', float), ('band', int),
                            ('q_flux', float), ('dflux', float),
                            ('q_snr', float), ('tot_snr', float),
                            ('ra', float), ('dec', float),
                            ('chipNum', int), ('xPix', float), ('yPix', float),
                            ('pmRA', float), ('pmDec', float), ('parallax', float)])

    sqlite_file_list = os.listdir(self.output_dir)

    n_tot_simulated = 0

    obshistid_unqid_simulated_set = set()
    for file_name in sqlite_file_list:
        if not file_name.endswith('db'):
            continue
        full_name = os.path.join(self.output_dir, file_name)
        self.assertTrue(os.path.exists(full_name))
        alert_db = DBObject(full_name, driver='sqlite')
        alert_data = alert_db.execute_arbitrary(alert_query, dtype=alert_dtype)
        if len(alert_data) == 0:
            continue

        mjd_list = ModifiedJulianDate.get_list(TAI=alert_data['TAI'])
        for i_obj in range(len(alert_data)):
            n_tot_simulated += 1
            obshistid_unqid_simulated_set.add((alert_data['uniqueId'][i_obj] << obshistid_bits) +
                                              alert_data['obshistId'][i_obj])

            unq = alert_data['uniqueId'][i_obj]
            # index into the *_truth arrays; assumes uniqueId is
            # (index+1)*1024 plus an offset < 1024 -- verify against fixture setup
            obj_dex = (unq//1024)-1
            # the alert tables appear to store these quantities in units
            # 1000x smaller than the truth arrays (e.g. mas vs arcsec)
            self.assertAlmostEqual(self.pmra_truth[obj_dex], 0.001*alert_data['pmRA'][i_obj], 4)
            self.assertAlmostEqual(self.pmdec_truth[obj_dex], 0.001*alert_data['pmDec'][i_obj], 4)
            self.assertAlmostEqual(self.px_truth[obj_dex], 0.001*alert_data['parallax'][i_obj], 4)

            # reported position must match the proper-motion-corrected truth
            ra_truth, dec_truth = applyProperMotion(self.ra_truth[obj_dex], self.dec_truth[obj_dex],
                                                    self.pmra_truth[obj_dex], self.pmdec_truth[obj_dex],
                                                    self.px_truth[obj_dex], self.vrad_truth[obj_dex],
                                                    mjd=mjd_list[i_obj])
            distance = angularSeparation(ra_truth, dec_truth,
                                         alert_data['ra'][i_obj], alert_data['dec'][i_obj])

            distance_arcsec = 3600.0*distance
            msg = '\ntruth: %e %e\nalert: %e %e\n' % (ra_truth, dec_truth,
                                                     alert_data['ra'][i_obj],
                                                     alert_data['dec'][i_obj])

            self.assertLess(distance_arcsec, 0.0005, msg=msg)

            obs = obs_dict[alert_data['obshistId'][i_obj]]

            # chipNum is the chip name with the 'R'/'S' letters and the
            # separators removed, read as an integer
            chipname = chipNameFromRaDec(self.ra_truth[obj_dex], self.dec_truth[obj_dex],
                                         pm_ra=self.pmra_truth[obj_dex],
                                         pm_dec=self.pmdec_truth[obj_dex],
                                         parallax=self.px_truth[obj_dex],
                                         v_rad=self.vrad_truth[obj_dex],
                                         obs_metadata=obs,
                                         camera=self.camera)

            chipnum = int(chipname.replace('R', '').replace('S', '').
                          replace(' ', '').replace(';', '').replace(',', '').
                          replace(':', ''))

            self.assertEqual(chipnum, alert_data['chipNum'][i_obj])

            xpix, ypix = pixelCoordsFromRaDec(self.ra_truth[obj_dex], self.dec_truth[obj_dex],
                                              pm_ra=self.pmra_truth[obj_dex],
                                              pm_dec=self.pmdec_truth[obj_dex],
                                              parallax=self.px_truth[obj_dex],
                                              v_rad=self.vrad_truth[obj_dex],
                                              obs_metadata=obs,
                                              camera=self.camera)

            self.assertAlmostEqual(alert_data['xPix'][i_obj], xpix, 4)
            self.assertAlmostEqual(alert_data['yPix'][i_obj], ypix, 4)

            # delta magnitude implied by the stored flux change
            dmag_sim = -2.5*np.log10(1.0+alert_data['dflux'][i_obj]/alert_data['q_flux'][i_obj])
            self.assertAlmostEqual(true_lc_dict[alert_data['uniqueId'][i_obj]][alert_data['obshistId'][i_obj]],
                                   dmag_sim, 3)

            mag_name = ('u', 'g', 'r', 'i', 'z', 'y')[alert_data['band'][i_obj]]
            m5 = obs.m5[mag_name]

            q_mag = dummy_sed.magFromFlux(alert_data['q_flux'][i_obj])
            self.assertAlmostEqual(self.mag0_truth_dict[alert_data['band'][i_obj]][obj_dex],
                                   q_mag, 4)

            # quiescent SNR: note this uses self.obs_mag_cutoff for the band
            # as the m5 value, not the pointing's own m5
            snr, gamma = calcSNR_m5(self.mag0_truth_dict[alert_data['band'][i_obj]][obj_dex],
                                    bp_dict[mag_name],
                                    self.obs_mag_cutoff[alert_data['band'][i_obj]],
                                    phot_params)

            self.assertAlmostEqual(snr/alert_data['q_snr'][i_obj], 1.0, 4)

            # total (quiescent + variable) SNR at the pointing's m5
            tot_mag = self.mag0_truth_dict[alert_data['band'][i_obj]][obj_dex] + \
                      true_lc_dict[alert_data['uniqueId'][i_obj]][alert_data['obshistId'][i_obj]]

            snr, gamma = calcSNR_m5(tot_mag, bp_dict[mag_name],
                                    m5, phot_params)
            self.assertAlmostEqual(snr/alert_data['tot_snr'][i_obj], 1.0, 4)

    # every (object, pointing) pair expected from the truth catalog was
    # simulated, and nothing else was
    for val in obshistid_unqid_set:
        self.assertIn(val, obshistid_unqid_simulated_set)
    self.assertEqual(len(obshistid_unqid_set), len(obshistid_unqid_simulated_set))

    # Second, check the baseline astrometry table on its own
    astrometry_query = 'SELECT uniqueId, ra, dec, TAI '
    astrometry_query += 'FROM baseline_astrometry'
    astrometry_dtype = np.dtype([('uniqueId', int), ('ra', float),
                                 ('dec', float), ('TAI', float)])

    # NOTE(review): tai_list is built but never used below
    tai_list = []
    for obs in self.obs_list:
        tai_list.append(obs.mjd.TAI)
    tai_list = np.array(tai_list)

    n_tot_ast_simulated = 0
    for file_name in sqlite_file_list:
        if not file_name.endswith('db'):
            continue
        full_name = os.path.join(self.output_dir, file_name)
        self.assertTrue(os.path.exists(full_name))
        alert_db = DBObject(full_name, driver='sqlite')
        astrometry_data = alert_db.execute_arbitrary(astrometry_query, dtype=astrometry_dtype)

        if len(astrometry_data) == 0:
            continue

        mjd_list = ModifiedJulianDate.get_list(TAI=astrometry_data['TAI'])
        for i_obj in range(len(astrometry_data)):
            n_tot_ast_simulated += 1
            obj_dex = (astrometry_data['uniqueId'][i_obj]//1024) - 1
            ra_truth, dec_truth = applyProperMotion(self.ra_truth[obj_dex], self.dec_truth[obj_dex],
                                                    self.pmra_truth[obj_dex], self.pmdec_truth[obj_dex],
                                                    self.px_truth[obj_dex], self.vrad_truth[obj_dex],
                                                    mjd=mjd_list[i_obj])

            distance = angularSeparation(ra_truth, dec_truth,
                                         astrometry_data['ra'][i_obj],
                                         astrometry_data['dec'][i_obj])

            self.assertLess(3600.0*distance, 0.0005)

    # drop the generator and collect before the final checks
    del alert_gen
    gc.collect()
    self.assertGreater(n_tot_simulated, 10)
    self.assertGreater(len(obshistid_unqid_simulated_set), 10)
    # some observations must have been filtered out by the dmag/mag cuts
    self.assertLess(len(obshistid_unqid_simulated_set), n_total_observations)
    self.assertGreater(n_tot_ast_simulated, 0)
def write_alerts(self, obshistid, data_dir, prefix_list, htmid_list, out_dir, out_prefix,
                 dmag_cutoff, lock=None, log_file_name=None):
    """
    Write the alerts for an obsHistId to a properly formatted avro file.

    Parameters
    ----------
    obshistid is the integer uniquely identifying the OpSim pointing
    being simulated

    data_dir is the directory containing the sqlite files created by
    the AlertDataGenerator

    prefix_list is a list of prefixes for those sqlite files.

    htmid_list is the list of htmids identifying the trixels that overlap
    this obshistid's field of view.  For each htmid in htmid_list and
    each prefix in prefix_list, this method will process the files

        data_dir/prefix_htmid_sqlite.db

    searching for alerts that correspond to this obshistid.  Files that
    do not exist are skipped with a warning.

    out_dir is the directory to which the avro files should be written

    out_prefix is the prefix of the avro file names

    dmag_cutoff is the minimum |delta magnitude| needed to trigger an alert

    lock is an optional multiprocessing.Lock() for use when running many
    instances of this method.  It prevents multiple processes from writing
    to the logfile or stdout at once.

    log_file_name is the name of an optional text file to which progress
    is written.

    Returns
    -------
    None.  Alerts are written to out_dir/out_prefix_obshistid.avro; any
    pre-existing file of that name is deleted first.
    """
    out_name = os.path.join(out_dir, '%s_%d.avro' % (out_prefix, obshistid))
    if os.path.exists(out_name):
        os.unlink(out_name)

    # The queries are the same for every sqlite file, so build them once.
    diasource_query = 'SELECT alert.uniqueId, alert.xPix, alert.yPix, '
    diasource_query += 'alert.chipNum, alert.dflux, alert.snr, alert.ra, alert.dec, '
    diasource_query += 'meta.band, meta.TAI, quiescent.flux, quiescent.snr '
    diasource_query += 'FROM alert_data as alert '
    diasource_query += 'INNER JOIN metadata AS meta ON alert.obshistId=meta.obshistId '
    diasource_query += 'INNER JOIN quiescent_flux AS quiescent ON quiescent.uniqueId=alert.uniqueID '
    diasource_query += 'AND quiescent.band=meta.band '
    diasource_query += 'WHERE alert.obshistId=%d ' % obshistid
    diasource_query += 'ORDER BY alert.uniqueId'

    diasource_dtype = np.dtype([('uniqueId', int),
                                ('xPix', float), ('yPix', float),
                                ('chipNum', int), ('dflux', float),
                                ('tot_snr', float), ('ra', float),
                                ('dec', float), ('band', int),
                                ('TAI', float), ('quiescent_flux', float),
                                ('quiescent_snr', float)])

    diaobject_query = 'SELECT uniqueId, ra, dec, TAI, pmRA, pmDec, parallax '
    diaobject_query += 'FROM baseline_astrometry'

    diaobject_dtype = np.dtype([('uniqueId', int),
                                ('ra', float), ('dec', float),
                                ('TAI', float), ('pmRA', float),
                                ('pmDec', float), ('parallax', float)])

    with DataFileWriter(open(out_name, "wb"), DatumWriter(), self._alert_schema) as data_writer:
        t_start = time.time()
        alert_ct = 0
        for htmid in htmid_list:
            for prefix in prefix_list:
                db_name = os.path.join(data_dir, '%s_%d_sqlite.db' % (prefix, htmid))
                if not os.path.exists(db_name):
                    warnings.warn('%s does not exist' % db_name)
                    continue

                db_obj = DBObject(db_name, driver='sqlite')

                diaobject_data = db_obj.execute_arbitrary(diaobject_query,
                                                          dtype=diaobject_dtype)
                diaobject_dict = self._create_objects(diaobject_data)

                diasource_data = db_obj.execute_arbitrary(diasource_query,
                                                          dtype=diasource_dtype)

                # Only the magnitude of dmag matters here, so its sign
                # convention is irrelevant to the cut.
                dmag = 2.5*np.log10(1.0+diasource_data['dflux']/diasource_data['quiescent_flux'])
                valid_alerts = np.where(np.abs(dmag) >= dmag_cutoff)
                diasource_data = diasource_data[valid_alerts]
                avro_diasource_list = self._create_sources(obshistid, diasource_data)

                for i_source, diasource in enumerate(avro_diasource_list):
                    alert_ct += 1
                    unq = diasource_data[i_source]['uniqueId']
                    diaobject = diaobject_dict[unq]

                    avro_alert = {}
                    # Use builtin int: np.long was removed in NumPy 1.24 and
                    # in NumPy 2 is a NumPy scalar, which avro's 'long'
                    # validation (isinstance(datum, int)) rejects.
                    avro_alert['alertId'] = (int(obshistid) << 20) + alert_ct
                    avro_alert['l1dbId'] = int(unq)
                    avro_alert['diaSource'] = diasource
                    avro_alert['diaObject'] = diaobject

                    data_writer.append(avro_alert)

    if lock is not None:
        lock.acquire()
    # Always release the lock, even if printing or logging fails, so that
    # other worker processes are not deadlocked.
    try:
        elapsed = (time.time()-t_start)/3600.0
        msg = 'finished obshistid %d; %d alerts in %.2e hrs' % (obshistid, alert_ct, elapsed)
        print(msg)
        if log_file_name is not None:
            with open(log_file_name, 'a') as out_file:
                out_file.write(msg)
                out_file.write('\n')
    finally:
        if lock is not None:
            lock.release()
def __init__(self, database=None, driver='sqlite', host=None, port=None):
    """
    Constructor for the class

    Parameters
    ----------
    database : string
        absolute path to the output of the OpSim database.  If None, only
        the user-interface-to-OpSim column table is set up and no database
        connection is made.
    driver : string, optional, defaults to 'sqlite'
        driver/dialect for the SQL database
    host : string, optional, defaults to None
        hostname; None is good for a local database
    port : int, optional, defaults to None
        port; None is good for a local database

    Raises
    ------
    RuntimeError
        if database is not None and os.path.exists(database) is False

    Notes
    -----
    For testing purposes a small OpSim database is available at
    `os.path.join(getPackageDir('sims_data'), 'OpSimData/opsimblitz1_1133_sqlite.db')`
    """
    self.driver = driver
    self.host = host
    self.port = port
    self.database = database
    self._seeing_column = 'FWHMeff'

    # A dict keyed on the user interface names of the OpSim data columns
    # (i.e. the args to getObservationMetaData).  Each value is a tuple:
    # (name of data column in OpSim,
    #  transformation to go from user interface to OpSim units,
    #  dtype in OpSim)
    #
    # Note: this dict will contain entries for every column (except propID) in the OpSim
    # summary table, not just those the ObservationMetaDataGenerator is designed to query
    # on.  The idea is that ObservationMetaData generated by this class will carry around
    # records of the values of all of the associated OpSim Summary columns so that users
    # can pass those values on to PhoSim/other tools at their own discretion.
    self._user_interface_to_opsim = {'obsHistID': ('obsHistID', None, np.int64),
                                     'expDate': ('expDate', None, int),
                                     'fieldRA': ('fieldRA', np.radians, float),
                                     'fieldDec': ('fieldDec', np.radians, float),
                                     'moonRA': ('moonRA', np.radians, float),
                                     'moonDec': ('moonDec', np.radians, float),
                                     'rotSkyPos': ('rotSkyPos', np.radians, float),
                                     'telescopeFilter':
                                         ('filter', lambda x: '\'{}\''.format(x), (str, 1)),
                                     'rawSeeing': ('rawSeeing', None, float),
                                     'sunAlt': ('sunAlt', np.radians, float),
                                     'moonAlt': ('moonAlt', np.radians, float),
                                     'dist2Moon': ('dist2Moon', np.radians, float),
                                     'moonPhase': ('moonPhase', None, float),
                                     'expMJD': ('expMJD', None, float),
                                     'altitude': ('altitude', np.radians, float),
                                     'azimuth': ('azimuth', np.radians, float),
                                     'visitExpTime': ('visitExpTime', None, float),
                                     'airmass': ('airmass', None, float),
                                     'm5': ('fiveSigmaDepth', None, float),
                                     'skyBrightness': ('filtSkyBrightness', None, float),
                                     'sessionID': ('sessionID', None, int),
                                     'fieldID': ('fieldID', None, int),
                                     'night': ('night', None, int),
                                     'visitTime': ('visitTime', None, float),
                                     'finRank': ('finRank', None, float),
                                     'FWHMgeom': ('FWHMgeom', None, float),
                                     # do not include FWHMeff; that is detected by
                                     # self._set_seeing_column()
                                     'transparency': ('transparency', None, float),
                                     'vSkyBright': ('vSkyBright', None, float),
                                     'rotTelPos': ('rotTelPos', None, float),
                                     'lst': ('lst', None, float),
                                     'solarElong': ('solarElong', None, float),
                                     'moonAz': ('moonAz', None, float),
                                     'sunAz': ('sunAz', None, float),
                                     'phaseAngle': ('phaseAngle', None, float),
                                     'rScatter': ('rScatter', None, float),
                                     'mieScatter': ('mieScatter', None, float),
                                     'moonBright': ('moonBright', None, float),
                                     'darkBright': ('darkBright', None, float),
                                     'wind': ('wind', None, float),
                                     'humidity': ('humidity', None, float),
                                     'slewDist': ('slewDist', None, float),
                                     'slewTime': ('slewTime', None, float),
                                     'ditheredRA': ('ditheredRA', None, float),
                                     'ditheredDec': ('ditheredDec', None, float)}

    if self.database is None:
        return

    if not os.path.exists(self.database):
        raise RuntimeError('%s does not exist' % self.database)

    self.opsimdb = DBObject(driver=self.driver, database=self.database,
                            host=self.host, port=self.port)

    # 27 January 2016
    # Detect whether the OpSim db you are connecting to uses 'finSeeing'
    # as its seeing column (deprecated), or FWHMeff, which is the modern
    # standard
    self._summary_columns = self.opsimdb.get_column_names('Summary')
    self._set_seeing_column(self._summary_columns)

    # Set up self.dtype containing the dtype of the recarray we expect back
    # from the SQL query.  Also set up self.baseQuery, which is just the
    # SELECT clause of the SQL query.
    #
    # self.active_columns will be a list containing the subset of OpSim database columns
    # (specified in self._user_interface_to_opsim) that actually exist in this opsim database
    dtypeList = []
    self.active_columns = []
    self._queried_columns = []  # all of the OpSim columns queried, referred
                                # to by their names in OpSim

    for column, rec in self._user_interface_to_opsim.items():
        if rec[0] in self._summary_columns:
            self.active_columns.append(column)
            dtypeList.append((rec[0], rec[2]))
            self._queried_columns.append(rec[0])

    # Now add any Summary columns not already included in the query.
    # Since we do not have explicit information about the data types of
    # these columns, we assume they are floats.
    select_columns = list(self._queried_columns)
    for column in self._summary_columns:
        if column not in self._queried_columns:
            select_columns.append(column)
            dtypeList.append((column, float))

    # Joining the full list (rather than prefixing ', ' to every extra
    # column) avoids emitting 'SELECT, col' when no known column matched.
    self.baseQuery = 'SELECT'
    if select_columns:
        self.baseQuery += ' ' + ', '.join(select_columns)

    self.dtype = np.dtype(dtypeList)
def write_sprinkled_lc(out_file_name, total_obs_md, pointing_dir, opsim_db_name,
                       ra_colname='descDitheredRA', dec_colname='descDitheredDec',
                       rottel_colname='descDitheredRotTelPos',
                       sql_file_name=None, bp_dict=None):
    """
    Create database of light curves

    Note: this is still under development.  It has not yet been
    used for a production-level truth catalog

    Parameters
    ----------
    out_file_name is the name of the sqlite file to be written

    total_obs_md is an ObservationMetaData covering the whole
    survey area
    (NOTE(review): not referenced in this function body)

    pointing_dir contains a series of files that are two columns: obshistid, mjd.
    The files must each have 'visits' in their name.  These specify the pointings
    for which we are assembling data.  See:
        https://github.com/LSSTDESC/DC2_Repo/tree/master/data/Run1.1
    for an example.

    opsim_db_name is the name of the OpSim database to be queried for pointings

    ra_colname is the column used for RA of the pointing (default:
    descDitheredRA)

    dec_colname is the column used for the Dec of the pointing (default:
    descDitheredDec)

    rottel_colname is the column used for the rotTelPos of the pointing
    (default: descDitheredRotTelPos)
    (NOTE(review): not currently forwarded to get_pointing_htmid)

    sql_file_name is the name of the parameter database produced by
    write_sprinkled_param_db to be used.  Required.

    bp_dict is a BandpassDict of the telescope filters to be used.  It is
    passed to SneSimulator and used via bp_dict.magListForSed for the
    quiescent AGN magnitudes, so None will fail once AGN are processed.

    Returns
    -------
    None

    Writes out a database to out_file_name.  The tables of this database and
    their columns are:

    light_curves:
        - uniqueId -- an int unique to all objects
        - obshistid -- an int unique to all pointings
        - mag -- the magnitude observed for this object at that pointing

    obs_metadata:
        - obshistid -- an int unique to all pointings
        - mjd -- the date of the pointing
        - filter -- an int corresponding to the telescope filter (0==u, 1==g..)

    variables_and_transients:
        - uniqueId -- an int unique to all objects
        - galaxy_id -- an int indicating the host galaxy
        - ra -- in degrees
        - dec -- in degrees
        - agn -- ==1 if object is an AGN
        - sn -- ==1 if object is a supernova

    Raises
    ------
    RuntimeError if sql_file_name is None or is not an existing file.
    """
    # sqlite3's Connection context manager only commits/rolls back; it does
    # not close the connection.  closing() releases it deterministically.
    from contextlib import closing

    t0_master = time.time()

    # Check for None explicitly: os.path.isfile(None) raises TypeError
    # instead of the intended RuntimeError.
    if sql_file_name is None:
        raise RuntimeError('sql_file_name must be specified')
    if not os.path.isfile(sql_file_name):
        raise RuntimeError('%s does not exist' % sql_file_name)

    sn_simulator = SneSimulator(bp_dict)
    sed_dir = os.environ['SIMS_SED_LIBRARY_DIR']

    create_sprinkled_sql_file(out_file_name)
    t_start = time.time()

    # get data about the pointings being simulated
    (htmid_dict, mjd_dict,
     filter_dict, obsmd_dict) = get_pointing_htmid(pointing_dir, opsim_db_name,
                                                   ra_colname=ra_colname,
                                                   dec_colname=dec_colname)

    t_htmid_dict = time.time()-t_start

    # integer encoding of the filters used in the output database
    filter_to_int = {'u': 0, 'g': 1, 'r': 2, 'i': 3, 'z': 4, 'y': 5}

    # put the data about the pointings in the obs_metadata table
    with closing(sqlite3.connect(out_file_name)) as conn:
        with conn:
            cursor = conn.cursor()
            values = ((int(obs), mjd_dict[obs], filter_to_int[filter_dict[obs]])
                      for obs in mjd_dict)
            cursor.executemany('''INSERT INTO obs_metadata VALUES (?,?,?)''', values)
            cursor.execute('''CREATE INDEX obs_filter ON obs_metadata (obshistid, filter)''')
            conn.commit()

    print('\ngot htmid_dict -- %d in %e seconds' % (len(htmid_dict), t_htmid_dict))

    db = DBObject(sql_file_name, driver='sqlite')

    # get a list of htmid corresponding to trixels in which
    # variables and transients can be found
    query = 'SELECT DISTINCT htmid FROM zpoint WHERE is_agn=1 OR is_sn=1'
    dtype = np.dtype([('htmid', int)])

    results = db.execute_arbitrary(query, dtype=dtype)

    object_htmid = results['htmid']

    agn_dtype = np.dtype([('uniqueId', int), ('galaxy_id', int),
                          ('ra', float), ('dec', float),
                          ('redshift', float), ('sed', str, 500),
                          ('magnorm', float), ('varParamStr', str, 500),
                          ('is_sprinkled', int)])

    agn_base_query = 'SELECT uniqueId, galaxy_id, '
    agn_base_query += 'raJ2000, decJ2000, '
    agn_base_query += 'redshift, sedFilepath, '
    agn_base_query += 'magNorm, varParamStr, is_sprinkled '
    agn_base_query += 'FROM zpoint WHERE is_agn=1 '

    sn_dtype = np.dtype([('uniqueId', int), ('galaxy_id', int),
                         ('ra', float), ('dec', float),
                         ('redshift', float), ('sn_truth_params', str, 500),
                         ('is_sprinkled', int)])

    sn_base_query = 'SELECT uniqueId, galaxy_id, '
    sn_base_query += 'raJ2000, decJ2000, '
    sn_base_query += 'redshift, sn_truth_params, is_sprinkled '
    sn_base_query += 'FROM zpoint WHERE is_sn=1 '

    n_floats = 0
    with closing(sqlite3.connect(out_file_name)) as conn:
        with conn:
            cursor = conn.cursor()

            t_before_htmid = time.time()

            # loop over trixels containing variables and transients, simulating
            # the light curves of those objects
            for htmid_dex, htmid in enumerate(object_htmid):
                if htmid_dex > 0:
                    htmid_duration = (time.time()-t_before_htmid)/3600.0
                    htmid_prediction = len(object_htmid)*htmid_duration/htmid_dex
                    print('%d htmid out of %d in %e hours; predict %e hours remaining' %
                          (htmid_dex, len(object_htmid), htmid_duration,
                           htmid_prediction-htmid_duration))

                mjd_arr = []
                obs_arr = []
                filter_arr = []

                # Find only those pointings which overlap the current trixel
                # (htmid_dict maps obshistid -> list of (min, max) htmid ranges)
                for obshistid in htmid_dict:
                    if any(bounds[0] <= htmid <= bounds[1]
                           for bounds in htmid_dict[obshistid]):
                        mjd_arr.append(mjd_dict[obshistid])
                        obs_arr.append(obshistid)
                        filter_arr.append(filter_to_int[filter_dict[obshistid]])

                if len(mjd_arr) == 0:
                    continue

                # sort the overlapping pointings chronologically
                mjd_arr = np.array(mjd_arr)
                obs_arr = np.array(obs_arr)
                filter_arr = np.array(filter_arr)

                sorted_dex = np.argsort(mjd_arr)
                mjd_arr = mjd_arr[sorted_dex]
                obs_arr = obs_arr[sorted_dex]
                filter_arr = filter_arr[sorted_dex]

                agn_query = agn_base_query + 'AND htmid=%d' % htmid

                agn_iter = db.get_arbitrary_chunk_iterator(agn_query, dtype=agn_dtype,
                                                           chunk_size=10000)

                for agn_results in agn_iter:
                    # put static data about the AGN (position, etc.) into the
                    # variables_and_transients table (agn=1, sn=0)
                    values = ((int(agn_results['uniqueId'][i_obj]),
                               int(agn_results['galaxy_id'][i_obj]),
                               np.degrees(agn_results['ra'][i_obj]),
                               np.degrees(agn_results['dec'][i_obj]),
                               int(agn_results['is_sprinkled'][i_obj]),
                               1, 0)
                              for i_obj in range(len(agn_results)))

                    cursor.executemany('''INSERT INTO variables_and_transients VALUES
                                          (?,?,?,?,?,?,?)''', values)

                    agn_simulator = AgnSimulator(agn_results['redshift'])

                    # quiescent magnitude of each AGN in each of the six filters
                    quiescent_mag = np.zeros((len(agn_results), 6), dtype=float)
                    for i_obj, (sed_name, zz, mm) in enumerate(zip(agn_results['sed'],
                                                                 agn_results['redshift'],
                                                                 agn_results['magnorm'])):
                        spec = Sed()
                        spec.readSED_flambda(os.path.join(sed_dir, sed_name))
                        fnorm = getImsimFluxNorm(spec, mm)
                        spec.multiplyFluxNorm(fnorm)
                        spec.redshiftSED(zz, dimming=True)
                        mag_list = bp_dict.magListForSed(spec)
                        quiescent_mag[i_obj] = mag_list

                    # simulate AGN variability; indexed below as
                    # dmag[filter][object][time]
                    dmag = agn_simulator.applyVariability(agn_results['varParamStr'],
                                                          expmjd=mjd_arr)

                    # loop over pointings that overlap the current trixel, writing
                    # out simulated photometry for each AGN observed in that pointing
                    for i_time, obshistid in enumerate(obs_arr):

                        # only include objects that were actually on a detector
                        are_on_chip = _actually_on_chip(np.degrees(agn_results['ra']),
                                                        np.degrees(agn_results['dec']),
                                                        obsmd_dict[obshistid])

                        valid_agn = np.where(are_on_chip)
                        if len(valid_agn[0]) == 0:
                            continue

                        values = ((int(agn_results['uniqueId'][i_obj]),
                                   int(obs_arr[i_time]),
                                   quiescent_mag[i_obj][filter_arr[i_time]] +
                                   dmag[filter_arr[i_time]][i_obj][i_time])
                                  for i_obj in valid_agn[0])

                        cursor.executemany('''INSERT INTO light_curves VALUES
                                              (?,?,?)''', values)
                        conn.commit()

                    n_floats += len(dmag.flatten())

                sn_query = sn_base_query + 'AND htmid=%d' % htmid

                sn_iter = db.get_arbitrary_chunk_iterator(sn_query, dtype=sn_dtype,
                                                          chunk_size=10000)

                for sn_results in sn_iter:
                    t0_sne = time.time()

                    # write static information about SNe to the
                    # variables_and_transients table (agn=0, sn=1)
                    values = ((int(sn_results['uniqueId'][i_obj]),
                               int(sn_results['galaxy_id'][i_obj]),
                               np.degrees(sn_results['ra'][i_obj]),
                               np.degrees(sn_results['dec'][i_obj]),
                               int(sn_results['is_sprinkled'][i_obj]),
                               0, 1)
                              for i_obj in range(len(sn_results)))

                    cursor.executemany('''INSERT INTO variables_and_transients VALUES
                                          (?,?,?,?,?,?,?)''', values)
                    conn.commit()

                    # indexed below as sn_mags[object][time]
                    sn_mags = sn_simulator.calculate_sn_magnitudes(sn_results['sn_truth_params'],
                                                                   mjd_arr, filter_arr)
                    print(' did %d sne in %e seconds' % (len(sn_results), time.time()-t0_sne))

                    # loop over pointings that overlap the current trixel, writing
                    # out simulated photometry for each SNe observed in that pointing
                    for i_time, obshistid in enumerate(obs_arr):

                        # only include objects that fell on a detector
                        are_on_chip = _actually_on_chip(np.degrees(sn_results['ra']),
                                                        np.degrees(sn_results['dec']),
                                                        obsmd_dict[obshistid])

                        # also drop non-finite magnitudes
                        valid_obj = np.where(np.logical_and(np.isfinite(sn_mags[:, i_time]),
                                                            are_on_chip))

                        if len(valid_obj[0]) == 0:
                            continue

                        values = ((int(sn_results['uniqueId'][i_obj]),
                                   int(obs_arr[i_time]),
                                   sn_mags[i_obj][i_time])
                                  for i_obj in valid_obj[0])

                        cursor.executemany('''INSERT INTO light_curves VALUES
                                              (?,?,?)''', values)
                        conn.commit()
                        n_floats += len(valid_obj[0])

            cursor.execute('CREATE INDEX unq_obs ON light_curves (uniqueId, obshistid)')
            conn.commit()

    print('n_floats %d' % n_floats)
    print('in %e seconds' % (time.time()-t0_master))
parser.add_argument('--partition', type=str, default=None, help='htmid tag of stars_partition_* table to run') parser.add_argument('--outdir', type=str, default=None, help='dir to write output to') parser.add_argument('--n_procs', type=int, default=20) args = parser.parse_args() assert args.partition is not None assert args.outdir is not None assert os.path.isdir(args.outdir) try: db = DBObject(database='LSST', port=1433, host='epyc.astro.washington.edu', driver='mssql+pymssql') except: db = DBObject(database='LSST', port=51432, host='localhost', driver='mssql+pymssql') table_name = 'stars_partition_%s' % args.partition out_name = os.path.join(args.outdir,'isvar_lookup_%s.txt' % args.partition) if os.path.isfile(out_name): os.unlink(out_name) #raise RuntimeError("\n%s\nexists\n" % out_name) query = "SELECT "
bp_dict = BandpassDict.loadTotalBandpassesFromFiles() bp = bp_dict['i'] z_grid = np.arange(0.0, 16.0, 0.01) k_grid = np.zeros(len(z_grid), dtype=float) for i_z, zz in enumerate(z_grid): ss = Sed(flambda=base_sed.flambda, wavelen=base_sed.wavelen) ss.redshiftSED(zz, dimming=True) k = k_correction(ss, bp, zz) k_grid[i_z] = k cosmo = CosmologyObject() db = DBObject(database='LSSTCATSIM', host='fatboy.phys.washington.edu', port=1433, driver='mssql+pymssql') query = 'SELECT magnorm_agn, redshift, varParamStr FROM ' query += 'galaxy WHERE varParamStr IS NOT NULL ' query += 'AND dec BETWEEN -2.5 AND 2.5 ' query += 'AND (ra<2.5 OR ra>357.5)' dtype = np.dtype([('magnorm', float), ('redshift', float), ('varParamStr', str, 400)]) data_iter = db.get_arbitrary_chunk_iterator(query, dtype=dtype, chunk_size=10000) with open('data/dc1_agn_params.txt', 'w') as out_file:
def test_catalog_db_object_cacheing(self):
    """
    Test that opening multiple CatalogDBObjects that connect to the same
    database only results in one connection being opened and used.  We
    will test this by instantiating two CatalogDBObjects and a DBObject
    that connect to the same database.  We will then test that the two
    CatalogDBObjects' connections are identical, but that the DBObject
    has its own connection.
    """
    self.assertEqual(len(CatalogDBObject._connection_cache), 0)

    class DbClass1(CatalogDBObject):
        database = self.db_name
        port = None
        host = None
        driver = 'sqlite'
        tableid = 'test'
        idColKey = 'id'
        objid = 'test_db_class_1'

        columns = [('identification', 'id')]

    class DbClass2(CatalogDBObject):
        database = self.db_name
        port = None
        host = None
        driver = 'sqlite'
        tableid = 'test'
        idColKey = 'id'
        objid = 'test_db_class_2'

        columns = [('other', 'i1')]

    db1 = DbClass1()
    db2 = DbClass2()
    # both CatalogDBObjects share the single cached connection
    self.assertIs(db1.connection, db2.connection)
    self.assertEqual(len(CatalogDBObject._connection_cache), 1)

    # a plain DBObject opens its own connection and does not touch the cache
    db3 = DBObject(database=self.db_name, driver='sqlite', host=None, port=None)
    self.assertIsNot(db1.connection, db3.connection)
    self.assertEqual(len(CatalogDBObject._connection_cache), 1)

    # check that if we had passed db1.connection to a DBObject,
    # the connections would be identical
    db4 = DBObject(connection=db1.connection)
    self.assertIs(db4.connection, db1.connection)
    self.assertEqual(len(CatalogDBObject._connection_cache), 1)

    # verify that db1 and db2 are both useable
    chunk = next(db1.query_columns(colnames=['id', 'i1', 'i2', 'identification']))
    self.assertEqual(len(chunk), 5)
    np.testing.assert_array_equal(chunk['id'], list(range(5)))
    np.testing.assert_array_equal(chunk['id'], chunk['identification'])
    np.testing.assert_array_equal(chunk['id']**2, chunk['i1'])
    np.testing.assert_array_equal(chunk['id'] * (-1), chunk['i2'])

    chunk = next(db2.query_columns(colnames=['id', 'i1', 'i2', 'other']))
    self.assertEqual(len(chunk), 5)
    np.testing.assert_array_equal(chunk['id'], list(range(5)))
    np.testing.assert_array_equal(chunk['id']**2, chunk['i1'])
    np.testing.assert_array_equal(chunk['i1'], chunk['other'])
    np.testing.assert_array_equal(chunk['id'] * (-1), chunk['i2'])