def init_hmc(model, stepsize, init="jitter+adapt_diag", chains=1):
    """Build a HamiltonianMC step with a jittered, diagonally adapted potential.

    Parameters
    ----------
    model : Model
        Model whose ``test_point``, ``dict_to_array`` and ``ndim`` are used
        to set up the potential.
    stepsize : float
        Forwarded to ``HamiltonianMC`` as ``step_scale``.
    init : str
        Initialization method. Only ``'jitter+adapt_diag'`` is implemented.
    chains : int
        Number of jittered copies of the test point that are averaged to
        obtain the initial mean of the potential.

    Returns
    -------
    step : pymc3.step_methods.HamiltonianMC
        Step method using a ``QuadPotentialDiagAdapt`` potential and
        ``path_length=1``. Unlike ``init_nuts``, the jittered starting
        points are not returned.

    Raises
    ------
    NotImplementedError
        If ``init`` is anything other than ``'jitter+adapt_diag'``.
    """
    # Fail fast with an informative message before importing anything or
    # touching the model.
    if init != 'jitter+adapt_diag':
        raise NotImplementedError(
            'Initializer {} is not supported.'.format(init))

    from pymc3.step_methods.hmc import quadpotential

    start = []
    for _ in range(chains):
        point = {name: val.copy() for name, val in model.test_point.items()}
        for val in point.values():
            # In-place uniform jitter in [-1, 1); ``val[...]`` keeps each
            # copied array's shape.
            val[...] += 2 * np.random.rand(*val.shape) - 1
        start.append(point)

    # Centre the potential on the average jittered point, unit variance.
    mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0)
    var = np.ones_like(mean)
    potential = quadpotential.QuadPotentialDiagAdapt(
        model.ndim, mean, var, 10)
    return pm.step_methods.HamiltonianMC(
        step_scale=stepsize, potential=potential, path_length=1)
def _fit_advi(model, n_init, random_seed, callbacks, progressbar):
    """Run ``pm.fit`` with ``method='advi'`` on ``model`` and return the approximation."""
    return pm.fit(
        random_seed=random_seed,
        n=n_init,
        method='advi',
        model=model,
        callbacks=callbacks,
        progressbar=progressbar,
        obj_optimizer=pm.adagrad_window,
    )


def _advi_diag_variance(approx, model):
    """Return the squared standard deviations of ``approx`` as a flat array."""
    stds = approx.bij.rmap(approx.std.eval())
    return model.dict_to_array(stds) ** 2


def _advi_flat_mean(approx, model):
    """Return the mean of ``approx`` as a flat array."""
    mean = approx.bij.rmap(approx.mean.get_value())
    return model.dict_to_array(mean)


def _unit_variance_adapt_potential(model, start):
    """Build a ``QuadPotentialDiagAdapt`` centred on the mean of ``start`` with unit variance."""
    mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0)
    var = np.ones_like(mean)
    return quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, var, 10)


def init_nuts(init='auto', chains=1, n_init=500000, model=None,
              random_seed=None, progressbar=True, **kwargs):
    """Set up the mass matrix initialization for NUTS.

    NUTS convergence and sampling speed is extremely dependent on the
    choice of mass/scaling matrix. This function implements different
    methods for choosing or adapting the mass matrix.

    Parameters
    ----------
    init : str
        Initialization method to use (case-insensitive).

        * auto : Choose a default initialization method automatically.
          Currently, this is `'jitter+adapt_diag'`, but this can change in
          the future. If you depend on the exact behaviour, choose an
          initialization method explicitly.
        * adapt_diag : Start with an identity mass matrix and then adapt
          a diagonal based on the variance of the tuning samples. All
          chains use the test value (usually the prior mean)
          as starting point.
        * jitter+adapt_diag : Same as `adapt_diag`, but add uniform jitter
          in [-1, 1] to the starting point in each chain.
        * advi+adapt_diag : Run ADVI and then adapt the resulting diagonal
          mass matrix based on the sample variance of the tuning samples.
        * advi+adapt_diag_grad : Run ADVI and then adapt the resulting
          diagonal mass matrix based on the variance of the gradients
          during tuning. This is **experimental** and might be removed
          in a future release.
        * advi : Run ADVI to estimate posterior mean and diagonal mass
          matrix.
        * advi_map: Initialize ADVI with MAP and use MAP as starting point.
        * map : Use the MAP as starting point; the result of
          `find_hessian` at the MAP is used for a full potential.
          This is discouraged.
        * nuts : Run NUTS and estimate posterior mean and mass matrix from
          the trace.
    chains : int
        Number of jobs to start; one starting point is produced per chain.
    n_init : int
        Number of iterations of initializer.
        If 'ADVI', number of iterations, if 'nuts', number of draws
        (with `n_init // 2` tuning steps).
    model : Model (optional if in `with` context)
    random_seed : int or array-like of int, optional
        If given, the first element is converted to `int`, used to seed
        NumPy's global random state and forwarded to the initializer.
    progressbar : bool
        Whether or not to display a progressbar for advi sampling.
    **kwargs : keyword arguments
        Extra keyword arguments are forwarded to pymc3.NUTS.

    Returns
    -------
    start : list
        Starting points for the sampler, one per chain.
    nuts_sampler : pymc3.step_methods.NUTS
        Instantiated and initialized NUTS sampler object

    Raises
    ------
    TypeError
        If `init` is not a string.
    ValueError
        If `vars` (from `kwargs`) does not cover all model variables, or
        the model has non-continuous variables.
    NotImplementedError
        If `init` names an unknown initialization method.
    """
    # Validate the cheap argument first so a bad `init` fails before any
    # model work is done.
    if not isinstance(init, str):
        raise TypeError('init must be a string.')
    init = init.lower()
    if init == 'auto':
        init = 'jitter+adapt_diag'

    model = pm.modelcontext(model)

    model_vars = kwargs.get('vars', model.vars)
    if set(model_vars) != set(model.vars):
        raise ValueError('Must use init_nuts on all variables of a model.')
    if not pm.model.all_continuous(model_vars):
        raise ValueError('init_nuts can only be used for models with only '
                         'continuous variables.')

    pm._log.info('Initializing NUTS using {}...'.format(init))

    if random_seed is not None:
        random_seed = int(np.atleast_1d(random_seed)[0])
        np.random.seed(random_seed)

    callbacks = [
        pm.callbacks.CheckParametersConvergence(
            tolerance=1e-2, diff='absolute'),
        pm.callbacks.CheckParametersConvergence(
            tolerance=1e-2, diff='relative'),
    ]

    # NOTE: `model` is passed explicitly to every pymc3 call below so that
    # an explicitly supplied model also works outside a `with` block.
    if init == 'adapt_diag':
        start = [model.test_point] * chains
        potential = _unit_variance_adapt_potential(model, start)
    elif init == 'jitter+adapt_diag':
        start = []
        for _ in range(chains):
            point = {name: val.copy()
                     for name, val in model.test_point.items()}
            for val in point.values():
                # In-place uniform jitter in [-1, 1).
                val[...] += 2 * np.random.rand(*val.shape) - 1
            start.append(point)
        potential = _unit_variance_adapt_potential(model, start)
    elif init in ('advi+adapt_diag', 'advi+adapt_diag_grad'):
        approx = _fit_advi(model, n_init, random_seed, callbacks, progressbar)
        start = list(approx.sample(draws=chains))
        cov = _advi_diag_variance(approx, model)
        mean = _advi_flat_mean(approx, model)
        weight = 50
        # The two methods differ only in the adapting potential class.
        if init == 'advi+adapt_diag_grad':
            potential_cls = quadpotential.QuadPotentialDiagAdaptGrad
        else:
            potential_cls = quadpotential.QuadPotentialDiagAdapt
        potential = potential_cls(model.ndim, mean, cov, weight)
    elif init == 'advi':
        approx = _fit_advi(model, n_init, random_seed, callbacks, progressbar)
        start = list(approx.sample(draws=chains))
        cov = _advi_diag_variance(approx, model)
        potential = quadpotential.QuadPotentialDiag(cov)
    elif init == 'advi_map':
        start = pm.find_MAP(include_transformed=True, model=model)
        approx = pm.MeanField(model=model, start=start)
        pm.fit(
            random_seed=random_seed,
            n=n_init,
            method=pm.KLqp(approx),
            model=model,
            callbacks=callbacks,
            progressbar=progressbar,
            obj_optimizer=pm.adagrad_window,
        )
        start = list(approx.sample(draws=chains))
        cov = _advi_diag_variance(approx, model)
        potential = quadpotential.QuadPotentialDiag(cov)
    elif init == 'map':
        start = pm.find_MAP(include_transformed=True, model=model)
        cov = pm.find_hessian(point=start, model=model)
        start = [start] * chains
        potential = quadpotential.QuadPotentialFull(cov)
    elif init == 'nuts':
        init_trace = pm.sample(
            draws=n_init,
            step=pm.NUTS(model=model),
            tune=n_init // 2,
            random_seed=random_seed,
            model=model,
        )
        cov = np.atleast_1d(pm.trace_cov(init_trace, model=model))
        start = list(np.random.choice(init_trace, chains))
        potential = quadpotential.QuadPotentialFull(cov)
    else:
        raise NotImplementedError(
            'Initializer {} is not supported.'.format(init))

    step = pm.NUTS(potential=potential, model=model, **kwargs)

    return start, step
mu=0, tau=tau2 * (D2 - rho2 * W2), shape=(1, len(y))) mu = tt.exp(phi1.T + phi2.T) + b0 + offset[:, np.newaxis] ncrimes = pm.Poisson('ncrimes', mu=mu, observed=y) pooled_trace = pm.sample(1000, njobs=4) #%% njobs = 4 from pymc3.step_methods.hmc import quadpotential with model: approx = pm.fit( n=200000, method='advi', progressbar=True, obj_optimizer=pm.adagrad_window, ) start = approx.sample(draws=njobs) start = list(start) stds = approx.gbij.rmap(approx.std.eval()) cov = model.dict_to_array(stds)**2 mean = approx.gbij.rmap(approx.mean.get_value()) mean = model.dict_to_array(mean) weight = 50 potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, cov, weight) step = pm.NUTS(potential=potential) pooled_trace = pm.sample(1000, step=step, njobs=njobs, tune=1000) pm.traceplot(pooled_trace)
import numpy as np
import pymc3 as pm
from pymc3.step_methods.hmc import quadpotential
from pymc3.step_methods import step_sizes

# Example: set up NUTS by hand the way `init='jitter+adapt_diag'` does,
# sample with tuning, then sample again with tuning switched off and the
# step size taken from the first run.

n_chains = 4

with pm.Model() as m:
    x = pm.Normal("x", shape=10)

    # init == 'jitter+adapt_diag'
    start = []
    for _ in range(n_chains):
        mean = {var: val.copy() for var, val in m.test_point.items()}
        for val in mean.values():
            # In-place uniform jitter in [-1, 1) on each copied value.
            val[...] += 2 * np.random.rand(*val.shape) - 1
        start.append(mean)
    # NOTE(review): `start` is only used to compute the potential's mean;
    # it is not passed to pm.sample below -- confirm this is intended.
    mean = np.mean([m.dict_to_array(vals) for vals in start], axis=0)
    var = np.ones_like(mean)
    potential = quadpotential.QuadPotentialDiagAdapt(m.ndim, mean, var, 10)
    step = pm.NUTS(potential=potential)
    trace1 = pm.sample(1000, step=step, tune=1000, cores=n_chains)

with m:  # need to be the same model
    # Last recorded "step_size_bar" sampler statistic of the first run.
    step_size = trace1.get_sampler_stats("step_size_bar")[-1]
    # Reuse the already-adapted step method with tuning disabled.
    step.tune = False
    # NOTE(review): the positional constants are presumably gamma, k and
    # t0 of DualAverageAdaptation -- confirm against pymc3's step_sizes.
    step.step_adapt = step_sizes.DualAverageAdaptation(
        step_size, step.target_accept, 0.05, 0.75, 10
    )
    trace2 = pm.sample(draws=100, step=step, tune=0, cores=n_chains)
    print(trace2[-1])