コード例 #1
0
ファイル: base.py プロジェクト: bambang/beat
    def __init__(self, handler=None, homepath=None, stage_number=-1, backend='csv'):
        """
        Set up a stage from an existing handler or from a home directory.

        Parameters
        ----------
        handler : SampleStage-like, optional
            Used as-is if given; takes precedence over ``homepath``.
        homepath : str, optional
            Directory from which a ``SampleStage`` is built (with
            ``backend``) when no ``handler`` is given.
        stage_number : int
            Stored as ``self.number`` (default -1).
        backend : str
            Stored as ``self.backend`` (default 'csv').

        Raises
        ------
        TypeError
            If neither ``handler`` nor ``homepath`` is given.
        """
        if handler is None:
            if homepath is None:
                raise TypeError('Either handler or homepath have to be not None')
            # No handler supplied: build one rooted at homepath.
            handler = SampleStage(homepath, backend=backend)

        self.handler = handler
        self.backend = backend
        self.number = stage_number
コード例 #2
0
class Stage(object):
    """
    Stage, containing sampling results and intermediate sampler
    parameters.

    Attributes
    ----------
    handler : SampleStage-like
        Resolves stage directories and loads traces / sampler parameters.
    backend : str
        Trace backend name, used when the handler is built from a homepath.
    number : int
        Stage number (``-1`` unless given otherwise).
    path : str
        Directory of the stage last resolved by :meth:`load_results`.
    step, updates
        Sampler step and updates, set when sampler params are loaded.
    mtrace
        Multi-trace of the stage, set when the trace is loaded.
    """
    number = None
    path = None
    step = None
    updates = None
    mtrace = None

    def __init__(self,
                 handler=None,
                 homepath=None,
                 stage_number=-1,
                 backend='csv'):
        """
        Parameters
        ----------
        handler : SampleStage-like, optional
            Used as-is if given; takes precedence over ``homepath``.
        homepath : str, optional
            Directory from which a ``SampleStage`` is built when no
            ``handler`` is given.
        stage_number : int
            Initial stage number.
        backend : str
            Trace backend name.

        Raises
        ------
        TypeError
            If neither ``handler`` nor ``homepath`` is given.
        """
        if handler is not None:
            self.handler = handler
        elif homepath is not None:
            self.handler = SampleStage(homepath, backend=backend)
        else:
            raise TypeError('Either handler or homepath have to be not None')

        self.backend = backend
        self.number = stage_number

    def load_results(self,
                     varnames=None,
                     model=None,
                     stage_number=None,
                     chains=None,
                     load='trace'):
        """
        Load stage results from sampling.

        Parameters
        ----------
        varnames : list of str, optional
            Names of the variables to load from the trace. Derived from
            ``model.unobserved_RVs`` if not given.
        model : :class:`pymc3.model.Model`, optional
            Required if ``varnames`` is not given and whenever sampler
            params are loaded (``load`` is 'params' or 'full').
        stage_number : int, optional
            Number of stage to load; defaults to ``self.number``. If the
            stage directory does not exist, the highest sampled stage
            reported by the handler is loaded instead.
        chains : list, optional
            of result chains to load
        load : str
            what to load and return: 'full' (trace and sampler params),
            'trace' or 'params'

        Raises
        ------
        ValueError
            If neither ``varnames`` nor ``model`` is given, if ``load`` is
            not one of the options above, or if sampler params are
            requested without a ``model``. All checks run before any
            results are read or any attribute is changed.
        """
        load_options = ('full', 'trace', 'params')

        if varnames is None:
            if model is None:
                raise ValueError(
                    'Either "varnames" or "model" need to be not None!')
            varnames = [var.name for var in model.unobserved_RVs]

        if load not in load_options:
            raise ValueError(
                'Unknown load option "%s"! Must be one of: %s' % (
                    load, ', '.join(load_options)))

        if load == 'full':
            to_load = ['params', 'trace']
        else:
            to_load = [load]

        # Fail before touching the disk so a bad call does not load a
        # (potentially large) trace or leave a half-updated stage behind.
        if 'params' in to_load and model is None:
            raise ValueError('To load sampler params model is required!')

        if stage_number is None:
            stage_number = self.number

        self.path = self.handler.stage_path(stage_number)

        if not os.path.exists(self.path):
            # Requested stage was not written (yet); fall back to the last
            # stage the handler reports as sampled.
            stage_number = self.handler.highest_sampled_stage()

            logger.info('Stage results %s do not exist! Loading last completed'
                        ' stage %s', self.path, stage_number)
            self.path = self.handler.stage_path(stage_number)

        self.number = stage_number

        if 'trace' in to_load:
            self.mtrace = self.handler.load_multitrace(stage_number,
                                                       varnames=varnames,
                                                       chains=chains)

        if 'params' in to_load:
            # Sampler params are restored inside the model context.
            with model:
                self.step, self.updates = self.handler.load_sampler_params(
                    stage_number)
コード例 #3
0
def estimate_hypers(step, problem):
    """
    Get initial estimates of the hyperparameters.

    Runs a sampling stage (stage 1) with the hyperparameter sampler
    configuration, then sets each hyperparameter's bounds to the range of
    its post-burn-in samples, rounded outwards to integers and padded by 2
    on each side. The test value is set to the midpoint of the new bounds,
    and the updated configuration is written to ``config_<mode>.yaml`` in
    the project directory.

    Parameters
    ----------
    step : sampler step object
        Passed to ``init_stage``; the returned step has its ``stage`` and
        ``n_steps`` attributes overwritten.
    problem : problem object
        Provides ``config``, ``outfolder``, ``model`` and ``hypernames``.

    Raises
    ------
    ValueError
        If ``n_chains`` is not a whole multiple of ``n_jobs``.
    """
    from beat.sampler.base import iter_parallel_chains, init_stage, \
        init_chain_hypers

    logger.info('... Estimating hyperparameters ...')

    pc = problem.config.problem_config
    sc = problem.config.hyper_sampler_config
    pa = sc.parameters

    # Chains are handed to the jobs in equal chunks (see chunksize below).
    # Modulo avoids relying on true division and float representation.
    if pa.n_chains % pa.n_jobs != 0:
        raise ValueError('n_chains / n_jobs has to be a whole number!')

    ensuredir(problem.outfolder)

    stage_handler = SampleStage(problem.outfolder, backend=sc.backend)
    chains, step, _ = init_stage(stage_handler=stage_handler,
                                 step=step,
                                 stage=0,
                                 progressbar=sc.progressbar,
                                 model=problem.model,
                                 rm_flag=pa.rm_flag)

    # setting stage to 1 otherwise only one sample
    step.stage = 1
    step.n_steps = pa.n_steps

    with problem.model:
        mtrace = iter_parallel_chains(draws=pa.n_steps,
                                      chains=chains,
                                      step=step,
                                      stage_path=stage_handler.stage_path(1),
                                      progressbar=sc.progressbar,
                                      model=problem.model,
                                      n_jobs=pa.n_jobs,
                                      initializer=init_chain_hypers,
                                      initargs=(problem, ),
                                      buffer_size=sc.buffer_size,
                                      buffer_thinning=sc.buffer_thinning,
                                      chunksize=int(pa.n_chains / pa.n_jobs))

    # Number of samples per chain that remain after buffer thinning; the
    # burn-in (pa.burn, a fraction) is applied relative to this length.
    thinned_chain_length = len(
        thin_buffer(list(range(pa.n_steps)),
                    sc.buffer_thinning,
                    ensure_last=True))
    burn = int(thinned_chain_length * pa.burn)

    for v in problem.hypernames:
        hyper = pc.hyperparameters[v]
        d = mtrace.get_values(v,
                              combine=True,
                              burn=burn,
                              thin=pa.thin,
                              squeeze=True)

        lower = num.floor(d.min()) - 2.
        upper = num.ceil(d.max()) + 2.
        logger.info('Updating hyperparameter %s from %f, %f to %f, %f',
                    v, hyper.lower, hyper.upper, lower, upper)
        hyper.lower = num.atleast_1d(lower)
        hyper.upper = num.atleast_1d(upper)
        hyper.testvalue = num.atleast_1d((upper + lower) / 2.)

    config_file_name = 'config_' + pc.mode + '.yaml'
    conf_out = os.path.join(problem.config.project_dir, config_file_name)

    problem.config.problem_config = pc
    bconfig.dump(problem.config, filename=conf_out)
コード例 #4
0
ファイル: test_pt.py プロジェクト: wangyf/beat
    def _test_sample(self, n_jobs, test_folder):
        """
        Run parallel-tempering sampling on a two-component Gaussian mixture
        and plot sampler diagnostics.

        The target is a 4-dimensional mixture of two Gaussians with means
        +0.5 and -0.5 in every dimension, standard deviation 0.1 and
        weights 0.1 / 0.9. It is sampled with ``pt.pt_sample`` into
        ``test_folder``. The trace and sampler history of stage -1 are then
        reloaded and shown as trace plots, the acceptance ratio and
        temperature-scaling history, and per-interval acceptance matrices
        and sample counts.

        NOTE(review): no assertions are made; the test only produces plots,
        and ``plt.show()`` may block with an interactive backend.

        Parameters
        ----------
        n_jobs : int
            Number of chains used by the Metropolis step and pt_sample.
        test_folder : str
            Home directory the sampler writes its results to.
        """
        logger.info('Running on %i cores...' % n_jobs)

        n = 4

        mu1 = num.ones(n) * (1. / 2)
        mu2 = -mu1

        stdev = 0.1
        sigma = num.power(stdev, 2) * num.eye(n)
        isigma = num.linalg.inv(sigma)
        dsigma = num.linalg.det(sigma)

        # Mixture weights; the stdev value is reused as weight of mode 1.
        w1 = stdev
        w2 = (1 - stdev)

        # Log-density of the weighted mixture of the two multivariate
        # Gaussians N(mu1, sigma) and N(mu2, sigma), built symbolically.
        def two_gaussians(x):
            log_like1 = - 0.5 * n * tt.log(2 * num.pi) \
                        - 0.5 * tt.log(dsigma) \
                        - 0.5 * (x - mu1).T.dot(isigma).dot(x - mu1)
            log_like2 = - 0.5 * n * tt.log(2 * num.pi) \
                        - 0.5 * tt.log(dsigma) \
                        - 0.5 * (x - mu2).T.dot(isigma).dot(x - mu2)
            return tt.log(w1 * tt.exp(log_like1) + w2 * tt.exp(log_like2))

        with pm.Model() as PT_test:
            # Flat prior on [-2, 2]^n, started at -1 in every dimension.
            X = pm.Uniform('X',
                           shape=n,
                           lower=-2. * num.ones_like(mu1),
                           upper=2. * num.ones_like(mu1),
                           testval=-1. * num.ones_like(mu1),
                           transform=None)
            # Deterministic 'tmp' exposes the likelihood to the sampler
            # (see likelihood_name below); the Potential adds it to the
            # model log-probability.
            like = pm.Deterministic('tmp', two_gaussians(X))
            llk = pm.Potential('like', like)

        with PT_test:
            step = metropolis.Metropolis(
                n_chains=n_jobs,
                likelihood_name=PT_test.deterministics[0].name,
                proposal_name='MultivariateCauchy',
                tune_interval=self.tune_interval)

        pt.pt_sample(
            step,
            n_chains=n_jobs,
            n_samples=self.n_samples,
            swap_interval=self.swap_interval,
            beta_tune_interval=self.beta_tune_interval,
            n_workers_posterior=self.n_workers_posterior,
            homepath=test_folder,
            progressbar=False,
            buffer_size=self.buffer_size,
            model=PT_test,
            rm_flag=False,
            keep_tmp=False)

        stage_handler = SampleStage(test_folder)

        # The sampling results are read back from stage -1.
        mtrace = stage_handler.load_multitrace(-1, varnames=PT_test.vars)
        history = load_objects(os.path.join(stage_handler.stage_path(-1), sample_p_outname))

        n_steps = self.n_samples
        burn = self.burn
        thin = self.thin

        # Per-chain burn-in and thinning applied to the traceplot values.
        # Assumes x stacks the chains contiguously, n_steps samples each —
        # TODO confirm.
        # NOTE(review): nend = n_steps * (i + 1) - 1 is an exclusive slice
        # end, so the last sample of every chain is dropped; confirm this
        # is intended.
        def burn_sample(x):
            if n_steps == 1:
                return x
            else:
                nchains = int(x.shape[0] / n_steps)
                xout = []
                for i in range(nchains):
                    nstart = int((n_steps * i) + (n_steps * burn))
                    nend = int(n_steps * (i + 1) - 1)
                    xout.append(x[nstart:nend:thin])

                return num.vstack(xout)

        from pymc3 import traceplot
        from matplotlib import pyplot as plt
        with PT_test:
            traceplot(mtrace, transform=burn_sample)

        # Acceptance ratio and temperature scaling per update interval.
        fig, axes = plt.subplots(
            nrows=1, ncols=2, figsize=mpl_papersize('a5', 'portrait'))
        axes[0].plot(history.acceptance, 'r')
        axes[0].set_ylabel('Acceptance ratio')
        axes[0].set_xlabel('Update interval')
        axes[1].plot(num.array(history.t_scales), 'k')
        axes[1].set_ylabel('Temperature scaling')
        axes[1].set_xlabel('Update interval')

        # One panel per recorded update interval, laid out ncol wide.
        n_acceptances = len(history)
        ncol = 3
        nrow = int(num.ceil(n_acceptances / float(ncol)))

        fig2, axes1 = plt.subplots(
            nrows=nrow, ncols=ncol, figsize=mpl_papersize('a4', 'portrait'))
        axes1 = num.atleast_2d(axes1)
        fig3, axes2 = plt.subplots(
            nrows=nrow, ncols=ncol, figsize=mpl_papersize('a4', 'portrait'))
        axes2 = num.atleast_2d(axes2)

        # Per-panel color limits, indexed by update interval below.
        acc_arrays = history.get_acceptance_matrixes_array()
        sc_arrays = history.get_sample_counts_array()
        scvmin = sc_arrays.min(0).min(0)
        scvmax = sc_arrays.max(0).max(0)
        accvmin = acc_arrays.min(0).min(0)
        accvmax = acc_arrays.max(0).max(0)

        for i in range(ncol * nrow):
            # presumably converts the flat index to (row, column) — confirm
            rowi, coli = mod_i(i, ncol)
            # NOTE(review): colorbars for these panels are not drawn.

            if i > n_acceptances - 1:
                # Grid cells beyond the last interval are removed; removal
                # failures raising KeyError are ignored.
                try:
                    fig2.delaxes(axes1[rowi, coli])
                    fig3.delaxes(axes2[rowi, coli])
                except KeyError:
                    pass
            else:
                axes1[rowi, coli].matshow(
                    history.acceptance_matrixes[i],
                    vmin=accvmin[i], vmax=accvmax[i], cmap='hot')
                axes1[rowi, coli].set_title('min %i, max%i' % (accvmin[i], accvmax[i]))
                axes1[rowi, coli].get_xaxis().set_ticklabels([])
                axes2[rowi, coli].matshow(
                    history.sample_counts[i], vmin=scvmin[i], vmax=scvmax[i], cmap='hot')
                axes2[rowi, coli].set_title('min %i, max%i' % (scvmin[i], scvmax[i]))
                axes2[rowi, coli].get_xaxis().set_ticklabels([])


        fig2.suptitle('Accepted number of samples')
        fig2.tight_layout()
        fig3.tight_layout()
        fig3.suptitle('Total number of samples')
        plt.show()
コード例 #5
0
def master_process(comm, tags, status, model, step, n_samples, swap_interval,
                   beta_tune_interval, n_workers_posterior, homepath,
                   progressbar, buffer_size, buffer_thinning, resample,
                   rm_flag):
    """
    Master process, that does the managing.
    Sends tasks to workers.
    Collects results and writes them to the trace.
    Fires workers once job is done.

    Each loop iteration receives one sample from each of two workers and
    writes the samples of posterior workers to the trace. It then lets the
    manager propose a chain swap between the two, re-tunes the betas every
    ``beta_tune_interval`` written samples, and sends the (possibly
    swapped) states back. Once ``n_samples`` samples have been written, the
    trace buffer and the sampler history are saved and every worker is
    sent an EXIT message.

    Parameters
    ----------
    comm : mpi.communicator
    tags : message tags
    status : mpi.Status object

    the rest see pt_sample doc-string

    Raises
    ------
    ValueError
        If ``n_workers_posterior`` is not smaller than the number of
        workers (``comm.size - 1``).
    """

    size = comm.size  # total number of processes
    n_workers = size - 1  # the master itself is not a worker

    if n_workers_posterior >= n_workers:
        raise ValueError(
            'Specified more workers that sample in the posterior "%i",'
            ' than there are total number of workers "%i"' %
            (n_workers_posterior, n_workers))

    # Results of the tempering run are written to the stage -1 directory.
    stage = -1
    active_workers = 0
    steps_until_tune = 0

    # start sampling of chains with given seed
    logger.info('Master starting with %d workers' % n_workers)
    logger.info('Packing stuff for workers')
    manager = TemperingManager(step=step,
                               n_workers=n_workers,
                               n_workers_posterior=n_workers_posterior,
                               model=model,
                               progressbar=progressbar,
                               buffer_size=buffer_size,
                               swap_interval=swap_interval,
                               beta_tune_interval=beta_tune_interval)

    stage_handler = SampleStage(homepath, backend=step.backend)
    stage_handler.clean_directory(stage, chains=None, rm_flag=rm_flag)

    logger.info('Initializing result trace...')
    logger.info('Writing samples to file every %i samples.' % buffer_size)
    trace = backend_catalog[step.backend](
        dir_path=stage_handler.stage_path(stage),
        model=model,
        buffer_size=buffer_size,
        buffer_thinning=buffer_thinning,
        progressbar=progressbar)
    trace.setup(n_samples, 0, overwrite=False)
    # TODO load starting points from existing trace

    logger.info('Sending work packages to workers...')
    manager.update_betas()
    # One work package per beta, handed to whichever worker reports ready.
    for _ in manager.betas:
        comm.recv(source=MPI.ANY_SOURCE, tag=tags.READY, status=status)
        source = status.Get_source()
        package = manager.get_package(source, resample=resample)
        comm.send(package, dest=source, tag=tags.INIT)
        logger.debug('Sent work package to worker %i' % source)
        active_workers += 1

    count_sample = 0
    counter = ChainCounter(n=n_samples,
                           n_jobs=1,
                           perc_disp=0.01,
                           subject='samples')

    logger.info('Posterior workers %s',
                list2string(manager.get_posterior_workers()))
    logger.info('Tuning worker betas every %i samples. \n' %
                beta_tune_interval)
    logger.info('Sampling ...')
    logger.info('------------')
    while True:

        m1 = num.empty(manager.step.lordering.size)
        comm.Recv([m1, MPI.DOUBLE],
                  source=MPI.ANY_SOURCE,
                  tag=MPI.ANY_TAG,
                  status=status)
        source1 = status.Get_source()
        logger.debug('Got sample 1 from worker %i' % source1)

        m2 = num.empty(manager.step.lordering.size)
        comm.Recv([m2, MPI.DOUBLE],
                  source=MPI.ANY_SOURCE,
                  tag=MPI.ANY_TAG,
                  status=status)
        source2 = status.Get_source()
        logger.debug('Got sample 2 from worker %i' % source2)

        # write results to trace if workers sample from posterior
        # NOTE(review): up to two samples are written per iteration, so
        # count_sample may end at n_samples + 1; confirm the trace backend
        # tolerates one extra write.
        for source, m in zip([source1, source2], [m1, m2]):
            if source in manager.get_posterior_workers():
                count_sample += 1
                counter(source)
                trace.write(manager.worker_a2l(m, source), count_sample)
                steps_until_tune += 1

        m1, m2 = manager.propose_chain_swap(m1, m2, source1, source2)
        # beta updating
        if steps_until_tune >= beta_tune_interval:
            manager.tune_betas()
            steps_until_tune = 0

        if count_sample < n_samples:
            logger.debug('Sending states back to workers ...')
            for source in [source1, source2]:
                # Check each worker's own flag: a worker whose beta changed
                # since it last received one gets the new beta first.
                if not manager.worker_beta_updated(source):
                    comm.Send([manager.get_beta(source), MPI.DOUBLE],
                              dest=source,
                              tag=tags.BETA)
                    # presumably records that this worker now has its
                    # current beta — confirm in TemperingManager
                    manager.worker_beta_updated(source, check=True)

            comm.Send(m1, dest=source1, tag=tags.SAMPLE)
            comm.Send(m2, dest=source2, tag=tags.SAMPLE)
        else:
            logger.info('Requested number of samples reached!')
            trace.record_buffer()
            manager.dump_history(save_dir=stage_handler.stage_path(stage))
            break

    logger.info('Master finished! Chain complete!')
    logger.debug('Firing ...')
    # Workers are ranks 1 .. size - 1; each gets an EXIT message.
    for i in range(1, size):
        logger.debug('Sending pay cheque to %i' % i)
        comm.send(None, dest=i, tag=tags.EXIT)
        logger.debug('Fired worker %i' % i)
        active_workers -= 1

    logger.info('Feierabend! Sampling finished!')