def init_hmc(model, stepsize, init="jitter+adapt_diag", chains=1):
    from pymc3.step_methods.hmc import quadpotential
    if init == 'jitter+adapt_diag':
        start = []
        for _ in range(chains):
            mean = {var: val.copy() for var, val in model.test_point.items()}
            for val in mean.values():
                val[...] += 2 * np.random.rand(*val.shape) - 1
            start.append(mean)
        mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0)
        var = np.ones_like(mean)
        potential = quadpotential.QuadPotentialDiagAdapt(
            model.ndim, mean, var, 10)

        return pm.step_methods.HamiltonianMC(step_scale=stepsize,
                                             potential=potential,
                                             path_length=1)
    else:
        raise NotImplementedError()
# Example #2
# 0
def _init_nuts_jittered_points(model, chains):
    """Return one copy of ``model.test_point`` per chain with [-1, 1) jitter."""
    points = []
    for _ in range(chains):
        point = {name: value.copy()
                 for name, value in model.test_point.items()}
        for value in point.values():
            value[...] += 2 * np.random.rand(*value.shape) - 1
        points.append(point)
    return points


def _init_nuts_fit_advi(model, n_init, random_seed, progressbar,
                        method='advi'):
    """Run ``pm.fit`` with the convergence callbacks used for NUTS init."""
    callbacks = [
        pm.callbacks.CheckParametersConvergence(tolerance=1e-2,
                                                diff='absolute'),
        pm.callbacks.CheckParametersConvergence(tolerance=1e-2,
                                                diff='relative'),
    ]
    return pm.fit(
        random_seed=random_seed,
        n=n_init,
        method=method,
        model=model,
        callbacks=callbacks,
        progressbar=progressbar,
        obj_optimizer=pm.adagrad_window,
    )


def _init_nuts_advi_start_and_cov(approx, model, chains):
    """Draw one start point per chain and the diagonal variances of ``approx``."""
    start = list(approx.sample(draws=chains))
    stds = approx.bij.rmap(approx.std.eval())
    cov = model.dict_to_array(stds) ** 2
    return start, cov


def init_nuts(init='auto',
              chains=1,
              n_init=500000,
              model=None,
              random_seed=None,
              progressbar=True,
              **kwargs):
    """Set up the mass matrix initialization for NUTS.

    NUTS convergence and sampling speed are extremely dependent on the
    choice of mass/scaling matrix. This function implements different
    methods for choosing or adapting the mass matrix.

    Parameters
    ----------
    init : str
        Initialization method to use (case-insensitive).

        * auto : Choose a default initialization method automatically.
          Currently, this is `'jitter+adapt_diag'`, but this can change in
          the future. If you depend on the exact behaviour, choose an
          initialization method explicitly.
        * adapt_diag : Start with an identity mass matrix and then adapt
          a diagonal based on the variance of the tuning samples. All
          chains use the test value (usually the prior mean) as starting
          point.
        * jitter+adapt_diag : Same as `adapt_diag`, but add uniform jitter
          in [-1, 1] to the starting point in each chain.
        * advi+adapt_diag : Run ADVI and then adapt the resulting diagonal
          mass matrix based on the sample variance of the tuning samples.
        * advi+adapt_diag_grad : Run ADVI and then adapt the resulting
          diagonal mass matrix based on the variance of the gradients
          during tuning. This is **experimental** and might be removed
          in a future release.
        * advi : Run ADVI to estimate posterior mean and diagonal mass
          matrix.
        * advi_map: Initialize ADVI with MAP and use MAP as starting point.
        * map : Use the MAP as starting point. This is discouraged.
        * nuts : Run NUTS and estimate posterior mean and mass matrix from
          the trace.
    chains : int
        Number of chains; one starting point is produced per chain.
    n_init : int
        Number of iterations of initializer
        If 'ADVI', number of iterations, if 'nuts', number of draws.
    model : Model (optional if in `with` context)
    random_seed : int or array-like of int, optional
        If given, it is reduced to a single int (its first element), used
        to seed numpy's global random state and forwarded to the ADVI /
        NUTS initializers.
    progressbar : bool
        Whether or not to display a progressbar for advi sampling.
    **kwargs : keyword arguments
        Extra keyword arguments are forwarded to pymc3.NUTS.

    Returns
    -------
    start : list of pymc3.model.Point
        Starting points for the sampler, one per chain
    nuts_sampler : pymc3.step_methods.NUTS
        Instantiated and initialized NUTS sampler object

    Raises
    ------
    ValueError
        If `vars` does not cover all model variables, or the model has
        non-continuous variables.
    TypeError
        If `init` is not a string.
    NotImplementedError
        If `init` names an unknown initialization method.
    """
    # Local import so the function does not depend on a module-level name.
    from pymc3.step_methods.hmc import quadpotential

    model = pm.modelcontext(model)

    variables = kwargs.get('vars', model.vars)
    if set(variables) != set(model.vars):
        raise ValueError('Must use init_nuts on all variables of a model.')
    if not pm.model.all_continuous(variables):
        raise ValueError('init_nuts can only be used for models with only '
                         'continuous variables.')

    if not isinstance(init, str):
        raise TypeError('init must be a string.')

    init = init.lower()
    if init == 'auto':
        init = 'jitter+adapt_diag'

    pm._log.info('Initializing NUTS using {}...'.format(init))

    if random_seed is not None:
        random_seed = int(np.atleast_1d(random_seed)[0])
        np.random.seed(random_seed)

    if init in ('adapt_diag', 'jitter+adapt_diag'):
        if init == 'adapt_diag':
            start = [model.test_point] * chains
        else:
            start = _init_nuts_jittered_points(model, chains)
        mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0)
        var = np.ones_like(mean)
        potential = quadpotential.QuadPotentialDiagAdapt(
            model.ndim, mean, var, 10)
    elif init in ('advi+adapt_diag', 'advi+adapt_diag_grad'):
        approx = _init_nuts_fit_advi(model, n_init, random_seed, progressbar)
        start, cov = _init_nuts_advi_start_and_cov(approx, model, chains)
        mean = model.dict_to_array(
            approx.bij.rmap(approx.mean.get_value()))
        weight = 50
        if init == 'advi+adapt_diag_grad':
            potential_cls = quadpotential.QuadPotentialDiagAdaptGrad
        else:
            potential_cls = quadpotential.QuadPotentialDiagAdapt
        potential = potential_cls(model.ndim, mean, cov, weight)
    elif init == 'advi':
        approx = _init_nuts_fit_advi(model, n_init, random_seed, progressbar)
        start, cov = _init_nuts_advi_start_and_cov(approx, model, chains)
        potential = quadpotential.QuadPotentialDiag(cov)
    elif init == 'advi_map':
        # Pass the model explicitly so an explicitly supplied ``model`` also
        # works outside a ``with`` block.
        map_point = pm.find_MAP(include_transformed=True, model=model)
        approx = pm.MeanField(model=model, start=map_point)
        _init_nuts_fit_advi(model, n_init, random_seed, progressbar,
                            method=pm.KLqp(approx))
        start, cov = _init_nuts_advi_start_and_cov(approx, model, chains)
        potential = quadpotential.QuadPotentialDiag(cov)
    elif init == 'map':
        map_point = pm.find_MAP(include_transformed=True, model=model)
        cov = pm.find_hessian(point=map_point, model=model)
        start = [map_point] * chains
        potential = quadpotential.QuadPotentialFull(cov)
    elif init == 'nuts':
        init_trace = pm.sample(draws=n_init,
                               step=pm.NUTS(model=model),
                               tune=n_init // 2,
                               random_seed=random_seed,
                               model=model)
        cov = np.atleast_1d(pm.trace_cov(init_trace))
        start = list(np.random.choice(init_trace, chains))
        potential = quadpotential.QuadPotentialFull(cov)
    else:
        raise NotImplementedError(
            'Initializer {} is not supported.'.format(init))

    # ``model`` is a named parameter, so it can never also arrive via kwargs.
    step = pm.NUTS(potential=potential, model=model, **kwargs)

    return start, step
                       mu=0,
                       tau=tau2 * (D2 - rho2 * W2),
                       shape=(1, len(y)))

    # Poisson mean: exponential of the transposed phi1 and phi2 summed, plus
    # b0 and plus offset reshaped to a column (offset[:, np.newaxis]).
    mu = tt.exp(phi1.T + phi2.T) + b0 + offset[:, np.newaxis]

    # The observed data y enters through a Poisson likelihood with mean mu.
    ncrimes = pm.Poisson('ncrimes', mu=mu, observed=y)
    # Draw 1000 samples; njobs=4 presumably runs four parallel chains --
    # TODO confirm against the installed PyMC3 version.
    pooled_trace = pm.sample(1000, njobs=4)
#%%
njobs = 4
from pymc3.step_methods.hmc import quadpotential
with model:
    # Fit a mean-field ADVI approximation; its mean and variances seed the
    # adaptive diagonal mass matrix built below.
    approx = pm.fit(
        n=200000,
        method='advi',
        progressbar=True,
        obj_optimizer=pm.adagrad_window,
    )
    # One starting point per job, drawn from the fitted approximation.
    start = approx.sample(draws=njobs)
    start = list(start)
    # Map the flat std / mean vectors back to per-variable dicts, then
    # flatten them in the model's own ordering.
    stds = approx.gbij.rmap(approx.std.eval())
    cov = model.dict_to_array(stds)**2
    mean = approx.gbij.rmap(approx.mean.get_value())
    mean = model.dict_to_array(mean)
    weight = 50
    potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, cov,
                                                     weight)
    step = pm.NUTS(potential=potential)
    # Hand the ADVI draws to the sampler; without ``start`` they would be
    # computed and then ignored.
    pooled_trace = pm.sample(1000, step=step, start=start, njobs=njobs,
                             tune=1000)

pm.traceplot(pooled_trace)
import numpy as np
import pymc3 as pm
from pymc3.step_methods.hmc import quadpotential
from pymc3.step_methods import step_sizes

n_chains = 4

with pm.Model() as m:
    x = pm.Normal("x", shape=10)
    # Manual equivalent of init == 'jitter+adapt_diag': one copy of the test
    # point per chain, each shifted by uniform noise in [-1, 1).
    start = []
    for _ in range(n_chains):
        mean = {var: val.copy() for var, val in m.test_point.items()}
        for val in mean.values():
            val[...] += 2 * np.random.rand(*val.shape) - 1
        start.append(mean)
    # The adaptive diagonal potential is centred on the average jittered
    # point and starts from unit variances.
    mean = np.mean([m.dict_to_array(vals) for vals in start], axis=0)
    var = np.ones_like(mean)
    potential = quadpotential.QuadPotentialDiagAdapt(m.ndim, mean, var, 10)
    step = pm.NUTS(potential=potential)
    # Pass the jittered points so each chain actually starts from them.
    trace1 = pm.sample(1000, step=step, start=start, tune=1000,
                       cores=n_chains)

with m:  # need to be the same model
    # Last recorded value of the "step_size_bar" sampler statistic from the
    # first run.
    step_size = trace1.get_sampler_stats("step_size_bar")[-1]

    # Reuse the already-adapted step object: switch tuning off and install a
    # fresh step-size adaptation object seeded with the step size above.
    step.tune = False
    step.step_adapt = step_sizes.DualAverageAdaptation(
        step_size, step.target_accept, 0.05, 0.75, 10
    )
    trace2 = pm.sample(draws=100, step=step, tune=0, cores=n_chains)
    print(trace2[-1])