示例#1
0
def main(data_path, config_file, data_file_ext, pool, seed, overwrite=False):
    """Rejection-sample every RV data file in a directory with The Joker.

    A ``cache`` sub-directory is created inside ``data_path``. It holds the
    shared prior-sample cache (``prior-samples.hdf5``) and one
    ``<basename>-joker.hdf5`` results file per input data file.

    Parameters
    ----------
    data_path : str
        Directory that contains the data files to process.
    config_file : str
        Path to a YAML configuration file. ``config['prior']['num_cache']``
        is read here; the parsed config is also passed to
        ``config_to_jokerparams``.
    data_file_ext : str
        Extension (without the leading dot) of the data files to process.
    pool : object
        Processing pool passed to ``TheJoker``. It is closed before exiting.
    seed : int or None
        Seed for the ``numpy.random.RandomState`` given to ``TheJoker``.
    overwrite : bool (optional)
        Regenerate the prior cache and every results file even if they
        already exist.

    Notes
    -----
    This function does not return: it calls ``sys.exit(0)`` after all files
    were processed, or ``sys.exit(1)`` as soon as sampling fails for a file.
    """
    # parse config file
    with open(config_file, 'r') as f:
        # safe_load instead of a bare yaml.load(): calling yaml.load() without
        # a Loader can construct arbitrary Python objects on older PyYAML and
        # is a TypeError on PyYAML >= 6.
        config = yaml.safe_load(f)
    config['config_file'] = config_file

    cache_path = path.join(data_path, 'cache')
    os.makedirs(cache_path, exist_ok=True)

    n_prior_samples = config['prior']['num_cache']
    joker_pars = config_to_jokerparams(config)

    prior_samples_file = path.join(cache_path, 'prior-samples.hdf5')

    # Create TheJoker sampler instance with the specified random seed and pool
    rnd = np.random.RandomState(seed=seed)
    logger.debug("Creating TheJoker instance with {0}, {1}".format(rnd, pool))
    joker = TheJoker(joker_pars, random_state=rnd, pool=pool)

    # Create a cache of prior samples if it doesn't exist yet (or if we were
    # asked to overwrite it).
    if not os.path.exists(prior_samples_file) or overwrite:
        logger.debug("Prior samples file not found - generating {0} samples..."
                     .format(n_prior_samples))
        make_prior_cache(prior_samples_file, joker,
                         nsamples=n_prior_samples)
        logger.debug("...done")

    data_files = glob.glob(path.join(data_path, '*.{0}'.format(data_file_ext)))

    for filename in data_files:
        basename = path.splitext(path.basename(filename))[0]
        logger.info('Processing file {0}'.format(basename))

        joker_results_filename = path.join(cache_path,
                                           '{0}-joker.hdf5'.format(basename))

        # Results are already cached for this file: skip it without paying
        # for reading the data table.
        if path.exists(joker_results_filename) and not overwrite:
            continue

        # Start the clock before loading so the "visits loaded" timing below
        # actually measures the load.
        t0 = time.time()

        data_tbl = QTable.read(filename)
        data = RVData(t=data_tbl['time'], rv=data_tbl['rv'],
                      stddev=data_tbl['rv_err'])
        logger.log(1, "\t visits loaded ({:.2f} seconds)"
                   .format(time.time() - t0))

        try:
            samples = joker.rejection_sample(
                data=data, prior_cache_file=prior_samples_file,
                return_logprobs=False)

        except Exception as e:
            # A sampling failure is fatal for the whole run: release the pool
            # and exit with a non-zero status.
            logger.warning("\t Failed sampling for star {0} \n Error: {1}"
                           .format(basename, str(e)))
            pool.close()
            sys.exit(1)

        logger.debug("\t done sampling ({:.2f} seconds)"
                     .format(time.time() - t0))

        # Write the samples that pass to the results file
        with h5py.File(joker_results_filename, 'w') as f:
            samples.to_hdf5(f)

        logger.debug("\t saved samples ({:.2f} seconds)"
                     .format(time.time() - t0))
        logger.debug("...done with star {} ({:.2f} seconds)"
                     .format(basename, time.time() - t0))

    pool.close()
    sys.exit(0)
示例#2
0
文件: init.py 项目: adrn/APOGEEBH
def _model_colnames(model, skip=('ID',)):
    """Return ``(colnames, varchar_colnames)`` for a database model.

    Column names are taken from ``model.__table__.columns`` and upper-cased.
    Names listed in ``skip`` are left out of both lists.
    """
    colnames = []
    varchar = []
    for column in model.__table__.columns:
        # str(column) looks like '<table>.<column>'; keep the column part.
        name = str(column).split('.')[1].upper()
        if name in skip:
            continue

        if str(column.type) == 'VARCHAR':
            varchar.append(name)

        colnames.append(name)

    return colnames, varchar


def _load_rows(session, tbl, model, colnames, varchar, batch_size, label):
    """Add every row of ``tbl`` to the session as ``model``, in batches.

    ``label`` ('star' or 'visit') is only used in the per-row log message.
    """
    pending = []
    with Timer() as t:
        for i, row in enumerate(tbl):
            row_data = tblrow_to_dbrow(row, colnames, varchar)

            # create a new object for this row
            obj = model(**row_data)
            pending.append(obj)
            logger.log(1, 'Adding {0} {1} to database'.format(label, obj))

            if i % batch_size == 0 and i > 0:
                session.add_all(pending)
                session.commit()
                logger.debug("Loaded batch {0} ({1:.2f} seconds)".format(
                    i, t.elapsed()))
                t.reset()
                pending = []

    # Flush whatever is left over from the last, partial batch.
    if len(pending) > 0:
        session.add_all(pending)
        session.commit()


def initialize_db(allVisit_file,
                  allStar_file,
                  database_path,
                  drop_all=False,
                  batch_size=4096):
    """Initialize the database given FITS filenames for the APOGEE data.

    Stars and visits are filtered on quality flags, loaded into the AllStar
    and AllVisit tables (skipping IDs that are already present), and finally
    each star is linked to its visits.

    Parameters
    ----------
    allVisit_file : str
        Full path to APOGEE allVisit file.
    allStar_file : str
        Full path to APOGEE allStar file.
    database_path : str
        Path of the database, passed on to ``db_connect``.
    drop_all : bool (optional)
        Drop all existing tables and re-create the database.
    batch_size : int (optional)
        How many rows to create before committing.
    """

    norm = lambda x: abspath(expanduser(x))
    allvisit = fits.getdata(norm(allVisit_file))
    allstar = fits.getdata(norm(allStar_file))

    # STAR_BAD
    ASPCAP_skip_bitmask = np.sum(2**np.array([23]))

    # VERY_BRIGHT_NEIGHBOR, LOW_SNR, SUSPECT_RV_COMBINATION, SUSPECT_BROAD_LINES
    STAR_skip_bitmask = np.sum(2**np.array([3, 4, 16, 17]))

    # First filter allStar flags
    mask = ((allstar['NVISITS'] >= 4) &
            ((allstar['ASPCAPFLAG'] & ASPCAP_skip_bitmask) == 0) &
            ((allstar['STARFLAG'] & STAR_skip_bitmask) == 0))
    stars = allstar[mask]
    visits = allvisit[np.isin(allvisit['APOGEE_ID'], stars['APOGEE_ID'])]

    # Next filter allVisit flags
    mask = (((visits['STARFLAG'] & STAR_skip_bitmask) == 0) &
            (visits['VRELERR'] < 100.) & np.isfinite(visits['VHELIO'])
            & np.isfinite(visits['VRELERR']) & (visits['VHELIO'] > -999))
    visits = visits[mask]

    # Keep only stars that still have at least 4 visits after the visit-level
    # cuts, and only the visits that belong to those stars.
    v_apogee_ids, counts = np.unique(visits['APOGEE_ID'], return_counts=True)
    stars = stars[np.isin(stars['APOGEE_ID'], v_apogee_ids[counts >= 4])]
    visits = visits[np.isin(visits['APOGEE_ID'], stars['APOGEE_ID'])]

    # uniquify the stars
    _, idx = np.unique(stars['APOGEE_ID'], return_index=True)
    allstar_tbl = Table(stars[idx])
    allvisit_tbl = Table(visits)

    Session, engine = db_connect(database_path, ensure_db_exists=True)
    logger.debug("Connected to database at '{}'".format(database_path))

    if drop_all:
        # this is the magic that creates the tables based on the definitions in
        # twoface/db/model.py
        Base.metadata.drop_all()
        Base.metadata.create_all()

    session = Session()

    # Make sure the session is released even if loading fails part-way.
    try:
        logger.debug("Loading allStar, allVisit tables...")

        # Figure out what data we need to pull out of the FITS files based on
        # what columns exist in the (empty) database
        allstar_colnames, allstar_varchar = _model_colnames(AllStar)
        allvisit_colnames, allvisit_varchar = _model_colnames(AllVisit)

        # ----------------------------------------------------------------------
        # First load the status table:
        #
        if session.query(Status).count() == 0:
            logger.debug("Populating Status table...")
            statuses = [
                Status(id=0, message='untouched'),
                Status(id=1, message='needs more prior samples'),
                Status(id=2, message='needs mcmc'),
                Status(id=3, message='error'),
                Status(id=4, message='completed'),
            ]

            session.add_all(statuses)
            session.commit()
            logger.debug("...done")

        # ----------------------------------------------------------------------
        # Load the AllStar table:
        #
        logger.info("Loading AllStar table")

        # What APOGEE_ID's are already loaded?
        all_ap_ids = np.array([x.strip() for x in allstar_tbl['APOGEE_ID']])
        loaded_ap_ids = [x[0] for x in session.query(AllStar.apogee_id).all()]
        mask = np.logical_not(np.isin(all_ap_ids, loaded_ap_ids))
        logger.debug("{0} stars already loaded".format(len(loaded_ap_ids)))
        logger.debug("{0} stars left to load".format(mask.sum()))

        _load_rows(session, allstar_tbl[mask], AllStar,
                   allstar_colnames, allstar_varchar, batch_size, 'star')

        # ----------------------------------------------------------------------
        # Load the AllVisit table:
        #
        logger.info("Loading AllVisit table")

        # What VISIT_ID's are already loaded?
        all_vis_ids = np.array([x.strip() for x in allvisit_tbl['VISIT_ID']])
        loaded_vis_ids = [x[0]
                          for x in session.query(AllVisit.visit_id).all()]
        mask = np.logical_not(np.isin(all_vis_ids, loaded_vis_ids))
        logger.debug("{0} visits already loaded".format(len(loaded_vis_ids)))
        logger.debug("{0} visits left to load".format(mask.sum()))

        _load_rows(session, allvisit_tbl[mask], AllVisit,
                   allvisit_colnames, allvisit_varchar, batch_size, 'visit')

        # ----------------------------------------------------------------------
        # Now associate rows in AllStar with rows in AllVisit
        logger.info("Linking AllVisit and AllStar tables")

        q = session.query(AllStar).order_by(AllStar.id)

        for i, sub_q in enumerate(paged_query(q, page_size=batch_size)):
            for star in sub_q:
                # Already linked on a previous run: leave it alone.
                if len(star.visits) > 0:
                    continue

                star_visits = session.query(AllVisit).filter(
                    AllVisit.apogee_id == star.apogee_id).all()

                if len(star_visits) == 0:
                    logger.warning(
                        "Visits not found for star {0}".format(star))
                    continue

                logger.log(
                    1,
                    'Attaching {0} visits to star {1}'.format(
                        len(star_visits), star))

                star.visits = star_visits

            logger.debug("Committing batch {0}".format(i))
            session.commit()

        session.commit()

    finally:
        session.close()
示例#3
0
def main(pool, seed, overwrite=False, _continue=False):
    """Run The Joker's iterative rejection sampler on unprocessed stars.

    Connects to a hard-coded sqlite database, builds the prior-sample cache
    when it is missing, and then loops over every ``AllStar`` row with
    ``vscatter >= 5`` whose APOGEE ID is not in the "done" sub-query. For each
    star the posterior samples and their log-prior values are written to a
    group named after the APOGEE ID in the results HDF5 file, and the status
    of the star's ``StarResult`` is updated.

    Parameters
    ----------
    pool : object
        Processing pool passed to ``TheJoker``; closed at the end.
    seed : int or None
        Seed for the ``numpy.random.RandomState`` given to ``TheJoker``.
    overwrite : bool (optional)
        Regenerate the prior-sample cache even if the file already exists.
    _continue : bool (optional)
        Not used by this function.

    Raises
    ------
    IOError
        If the sqlite database file does not exist.
    """

    # HACK: hard-coded configuration!
    db_path = path.abspath('../cache/apogeebh.sqlite')
    prior_samples_file = path.abspath('../cache/prior-samples.hdf5')
    results_filename = path.abspath('../cache/apogeebh-joker.hdf5')
    n_prior = 536870912  # number of prior samples to generate
    n_requested_samples = 256  # how many samples to generate, nominally
    max_samples_per_star = 2048  # max. number of posterior samples to save
    commit_every = 16  # MAGIC NUMBER: stars to process between commits

    if not os.path.exists(db_path):
        raise IOError(
            "sqlite database not found at '{0}'\n Did you run "
            "scripts/load_dr15_db.py yet for that database?".format(db_path))

    logger.debug("Connecting to sqlite database at '{0}'".format(db_path))
    Session, engine = db_connect(database_path=db_path, ensure_db_exists=False)
    session = Session()

    # Sampler set-up: parameters, random state and TheJoker itself
    params = JokerParams(P_min=1 * u.day,
                         P_max=1024 * u.day,
                         jitter=150 * u.m / u.s)
    rnd = np.random.RandomState(seed=seed)
    logger.debug("Creating TheJoker instance with {0}, {1}".format(rnd, pool))
    joker = TheJoker(params, random_state=rnd, pool=pool)

    # Build the prior-sample cache when it is missing or must be replaced
    need_prior_cache = overwrite or not os.path.exists(prior_samples_file)
    if need_prior_cache:
        logger.debug(
            "Prior samples file not found - generating {0} samples...".format(
                n_prior))
        make_prior_cache(prior_samples_file, joker, nsamples=n_prior)
        logger.debug("...done")

    # APOGEE IDs that already have a result with a non-zero status
    done_subq = (session.query(AllStar.apogee_id)
                 .join(StarResult, Status)
                 .filter(Status.id > 0)
                 .distinct())

    # Stars that still need processing
    star_query = (session.query(AllStar)
                  .filter(AllStar.vscatter >= 5.)
                  .filter(~AllStar.apogee_id.in_(done_subq)))

    # Base query used to look up the StarResult belonging to a star
    result_query = (session.query(StarResult).join(AllStar)
                    .filter(Status.id == 0)
                    .filter(~AllStar.apogee_id.in_(done_subq)))

    logger.info("{0} stars left to process".format(star_query.count()))

    # The results file must exist before it is opened in 'r+' mode below
    if not os.path.exists(results_filename):
        h5py.File(results_filename, 'w').close()

    # --------------------------------------------------------------------------
    # Per-star processing. Only stars that were sampled successfully advance
    # the commit counter.
    n_processed = 0
    for star in star_query.all():
        apogee_id = star.apogee_id
        star_results = result_query.filter(AllStar.apogee_id == apogee_id)

        if star_results.count() < 1:
            logger.debug('Star {0} has no result object!'.format(apogee_id))
            fresh = StarResult()
            star.result = fresh
            session.add(fresh)
            session.commit()

        # limit(1): the APOGEE_ID isn't unique, but all visits of a star are
        # attached to all of its rows, so any one of them will do.
        result = star_results.limit(1).one()

        logger.log(1, "Starting star {0}".format(apogee_id))
        logger.log(1, "Current status: {0}".format(str(result.status)))
        t0 = time.time()

        data = star.apogeervdata()
        logger.log(
            1, "\t visits loaded ({:.2f} seconds)".format(time.time() - t0))

        try:
            samples, ln_prior = joker.iterative_rejection_sample(
                data=data,
                n_requested_samples=n_requested_samples,
                prior_cache_file=prior_samples_file,
                n_prior_samples=n_prior,
                return_logprobs=True)
        except Exception as e:
            logger.warning(
                "\t Failed sampling for star {0} \n Error: {1}".format(
                    apogee_id, str(e)))
            continue

        logger.debug(
            "\t done sampling ({:.2f} seconds)".format(time.time() - t0))

        # Cap how many samples are stored per star
        all_ln_probs = ln_prior[:max_samples_per_star]
        samples = samples[:max_samples_per_star]
        n_kept = len(all_ln_probs)

        # Replace any previous group for this star in the results file
        with h5py.File(results_filename, 'r+') as f:
            if apogee_id in f:
                del f[apogee_id]

            # HACK: this will overwrite the past samples!
            group = f.create_group(apogee_id)
            samples.to_hdf5(group)

            if 'ln_prior_probs' in group:
                del group['ln_prior_probs']
            group.create_dataset('ln_prior_probs', data=all_ln_probs)

        logger.debug(
            "\t saved samples ({:.2f} seconds)".format(time.time() - t0))

        # Decide the follow-up status from how many samples survived
        if n_kept >= n_requested_samples:
            new_status = 4  # completed
        elif n_kept == 1 or unimodal_P(samples, data):
            # A single sample, or several that look unimodal
            new_status = 2  # needs mcmc
        else:
            # Several samples, but fewer than requested
            new_status = 1  # needs more samples
        result.status_id = new_status

        logger.debug("...done with star {} ({:.2f} seconds)".format(
            apogee_id, time.time() - t0))

        if n_processed > 0 and n_processed % commit_every == 0:
            session.commit()
        n_processed += 1

    pool.close()

    session.commit()
    session.close()
示例#4
0
def main(db_file, pool, seed, overwrite=False):
    """Rejection-sample a single, hard-coded star with The Joker.

    The sampler parameters, number of prior samples, run name and APOGEE ID
    are all hard-set inside this function. The star's samples are written to
    ``<apogee_id>.hdf5`` in ``TWOFACE_CACHE_PATH``.

    Parameters
    ----------
    db_file : str
        Filename of the sqlite database inside ``TWOFACE_CACHE_PATH``.
    pool : object
        Processing pool passed to ``TheJoker``. It is always closed before
        this function returns or exits.
    seed : int or None
        Seed for the ``numpy.random.RandomState`` given to ``TheJoker``.
    overwrite : bool (optional)
        Regenerate the prior-sample cache even if the file already exists.

    Raises
    ------
    IOError
        If the sqlite database file does not exist.
    SystemExit
        With status 1 if rejection sampling fails.
    """

    db_path = join(TWOFACE_CACHE_PATH, db_file)
    if not os.path.exists(db_path):
        raise IOError("sqlite database not found at '{0}'\n Did you run "
                      "scripts/initdb.py yet for that database?"
                      .format(db_path))

    logger.debug("Connecting to sqlite database at '{0}'".format(db_path))
    Session, engine = db_connect(database_path=db_path,
                                 ensure_db_exists=False)
    session = Session()

    # The try/finally guarantees that the pool and the database session are
    # released on every exit path, including sys.exit() and unexpected errors.
    try:
        # TODO: all hard-set, these should be options
        params = JokerParams(P_min=10 * u.day,
                             P_max=1000 * u.day,
                             jitter=(9.5, 1.64),
                             jitter_unit=u.m/u.s)
        n_prior_samples = 2**28
        run_name = 'apogee-jitter'
        apogee_id = '2M01231070+1801407'

        results_filename = join(TWOFACE_CACHE_PATH,
                                '{0}.hdf5'.format(apogee_id))
        prior_samples_file = join(TWOFACE_CACHE_PATH,
                                  '{0}-prior.hdf5'.format(apogee_id))

        # Create TheJoker sampler instance with the specified random seed and
        # pool
        rnd = np.random.RandomState(seed=seed)
        logger.debug(
            "Creating TheJoker instance with {0}, {1}".format(rnd, pool))
        joker = TheJoker(params, random_state=rnd, pool=pool)

        # Create a cache of prior samples if it doesn't exist yet (or if we
        # were asked to overwrite it).
        if not os.path.exists(prior_samples_file) or overwrite:
            logger.debug(
                "Prior samples file not found - generating {0} samples..."
                .format(n_prior_samples))
            make_prior_cache(prior_samples_file, joker,
                             nsamples=n_prior_samples)
            logger.debug("...done")

        # Look up the one star we care about within the named run
        star_query = session.query(AllStar)\
                            .join(StarResult, JokerRun, Status)\
                            .filter(JokerRun.name == run_name)\
                            .filter(AllStar.apogee_id == apogee_id)
        star = star_query.limit(1).one()

        logger.log(1, "Starting star {0}".format(star.apogee_id))
        t0 = time.time()

        data = star.apogeervdata()
        logger.log(1, "\t visits loaded ({:.2f} seconds)"
                   .format(time.time()-t0))
        try:
            samples = joker.rejection_sample(
                data=data, prior_cache_file=prior_samples_file,
                return_logprobs=False)

        except Exception as e:
            logger.warning("\t Failed sampling for star {0} \n Error: {1}"
                           .format(star.apogee_id, str(e)))
            sys.exit(1)

        logger.debug(
            "\t done sampling ({:.2f} seconds)".format(time.time()-t0))

        # Write the samples that pass to the results file
        with h5py.File(results_filename, 'w') as f:
            samples.to_hdf5(f)

        logger.debug(
            "\t saved samples ({:.2f} seconds)".format(time.time()-t0))
        logger.debug("...done with star {} ({:.2f} seconds)"
                     .format(star.apogee_id, time.time()-t0))

    finally:
        pool.close()
        session.close()
def main(config_file, pool, seed, overwrite=False):
    """Continue sampling for stars of a run that are flagged "needs mcmc".

    The run is looked up from the YAML config, its cached posterior samples
    are read from ``<run.name>.hdf5`` in ``TWOFACE_CACHE_PATH``, and each
    selected star is either re-sampled from a second, larger prior cache
    (status 1) or handed to MCMC (status 2).

    Parameters
    ----------
    config_file : str
        Path to the YAML configuration file. ``config['database_file']`` is
        read here; the parsed config is also passed to ``get_run``.
    pool : object
        Processing pool passed to ``TheJoker``.
    seed : int or None
        Seed for the ``numpy.random.RandomState`` given to ``TheJoker``.
    overwrite : bool (optional)
        Not used by this function; ``get_run`` is always called with
        ``overwrite=False``.

    Raises
    ------
    IOError
        If the sqlite database or the posterior-samples file does not exist.

    Notes
    -----
    This is work-in-progress code. After the first star that has a result
    object has been handled, the pool and session are closed and the process
    exits with status 0, so the bookkeeping further down in the loop body is
    currently never reached.
    """
    import pickle
    import sys

    config_file = path.abspath(path.expanduser(config_file))

    # parse config file
    with open(config_file, 'r') as f:
        # safe_load instead of a bare yaml.load(): calling yaml.load() without
        # a Loader can construct arbitrary Python objects on older PyYAML and
        # is a TypeError on PyYAML >= 6.
        config = yaml.safe_load(f)

    # filename of sqlite database
    database_file = config['database_file']

    db_path = path.join(TWOFACE_CACHE_PATH, database_file)
    if not os.path.exists(db_path):
        raise IOError(
            "sqlite database not found at '{0}'\n Did you run "
            "scripts/initdb.py yet for that database?".format(db_path))

    logger.debug("Connecting to sqlite database at '{0}'".format(db_path))
    Session, engine = db_connect(database_path=db_path, ensure_db_exists=False)
    session = Session()

    run = get_run(config, session, overwrite=False)

    # The file with cached posterior samples:
    results_filename = path.join(TWOFACE_CACHE_PATH,
                                 "{0}.hdf5".format(run.name))
    if not path.exists(results_filename):
        raise IOError(
            "Posterior samples result file {0} doesn't exist! Are "
            "you sure you ran `run_apogee.py`?".format(results_filename))

    # Create TheJoker sampler instance with the specified random seed and pool
    rnd = np.random.RandomState(seed=seed)
    logger.debug("Creating TheJoker instance with {0}, {1}".format(rnd, pool))
    params = run.get_joker_params()
    joker = TheJoker(params, random_state=rnd, pool=pool)

    # Path of the second, larger prior cache used by the "needs more prior
    # samples" branch below. Previously this name was only defined in
    # commented-out code, so that branch always hit a NameError that was
    # swallowed as a "Failed sampling" warning.
    # TODO: generate this cache when it is missing (roughly
    # 8 * config['prior']['num_cache'] samples, ~100 GB, batch_size=2**24);
    # skipped for now because only "needs mcmc" stars are selected.
    _path, ext = path.splitext(run.prior_samples_file)
    new_path = '{0}_moar{1}'.format(_path, ext)

    # Get all stars in this JokerRun that need more prior samples
    # TODO HACK: this query only selects "needs mcmc" stars!
    star_query = session.query(AllStar).join(StarResult, JokerRun, Status)\
                                       .filter(JokerRun.name == run.name)\
                                       .filter(Status.id == 2)
    # .filter(Status.id == 1)

    # Base query to get a StarResult for a given Star so we can update the
    # status, etc.
    result_query = session.query(StarResult).join(AllStar, JokerRun)\
                                            .filter(JokerRun.name == run.name)

    n_stars = star_query.count()
    logger.info("{0} stars left to process for run more samples '{1}'".format(
        n_stars, run.name))

    # --------------------------------------------------------------------------
    # Here is where we do the actual processing of the data for each star. We
    # loop through all stars that still need processing and continue with
    # rejection sampling.

    count = 0  # how many stars we've processed in this star batch
    batch_size = 16  # MAGIC NUMBER: how many stars to process before committing
    for star in star_query.all():

        if result_query.filter(
                AllStar.apogee_id == star.apogee_id).count() < 1:
            logger.debug('Star {0} has no result object!'.format(
                star.apogee_id))
            continue

        # Retrieve existing StarResult from database. We limit(1) because the
        # APOGEE_ID isn't unique, but we attach all visits for a given star to
        # all rows, so grabbing one of them is fine.
        result = result_query.filter(AllStar.apogee_id == star.apogee_id)\
                             .limit(1).one()

        logger.log(1, "Starting star {0}".format(star.apogee_id))
        logger.log(1, "Current status: {0}".format(str(result.status)))
        t0 = time.time()

        data = star.apogeervdata()
        logger.log(
            1, "\t visits loaded ({:.2f} seconds)".format(time.time() - t0))

        if result.status.id == 1:  # needs more prior samples

            try:
                samples, ln_prior = joker.iterative_rejection_sample(
                    data=data,
                    n_requested_samples=run.requested_samples_per_star,
                    # HACK: prior_cache_file=run.prior_samples_file,
                    prior_cache_file=new_path,
                    return_logprobs=True)

            except Exception as e:
                logger.warning(
                    "\t Failed sampling for star {0} \n Error: {1}".format(
                        star.apogee_id, str(e)))
                continue

            logger.debug(
                "\t done sampling ({:.2f} seconds)".format(time.time() - t0))

        elif result.status.id == 2:  # needs mcmc
            logger.debug('Firing up MCMC:')

            with h5py.File(results_filename, 'r') as f:
                samples0 = JokerSamples.from_hdf5(f[star.apogee_id])

            n_walkers = 2 * run.requested_samples_per_star
            model, samples, sampler = joker.mcmc_sample(data,
                                                        samples0,
                                                        n_burn=1024,
                                                        n_steps=65536,
                                                        n_walkers=n_walkers,
                                                        return_sampler=True)

            # Detach the pool before pickling the sampler
            sampler.pool = None
            with open('test-mcmc.pickle', 'wb') as f:
                pickle.dump(sampler, f)

        # HACK: debugging early exit - stop after the first processed star.
        # The session is closed as well so it is not left open on exit.
        pool.close()
        session.close()
        sys.exit(0)

        # NOTE(review): everything below in this loop body is unreachable
        # while the early exit above is in place. Before re-enabling it:
        # `ln_prior` is not defined by the MCMC branch, and
        # `run.n_requested_samples` differs from the
        # `run.requested_samples_per_star` attribute used elsewhere in this
        # function - confirm which one exists.

        # For now, it's sufficient to write the run results to an HDF5 file
        n = run.requested_samples_per_star
        all_ln_probs = ln_prior[:n]
        samples = samples[:n]

        # Write the samples that pass to the results file
        with h5py.File(results_filename, 'r+') as f:
            if star.apogee_id in f:
                del f[star.apogee_id]

            g = f.create_group(star.apogee_id)
            samples.to_hdf5(g)
            g.create_dataset('ln_prior_probs', data=all_ln_probs)

        logger.debug("\t saved samples ({:.2f} seconds)".format(time.time() -
                                                                t0))

        result.status_id = get_status_id(samples, data,
                                         run.n_requested_samples)

        logger.debug("...done with star {} ({:.2f} seconds)".format(
            star.apogee_id,
            time.time() - t0))

        if count % batch_size == 0 and count > 0:
            session.commit()

        count += 1

    # Only reached when no star made it past the result-object check above
    pool.close()

    session.commit()
    session.close()
示例#6
0
def main(config_file, pool, seed, overwrite=False):
    """Run The Joker on a control sample of simulated, orbit-free RV data.

    A random subset of up to ``NCONTROL`` already-processed stars is chosen.
    For each one, the observed RVs are replaced by Gaussian draws around the
    star's mean RV (see the comments in the loop), the result is sampled with
    ``TheJoker.iterative_rejection_sample``, and the samples are stored under
    the star's APOGEE ID in ``<run.name>-control.hdf5``.

    Parameters
    ----------
    config_file : str
        Path to the YAML configuration file. It must contain a
        ``database_file`` entry; the parsed config is also passed to
        ``get_run``.
    pool : object
        Processing pool passed to ``TheJoker``; closed at the end.
    seed : int or None
        Seed for the ``numpy.random.RandomState`` used for both the star
        selection and the sampler. Defaults to 42 when ``None``.
    overwrite : bool (optional)
        Not used by this function; ``get_run`` is always called with
        ``overwrite=False``.

    Raises
    ------
    ValueError
        If the config has no ``database_file`` entry.
    IOError
        If the sqlite database or the run's prior-sample cache is missing.
    """
    # Default seed:
    if seed is None:
        seed = 42

    config_file = abspath(expanduser(config_file))

    # parse config file
    with open(config_file, 'r') as f:
        # safe_load instead of a bare yaml.load(): calling yaml.load() without
        # a Loader can construct arbitrary Python objects on older PyYAML and
        # is a TypeError on PyYAML >= 6.
        config = yaml.safe_load(f)
    config['config_file'] = config_file

    # filename of sqlite database
    database_file = config.get('database_file')
    if database_file is None:
        # Fail with a clear message instead of the TypeError that
        # join(TWOFACE_CACHE_PATH, None) would raise.
        raise ValueError(
            "No 'database_file' entry found in config file '{0}'".format(
                config_file))

    db_path = join(TWOFACE_CACHE_PATH, database_file)
    if not os.path.exists(db_path):
        raise IOError(
            "sqlite database not found at '{0}'\n Did you run "
            "scripts/initdb.py yet for that database?".format(db_path))

    logger.debug("Connecting to sqlite database at '{0}'".format(db_path))
    Session, engine = db_connect(database_path=db_path, ensure_db_exists=False)
    session = Session()

    # Retrieve or create a JokerRun instance
    run = get_run(config, session, overwrite=False)  # never overwrite
    params = run.get_joker_params()

    # Create TheJoker sampler instance with the specified random seed and pool
    rnd = np.random.RandomState(seed=seed)
    logger.debug("Creating TheJoker instance with {0}, {1}".format(rnd, pool))
    joker = TheJoker(params, random_state=rnd, pool=pool)

    # Create a file to cache the resulting posterior samples
    results_filename = join(TWOFACE_CACHE_PATH,
                            "{0}-control.hdf5".format(run.name))

    # Ensure that the results file exists - this is where we cache samples that
    # pass the rejection sampling step
    if not os.path.exists(results_filename):
        with h5py.File(results_filename, 'w') as f:
            pass

    # Stars that already have a group in the results file are done
    with h5py.File(results_filename, 'r') as f:
        done_apogee_ids = list(f.keys())

    # The prior-sample cache is never generated here: it must already exist.
    if not os.path.exists(run.prior_samples_file):
        raise IOError("Prior cache must already exist.")

    # Get random IDs
    star_ids = session.query(AllStar.apogee_id)\
                      .join(StarResult, JokerRun, Status)\
                      .filter(Status.id > 0).distinct().all()
    star_ids = np.array([x[0] for x in star_ids])

    # choice(..., replace=False) raises when asked for more items than are
    # available, so cap the control-sample size at the number of stars.
    n_control = min(NCONTROL, len(star_ids))
    if n_control < NCONTROL:
        logger.warning(
            "Only {0} processed stars available; control sample reduced "
            "from {1}".format(n_control, NCONTROL))
    star_ids = rnd.choice(star_ids, size=n_control, replace=False)
    star_ids = star_ids[~np.isin(star_ids, done_apogee_ids)]

    n_stars = len(star_ids)
    logger.info(
        "{0} stars left to process for run '{1}'; {2} already done.".format(
            n_stars, run.name, len(done_apogee_ids)))

    # --------------------------------------------------------------------------
    # Here is where we do the actual processing of the data for each star. We
    # loop through all stars that still need processing and iteratively
    # rejection sample with larger and larger prior sample batch sizes. We do
    # this for efficiency, but the argument for this is somewhat made up...

    for apid in star_ids:
        star = AllStar.get_apogee_id(session, apid)

        logger.log(1, "Starting star {0}".format(star.apogee_id))
        t0 = time.time()

        orig_data = star.apogeervdata()

        # HACK: this assumes we're sampling over the excess variance parameter
        # Generate new data with no RV orbital variations:
        #   y   ~ Normal(jitter[0], jitter[1])
        #   s   = exp(y / 2), in the jitter unit
        #   std = sqrt(s^2 + stddev^2), the per-visit scatter
        #   rv  ~ Normal(mean(original rv), std)
        y = rnd.normal(params.jitter[0], params.jitter[1])
        s = np.exp(0.5 * y) * params._jitter_unit
        std = np.sqrt(s**2 + orig_data.stddev**2).to(orig_data.rv.unit).value
        new_rv = rnd.normal(np.mean(orig_data.rv).value, std)
        data = APOGEERVData(t=orig_data.t,
                            rv=new_rv * orig_data.rv.unit,
                            stddev=orig_data.stddev)

        logger.log(
            1, "\t visits loaded ({:.2f} seconds)".format(time.time() - t0))
        try:
            # The log-prior values are requested but not stored for the
            # control sample.
            samples, _ = joker.iterative_rejection_sample(
                data=data,
                n_requested_samples=run.requested_samples_per_star,
                prior_cache_file=run.prior_samples_file,
                n_prior_samples=run.max_prior_samples,
                return_logprobs=True)

        except Exception as e:
            logger.warning(
                "\t Failed sampling for star {0} \n Error: {1}".format(
                    star.apogee_id, str(e)))
            continue

        logger.debug("\t done sampling ({:.2f} seconds)".format(time.time() -
                                                                t0))

        # For now, it's sufficient to write the run results to an HDF5 file
        n = run.requested_samples_per_star
        samples = samples[:n]

        # Write the samples that pass to the results file
        with h5py.File(results_filename, 'r+') as f:
            if star.apogee_id in f:
                del f[star.apogee_id]

            # HACK: this will overwrite the past samples!
            g = f.create_group(star.apogee_id)
            samples.to_hdf5(g)

        logger.debug("\t saved samples ({:.2f} seconds)".format(time.time() -
                                                                t0))

    pool.close()
    session.close()
示例#7
0
def main(config_file, pool, seed, overwrite=False, _continue=False):
    """Rejection-sample every star of a run that still needs processing.

    Reads the YAML config, connects to the sqlite database it names,
    retrieves (or creates) the corresponding ``JokerRun``, makes sure the
    prior-sample cache exists, and then runs
    ``TheJoker.iterative_rejection_sample`` on each star whose status id is
    0. The retained samples and their log-prior values are written to
    ``<TWOFACE_CACHE_PATH>/<run.name>.hdf5`` (one group per APOGEE ID) and
    the star's ``StarResult.status_id`` is updated to one of:

    * 4 -- completed (at least the requested number of samples returned)
    * 2 -- needs mcmc (a single sample, or samples that look unimodal)
    * 1 -- needs more samples

    Parameters
    ----------
    config_file : str
        Path to the YAML configuration file; ``~`` is expanded and the path
        is made absolute.
    pool
        Processing pool handed to ``TheJoker``. It is closed before this
        function returns or propagates an error raised after the database
        session was opened.
    seed : int
        Seed for the ``numpy.random.RandomState`` given to ``TheJoker``.
    overwrite : bool, optional
        Forwarded to ``get_run``; also forces regeneration of the prior
        samples cache file.
    _continue : bool, optional
        Accepted for interface compatibility; not used by this function.

    Raises
    ------
    ValueError
        If the config does not specify ``database_file``.
    IOError
        If the sqlite database file does not exist.
    """
    config_file = abspath(expanduser(config_file))

    # parse config file
    with open(config_file, 'r') as f:
        # safe_load: the config is plain data, and ``yaml.load`` without an
        # explicit Loader is rejected by recent PyYAML releases (and can
        # construct arbitrary objects on older ones).
        config = yaml.safe_load(f)
        config['config_file'] = config_file

    # filename of sqlite database - fail early with a readable message
    # instead of letting join() choke on None
    database_file = config.get('database_file')
    if database_file is None:
        raise ValueError(
            "No 'database_file' specified in config file '{0}'".format(
                config_file))

    db_path = join(TWOFACE_CACHE_PATH, database_file)
    if not os.path.exists(db_path):
        raise IOError(
            "sqlite database not found at '{0}'\n Did you run "
            "scripts/initdb.py yet for that database?".format(db_path))

    logger.debug("Connecting to sqlite database at '{0}'".format(db_path))
    Session, engine = db_connect(database_path=db_path, ensure_db_exists=False)
    session = Session()

    try:
        # Retrieve or create a JokerRun instance
        run = get_run(config, session, overwrite=overwrite)
        params = run.get_joker_params()

        # Create TheJoker sampler instance with the specified random seed and
        # pool
        rnd = np.random.RandomState(seed=seed)
        logger.debug(
            "Creating TheJoker instance with {0}, {1}".format(rnd, pool))
        joker = TheJoker(params, random_state=rnd, pool=pool)

        # Create a cache of prior samples (if it doesn't exist) and store the
        # filename in the database.
        if not os.path.exists(run.prior_samples_file) or overwrite:
            n_prior_cache = config['prior']['num_cache']
            logger.debug(
                "Prior samples file not found - generating {0} samples..."
                .format(n_prior_cache))
            make_prior_cache(run.prior_samples_file,
                             joker,
                             nsamples=n_prior_cache)
            logger.debug("...done")

        # Get done APOGEE ID's
        done_subq = session.query(AllStar.apogee_id)\
                           .join(StarResult, JokerRun, Status)\
                           .filter(Status.id > 0).distinct()

        # Query to get all stars associated with this run that need
        # processing: they should have a status id = 0 (needs processing)
        star_query = session.query(AllStar)\
                            .join(StarResult, JokerRun, Status)\
                            .filter(JokerRun.name == run.name)\
                            .filter(Status.id == 0)\
                            .filter(~AllStar.apogee_id.in_(done_subq))

        # Base query to get a StarResult for a given Star so we can update
        # the status, etc.
        # NOTE(review): unlike star_query, Status is filtered on here but is
        # not part of the join - confirm this is intended.
        result_query = session.query(StarResult).join(AllStar, JokerRun)\
                              .filter(JokerRun.name == run.name)\
                              .filter(Status.id == 0)\
                              .filter(~AllStar.apogee_id.in_(done_subq))

        # Create a file to cache the resulting posterior samples
        results_filename = join(TWOFACE_CACHE_PATH,
                                "{0}.hdf5".format(run.name))
        n_stars = star_query.count()
        logger.info("{0} stars left to process for run '{1}'".format(
            n_stars, run.name))

        # Ensure that the results file exists - this is where we cache
        # samples that pass the rejection sampling step
        if not os.path.exists(results_filename):
            with h5py.File(results_filename, 'w'):
                pass

        # ---------------------------------------------------------------------
        # Here is where we do the actual processing of the data for each
        # star. We loop through all stars that still need processing and
        # iteratively rejection sample with larger and larger prior sample
        # batch sizes. We do this for efficiency, but the argument for this
        # is somewhat made up...

        count = 0  # how many stars we've processed so far
        batch_size = 16  # MAGIC NUMBER: stars to process before committing
        for star in star_query.all():

            # Retrieve existing StarResult from database. The APOGEE_ID isn't
            # unique, but we attach all visits for a given star to all rows,
            # so grabbing the first one is fine. first() returns None when
            # there is no row at all.
            result = result_query.filter(
                AllStar.apogee_id == star.apogee_id).first()
            if result is None:
                logger.debug('Star {0} has no result object!'.format(
                    star.apogee_id))
                continue

            logger.log(1, "Starting star {0}".format(star.apogee_id))
            logger.log(1, "Current status: {0}".format(str(result.status)))
            t0 = time.time()

            data = star.apogeervdata()
            logger.log(
                1,
                "\t visits loaded ({:.2f} seconds)".format(time.time() - t0))
            try:
                samples, ln_prior = joker.iterative_rejection_sample(
                    data=data,
                    n_requested_samples=run.requested_samples_per_star,
                    prior_cache_file=run.prior_samples_file,
                    n_prior_samples=run.max_prior_samples,
                    return_logprobs=True)

            except Exception as e:
                # Deliberate best-effort: one bad star must not abort the
                # whole run, so log it and move on to the next one.
                logger.warning(
                    "\t Failed sampling for star {0} \n Error: {1}".format(
                        star.apogee_id, str(e)))
                continue

            logger.debug("\t done sampling ({:.2f} seconds)".format(
                time.time() - t0))

            # For now, it's sufficient to write the run results to an HDF5
            # file
            n = run.requested_samples_per_star
            all_ln_probs = ln_prior[:n]
            samples = samples[:n]
            n_actual_samples = len(all_ln_probs)

            # Write the samples that pass to the results file
            with h5py.File(results_filename, 'r+') as f:
                if star.apogee_id in f:
                    del f[star.apogee_id]

                # HACK: this will overwrite the past samples!
                g = f.create_group(star.apogee_id)
                samples.to_hdf5(g)

                if 'ln_prior_probs' in g:
                    del g['ln_prior_probs']
                g.create_dataset('ln_prior_probs', data=all_ln_probs)

            logger.debug("\t saved samples ({:.2f} seconds)".format(
                time.time() - t0))

            if n_actual_samples >= n:
                result.status_id = 4  # completed

            elif n_actual_samples == 1 or unimodal_P(samples, data):
                # Either only one sample was returned (probably unimodal) or
                # several were returned but they look unimodal: this star
                # needs MCMC
                result.status_id = 2  # needs mcmc

            else:
                # Multiple samples were returned, but not enough to satisfy
                # the number requested in the config file
                result.status_id = 1  # needs more samples

            logger.debug("...done with star {} ({:.2f} seconds)".format(
                star.apogee_id,
                time.time() - t0))

            # Commit after every full batch of processed stars
            count += 1
            if count % batch_size == 0:
                session.commit()

        # Flush whatever is left from the last, partial batch
        session.commit()

    finally:
        # Always release the pool and the database session, even when an
        # error interrupts the run
        try:
            pool.close()
        finally:
            session.close()