Example #1
def main(pool, circ, P, time_sampling, overwrite=False):

    pars = JokerParams(P_min=1 * u.day, P_max=1024 * u.day)
    joker = TheJoker(pars, pool=pool)

    # make the prior cache file
    prior_file, samples_file = make_caches(2**28,
                                           joker,
                                           circ=circ,
                                           P=P,
                                           overwrite=overwrite,
                                           time_sampling=time_sampling)

    logger.info('Files: {0}, {1}'.format(prior_file, samples_file))

    n_epochs = np.arange(3, 12 + 1, 1)
    for n_epoch, i, data, P in make_data(n_epochs,
                                         n_orbits=512,
                                         P=P,
                                         time_sampling=time_sampling,
                                         circ=circ):
        logger.debug("N epochs: {0}, orbit {1}".format(n_epoch, i))

        key = '{0}-{1}'.format(n_epoch, i)
        with h5py.File(samples_file, 'r') as f:
            if key in f:
                logger.debug('-- already done! skipping...')
                continue

        samples = joker.iterative_rejection_sample(n_requested_samples=256,
                                                   prior_cache_file=prior_file,
                                                   data=data)
        logger.debug("-- done sampling - {0} samples returned".format(
            len(samples)))

        with h5py.File(samples_file, 'a') as f:
            g = f.create_group(key)
            samples.to_hdf5(g)
            g.attrs['P'] = P
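
As a usage note, the samples cached above can be read back per key. A minimal sketch, reusing the same h5py/thejoker imports as the code above (load_cached_samples is a hypothetical helper, not part of the script):

def load_cached_samples(samples_file):
    """Read every '{n_epoch}-{i}' group that main() writes above."""
    results = {}
    with h5py.File(samples_file, 'r') as f:
        for key in f:
            results[key] = (JokerSamples.from_hdf5(f[key]),
                            f[key].attrs['P'])
    return results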
Example #2
File: run_mcmc.py Project: adrn/APOGEEBH
def main(data_path, apogee_id, config, data_file_ext, pool, overwrite=False):

    cache_path = path.join(data_path, 'cache')
    os.makedirs(cache_path, exist_ok=True)

    # Create TheJoker sampler instance with a fresh (unseeded) random state
    # and the pool; joker_params is assumed to be defined at module level
    rnd = np.random.RandomState()
    logger.debug("Creating TheJoker instance with {0}, {1}".format(rnd, pool))
    joker = TheJoker(joker_params, random_state=rnd, pool=pool)

    n_walkers = config['emcee']['n_walkers']
    n_steps = config['emcee']['n_steps']

    mcmc_model_filename = path.join(cache_path, 'model.pickle')

    logger.info('Processing {0}'.format(apogee_id))

    joker_results_filename = path.join(cache_path,
                                       '{0}-joker.hdf5'.format(apogee_id))
    mcmc_results_filename = path.join(cache_path,
                                      '{0}-mcmc.hdf5'.format(apogee_id))
    mcmc_chain_filename = path.join(cache_path,
                                    '{0}-chain.npy'.format(apogee_id))

    # `data` (the RV measurements for this star) is assumed to be loaded
    # earlier in this script
    with h5py.File(joker_results_filename, 'r') as f:
        joker_samples = JokerSamples.from_hdf5(f)

    model = TheJokerMCMCModel(joker_params=joker_params, data=data)

    if not path.exists(mcmc_chain_filename) or overwrite:
        joker = TheJoker(joker_params)
        joker.params.jitter = (8.5, 0.9)  # HACK!

        if unimodal_P(joker_samples, data):
            logger.debug("Samples are unimodal. Preparing to run MCMC...")

            sample = MAP_sample(data, joker_samples, joker.params)

            ball_scale = 1E-5
            p0_mean = np.squeeze(model.pack_samples(sample))

            # P, M0, e, omega, jitter, K, v0
            p0 = np.zeros((n_walkers, len(p0_mean)))
            for i in range(p0.shape[1]):
                if i in [2, 4]:  # eccentricity, jitter
                    p0[:, i] = np.abs(
                        np.random.normal(p0_mean[i],
                                         ball_scale,
                                         size=n_walkers))

                else:
                    p0[:, i] = np.random.normal(p0_mean[i],
                                                ball_scale,
                                                size=n_walkers)

            p0 = model.to_mcmc_params(p0.T).T

            n_dim = p0.shape[1]
            # `logprob` (the MCMC log-posterior function) is assumed to be
            # defined elsewhere in this script
            sampler = emcee.EnsembleSampler(n_walkers,
                                            n_dim,
                                            logprob,
                                            pool=pool)

            logger.debug('Running MCMC for {0} steps...'.format(n_steps))
            time0 = time.time()
            _ = sampler.run_mcmc(p0, n_steps)
            logger.debug('...time spent sampling: {0}'.format(time.time() -
                                                              time0))

            samples = model.unpack_samples_mcmc(sampler.chain[:, -1])
            samples.t0 = joker_samples.t0

            np.save(mcmc_chain_filename, sampler.chain.astype('f4'))

            if not path.exists(mcmc_model_filename):
                with open(mcmc_model_filename, 'wb') as f:
                    pickle.dump(model, f)

        else:
            logger.error("Samples are multimodal!")
            return

    if not path.exists(mcmc_results_filename) or overwrite:
        chain = np.load(mcmc_chain_filename)
        with h5py.File(mcmc_results_filename, 'w') as f:
            n_walkers, n_steps, n_pars = chain.shape

            logger.debug('Adding star {0} to MCMC cache'.format(apogee_id))

            g = f.create_group('chain-stats')
            Rs = gelman_rubin(chain[:, n_steps // 2:])
            g.create_dataset(name='gelman_rubin', data=Rs)

            # thin the second half of the chain and flatten over walkers
            end_pos = chain[:, n_steps // 2::1024].reshape(-1, n_pars)
            samples = model.unpack_samples_mcmc(end_pos)
            samples.to_hdf5(f)
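
gelman_rubin above is a project helper. A textbook stand-in with the same call pattern, taking a chain of shape (n_walkers, n_steps, n_pars) and returning one R-hat per parameter, might look like this sketch:

def gelman_rubin_sketch(chain):
    """Gelman-Rubin R-hat computed over the walker dimension."""
    m, n, _ = chain.shape
    walker_means = chain.mean(axis=1)            # (m, n_pars)
    W = chain.var(axis=1, ddof=1).mean(axis=0)   # within-walker variance
    B = n * walker_means.var(axis=0, ddof=1)     # between-walker variance
    var_hat = (n - 1) / n * W + B / n            # pooled variance estimate
    return np.sqrt(var_hat / W)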
Example #3
File: init.py Project: adrn/APOGEEBH
def initialize_db(allVisit_file,
                  allStar_file,
                  database_path,
                  drop_all=False,
                  batch_size=4096):
    """Initialize the database given FITS filenames for the APOGEE data.

    Parameters
    ----------
    allVisit_file : str
        Full path to APOGEE allVisit file.
    allStar_file : str
        Full path to APOGEE allStar file.
    database_path : str
        Full path to the database file.
    drop_all : bool (optional)
        Drop all existing tables and re-create the database.
    batch_size : int (optional)
        How many rows to create before committing.
    """

    norm = lambda x: abspath(expanduser(x))
    allvisit = fits.getdata(norm(allVisit_file))
    allstar = fits.getdata(norm(allStar_file))

    # STAR_BAD
    ASPCAP_skip_bitmask = np.sum(2**np.array([23]))

    # VERY_BRIGHT_NEIGHBOR, LOW_SNR, SUSPECT_RV_COMBINATION, SUSPECT_BROAD_LINES
    STAR_skip_bitmask = np.sum(2**np.array([3, 4, 16, 17]))

    # First filter allStar flags
    mask = ((allstar['NVISITS'] >= 4) &
            ((allstar['ASPCAPFLAG'] & ASPCAP_skip_bitmask) == 0) &
            ((allstar['STARFLAG'] & STAR_skip_bitmask) == 0))
    stars = allstar[mask]
    visits = allvisit[np.isin(allvisit['APOGEE_ID'], stars['APOGEE_ID'])]

    # Next filter allVisit flags
    mask = (((visits['STARFLAG'] & STAR_skip_bitmask) == 0) &
            (visits['VRELERR'] < 100.) &
            np.isfinite(visits['VHELIO']) &
            np.isfinite(visits['VRELERR']) &
            (visits['VHELIO'] > -999))
    visits = visits[mask]
    v_apogee_ids, counts = np.unique(visits['APOGEE_ID'], return_counts=True)
    stars = stars[np.isin(stars['APOGEE_ID'], v_apogee_ids[counts >= 4])]
    visits = visits[np.isin(visits['APOGEE_ID'], stars['APOGEE_ID'])]

    # uniquify the stars
    _, idx = np.unique(stars['APOGEE_ID'], return_index=True)
    allstar_tbl = Table(stars[idx])
    allvisit_tbl = Table(visits)

    Session, engine = db_connect(database_path, ensure_db_exists=True)
    logger.debug("Connected to database at '{}'".format(database_path))

    if drop_all:
        # this is the magic that creates the tables based on the definitions in
        # twoface/db/model.py
        Base.metadata.drop_all()
        Base.metadata.create_all()

    session = Session()

    logger.debug("Loading allStar, allVisit tables...")

    # Figure out what data we need to pull out of the FITS files based on what
    # columns exist in the (empty) database
    allstar_skip = ['ID']
    allstar_colnames = []
    allstar_varchar = []
    for x in AllStar.__table__.columns:
        col = str(x).split('.')[1].upper()
        if col in allstar_skip:
            continue

        if str(x.type) == 'VARCHAR':
            allstar_varchar.append(col)

        allstar_colnames.append(col)

    allvisit_skip = ['ID']
    allvisit_colnames = []
    allvisit_varchar = []
    for x in AllVisit.__table__.columns:
        col = str(x).split('.')[1].upper()
        if col in allvisit_skip:
            continue

        if str(x.type) == 'VARCHAR':
            allvisit_varchar.append(col)

        allvisit_colnames.append(col)

    # --------------------------------------------------------------------------
    # First load the status table:
    #
    if session.query(Status).count() == 0:
        logger.debug("Populating Status table...")
        statuses = list()
        statuses.append(Status(id=0, message='untouched'))
        statuses.append(Status(id=1, message='needs more prior samples'))
        statuses.append(Status(id=2, message='needs mcmc'))
        statuses.append(Status(id=3, message='error'))
        statuses.append(Status(id=4, message='completed'))

        session.add_all(statuses)
        session.commit()
        logger.debug("...done")

    # --------------------------------------------------------------------------
    # Load the AllStar table:
    #
    logger.info("Loading AllStar table")

    # What APOGEE_ID's are already loaded?
    all_ap_ids = np.array([x.strip() for x in allstar_tbl['APOGEE_ID']])
    loaded_ap_ids = [x[0] for x in session.query(AllStar.apogee_id).all()]
    mask = np.logical_not(np.isin(all_ap_ids, loaded_ap_ids))
    logger.debug("{0} stars already loaded".format(len(loaded_ap_ids)))
    logger.debug("{0} stars left to load".format(mask.sum()))

    stars = []
    with Timer() as t:
        i = 0
        for row in allstar_tbl[mask]:  # Load every star
            row_data = tblrow_to_dbrow(row, allstar_colnames, allstar_varchar)

            # create a new object for this row
            star = AllStar(**row_data)
            stars.append(star)
            logger.log(1, 'Adding star {0} to database'.format(star))

            if i % batch_size == 0 and i > 0:
                session.add_all(stars)
                session.commit()
                logger.debug("Loaded batch {0} ({1:.2f} seconds)".format(
                    i, t.elapsed()))
                t.reset()
                stars = []

            i += 1

    if len(stars) > 0:
        session.add_all(stars)
        session.commit()

    # --------------------------------------------------------------------------
    # Load the AllVisit table:
    #
    logger.info("Loading AllVisit table")

    # What VISIT_ID's are already loaded?
    all_vis_ids = np.array([x.strip() for x in allvisit_tbl['VISIT_ID']])
    loaded_vis_ids = [x[0] for x in session.query(AllVisit.visit_id).all()]
    mask = np.logical_not(np.isin(all_vis_ids, loaded_vis_ids))
    logger.debug("{0} visits already loaded".format(len(loaded_vis_ids)))
    logger.debug("{0} visits left to load".format(mask.sum()))

    visits = []
    with Timer() as t:
        i = 0
        for row in allvisit_tbl[mask]:  # Load every visit
            row_data = tblrow_to_dbrow(row, allvisit_colnames,
                                       allvisit_varchar)

            # create a new object for this row
            visit = AllVisit(**row_data)
            visits.append(visit)
            logger.log(1, 'Adding visit {0} to database'.format(visit))

            if i % batch_size == 0 and i > 0:
                session.add_all(visits)
                session.commit()
                logger.debug("Loaded batch {0} ({1:.2f} seconds)".format(
                    i, t.elapsed()))
                t.reset()
                visits = []

            i += 1

    if len(visits) > 0:
        session.add_all(visits)
        session.commit()

    # --------------------------------------------------------------------------
    # Now associate rows in AllStar with rows in AllVisit
    logger.info("Linking AllVisit and AllStar tables")

    q = session.query(AllStar).order_by(AllStar.id)

    for i, sub_q in enumerate(paged_query(q, page_size=batch_size)):
        for star in sub_q:
            if len(star.visits) > 0:
                continue

            visits = session.query(AllVisit).filter(
                AllVisit.apogee_id == star.apogee_id).all()

            if len(visits) == 0:
                logger.warning("Visits not found for star {0}".format(star))
                continue

            logger.log(
                1,
                'Attaching {0} visits to star {1}'.format(len(visits), star))

            star.visits = visits

        logger.debug("Committing batch {0}".format(i))
        session.commit()

    session.commit()
    session.close()
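
The flag filtering in initialize_db is a bitwise AND against a sum of powers of two. A small self-contained illustration using the same STARFLAG bits skipped above:

import numpy as np

# VERY_BRIGHT_NEIGHBOR (3), LOW_SNR (4), SUSPECT_RV_COMBINATION (16),
# SUSPECT_BROAD_LINES (17)
skip_bitmask = np.sum(2**np.array([3, 4, 16, 17]))

starflag = 2**4 | 2**9                  # LOW_SNR set, plus an unrelated bit
keep = (starflag & skip_bitmask) == 0   # False, because bit 4 is in the mask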
Example #4
def main(data_path, config_file, data_file_ext, pool, seed, overwrite=False):
    # parse config file
    with open(config_file, 'r') as f:
        config = yaml.safe_load(f)
        config['config_file'] = config_file

    cache_path = path.join(data_path, 'cache')
    os.makedirs(cache_path, exist_ok=True)

    n_prior_samples = config['prior']['num_cache']
    n_walkers = config['emcee']['n_walkers']
    joker_pars = config_to_jokerparams(config)

    prior_samples_file = path.join(cache_path, 'prior-samples.hdf5')

    # Create TheJoker sampler instance with the specified random seed and pool
    rnd = np.random.RandomState(seed=seed)
    logger.debug("Creating TheJoker instance with {0}, {1}".format(rnd, pool))
    joker = TheJoker(joker_pars, random_state=rnd, pool=pool)

    # Create a cache of prior samples (if it doesn't exist) and store the
    # filename in the database.
    if not os.path.exists(prior_samples_file) or overwrite:
        logger.debug("Prior samples file not found - generating {0} samples..."
                     .format(n_prior_samples))
        make_prior_cache(prior_samples_file, joker,
                         nsamples=n_prior_samples)
        logger.debug("...done")

    mcmc_model_filename = path.join(cache_path, 'model.pickle')
    data_files = glob.glob(path.join(data_path, '*.{0}'.format(data_file_ext)))

    for filename in data_files:
        basename = path.splitext(path.basename(filename))[0]
        logger.info('Processing file {0}'.format(basename))

        joker_results_filename = path.join(cache_path,
                                           '{0}-joker.hdf5'.format(basename))
        mcmc_results_filename = path.join(cache_path,
                                          '{0}-mcmc.hdf5'.format(basename))
        mcmc_chain_filename = path.join(cache_path,
                                        '{0}-chain.npy'.format(basename))

        t0 = time.time()
        data_tbl = QTable.read(filename)
        data = RVData(t=data_tbl['time'], rv=data_tbl['rv'],
                      stddev=data_tbl['rv_err'])
        logger.log(1, "\t data loaded ({:.2f} seconds)"
                   .format(time.time() - t0))

        if not path.exists(joker_results_filename) or overwrite:
            try:
                samples = joker.rejection_sample(
                    data=data, prior_cache_file=prior_samples_file,
                    return_logprobs=False)

            except Exception as e:
                logger.warning("\t Failed sampling for star {0} \n Error: {1}"
                               .format(basename, str(e)))
                pool.close()
                sys.exit(1)

            logger.debug("\t done sampling ({:.2f} seconds)"
                         .format(time.time() - t0))

            # Write the samples that pass to the results file
            with h5py.File(joker_results_filename, 'w') as f:
                samples.to_hdf5(f)

            logger.debug("\t saved samples ({:.2f} seconds)"
                         .format(time.time() - t0))
            logger.debug("...done with star {} ({:.2f} seconds)"
                         .format(basename, time.time() - t0))

    pool.close()
    sys.exit(0)
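
The per-star data files read above only need 'time', 'rv', and 'rv_err' columns. A sketch of writing a compatible file; the units and the ECSV format here are assumptions, not requirements imposed by the script:

from astropy.table import QTable
from astropy.time import Time
import astropy.units as u

tbl = QTable()
tbl['time'] = Time([2455000.0, 2455020.0, 2455045.0], format='jd')
tbl['rv'] = [10.2, 14.1, 9.7] * u.km / u.s
tbl['rv_err'] = [0.1, 0.1, 0.2] * u.km / u.s
tbl.write('star-0001.ecsv')  # the extension should match data_file_ext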
Example #5
def main(pool, seed, overwrite=False, _continue=False):

    # HACK: hard-coded configuration!
    db_path = path.abspath('../cache/apogeebh.sqlite')
    prior_samples_file = path.abspath('../cache/prior-samples.hdf5')
    results_filename = path.abspath('../cache/apogeebh-joker.hdf5')
    n_prior = 536870912  # number of prior samples to generate
    n_requested_samples = 256  # how many samples to generate, nominally
    max_samples_per_star = 2048  # max. number of posterior samples to save
    P_min = 1 * u.day
    P_max = 1024 * u.day
    jitter = 150 * u.m / u.s

    if not os.path.exists(db_path):
        raise IOError(
            "sqlite database not found at '{0}'\n Did you run "
            "scripts/load_dr15_db.py yet for that database?".format(db_path))

    logger.debug("Connecting to sqlite database at '{0}'".format(db_path))
    Session, engine = db_connect(database_path=db_path, ensure_db_exists=False)
    session = Session()

    # Retrieve or create a JokerRun instance
    params = JokerParams(P_min=P_min, P_max=P_max, jitter=jitter)

    # Create TheJoker sampler instance with the specified random seed and pool
    rnd = np.random.RandomState(seed=seed)
    logger.debug("Creating TheJoker instance with {0}, {1}".format(rnd, pool))
    joker = TheJoker(params, random_state=rnd, pool=pool)

    # Create a cache of prior samples (if it doesn't exist) and store the
    # filename in the database.
    if not os.path.exists(prior_samples_file) or overwrite:
        logger.debug(
            "Prior samples file not found - generating {0} samples...".format(
                n_prior))
        make_prior_cache(prior_samples_file, joker, nsamples=n_prior)
        logger.debug("...done")

    # Get done APOGEE ID's
    done_subq = session.query(AllStar.apogee_id)\
                       .join(StarResult, Status)\
                       .filter(Status.id > 0).distinct()

    # Query to get all stars associated with this run that need processing:
    # they should have a status id = 0 (needs processing)
    star_query = session.query(AllStar)\
                        .filter(AllStar.vscatter >= 5.)\
                        .filter(~AllStar.apogee_id.in_(done_subq))

    # Base query to get a StarResult for a given Star so we can update the
    # status, etc.
    result_query = session.query(StarResult).join(AllStar)\
                          .filter(Status.id == 0)\
                          .filter(~AllStar.apogee_id.in_(done_subq))

    n_stars = star_query.count()
    logger.info("{0} stars left to process".format(n_stars))

    # Ensure that the results file exists - this is where we cache samples that
    # pass the rejection sampling step
    if not os.path.exists(results_filename):
        with h5py.File(results_filename, 'w') as f:
            pass

    # --------------------------------------------------------------------------
    # Here is where we do the actual processing of the data for each star. We
    # loop through all stars that still need processing and iteratively
    # rejection sample with larger and larger prior sample batch sizes. We do
    # this for efficiency, but the argument for this is somewhat made up...

    count = 0  # how many stars we've processed in this star batch
    batch_size = 16  # MAGIC NUMBER: how many stars to process before committing
    for star in star_query.all():

        if result_query.filter(
                AllStar.apogee_id == star.apogee_id).count() < 1:
            logger.debug('Star {0} has no result object!'.format(
                star.apogee_id))
            result = StarResult()
            star.result = result
            session.add(result)
            session.commit()

        # Retrieve existing StarResult from database. We limit(1) because
        # the APOGEE_ID isn't unique, but we attach all visits for a given
        # star to all rows, so grabbing one of them is fine.
        result = result_query.filter(AllStar.apogee_id == star.apogee_id)\
                             .limit(1).one()

        logger.log(1, "Starting star {0}".format(star.apogee_id))
        logger.log(1, "Current status: {0}".format(str(result.status)))
        t0 = time.time()

        data = star.apogeervdata()
        logger.log(
            1, "\t visits loaded ({:.2f} seconds)".format(time.time() - t0))
        try:
            samples, ln_prior = joker.iterative_rejection_sample(
                data=data,
                n_requested_samples=n_requested_samples,
                prior_cache_file=prior_samples_file,
                n_prior_samples=n_prior,
                return_logprobs=True)

        except Exception as e:
            logger.warning(
                "\t Failed sampling for star {0} \n Error: {1}".format(
                    star.apogee_id, str(e)))
            continue

        logger.debug("\t done sampling ({:.2f} seconds)".format(time.time() -
                                                                t0))

        # For now, it's sufficient to write the run results to an HDF5 file
        all_ln_probs = ln_prior[:max_samples_per_star]
        samples = samples[:max_samples_per_star]
        n_actual_samples = len(all_ln_probs)

        # Write the samples that pass to the results file
        with h5py.File(results_filename, 'r+') as f:
            if star.apogee_id in f:
                del f[star.apogee_id]

            # HACK: this will overwrite the past samples!
            g = f.create_group(star.apogee_id)
            samples.to_hdf5(g)

            if 'ln_prior_probs' in g:
                del g['ln_prior_probs']
            g.create_dataset('ln_prior_probs', data=all_ln_probs)

        logger.debug("\t saved samples ({:.2f} seconds)".format(time.time() -
                                                                t0))

        if n_actual_samples >= n_requested_samples:
            result.status_id = 4  # completed

        elif n_actual_samples == 1:
            # Only one sample was returned - this is probably unimodal, so this
            # star needs MCMC
            result.status_id = 2  # needs mcmc

        else:

            if unimodal_P(samples, data):
                # Multiple samples were returned, but they look unimodal
                result.status_id = 2  # needs mcmc

            else:
                # Multiple samples were returned, but not enough to satisfy the
                # number requested in the config file
                result.status_id = 1  # needs more samples

        logger.debug("...done with star {} ({:.2f} seconds)".format(
            star.apogee_id,
            time.time() - t0))

        if count % batch_size == 0 and count > 0:
            session.commit()

        count += 1

    pool.close()

    session.commit()
    session.close()
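
unimodal_P is a project helper whose implementation is not shown here. One plausible heuristic with the same interface, treating the period samples as a single mode when their spread is below the rough period resolution set by the observing baseline (an assumption, not the project's actual criterion):

def unimodal_P_sketch(samples, data):
    """Crude check: do the period samples occupy a single mode?"""
    P = samples['P'].to(u.day).value
    T = data.t.jd.max() - data.t.jd.min()     # time baseline in days
    dP = np.median(P)**2 / (2 * np.pi * T)    # approximate period resolution
    return np.ptp(P) < dP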
Example #6
def get_run(config, session, overwrite=False):
    """Get a JokerRun row instance. Create one if it doesn't exist, otherwise
    just return the existing one.
    """

    # See if this run (by name) is in the database already, if so, grab that.
    try:
        run = session.query(JokerRun).filter(
            JokerRun.name == config['name']).one()
        logger.info("JokerRun '{0}' already found in database".format(
            config['name']))

    except NoResultFound:
        run = None

    except MultipleResultsFound:
        raise MultipleResultsFound(
            "Multiple JokerRun rows found for name '{0}'".format(
                config['name']))

    if run is not None:
        if overwrite:
            session.query(StarResult)\
                   .filter(StarResult.jokerrun_id == run.id)\
                   .delete()
            session.commit()
            session.delete(run)
            session.commit()

        else:
            return run

    # If we've gotten here, this run doesn't exist in the database yet, so
    # create it using the parameters read from the config file.
    logger.info(
        "JokerRun '{0}' not found in database, creating entry...".format(
            config['name']))

    # Create a JokerRun for this run
    run = JokerRun()
    run.config_file = config['config_file']
    run.name = config['name']
    run.P_min = u.Quantity(*config['hyperparams']['P_min'].split())
    run.P_max = u.Quantity(*config['hyperparams']['P_max'].split())
    run.requested_samples_per_star = int(
        config['hyperparams']['requested_samples_per_star'])
    run.max_prior_samples = int(config['prior']['max_samples'])
    run.prior_samples_file = join(TWOFACE_CACHE_PATH,
                                  config['prior']['samples_file'])

    if 'jitter' in config['hyperparams']:
        # jitter is fixed to some quantity, specified in config file
        run.jitter = u.Quantity(*config['hyperparams']['jitter'].split())
        logger.debug('Jitter is fixed to: {0:.2f}'.format(run.jitter))

    elif 'jitter_prior_mean' in config['hyperparams']:
        # jitter prior parameters are specified in config file
        run.jitter_mean = config['hyperparams']['jitter_prior_mean']
        run.jitter_stddev = config['hyperparams']['jitter_prior_stddev']
        run.jitter_unit = config['hyperparams']['jitter_prior_unit']
        logger.debug('Sampling in jitter with mean = {0:.2f} (stddev in '
                     'log(var) = {1:.2f}) [{2}]'.format(
                         np.sqrt(np.exp(run.jitter_mean)), run.jitter_stddev,
                         run.jitter_unit))

    else:
        # no jitter is specified, assume no jitter
        run.jitter = 0. * u.m / u.s
        logger.debug('No jitter.')

    # Get all stars with >=3 visits
    q = session.query(AllStar).join(AllVisitToAllStar, AllVisit)\
                              .group_by(AllStar.apstar_id)\
                              .having(func.count(AllVisit.id) >= 3)

    stars = q.all()

    run.stars = stars
    session.add(run)
    session.commit()

    return run


def main(config_file, pool, seed, overwrite=False):
    config_file = path.abspath(path.expanduser(config_file))

    # parse config file
    with open(config_file, 'r') as f:
        config = yaml.safe_load(f)

    # filename of sqlite database
    database_file = config['database_file']

    db_path = path.join(TWOFACE_CACHE_PATH, database_file)
    if not os.path.exists(db_path):
        raise IOError(
            "sqlite database not found at '{0}'\n Did you run "
            "scripts/initdb.py yet for that database?".format(db_path))

    logger.debug("Connecting to sqlite database at '{0}'".format(db_path))
    Session, engine = db_connect(database_path=db_path, ensure_db_exists=False)
    session = Session()

    run = get_run(config, session, overwrite=False)

    # The file with cached posterior samples:
    results_filename = path.join(TWOFACE_CACHE_PATH,
                                 "{0}.hdf5".format(run.name))
    if not path.exists(results_filename):
        raise IOError(
            "Posterior samples result file {0} doesn't exist! Are "
            "you sure you ran `run_apogee.py`?".format(results_filename))

    # Create TheJoker sampler instance with the specified random seed and pool
    rnd = np.random.RandomState(seed=seed)
    logger.debug("Creating TheJoker instance with {0}, {1}".format(rnd, pool))
    params = run.get_joker_params()
    joker = TheJoker(params, random_state=rnd, pool=pool)

    # TODO: we should make sure a 2nd prior cache exists, but because I'm only
    # going to deal with "needs mcmc", ignore this
    # _path, ext = path.splitext(run.prior_samples_file)
    # new_path = '{0}_moar{1}'.format(_path, ext)
    # if not path.exists(new_path):
    #     make_prior_cache(new_path, joker,
    #                      nsamples=8 * config['prior']['num_cache'], # ~100 GB
    #                      batch_size=2**24) # MAGIC NUMBER

    # Get all stars in this JokerRun that need more prior samples
    # TODO HACK: this query only selects "needs mcmc" stars!
    star_query = session.query(AllStar).join(StarResult, JokerRun, Status)\
                                       .filter(JokerRun.name == run.name)\
                                       .filter(Status.id == 2)
    # .filter(Status.id == 1)

    # Base query to get a StarResult for a given Star so we can update the
    # status, etc.
    result_query = session.query(StarResult).join(AllStar, JokerRun)\
                                            .filter(JokerRun.name == run.name)

    n_stars = star_query.count()
    logger.info("{0} stars left to process for run more samples '{1}'".format(
        n_stars, run.name))

    # --------------------------------------------------------------------------
    # Here is where we do the actual processing of the data for each star. We
    # loop through all stars that still need processing and continue with
    # rejection sampling.

    count = 0  # how many stars we've processed in this star batch
    batch_size = 16  # MAGIC NUMBER: how many stars to process before committing
    for star in star_query.all():

        if result_query.filter(
                AllStar.apogee_id == star.apogee_id).count() < 1:
            logger.debug('Star {0} has no result object!'.format(
                star.apogee_id))
            continue

        # Retrieve existing StarResult from database. We limit(1) because the
        # APOGEE_ID isn't unique, but we attach all visits for a given star to
        # all rows, so grabbing one of them is fine.
        result = result_query.filter(AllStar.apogee_id == star.apogee_id)\
                             .limit(1).one()

        logger.log(1, "Starting star {0}".format(star.apogee_id))
        logger.log(1, "Current status: {0}".format(str(result.status)))
        t0 = time.time()

        data = star.apogeervdata()
        logger.log(
            1, "\t visits loaded ({:.2f} seconds)".format(time.time() - t0))

        if result.status.id == 1:  # needs more prior samples

            try:
                samples, ln_prior = joker.iterative_rejection_sample(
                    data=data,
                    n_requested_samples=run.requested_samples_per_star,
                    # HACK: new_path is defined in the commented-out block
                    # above; uncomment it (or fall back to
                    # run.prior_samples_file) before using this branch
                    prior_cache_file=new_path,
                    return_logprobs=True)

            except Exception as e:
                logger.warning(
                    "\t Failed sampling for star {0} \n Error: {1}".format(
                        star.apogee_id, str(e)))
                continue

            logger.debug(
                "\t done sampling ({:.2f} seconds)".format(time.time() - t0))

        elif result.status.id == 2:  # needs mcmc
            logger.debug('Firing up MCMC:')

            with h5py.File(results_filename, 'r') as f:
                samples0 = JokerSamples.from_hdf5(f[star.apogee_id])

            n_walkers = 2 * run.requested_samples_per_star
            model, samples, sampler = joker.mcmc_sample(data,
                                                        samples0,
                                                        n_burn=1024,
                                                        n_steps=65536,
                                                        n_walkers=n_walkers,
                                                        return_sampler=True)

            sampler.pool = None
            import pickle
            with open('test-mcmc.pickle', 'wb') as f:
                pickle.dump(sampler, f)

        # NOTE: debugging shortcut -- this exits after the first star, so the
        # save/status-update code below is currently unreachable
        pool.close()
        import sys
        sys.exit(0)

        # For now, it's sufficient to write the run results to an HDF5 file
        n = run.requested_samples_per_star
        all_ln_probs = ln_prior[:n]
        samples = samples[:n]

        # Write the samples that pass to the results file
        with h5py.File(results_filename, 'r+') as f:
            if star.apogee_id in f:
                del f[star.apogee_id]

            g = f.create_group(star.apogee_id)
            samples.to_hdf5(g)
            g.create_dataset('ln_prior_probs', data=all_ln_probs)

        logger.debug("\t saved samples ({:.2f} seconds)".format(time.time() -
                                                                t0))

        result.status_id = get_status_id(samples, data,
                                         run.requested_samples_per_star)

        logger.debug("...done with star {} ({:.2f} seconds)".format(
            star.apogee_id,
            time.time() - t0))

        if count % batch_size == 0 and count > 0:
            session.commit()

        count += 1

    pool.close()

    session.commit()
    session.close()
Example #8
def main(config_file, pool, seed, overwrite=False):
    # Default seed:
    if seed is None:
        seed = 42

    config_file = abspath(expanduser(config_file))

    # parse config file
    with open(config_file, 'r') as f:
        config = yaml.safe_load(f)
        config['config_file'] = config_file

    # filename of sqlite database
    if 'database_file' not in config:
        raise KeyError("Config file must specify 'database_file'.")

    database_file = config['database_file']

    db_path = join(TWOFACE_CACHE_PATH, database_file)
    if not os.path.exists(db_path):
        raise IOError(
            "sqlite database not found at '{0}'\n Did you run "
            "scripts/initdb.py yet for that database?".format(db_path))

    logger.debug("Connecting to sqlite database at '{0}'".format(db_path))
    Session, engine = db_connect(database_path=db_path, ensure_db_exists=False)
    session = Session()

    # Retrieve or create a JokerRun instance
    run = get_run(config, session, overwrite=False)  # never overwrite
    params = run.get_joker_params()

    # Create TheJoker sampler instance with the specified random seed and pool
    rnd = np.random.RandomState(seed=seed)
    logger.debug("Creating TheJoker instance with {0}, {1}".format(rnd, pool))
    joker = TheJoker(params, random_state=rnd, pool=pool)

    # Create a file to cache the resulting posterior samples
    results_filename = join(TWOFACE_CACHE_PATH,
                            "{0}-control.hdf5".format(run.name))

    # Ensure that the results file exists - this is where we cache samples that
    # pass the rejection sampling step
    if not os.path.exists(results_filename):
        with h5py.File(results_filename, 'w') as f:
            pass

    with h5py.File(results_filename, 'r') as f:
        done_apogee_ids = list(f.keys())

    # Create a cache of prior samples (if it doesn't exist) and store the
    # filename in the database.
    if not os.path.exists(run.prior_samples_file):
        raise IOError("Prior cache must already exist.")

    # Get random IDs (NCONTROL is a module-level constant in this script)
    star_ids = session.query(AllStar.apogee_id)\
                      .join(StarResult, JokerRun, Status)\
                      .filter(Status.id > 0).distinct().all()
    star_ids = np.array([x[0] for x in star_ids])
    star_ids = rnd.choice(star_ids, size=NCONTROL, replace=False)
    star_ids = star_ids[~np.isin(star_ids, done_apogee_ids)]

    n_stars = len(star_ids)
    logger.info(
        "{0} stars left to process for run '{1}'; {2} already done.".format(
            n_stars, run.name, len(done_apogee_ids)))

    # --------------------------------------------------------------------------
    # Here is where we do the actual processing of the data for each star. We
    # loop through all stars that still need processing and iteratively
    # rejection sample with larger and larger prior sample batch sizes. We do
    # this for efficiency, but the argument for this is somewhat made up...

    for apid in star_ids:
        star = AllStar.get_apogee_id(session, apid)

        logger.log(1, "Starting star {0}".format(star.apogee_id))
        t0 = time.time()

        orig_data = star.apogeervdata()

        # HACK: this assumes we're sampling over the excess variance parameter
        # Generate new data with no RV orbital variations
        y = rnd.normal(params.jitter[0], params.jitter[1])
        s = np.exp(0.5 * y) * params._jitter_unit
        std = np.sqrt(s**2 + orig_data.stddev**2).to(orig_data.rv.unit).value
        new_rv = rnd.normal(np.mean(orig_data.rv).value, std)
        data = APOGEERVData(t=orig_data.t,
                            rv=new_rv * orig_data.rv.unit,
                            stddev=orig_data.stddev)

        logger.log(
            1, "\t visits loaded ({:.2f} seconds)".format(time.time() - t0))
        try:
            samples, ln_prior = joker.iterative_rejection_sample(
                data=data,
                n_requested_samples=run.requested_samples_per_star,
                prior_cache_file=run.prior_samples_file,
                n_prior_samples=run.max_prior_samples,
                return_logprobs=True)

        except Exception as e:
            logger.warning(
                "\t Failed sampling for star {0} \n Error: {1}".format(
                    star.apogee_id, str(e)))
            continue

        logger.debug("\t done sampling ({:.2f} seconds)".format(time.time() -
                                                                t0))

        # For now, it's sufficient to write the run results to an HDF5 file
        n = run.requested_samples_per_star
        samples = samples[:n]

        # Write the samples that pass to the results file
        with h5py.File(results_filename, 'r+') as f:
            if star.apogee_id in f:
                del f[star.apogee_id]

            # HACK: this will overwrite the past samples!
            g = f.create_group(star.apogee_id)
            samples.to_hdf5(g)

        logger.debug("\t saved samples ({:.2f} seconds)".format(time.time() -
                                                                t0))

    pool.close()
    session.close()
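
The control-sample construction above draws pure-noise RVs around the mean, with the sampled jitter folded into the per-visit uncertainties. The same trick in isolation, as a sketch under the log-variance jitter convention used above:

def noise_only_rv(rnd, rv, stddev, jitter_mean, jitter_stddev, jitter_unit):
    """RVs with no orbital signal: same epochs and errors, constant true RV."""
    y = rnd.normal(jitter_mean, jitter_stddev)   # draw y = log(s**2)
    s = np.exp(0.5 * y) * jitter_unit            # jitter amplitude
    std = np.sqrt(s**2 + stddev**2).to(rv.unit).value
    return rnd.normal(np.mean(rv).value, std) * rv.unit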
Example #9
def main(config_file, pool, seed, overwrite=False, _continue=False):
    config_file = abspath(expanduser(config_file))

    # parse config file
    with open(config_file, 'r') as f:
        config = yaml.safe_load(f)
        config['config_file'] = config_file

    # filename of sqlite database
    if 'database_file' not in config:
        raise KeyError("Config file must specify 'database_file'.")

    database_file = config['database_file']

    db_path = join(TWOFACE_CACHE_PATH, database_file)
    if not os.path.exists(db_path):
        raise IOError(
            "sqlite database not found at '{0}'\n Did you run "
            "scripts/initdb.py yet for that database?".format(db_path))

    logger.debug("Connecting to sqlite database at '{0}'".format(db_path))
    Session, engine = db_connect(database_path=db_path, ensure_db_exists=False)
    session = Session()

    # Retrieve or create a JokerRun instance
    run = get_run(config, session, overwrite=overwrite)
    params = run.get_joker_params()

    # Create TheJoker sampler instance with the specified random seed and pool
    rnd = np.random.RandomState(seed=seed)
    logger.debug("Creating TheJoker instance with {0}, {1}".format(rnd, pool))
    joker = TheJoker(params, random_state=rnd, pool=pool)

    # Create a cache of prior samples (if it doesn't exist) and store the
    # filename in the database.
    if not os.path.exists(run.prior_samples_file) or overwrite:
        logger.debug(
            "Prior samples file not found - generating {0} samples...".format(
                config['prior']['num_cache']))
        make_prior_cache(run.prior_samples_file,
                         joker,
                         nsamples=config['prior']['num_cache'])
        logger.debug("...done")

    # Get done APOGEE ID's
    done_subq = session.query(AllStar.apogee_id)\
                       .join(StarResult, JokerRun, Status)\
                       .filter(Status.id > 0).distinct()

    # Query to get all stars associated with this run that need processing:
    # they should have a status id = 0 (needs processing)
    star_query = session.query(AllStar).join(StarResult, JokerRun, Status)\
                                       .filter(JokerRun.name == run.name)\
                                       .filter(Status.id == 0)\
                                       .filter(~AllStar.apogee_id.in_(done_subq))

    # Base query to get a StarResult for a given Star so we can update the
    # status, etc.
    result_query = session.query(StarResult).join(AllStar, JokerRun)\
                                            .filter(JokerRun.name == run.name)\
                                            .filter(Status.id == 0)\
                                            .filter(~AllStar.apogee_id.in_(done_subq))

    # Create a file to cache the resulting posterior samples
    results_filename = join(TWOFACE_CACHE_PATH, "{0}.hdf5".format(run.name))
    n_stars = star_query.count()
    logger.info("{0} stars left to process for run '{1}'".format(
        n_stars, run.name))

    # Ensure that the results file exists - this is where we cache samples that
    # pass the rejection sampling step
    if not os.path.exists(results_filename):
        with h5py.File(results_filename, 'w') as f:
            pass

    # --------------------------------------------------------------------------
    # Here is where we do the actual processing of the data for each star. We
    # loop through all stars that still need processing and iteratively
    # rejection sample with larger and larger prior sample batch sizes. We do
    # this for efficiency, but the argument for this is somewhat made up...

    count = 0  # how many stars we've processed in this star batch
    batch_size = 16  # MAGIC NUMBER: how many stars to process before committing
    for star in star_query.all():

        if result_query.filter(
                AllStar.apogee_id == star.apogee_id).count() < 1:
            logger.debug('Star {0} has no result object!'.format(
                star.apogee_id))
            continue

        # Retrieve existing StarResult from database. We limit(1) because the
        # APOGEE_ID isn't unique, but we attach all visits for a given star to
        # all rows, so grabbing one of them is fine.
        result = result_query.filter(AllStar.apogee_id == star.apogee_id)\
                             .limit(1).one()

        logger.log(1, "Starting star {0}".format(star.apogee_id))
        logger.log(1, "Current status: {0}".format(str(result.status)))
        t0 = time.time()

        data = star.apogeervdata()
        logger.log(
            1, "\t visits loaded ({:.2f} seconds)".format(time.time() - t0))
        try:
            samples, ln_prior = joker.iterative_rejection_sample(
                data=data,
                n_requested_samples=run.requested_samples_per_star,
                prior_cache_file=run.prior_samples_file,
                n_prior_samples=run.max_prior_samples,
                return_logprobs=True)

        except Exception as e:
            logger.warning(
                "\t Failed sampling for star {0} \n Error: {1}".format(
                    star.apogee_id, str(e)))
            continue

        logger.debug("\t done sampling ({:.2f} seconds)".format(time.time() -
                                                                t0))

        # For now, it's sufficient to write the run results to an HDF5 file
        n = run.requested_samples_per_star
        all_ln_probs = ln_prior[:n]
        samples = samples[:n]
        n_actual_samples = len(all_ln_probs)

        # Write the samples that pass to the results file
        with h5py.File(results_filename, 'r+') as f:
            if star.apogee_id in f:
                del f[star.apogee_id]

            # HACK: this will overwrite the past samples!
            g = f.create_group(star.apogee_id)
            samples.to_hdf5(g)

            if 'ln_prior_probs' in g:
                del g['ln_prior_probs']
            g.create_dataset('ln_prior_probs', data=all_ln_probs)

        logger.debug("\t saved samples ({:.2f} seconds)".format(time.time() -
                                                                t0))

        if n_actual_samples >= run.requested_samples_per_star:
            result.status_id = 4  # completed

        elif n_actual_samples == 1:
            # Only one sample was returned - this is probably unimodal, so this
            # star needs MCMC
            result.status_id = 2  # needs mcmc

        else:

            if unimodal_P(samples, data):
                # Multiple samples were returned, but they look unimodal
                result.status_id = 2  # needs mcmc

            else:
                # Multiple samples were returned, but not enough to satisfy the
                # number requested in the config file
                result.status_id = 1  # needs more samples

        logger.debug("...done with star {} ({:.2f} seconds)".format(
            star.apogee_id,
            time.time() - t0))

        if count % batch_size == 0 and count > 0:
            session.commit()

        count += 1

    pool.close()

    session.commit()
    session.close()
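
For reference, the integer status_id values assigned throughout these examples map onto the Status rows created in the initialize_db example (Example #3):

STATUS_MESSAGES = {
    0: 'untouched',
    1: 'needs more prior samples',
    2: 'needs mcmc',
    3: 'error',
    4: 'completed',
}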