def main(pool, circ, P, time_sampling, overwrite=False):
    """Rejection-sample simulated data sets and cache the resulting samples.

    For every data set yielded by ``make_data`` the sampler is run once and
    the samples are stored in a group named ``'<n_epoch>-<i>'`` of the samples
    file returned by ``make_caches``. Groups that already exist are skipped,
    so an interrupted run can be resumed.

    Parameters
    ----------
    pool
        Processing pool handed to ``TheJoker``.
    circ, P, time_sampling
        Forwarded unchanged to ``make_caches`` and ``make_data``.
    overwrite : bool (optional)
        Forwarded to ``make_caches``.
    """
    pars = JokerParams(P_min=1 * u.day, P_max=1024 * u.day)
    joker = TheJoker(pars, pool=pool)

    # make the prior cache file
    prior_file, samples_file = make_caches(2**28, joker, circ=circ, P=P,
                                           overwrite=overwrite,
                                           time_sampling=time_sampling)
    logger.info('Files: {0}, {1}'.format(prior_file, samples_file))

    n_epochs = np.arange(3, 12 + 1, 1)
    simulated = make_data(n_epochs, n_orbits=512, P=P,
                          time_sampling=time_sampling, circ=circ)

    # The per-data-set value is named ``orbit_P`` so it does not shadow the
    # ``P`` argument of this function.
    for n_epoch, i, data, orbit_P in simulated:
        logger.debug("N epochs: {0}, orbit {1}".format(n_epoch, i))

        key = '{0}-{1}'.format(n_epoch, i)
        with h5py.File(samples_file, 'r') as f:
            if key in f:
                logger.debug('-- already done! skipping...')
                continue

        samples = joker.iterative_rejection_sample(n_requested_samples=256,
                                                   prior_cache_file=prior_file,
                                                   data=data)
        logger.debug("-- done sampling - {0} samples returned".format(
            len(samples)))

        # Explicit append mode: h5py >= 3.0 opens files read-only when no mode
        # is given, which would make create_group() fail.
        with h5py.File(samples_file, 'a') as f:
            g = f.create_group(key)
            samples.to_hdf5(g)
            g.attrs['P'] = orbit_P
def main(data_path, apogee_id, config, data_file_ext, pool, overwrite=False):
    """Run emcee on one star, starting from its cached Joker samples.

    Reads ``<apogee_id>-joker.hdf5`` from ``<data_path>/cache``. If the
    samples are unimodal in period, an MCMC chain is run and saved to
    ``<apogee_id>-chain.npy``; chain statistics and down-sampled samples are
    then written to ``<apogee_id>-mcmc.hdf5``.

    Parameters
    ----------
    data_path : str
        Directory that contains (or will contain) the ``cache`` directory.
    apogee_id : str
        Identifier used to build the cache file names.
    config : dict
        Must provide ``config['emcee']['n_walkers']`` and
        ``config['emcee']['n_steps']``.
    data_file_ext : str
        Unused by this function.
    pool
        Processing pool handed to ``TheJoker`` and ``emcee``.
    overwrite : bool (optional)
        Re-run the MCMC and rewrite the results even if cached files exist.

    Returns
    -------
    None
        Returns early, after logging an error, if the samples are multimodal.
    """
    # NOTE(review): ``joker_params``, ``data`` and ``logprob`` are not defined
    # in this function - presumably module-level globals; confirm.
    cache_path = path.join(data_path, 'cache')
    os.makedirs(cache_path, exist_ok=True)

    # Create TheJoker sampler instance with the specified random seed and pool
    rnd = np.random.RandomState()
    logger.debug("Creating TheJoker instance with {0}, {1}".format(rnd, pool))
    joker = TheJoker(joker_params, random_state=rnd, pool=pool)

    n_walkers = config['emcee']['n_walkers']
    n_steps = config['emcee']['n_steps']

    mcmc_model_filename = path.join(cache_path, 'model.pickle')

    logger.info('Processing {0}'.format(apogee_id))

    joker_results_filename = path.join(cache_path,
                                       '{0}-joker.hdf5'.format(apogee_id))
    mcmc_results_filename = path.join(cache_path,
                                      '{0}-mcmc.hdf5'.format(apogee_id))
    mcmc_chain_filename = path.join(cache_path,
                                    '{0}-chain.npy'.format(apogee_id))

    # Read-only: this file is only an input here.
    with h5py.File(joker_results_filename, 'r') as f:
        joker_samples = JokerSamples.from_hdf5(f)

    model = TheJokerMCMCModel(joker_params=joker_params, data=data)

    if not path.exists(mcmc_chain_filename) or overwrite:
        joker = TheJoker(joker_params)
        joker.params.jitter = (8.5, 0.9)  # HACK!

        if not unimodal_P(joker_samples, data):
            logger.error("Samples are multimodal!")
            return

        logger.debug("Samples are unimodal. Preparing to run MCMC...")

        sample = MAP_sample(data, joker_samples, joker.params)

        ball_scale = 1E-5
        p0_mean = np.squeeze(model.pack_samples(sample))

        # Columns: P, M0, e, omega, jitter, K, v0
        p0 = np.zeros((n_walkers, len(p0_mean)))
        for i in range(p0.shape[1]):
            draws = np.random.normal(p0_mean[i], ball_scale, size=n_walkers)
            if i in (2, 4):  # eccentricity, jitter: keep non-negative
                draws = np.abs(draws)
            p0[:, i] = draws

        p0 = model.to_mcmc_params(p0.T).T

        n_dim = p0.shape[1]
        sampler = emcee.EnsembleSampler(n_walkers, n_dim, logprob, pool=pool)

        logger.debug('Running MCMC for {0} steps...'.format(n_steps))
        time0 = time.time()
        _ = sampler.run_mcmc(p0, n_steps)
        logger.debug('...time spent sampling: {0}'.format(time.time() - time0))

        samples = model.unpack_samples_mcmc(sampler.chain[:, -1])
        samples.t0 = joker_samples.t0

        np.save(mcmc_chain_filename, sampler.chain.astype('f4'))

        if not path.exists(mcmc_model_filename):
            with open(mcmc_model_filename, 'wb') as f:
                pickle.dump(model, f)

    if not path.exists(mcmc_results_filename) or overwrite:
        chain = np.load(mcmc_chain_filename)
        n_walkers, n_steps, n_pars = chain.shape

        logger.debug('Adding star {0} to MCMC cache'.format(apogee_id))

        # Compute everything before touching the output file so a failure
        # here cannot leave a truncated results file behind.
        Rs = gelman_rubin(chain[:, n_steps // 2:])

        # use the second half of the chain, keeping every 1024th step
        end_pos = chain[:, n_steps // 2::1024].reshape(-1, n_pars)
        samples = model.unpack_samples_mcmc(end_pos)

        # 'w' so that overwrite=True replaces an existing file instead of
        # failing when the 'chain-stats' group already exists.
        with h5py.File(mcmc_results_filename, 'w') as f:
            g = f.create_group('chain-stats')
            g.create_dataset(name='gelman_rubin', data=Rs)
            samples.to_hdf5(f)
def _table_columns(model_cls, skip=('ID',)):
    """Return (column names, VARCHAR column names) of a model, upper-cased."""
    colnames = []
    varchar = []
    for column in model_cls.__table__.columns:
        name = str(column).split('.')[1].upper()
        if name in skip:
            continue

        if str(column.type) == 'VARCHAR':
            varchar.append(name)

        colnames.append(name)

    return colnames, varchar


def _load_new_rows(session, tbl, id_colname, id_attr, model_cls, label,
                   batch_size):
    """Insert the rows of ``tbl`` whose ID is not yet in the database."""
    colnames, varchar = _table_columns(model_cls)

    all_ids = np.array([x.strip() for x in tbl[id_colname]])
    loaded_ids = [x[0] for x in session.query(id_attr).all()]
    mask = np.logical_not(np.isin(all_ids, loaded_ids))
    logger.debug("{0} {1}s already loaded".format(len(loaded_ids), label))
    logger.debug("{0} {1}s left to load".format(mask.sum(), label))

    pending = []
    with Timer() as t:
        for i, row in enumerate(tbl[mask]):
            row_data = tblrow_to_dbrow(row, colnames, varchar)

            # create a new object for this row
            obj = model_cls(**row_data)
            pending.append(obj)
            logger.log(1, 'Adding {0} {1} to database'.format(label, obj))

            if i % batch_size == 0 and i > 0:
                session.add_all(pending)
                session.commit()
                logger.debug("Loaded batch {0} ({1:.2f} seconds)".format(
                    i, t.elapsed()))
                t.reset()
                pending = []

    # flush whatever is left over from the last, partial batch
    if pending:
        session.add_all(pending)
        session.commit()


def initialize_db(allVisit_file, allStar_file, database_path,
                  drop_all=False, batch_size=4096):
    """Initialize the database given FITS filenames for the APOGEE data.

    Stars and visits are filtered on quality flags, loaded into the AllStar
    and AllVisit tables (skipping IDs that are already present), and finally
    each star is linked to its visits.

    Parameters
    ----------
    allVisit_file : str
        Full path to APOGEE allVisit file.
    allStar_file : str
        Full path to APOGEE allStar file.
    database_path : str
        Path to the database file, passed to ``db_connect``.
    drop_all : bool (optional)
        Drop all existing tables and re-create the database.
    batch_size : int (optional)
        How many rows to create before committing.
    """
    allvisit = fits.getdata(abspath(expanduser(allVisit_file)))
    allstar = fits.getdata(abspath(expanduser(allStar_file)))

    # STAR_BAD
    ASPCAP_skip_bitmask = np.sum(2**np.array([23]))

    # VERY_BRIGHT_NEIGHBOR, LOW_SNR, SUSPECT_RV_COMBINATION, SUSPECT_BROAD_LINES
    STAR_skip_bitmask = np.sum(2**np.array([3, 4, 16, 17]))

    # First filter allStar flags
    mask = ((allstar['NVISITS'] >= 4) &
            ((allstar['ASPCAPFLAG'] & ASPCAP_skip_bitmask) == 0) &
            ((allstar['STARFLAG'] & STAR_skip_bitmask) == 0))
    stars = allstar[mask]
    visits = allvisit[np.isin(allvisit['APOGEE_ID'], stars['APOGEE_ID'])]

    # Next filter allVisit flags
    mask = (((visits['STARFLAG'] & STAR_skip_bitmask) == 0) &
            (visits['VRELERR'] < 100.) &
            np.isfinite(visits['VHELIO']) &
            np.isfinite(visits['VRELERR']) &
            (visits['VHELIO'] > -999))
    visits = visits[mask]

    # Keep only stars that still have >= 4 visits after the visit cuts
    v_apogee_ids, counts = np.unique(visits['APOGEE_ID'], return_counts=True)
    stars = stars[np.isin(stars['APOGEE_ID'], v_apogee_ids[counts >= 4])]
    visits = visits[np.isin(visits['APOGEE_ID'], stars['APOGEE_ID'])]

    # uniquify the stars
    _, idx = np.unique(stars['APOGEE_ID'], return_index=True)
    allstar_tbl = Table(stars[idx])
    allvisit_tbl = Table(visits)

    Session, engine = db_connect(database_path, ensure_db_exists=True)
    logger.debug("Connected to database at '{}'".format(database_path))

    if drop_all:
        # this is the magic that creates the tables based on the definitions in
        # twoface/db/model.py
        Base.metadata.drop_all()
        Base.metadata.create_all()

    session = Session()

    logger.debug("Loading allStar, allVisit tables...")

    # --------------------------------------------------------------------------
    # First load the status table:
    #
    if session.query(Status).count() == 0:
        logger.debug("Populating Status table...")
        statuses = [
            Status(id=0, message='untouched'),
            Status(id=1, message='needs more prior samples'),
            Status(id=2, message='needs mcmc'),
            Status(id=3, message='error'),
            Status(id=4, message='completed'),
        ]
        session.add_all(statuses)
        session.commit()
        logger.debug("...done")

    # --------------------------------------------------------------------------
    # Load the AllStar table:
    #
    logger.info("Loading AllStar table")
    _load_new_rows(session, allstar_tbl, 'APOGEE_ID', AllStar.apogee_id,
                   AllStar, 'star', batch_size)

    # --------------------------------------------------------------------------
    # Load the AllVisit table:
    #
    logger.info("Loading AllVisit table")
    _load_new_rows(session, allvisit_tbl, 'VISIT_ID', AllVisit.visit_id,
                   AllVisit, 'visit', batch_size)

    # --------------------------------------------------------------------------
    # Now associate rows in AllStar with rows in AllVisit
    logger.info("Linking AllVisit and AllStar tables")

    q = session.query(AllStar).order_by(AllStar.id)

    for i, sub_q in enumerate(paged_query(q, page_size=batch_size)):
        for star in sub_q:
            # already linked on a previous run
            if len(star.visits) > 0:
                continue

            star_visits = session.query(AllVisit).filter(
                AllVisit.apogee_id == star.apogee_id).all()

            if len(star_visits) == 0:
                logger.warning("Visits not found for star {0}".format(star))
                continue

            logger.log(1, 'Attaching {0} visits to star {1}'.format(
                len(star_visits), star))

            star.visits = star_visits

        logger.debug("Committing batch {0}".format(i))
        session.commit()

    session.commit()
    session.close()
def main(data_path, config_file, data_file_ext, pool, seed, overwrite=False):
    """Rejection-sample every data file in ``data_path`` with The Joker.

    A prior-sample cache is created in ``<data_path>/cache`` if needed, then
    each ``*.<data_file_ext>`` file is sampled and its samples are written to
    ``<basename>-joker.hdf5`` in the cache directory. Files that already have
    a results file are skipped unless ``overwrite`` is set.

    The function does not return: it closes the pool and calls
    ``sys.exit(0)`` when finished, or ``sys.exit(1)`` on the first sampling
    failure.

    Parameters
    ----------
    data_path : str
        Directory containing the data files.
    config_file : str
        Path to the YAML configuration file.
    data_file_ext : str
        Extension (without the dot) of the data files to process.
    pool
        Processing pool handed to ``TheJoker``.
    seed : int
        Seed for the ``numpy`` random state.
    overwrite : bool (optional)
        Regenerate the prior cache and all results.
    """
    # parse config file
    with open(config_file, 'r') as f:
        # safe_load: yaml.load() without an explicit Loader is rejected by
        # PyYAML >= 6 and can construct arbitrary objects on older versions.
        config = yaml.safe_load(f)
    config['config_file'] = config_file

    cache_path = path.join(data_path, 'cache')
    os.makedirs(cache_path, exist_ok=True)

    n_prior_samples = config['prior']['num_cache']

    joker_pars = config_to_jokerparams(config)
    prior_samples_file = path.join(cache_path, 'prior-samples.hdf5')

    # Create TheJoker sampler instance with the specified random seed and pool
    rnd = np.random.RandomState(seed=seed)
    logger.debug("Creating TheJoker instance with {0}, {1}".format(rnd, pool))
    joker = TheJoker(joker_pars, random_state=rnd, pool=pool)

    # Create a cache of prior samples (if it doesn't exist)
    if not os.path.exists(prior_samples_file) or overwrite:
        logger.debug("Prior samples file not found - generating {0} samples..."
                     .format(n_prior_samples))
        make_prior_cache(prior_samples_file, joker, nsamples=n_prior_samples)
        logger.debug("...done")

    data_files = glob.glob(path.join(data_path, '*.{0}'.format(data_file_ext)))
    for filename in data_files:
        basename = path.splitext(path.basename(filename))[0]
        logger.info('Processing file {0}'.format(basename))

        joker_results_filename = path.join(cache_path,
                                           '{0}-joker.hdf5'.format(basename))

        # Nothing to do for this file: skip before reading the data.
        if path.exists(joker_results_filename) and not overwrite:
            continue

        # Start the clock before loading so "visits loaded" is meaningful.
        t0 = time.time()
        data_tbl = QTable.read(filename)
        data = RVData(t=data_tbl['time'], rv=data_tbl['rv'],
                      stddev=data_tbl['rv_err'])
        logger.log(1, "\t visits loaded ({:.2f} seconds)"
                   .format(time.time() - t0))

        try:
            samples = joker.rejection_sample(
                data=data, prior_cache_file=prior_samples_file,
                return_logprobs=False)

        except Exception as e:
            logger.warning("\t Failed sampling for star {0} \n Error: {1}"
                           .format(basename, str(e)))
            pool.close()
            sys.exit(1)

        logger.debug("\t done sampling ({:.2f} seconds)"
                     .format(time.time() - t0))

        # Write the samples that pass to the results file
        with h5py.File(joker_results_filename, 'w') as f:
            samples.to_hdf5(f)

        logger.debug("\t saved samples ({:.2f} seconds)"
                     .format(time.time() - t0))
        logger.debug("...done with star {} ({:.2f} seconds)"
                     .format(basename, time.time() - t0))

    pool.close()
    sys.exit(0)
def main(pool, seed, overwrite=False, _continue=False):
    """Rejection-sample every unprocessed high-scatter star in the database.

    All paths and hyper-parameters are hard-coded below. For each star the
    samples are written to a group of the results HDF5 file (named by the
    APOGEE ID) and the status of its StarResult row is updated; the session
    is committed every few stars and once more at the end.

    Parameters
    ----------
    pool
        Processing pool handed to ``TheJoker``; closed before returning.
    seed : int
        Seed for the ``numpy`` random state.
    overwrite : bool (optional)
        Regenerate the prior-sample cache even if it exists.
    _continue : bool (optional)
        Unused by this function.
    """
    # HACK: hard-coded configuration!
    db_path = path.abspath('../cache/apogeebh.sqlite')
    prior_samples_file = path.abspath('../cache/prior-samples.hdf5')
    results_filename = path.abspath('../cache/apogeebh-joker.hdf5')

    n_prior = 2**29  # number of prior samples to generate (536870912)
    n_requested_samples = 256  # how many samples to generate, nominally
    max_samples_per_star = 2048  # max. number of posterior samples to save

    P_min = 1 * u.day
    P_max = 1024 * u.day
    jitter = 150 * u.m / u.s

    if not os.path.exists(db_path):
        raise IOError(
            "sqlite database not found at '{0}'\n Did you run "
            "scripts/load_dr15_db.py yet for that database?".format(db_path))

    logger.debug("Connecting to sqlite database at '{0}'".format(db_path))
    Session, engine = db_connect(database_path=db_path,
                                 ensure_db_exists=False)
    session = Session()

    params = JokerParams(P_min=P_min, P_max=P_max, jitter=jitter)

    # Sampler with a seeded random state, sharing the caller's pool
    rng = np.random.RandomState(seed=seed)
    logger.debug("Creating TheJoker instance with {0}, {1}".format(rng, pool))
    joker = TheJoker(params, random_state=rng, pool=pool)

    # Build the prior-sample cache when it is missing (or when forced)
    if overwrite or not os.path.exists(prior_samples_file):
        logger.debug(
            "Prior samples file not found - generating {0} samples...".format(
                n_prior))
        make_prior_cache(prior_samples_file, joker, nsamples=n_prior)
        logger.debug("...done")

    # APOGEE IDs that already have a non-zero status
    done_subq = (session.query(AllStar.apogee_id)
                 .join(StarResult, Status)
                 .filter(Status.id > 0)
                 .distinct())
    not_done = ~AllStar.apogee_id.in_(done_subq)

    # Stars that still need processing
    star_query = (session.query(AllStar)
                  .filter(AllStar.vscatter >= 5.)
                  .filter(not_done))

    # Base query used to fetch the StarResult of a given star
    result_query = (session.query(StarResult)
                    .join(AllStar)
                    .filter(Status.id == 0)
                    .filter(not_done))

    n_stars = star_query.count()
    logger.info("{0} stars left to process".format(n_stars))

    # Make sure the results file exists: samples that pass the rejection
    # sampling step are cached there
    if not os.path.exists(results_filename):
        h5py.File(results_filename, 'w').close()

    # --------------------------------------------------------------------------
    # Main loop over the stars that still need processing.
    n_done = 0  # stars processed so far
    commit_every = 16  # MAGIC NUMBER: stars to process between commits
    for star in star_query.all():
        this_star = AllStar.apogee_id == star.apogee_id

        if result_query.filter(this_star).count() < 1:
            logger.debug('Star {0} has no result object!'.format(
                star.apogee_id))
            new_result = StarResult()
            star.result = new_result
            session.add(new_result)
            session.commit()

        # limit(1) because the APOGEE_ID isn't unique; all visits of a star
        # are attached to all of its rows, so any one of them will do.
        result = result_query.filter(this_star).limit(1).one()

        logger.log(1, "Starting star {0}".format(star.apogee_id))
        logger.log(1, "Current status: {0}".format(str(result.status)))
        tic = time.time()

        data = star.apogeervdata()
        logger.log(
            1, "\t visits loaded ({:.2f} seconds)".format(time.time() - tic))

        try:
            samples, ln_prior = joker.iterative_rejection_sample(
                data=data,
                n_requested_samples=n_requested_samples,
                prior_cache_file=prior_samples_file,
                n_prior_samples=n_prior,
                return_logprobs=True)

        except Exception as e:
            logger.warning(
                "\t Failed sampling for star {0} \n Error: {1}".format(
                    star.apogee_id, str(e)))
            continue

        logger.debug("\t done sampling ({:.2f} seconds)".format(
            time.time() - tic))

        # Truncate to the maximum number of samples we store per star
        all_ln_probs = ln_prior[:max_samples_per_star]
        samples = samples[:max_samples_per_star]
        n_actual_samples = len(all_ln_probs)

        # Store the samples in the results file
        with h5py.File(results_filename, 'r+') as h5f:
            if star.apogee_id in h5f:
                del h5f[star.apogee_id]

            # HACK: this will overwrite the past samples!
            grp = h5f.create_group(star.apogee_id)
            samples.to_hdf5(grp)

            if 'ln_prior_probs' in grp:
                del grp['ln_prior_probs']
            grp.create_dataset('ln_prior_probs', data=all_ln_probs)

        logger.debug("\t saved samples ({:.2f} seconds)".format(
            time.time() - tic))

        if n_actual_samples >= n_requested_samples:
            status_id = 4  # completed
        elif n_actual_samples == 1 or unimodal_P(samples, data):
            # A single sample, or several that look unimodal: needs mcmc
            status_id = 2
        else:
            # Several samples, but fewer than requested: needs more samples
            status_id = 1
        result.status_id = status_id

        logger.debug("...done with star {} ({:.2f} seconds)".format(
            star.apogee_id, time.time() - tic))

        if n_done > 0 and n_done % commit_every == 0:
            session.commit()

        n_done += 1

    pool.close()

    session.commit()
    session.close()
def get_run(config, session, overwrite=False):
    """Return the JokerRun row named ``config['name']``, creating it if needed.

    An existing run is returned as-is unless ``overwrite`` is true, in which
    case the run and its StarResult rows are deleted and a fresh run is built
    from the values in ``config``.
    """
    run_name = config['name']

    # Look the run up by name first
    try:
        run = session.query(JokerRun).filter(
            JokerRun.name == run_name).one()
    except NoResultFound:
        run = None
    except MultipleResultsFound:
        raise MultipleResultsFound(
            "Multiple JokerRun rows found for name '{0}'".format(run_name))
    else:
        logger.info("JokerRun '{0}' already found in database".format(
            run_name))

    if run is not None:
        if not overwrite:
            return run

        # Overwriting: remove the run's results, then the run itself
        session.query(StarResult)\
            .filter(StarResult.jokerrun_id == run.id)\
            .delete()
        session.commit()
        session.delete(run)
        session.commit()

    # Reaching this point means a new entry has to be created from the
    # parameters read from the config file.
    logger.info(
        "JokerRun '{0}' not found in database, creating entry...".format(
            run_name))

    run = JokerRun()
    run.config_file = config['config_file']
    run.name = run_name

    hyper = config['hyperparams']
    run.P_min = u.Quantity(*hyper['P_min'].split())
    run.P_max = u.Quantity(*hyper['P_max'].split())
    run.requested_samples_per_star = int(hyper['requested_samples_per_star'])

    prior_cfg = config['prior']
    run.max_prior_samples = int(prior_cfg['max_samples'])
    run.prior_samples_file = join(TWOFACE_CACHE_PATH,
                                  prior_cfg['samples_file'])

    if 'jitter' in hyper:
        # fixed jitter, given as a quantity string in the config file
        run.jitter = u.Quantity(*hyper['jitter'].split())
        logger.debug('Jitter is fixed to: {0:.2f}'.format(run.jitter))

    elif 'jitter_prior_mean' in hyper:
        # jitter is sampled: the config file gives the prior parameters
        run.jitter_mean = hyper['jitter_prior_mean']
        run.jitter_stddev = hyper['jitter_prior_stddev']
        run.jitter_unit = hyper['jitter_prior_unit']
        jitter_msg = ('Sampling in jitter with mean = {0:.2f} (stddev in '
                      'log(var) = {1:.2f}) [{2}]')
        logger.debug(jitter_msg.format(np.sqrt(np.exp(run.jitter_mean)),
                                       run.jitter_stddev,
                                       run.jitter_unit))

    else:
        # nothing specified: assume no jitter
        run.jitter = 0. * u.m / u.s
        logger.debug('No jitter.')

    # Attach all stars with >=3 visits
    enough_visits = session.query(AllStar)\
        .join(AllVisitToAllStar, AllVisit)\
        .group_by(AllStar.apstar_id)\
        .having(func.count(AllVisit.id) >= 3)
    run.stars = enough_visits.all()

    session.add(run)
    session.commit()

    return run
def main(config_file, pool, seed, overwrite=False):
    """Continue sampling for stars of a JokerRun that need more work.

    Currently only stars with status 2 ("needs mcmc") are selected. For the
    first such star an MCMC run is started, the sampler is pickled to
    ``test-mcmc.pickle``, and the process exits via ``sys.exit(0)``.

    Parameters
    ----------
    config_file : str
        Path to the YAML configuration file.
    pool
        Processing pool handed to ``TheJoker``.
    seed : int
        Seed for the ``numpy`` random state.
    overwrite : bool (optional)
        Unused by this function; the run is never overwritten here.

    Raises
    ------
    IOError
        If the sqlite database or the posterior-samples file is missing.
    """
    import pickle
    import sys

    config_file = path.abspath(path.expanduser(config_file))

    # parse config file
    with open(config_file, 'r') as f:
        # safe_load: yaml.load() without an explicit Loader is rejected by
        # PyYAML >= 6 and can construct arbitrary objects on older versions.
        config = yaml.safe_load(f)
    # get_run() reads this key when it has to create a new run
    config['config_file'] = config_file

    # filename of sqlite database
    database_file = config['database_file']

    db_path = path.join(TWOFACE_CACHE_PATH, database_file)
    if not os.path.exists(db_path):
        raise IOError(
            "sqlite database not found at '{0}'\n Did you run "
            "scripts/initdb.py yet for that database?".format(db_path))

    logger.debug("Connecting to sqlite database at '{0}'".format(db_path))
    Session, engine = db_connect(database_path=db_path,
                                 ensure_db_exists=False)
    session = Session()

    run = get_run(config, session, overwrite=False)

    # The file with cached posterior samples:
    results_filename = path.join(TWOFACE_CACHE_PATH,
                                 "{0}.hdf5".format(run.name))
    if not path.exists(results_filename):
        raise IOError(
            "Posterior samples result file {0} doesn't exist! Are "
            "you sure you ran `run_apogee.py`?".format(results_filename))

    # Create TheJoker sampler instance with the specified random seed and pool
    rnd = np.random.RandomState(seed=seed)
    logger.debug("Creating TheJoker instance with {0}, {1}".format(rnd, pool))
    params = run.get_joker_params()
    joker = TheJoker(params, random_state=rnd, pool=pool)

    # Path of the second, larger prior cache used by "needs more prior
    # samples" stars.
    # TODO: we should make sure this 2nd prior cache exists, but because only
    # "needs mcmc" stars are handled for now, it is not generated here:
    #     make_prior_cache(new_path, joker,
    #                      nsamples=8 * config['prior']['num_cache'],  # ~100 GB
    #                      batch_size=2**24)  # MAGIC NUMBER
    _path, ext = path.splitext(run.prior_samples_file)
    new_path = '{0}_moar{1}'.format(_path, ext)

    # Get all stars in this JokerRun that need more prior samples
    # TODO HACK: this query only selects "needs mcmc" stars!
    star_query = session.query(AllStar).join(StarResult, JokerRun, Status)\
        .filter(JokerRun.name == run.name)\
        .filter(Status.id == 2)
    # .filter(Status.id == 1)

    # Base query to get a StarResult for a given Star so we can update the
    # status, etc.
    result_query = session.query(StarResult).join(AllStar, JokerRun)\
        .filter(JokerRun.name == run.name)

    n_stars = star_query.count()
    logger.info("{0} stars left to process for run more samples '{1}'".format(
        n_stars, run.name))

    # --------------------------------------------------------------------------
    # Here is where we do the actual processing of the data for each star. We
    # loop through all stars that still need processing and continue with
    # rejection sampling.

    count = 0  # how many stars we've processed in this star batch
    batch_size = 16  # MAGIC NUMBER: how many stars to process before committing
    for star in star_query.all():

        if result_query.filter(
                AllStar.apogee_id == star.apogee_id).count() < 1:
            logger.debug('Star {0} has no result object!'.format(
                star.apogee_id))
            continue

        # Retrieve existing StarResult from database. We limit(1) because the
        # APOGEE_ID isn't unique, but we attach all visits for a given star to
        # all rows, so grabbing one of them is fine.
        result = result_query.filter(AllStar.apogee_id == star.apogee_id)\
            .limit(1).one()

        logger.log(1, "Starting star {0}".format(star.apogee_id))
        logger.log(1, "Current status: {0}".format(str(result.status)))
        t0 = time.time()

        data = star.apogeervdata()
        logger.log(
            1, "\t visits loaded ({:.2f} seconds)".format(time.time() - t0))

        if result.status.id == 1:  # needs more prior samples

            try:
                samples, ln_prior = joker.iterative_rejection_sample(
                    data=data,
                    n_requested_samples=run.requested_samples_per_star,
                    # HACK: prior_cache_file=run.prior_samples_file,
                    prior_cache_file=new_path,
                    return_logprobs=True)

            except Exception as e:
                logger.warning(
                    "\t Failed sampling for star {0} \n Error: {1}".format(
                        star.apogee_id, str(e)))
                continue

            logger.debug(
                "\t done sampling ({:.2f} seconds)".format(time.time() - t0))

        elif result.status.id == 2:  # needs mcmc
            logger.debug('Firing up MCMC:')

            with h5py.File(results_filename, 'r') as f:
                samples0 = JokerSamples.from_hdf5(f[star.apogee_id])

            n_walkers = 2 * run.requested_samples_per_star
            model, samples, sampler = joker.mcmc_sample(
                data, samples0, n_burn=1024, n_steps=65536,
                n_walkers=n_walkers, return_sampler=True)

            # HACK: dump the sampler for inspection and stop after the first
            # MCMC star. The pool is detached first so the sampler pickles.
            sampler.pool = None
            with open('test-mcmc.pickle', 'wb') as f:
                pickle.dump(sampler, f)

            pool.close()
            sys.exit(0)

        else:
            # No samples were produced for any other status, so there is
            # nothing to save for this star.
            continue

        # For now, it's sufficient to write the run results to an HDF5 file
        n = run.requested_samples_per_star
        all_ln_probs = ln_prior[:n]
        samples = samples[:n]

        # Write the samples that pass to the results file
        with h5py.File(results_filename, 'r+') as f:
            if star.apogee_id in f:
                del f[star.apogee_id]

            g = f.create_group(star.apogee_id)
            samples.to_hdf5(g)
            g.create_dataset('ln_prior_probs', data=all_ln_probs)

        logger.debug("\t saved samples ({:.2f} seconds)".format(
            time.time() - t0))

        # Same attribute as used for the sampling request above
        result.status_id = get_status_id(samples, data,
                                         run.requested_samples_per_star)

        logger.debug("...done with star {} ({:.2f} seconds)".format(
            star.apogee_id, time.time() - t0))

        if count % batch_size == 0 and count > 0:
            session.commit()

        count += 1

    pool.close()

    session.commit()
    session.close()
def main(config_file, pool, seed, overwrite=False):
    """Run The Joker on a control sample with no orbital RV variations.

    A random subset of up to ``NCONTROL`` already-processed stars is drawn.
    For each one, the observed velocities are replaced by Gaussian noise
    around the mean velocity (visit uncertainties plus a jitter drawn from the
    jitter prior), the sampler is run on this fake data, and the samples are
    written to ``<run name>-control.hdf5``. Stars already present in that file
    are skipped.

    Parameters
    ----------
    config_file : str
        Path to the YAML configuration file.
    pool
        Processing pool handed to ``TheJoker``; closed before returning.
    seed : int or None
        Seed for the ``numpy`` random state; ``None`` means 42.
    overwrite : bool (optional)
        Unused by this function; the run is never overwritten here.

    Raises
    ------
    KeyError
        If the config file has no ``database_file`` entry.
    IOError
        If the sqlite database or the prior-sample cache is missing.
    """
    # Default seed:
    if seed is None:
        seed = 42

    config_file = abspath(expanduser(config_file))

    # parse config file
    with open(config_file, 'r') as f:
        # safe_load: yaml.load() without an explicit Loader is rejected by
        # PyYAML >= 6 and can construct arbitrary objects on older versions.
        config = yaml.safe_load(f)
    config['config_file'] = config_file

    # filename of sqlite database; fail clearly instead of passing None on
    # to os.path.join()
    if 'database_file' not in config:
        raise KeyError(
            "'database_file' must be specified in config file "
            "'{0}'".format(config_file))
    database_file = config['database_file']

    db_path = join(TWOFACE_CACHE_PATH, database_file)
    if not os.path.exists(db_path):
        raise IOError(
            "sqlite database not found at '{0}'\n Did you run "
            "scripts/initdb.py yet for that database?".format(db_path))

    logger.debug("Connecting to sqlite database at '{0}'".format(db_path))
    Session, engine = db_connect(database_path=db_path,
                                 ensure_db_exists=False)
    session = Session()

    # Retrieve or create a JokerRun instance
    run = get_run(config, session, overwrite=False)  # never overwrite
    params = run.get_joker_params()

    # Create TheJoker sampler instance with the specified random seed and pool
    rnd = np.random.RandomState(seed=seed)
    logger.debug("Creating TheJoker instance with {0}, {1}".format(rnd, pool))
    joker = TheJoker(params, random_state=rnd, pool=pool)

    # Create a file to cache the resulting posterior samples
    results_filename = join(TWOFACE_CACHE_PATH,
                            "{0}-control.hdf5".format(run.name))

    # Ensure that the results file exists - this is where we cache samples that
    # pass the rejection sampling step
    if not os.path.exists(results_filename):
        with h5py.File(results_filename, 'w') as f:
            pass

    with h5py.File(results_filename, 'r') as f:
        done_apogee_ids = list(f.keys())

    # The prior cache is not generated here; it has to exist already.
    if not os.path.exists(run.prior_samples_file):
        raise IOError("Prior cache must already exist.")

    # Get random IDs
    star_ids = session.query(AllStar.apogee_id)\
        .join(StarResult, JokerRun, Status)\
        .filter(Status.id > 0).distinct().all()
    star_ids = np.array([x[0] for x in star_ids])

    # Sampling without replacement fails if more items are requested than
    # exist, so cap the control sample size at the number of available stars.
    n_control = min(NCONTROL, len(star_ids))
    if n_control > 0:
        star_ids = rnd.choice(star_ids, size=n_control, replace=False)
    star_ids = star_ids[~np.isin(star_ids, done_apogee_ids)]

    n_stars = len(star_ids)
    logger.info(
        "{0} stars left to process for run '{1}'; {2} already done.".format(
            n_stars, run.name, len(done_apogee_ids)))

    # --------------------------------------------------------------------------
    # Here is where we do the actual processing of the data for each star. We
    # loop through all stars that still need processing and iteratively
    # rejection sample with larger and larger prior sample batch sizes. We do
    # this for efficiency, but the argument for this is somewhat made up...

    for apid in star_ids:
        star = AllStar.get_apogee_id(session, apid)

        logger.log(1, "Starting star {0}".format(star.apogee_id))
        t0 = time.time()

        orig_data = star.apogeervdata()

        # HACK: this assumes we're sampling over the excess variance parameter
        # Generate new data with no RV orbital variations
        y = rnd.normal(params.jitter[0], params.jitter[1])
        s = np.exp(0.5 * y) * params._jitter_unit
        std = np.sqrt(s**2 + orig_data.stddev**2).to(orig_data.rv.unit).value
        new_rv = rnd.normal(np.mean(orig_data.rv).value, std)
        data = APOGEERVData(t=orig_data.t,
                            rv=new_rv * orig_data.rv.unit,
                            stddev=orig_data.stddev)

        logger.log(
            1, "\t visits loaded ({:.2f} seconds)".format(time.time() - t0))

        try:
            # the returned log-prior values are not stored for control stars
            samples, _ = joker.iterative_rejection_sample(
                data=data,
                n_requested_samples=run.requested_samples_per_star,
                prior_cache_file=run.prior_samples_file,
                n_prior_samples=run.max_prior_samples,
                return_logprobs=True)

        except Exception as e:
            logger.warning(
                "\t Failed sampling for star {0} \n Error: {1}".format(
                    star.apogee_id, str(e)))
            continue

        logger.debug("\t done sampling ({:.2f} seconds)".format(
            time.time() - t0))

        # For now, it's sufficient to write the run results to an HDF5 file
        n = run.requested_samples_per_star
        samples = samples[:n]

        # Write the samples that pass to the results file
        with h5py.File(results_filename, 'r+') as f:
            if star.apogee_id in f:
                del f[star.apogee_id]

            # HACK: this will overwrite the past samples!
            g = f.create_group(star.apogee_id)
            samples.to_hdf5(g)

        logger.debug("\t saved samples ({:.2f} seconds)".format(
            time.time() - t0))

    pool.close()
    session.close()
def main(config_file, pool, seed, overwrite=False, _continue=False):
    """Rejection-sample all unprocessed stars of a JokerRun.

    The run is looked up (or created) from the config file, a prior-sample
    cache is generated if needed, and every star of the run with status 0 is
    sampled. Samples are written to ``<run name>.hdf5`` in one group per
    APOGEE ID and the status of each StarResult row is updated. The session
    is committed every few stars and once more at the end.

    Parameters
    ----------
    config_file : str
        Path to the YAML configuration file.
    pool
        Processing pool handed to ``TheJoker``; closed before returning.
    seed : int
        Seed for the ``numpy`` random state.
    overwrite : bool (optional)
        Passed to ``get_run`` and also forces regeneration of the prior cache.
    _continue : bool (optional)
        Unused by this function.

    Raises
    ------
    KeyError
        If the config file has no ``database_file`` entry.
    IOError
        If the sqlite database is missing.
    """
    config_file = abspath(expanduser(config_file))

    # parse config file
    with open(config_file, 'r') as f:
        # safe_load: yaml.load() without an explicit Loader is rejected by
        # PyYAML >= 6 and can construct arbitrary objects on older versions.
        config = yaml.safe_load(f)
    config['config_file'] = config_file

    # filename of sqlite database; fail clearly instead of passing None on
    # to os.path.join()
    if 'database_file' not in config:
        raise KeyError(
            "'database_file' must be specified in config file "
            "'{0}'".format(config_file))
    database_file = config['database_file']

    db_path = join(TWOFACE_CACHE_PATH, database_file)
    if not os.path.exists(db_path):
        raise IOError(
            "sqlite database not found at '{0}'\n Did you run "
            "scripts/initdb.py yet for that database?".format(db_path))

    logger.debug("Connecting to sqlite database at '{0}'".format(db_path))
    Session, engine = db_connect(database_path=db_path,
                                 ensure_db_exists=False)
    session = Session()

    # Retrieve or create a JokerRun instance
    run = get_run(config, session, overwrite=overwrite)
    params = run.get_joker_params()

    # Create TheJoker sampler instance with the specified random seed and pool
    rnd = np.random.RandomState(seed=seed)
    logger.debug("Creating TheJoker instance with {0}, {1}".format(rnd, pool))
    joker = TheJoker(params, random_state=rnd, pool=pool)

    # Create a cache of prior samples (if it doesn't exist)
    if not os.path.exists(run.prior_samples_file) or overwrite:
        logger.debug(
            "Prior samples file not found - generating {0} samples...".format(
                config['prior']['num_cache']))
        make_prior_cache(run.prior_samples_file, joker,
                         nsamples=config['prior']['num_cache'])
        logger.debug("...done")

    # Get done APOGEE ID's
    done_subq = session.query(AllStar.apogee_id)\
        .join(StarResult, JokerRun, Status)\
        .filter(Status.id > 0).distinct()

    # Query to get all stars associated with this run that need processing:
    # they should have a status id = 0 (needs processing)
    star_query = session.query(AllStar).join(StarResult, JokerRun, Status)\
        .filter(JokerRun.name == run.name)\
        .filter(Status.id == 0)\
        .filter(~AllStar.apogee_id.in_(done_subq))

    # Base query to get a StarResult for a given Star so we can update the
    # status, etc.
    result_query = session.query(StarResult).join(AllStar, JokerRun)\
        .filter(JokerRun.name == run.name)\
        .filter(Status.id == 0)\
        .filter(~AllStar.apogee_id.in_(done_subq))

    # Create a file to cache the resulting posterior samples
    results_filename = join(TWOFACE_CACHE_PATH, "{0}.hdf5".format(run.name))
    n_stars = star_query.count()
    logger.info("{0} stars left to process for run '{1}'".format(
        n_stars, run.name))

    # Ensure that the results file exists - this is where we cache samples that
    # pass the rejection sampling step
    if not os.path.exists(results_filename):
        with h5py.File(results_filename, 'w') as f:
            pass

    # --------------------------------------------------------------------------
    # Here is where we do the actual processing of the data for each star. We
    # loop through all stars that still need processing and iteratively
    # rejection sample with larger and larger prior sample batch sizes. We do
    # this for efficiency, but the argument for this is somewhat made up...

    count = 0  # how many stars we've processed in this star batch
    batch_size = 16  # MAGIC NUMBER: how many stars to process before committing
    for star in star_query.all():

        if result_query.filter(
                AllStar.apogee_id == star.apogee_id).count() < 1:
            logger.debug('Star {0} has no result object!'.format(
                star.apogee_id))
            continue

        # Retrieve existing StarResult from database. We limit(1) because the
        # APOGEE_ID isn't unique, but we attach all visits for a given star to
        # all rows, so grabbing one of them is fine.
        result = result_query.filter(AllStar.apogee_id == star.apogee_id)\
            .limit(1).one()

        logger.log(1, "Starting star {0}".format(star.apogee_id))
        logger.log(1, "Current status: {0}".format(str(result.status)))
        t0 = time.time()

        data = star.apogeervdata()
        logger.log(
            1, "\t visits loaded ({:.2f} seconds)".format(time.time() - t0))

        try:
            samples, ln_prior = joker.iterative_rejection_sample(
                data=data,
                n_requested_samples=run.requested_samples_per_star,
                prior_cache_file=run.prior_samples_file,
                n_prior_samples=run.max_prior_samples,
                return_logprobs=True)

        except Exception as e:
            logger.warning(
                "\t Failed sampling for star {0} \n Error: {1}".format(
                    star.apogee_id, str(e)))
            continue

        logger.debug("\t done sampling ({:.2f} seconds)".format(
            time.time() - t0))

        # For now, it's sufficient to write the run results to an HDF5 file
        n = run.requested_samples_per_star
        all_ln_probs = ln_prior[:n]
        samples = samples[:n]
        n_actual_samples = len(all_ln_probs)

        # Write the samples that pass to the results file
        with h5py.File(results_filename, 'r+') as f:
            if star.apogee_id in f:
                del f[star.apogee_id]

            # HACK: this will overwrite the past samples!
            g = f.create_group(star.apogee_id)
            samples.to_hdf5(g)

            if 'ln_prior_probs' in g:
                del g['ln_prior_probs']
            g.create_dataset('ln_prior_probs', data=all_ln_probs)

        logger.debug("\t saved samples ({:.2f} seconds)".format(
            time.time() - t0))

        if n_actual_samples >= run.requested_samples_per_star:
            result.status_id = 4  # completed

        elif n_actual_samples == 1 or unimodal_P(samples, data):
            # Either a single sample was returned (probably unimodal), or
            # several were returned and they look unimodal
            result.status_id = 2  # needs mcmc

        else:
            # Multiple samples were returned, but not enough to satisfy the
            # number requested in the config file
            result.status_id = 1  # needs more samples

        logger.debug("...done with star {} ({:.2f} seconds)".format(
            star.apogee_id, time.time() - t0))

        if count % batch_size == 0 and count > 0:
            session.commit()

        count += 1

    pool.close()

    session.commit()
    session.close()