def test_nest_context_works(self):
    with pm.Model() as m:
        new = NewModel()
        with new:
            assert pm.modelcontext(None) is new
        assert pm.modelcontext(None) is m
    assert "v1" in m.named_vars
    assert "v2" in m.named_vars
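# Usage sketch (added; not part of the test suite): the test above relies on
# pm.modelcontext(None) resolving to the innermost model on the context stack.
# A minimal illustration with plain nested models, assuming `pymc` imports as `pm`:
import pymc as pm

with pm.Model() as outer:
    with pm.Model() as inner:
        # inside the nested block the inner model is the active context
        assert pm.modelcontext(None) is inner
    # after leaving the inner block, the outer model is returned again
    assert pm.modelcontext(None) is outer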
def __init__(self, *args, **kwargs):
    """
    Initialise DEMetropolisZMLDA; uses the parent class' __init__ and
    extra code specific for use within MLDA.
    """
    # flag used for signaling the end of tuning
    self.tuning_end_trigger = False

    model = pm.modelcontext(kwargs.get("model", None))
    initial_values = model.initial_point

    # flag that variance reduction is activated - forces DEMetropolisZMLDA
    # to store quantities of interest in a register if True
    self.mlda_variance_reduction = kwargs.pop("mlda_variance_reduction", False)
    if self.mlda_variance_reduction:
        # Subsampling rate of MLDA sampler one level up
        self.mlda_subsampling_rate_above = kwargs.pop("mlda_subsampling_rate_above")
        self.sub_counter = 0
        self.Q_last = np.nan
        self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above

    # call parent class __init__
    super().__init__(*args, **kwargs)

    # modify the delta function and point to model if VR is used
    if self.mlda_variance_reduction:
        self.model = model
        self.delta_logp_factory = self.delta_logp
        self.delta_logp = lambda q, q0: -self.delta_logp_factory(q0, q)
def model_to_graphviz(model=None, *, formatting: str = "plain"):
    """Produce a graphviz Digraph from a PyMC model.

    Requires graphviz, which may be installed most easily with
        conda install -c conda-forge python-graphviz

    Alternatively, you may install the `graphviz` binaries yourself, and then
    `pip install graphviz` to get the python bindings. See
    http://graphviz.readthedocs.io/en/stable/manual.html for more information.

    Parameters
    ----------
    model : pm.Model
        The model to plot. Not required when called from inside a modelcontext.
    formatting : str
        one of { "plain" }
    """
    if "plain" not in formatting:
        raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.")
    if formatting != "plain":
        warnings.warn(
            "Formattings other than 'plain' are currently not supported.",
            UserWarning,
            stacklevel=2,
        )
    model = pm.modelcontext(model)
    return ModelGraph(model).make_graph(formatting=formatting)
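# Usage sketch (added; the toy model is an assumption, mirroring the Examples
# section of the var_names-aware variant of this function further below):
import numpy as np
import pymc as pm

with pm.Model() as m:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=np.array([0.1, -0.3, 0.2]))

# inside a model context the `model` argument may be omitted;
# outside of one, pass it explicitly as done here
graph = pm.model_to_graphviz(m, formatting="plain")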
def subsample(
    draws=1,
    step=None,
    start=None,
    trace=None,
    tune=0,
    model=None,
):
    """
    A stripped down version of sample(), which is called only by
    the RecursiveDAProposal (which is the proposal used in the MLDA stepper).
    RecursiveDAProposal only requires a small set of the input parameters and
    checks normally performed by sample(), and this function thus skips some
    of the code in sampler(). It directly calls _iter_sample(), rather than
    sample_many(). The result is a reduced overhead when running multiple
    levels in MLDA.
    """
    model = pm.modelcontext(model)
    chain = 0
    random_seed = np.random.randint(2**30)
    callback = None

    draws += tune

    sampling = pm.sampling._iter_sample(
        draws, step, start, trace, chain, tune, model, random_seed, callback
    )

    try:
        for it, (trace, _) in enumerate(sampling):
            pass
    except KeyboardInterrupt:
        pass

    return trace
def __init__(self, *args, **kwargs):
    """
    Initialise MetropolisMLDA. This is a mix of the parent class'
    initialisation and some extra code specific to MLDA.
    """
    model = pm.modelcontext(kwargs.get("model", None))
    initial_values = model.compute_initial_point()

    # flag that variance reduction is activated - forces MetropolisMLDA
    # to store quantities of interest in a register if True
    self.mlda_variance_reduction = kwargs.pop("mlda_variance_reduction", False)
    if self.mlda_variance_reduction:
        # Subsampling rate of MLDA sampler one level up
        self.mlda_subsampling_rate_above = kwargs.pop("mlda_subsampling_rate_above")
        self.sub_counter = 0
        self.Q_last = np.nan
        self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above

    # call parent class __init__
    super().__init__(*args, **kwargs)

    # modify the delta function and point to model if VR is used
    if self.mlda_variance_reduction:
        self.model = model
        self.delta_logp_factory = self.delta_logp
        self.delta_logp = lambda q, q0: -self.delta_logp_factory(q0, q)
def model_to_graphviz(
    model=None, *, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"
):
    """Produce a graphviz Digraph from a PyMC model.

    Requires graphviz, which may be installed most easily with
        conda install -c conda-forge python-graphviz

    Alternatively, you may install the `graphviz` binaries yourself, and then
    `pip install graphviz` to get the python bindings. See
    http://graphviz.readthedocs.io/en/stable/manual.html for more information.

    Parameters
    ----------
    model : pm.Model
        The model to plot. Not required when called from inside a modelcontext.
    var_names : iterable of variable names, optional
        Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
    formatting : str, optional
        one of { "plain" }

    Examples
    --------
    How to plot the graph of the model.

    .. code-block:: python

        import numpy as np
        from pymc import HalfCauchy, Model, Normal, model_to_graphviz

        J = 8
        y = np.array([28, 8, -3, 7, -1, 1, 18, 12])
        sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])

        with Model() as schools:
            eta = Normal("eta", 0, 1, shape=J)
            mu = Normal("mu", 0, sigma=1e6)
            tau = HalfCauchy("tau", 25)
            theta = mu + tau * eta
            obs = Normal("obs", theta, sigma=sigma, observed=y)

        model_to_graphviz(schools)
    """
    if "plain" not in formatting:
        raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.")
    if formatting != "plain":
        warnings.warn(
            "Formattings other than 'plain' are currently not supported.",
            UserWarning,
            stacklevel=2,
        )
    model = pm.modelcontext(model)
    return ModelGraph(model).make_graph(var_names=var_names, formatting=formatting)
def __init__(self, vars, order="random", transit_p=0.8, model=None):
    model = pm.modelcontext(model)

    # transition probabilities
    self.transit_p = transit_p

    initial_point = model.initial_point()
    vars = [model.rvs_to_values.get(var, var) for var in vars]
    self.dim = sum(initial_point[v.name].size for v in vars)

    if order == "random":
        self.shuffle_dims = True
        self.order = list(range(self.dim))
    else:
        if sorted(order) != list(range(self.dim)):
            raise ValueError("Argument 'order' has to be a permutation")
        self.shuffle_dims = False
        self.order = order

    if not all([v.dtype in pm.discrete_types for v in vars]):
        raise ValueError("All variables must be binary for BinaryGibbsMetropolis")

    super().__init__(vars, [model.compile_logp()])
def __init__(self, vars, proposal="uniform", order="random", model=None):
    model = pm.modelcontext(model)

    vars = [model.rvs_to_values.get(var, var) for var in vars]
    vars = pm.inputvars(vars)

    initial_point = model.initial_point()

    dimcats = []
    # The above variable is a list of pairs (aggregate dimension, number
    # of categories). For example, if vars = [x, y] with x being a 2-D
    # variable with M categories and y being a 3-D variable with N
    # categories, we will have dimcats = [(0, M), (1, M), (2, N), (3, N), (4, N)].
    for v in vars:
        v_init_val = initial_point[v.name]

        rv_var = model.values_to_rvs[v]
        distr = getattr(rv_var.owner, "op", None)

        if isinstance(distr, CategoricalRV):
            k_graph = rv_var.owner.inputs[3].shape[-1]
            (k_graph,), _ = rvs_to_value_vars((k_graph,), apply_transforms=True)
            k = model.compile_fn(k_graph, inputs=model.value_vars, on_unused_input="ignore")(
                initial_point
            )
        elif isinstance(distr, BernoulliRV):
            k = 2
        else:
            raise ValueError(
                "All variables must be categorical or binary for CategoricalGibbsMetropolis"
            )

        start = len(dimcats)
        dimcats += [(dim, k) for dim in range(start, start + v_init_val.size)]

    if order == "random":
        self.shuffle_dims = True
        self.dimcats = dimcats
    else:
        if sorted(order) != list(range(len(dimcats))):
            raise ValueError("Argument 'order' has to be a permutation")
        self.shuffle_dims = False
        self.dimcats = [dimcats[j] for j in order]

    if proposal == "uniform":
        self.astep = self.astep_unif
    elif proposal == "proportional":
        # Use the optimized "Metropolized Gibbs Sampler" described in Liu96.
        self.astep = self.astep_prop
    else:
        raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'")

    super().__init__(vars, [model.compile_logp()])
def __init__(self, name="", model=None):
    super().__init__(name, model)
    assert pm.modelcontext(None) is self
    # 1) init variables with Var method
    self.register_rv(pm.Normal.dist(), "v1")
    self.v2 = pm.Normal("v2", mu=0, sigma=1)
    # 2) Potentials and Deterministic variables with method too
    # be sure that names will not overlap with other same models
    pm.Deterministic("d", at.constant(1))
    pm.Potential("p", at.constant(1))
def __init__(
    self,
    vars=None,
    S=None,
    proposal_dist=None,
    lamb=None,
    scaling=0.001,
    tune=None,
    tune_interval=100,
    model=None,
    mode=None,
    **kwargs,
):
    model = pm.modelcontext(model)
    initial_values = model.initial_point()
    initial_values_size = sum(initial_values[n.name].size for n in model.value_vars)

    if vars is None:
        vars = model.cont_vars
    else:
        vars = [model.rvs_to_values.get(var, var) for var in vars]
    vars = pm.inputvars(vars)

    if S is None:
        S = np.ones(initial_values_size)

    if proposal_dist is not None:
        self.proposal_dist = proposal_dist(S)
    else:
        self.proposal_dist = UniformProposal(S)

    self.scaling = np.atleast_1d(scaling).astype("d")
    if lamb is None:
        # default to the optimal lambda for normally distributed targets
        lamb = 2.38 / np.sqrt(2 * initial_values_size)
    self.lamb = float(lamb)
    if tune not in {None, "scaling", "lambda"}:
        raise ValueError('The parameter "tune" must be one of {None, scaling, lambda}')
    self.tune = tune
    self.tune_interval = tune_interval
    self.steps_until_tune = tune_interval
    self.accepted = 0

    self.mode = mode

    shared = pm.make_shared_replacements(initial_values, vars, model)
    self.delta_logp = delta_logp(initial_values, model.logpt(), vars, shared)
    super().__init__(vars, shared)
def _get_priors(self, model=None):
    """Return prior distributions of the likelihood.

    Returns
    -------
    dict : mapping name -> pymc distribution
    """
    model = pymc.modelcontext(model)

    priors = {}
    for key, val in self.priors.items():
        if isinstance(val, numbers.Number):
            priors[key] = val
        else:
            priors[key] = model.Var(val[0], val[1])

    return priors
def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None):
    model = pm.modelcontext(model)

    self.scaling = scaling
    self.tune = tune
    self.tune_interval = tune_interval
    self.steps_until_tune = tune_interval
    self.accepted = 0

    vars = [model.rvs_to_values.get(var, var) for var in vars]

    if not all([v.dtype in pm.discrete_types for v in vars]):
        raise ValueError("All variables must be Bernoulli for BinaryMetropolis")

    super().__init__(vars, [model.fastlogp])
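# Usage sketch (added; the model is an assumed toy example): BinaryMetropolis is
# intended for Bernoulli-valued variables, so it is assigned explicitly as the step method.
import pymc as pm

with pm.Model() as m:
    z = pm.Bernoulli("z", p=0.5, shape=3)
    step = pm.BinaryMetropolis([z], tune_interval=50)
    idata = pm.sample(1000, step=step, chains=2)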
def sample_numpyro_nuts( draws=1000, tune=1000, chains=4, target_accept=0.8, random_seed=10, model=None, var_names=None, progress_bar=True, keep_untransformed=False, chain_method="parallel", ): from numpyro.infer import MCMC, NUTS model = modelcontext(model) if var_names is None: var_names = model.unobserved_value_vars vars_to_sample = list( get_default_varnames(var_names, include_transformed=keep_untransformed)) coords = { cname: np.array(cvals) if isinstance(cvals, tuple) else cvals for cname, cvals in model.coords.items() if cvals is not None } if hasattr(model, "RV_dims"): dims = { var_name: [dim for dim in dims if dim is not None] for var_name, dims in model.RV_dims.items() } else: dims = {} tic1 = pd.Timestamp.now() print("Compiling...", file=sys.stdout) rv_names = [rv.name for rv in model.value_vars] init_state = [model.initial_point[rv_name] for rv_name in rv_names] init_state_batched = jax.tree_map( lambda x: np.repeat(x[None, ...], chains, axis=0), init_state) logp_fn = get_jaxified_logp(model) nuts_kernel = NUTS( potential_fn=logp_fn, target_accept_prob=target_accept, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, ) pmap_numpyro = MCMC( nuts_kernel, num_warmup=tune, num_samples=draws, num_chains=chains, postprocess_fn=None, chain_method=chain_method, progress_bar=progress_bar, ) tic2 = pd.Timestamp.now() print("Compilation time = ", tic2 - tic1, file=sys.stdout) print("Sampling...", file=sys.stdout) seed = jax.random.PRNGKey(random_seed) map_seed = jax.random.split(seed, chains) if chains == 1: init_params = init_state map_seed = seed else: init_params = init_state_batched pmap_numpyro.run( map_seed, init_params=init_params, extra_fields=( "num_steps", "potential_energy", "energy", "adapt_state.step_size", "accept_prob", "diverging", ), ) raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True) tic3 = pd.Timestamp.now() print("Sampling time = ", tic3 - tic2, file=sys.stdout) print("Transforming variables...", file=sys.stdout) mcmc_samples = {} for v in vars_to_sample: fgraph = FunctionGraph(model.value_vars, [v], clone=False) optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"]) jax_fn = jax_funcify(fgraph) result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0] mcmc_samples[v.name] = result tic4 = pd.Timestamp.now() print("Transformation time = ", tic4 - tic3, file=sys.stdout) posterior = mcmc_samples az_trace = az.from_dict( posterior=posterior, log_likelihood=_get_log_likelihood(model, raw_mcmc_samples), observed_data=find_observations(model), sample_stats=_sample_stats_to_xarray(pmap_numpyro), coords=coords, dims=dims, ) return az_trace
def __init__( self, coarse_models: List[Model], vars: Optional[list] = None, base_sampler="DEMetropolisZ", base_S: Optional = None, base_proposal_dist: Optional[Type[Proposal]] = None, base_scaling: Optional = None, tune: bool = True, base_tune_target: str = "lambda", base_tune_interval: int = 100, base_lamb: Optional = None, base_tune_drop_fraction: float = 0.9, model: Optional[Model] = None, mode: Optional = None, subsampling_rates: List[int] = 5, base_blocked: bool = False, variance_reduction: bool = False, store_Q_fine: bool = False, adaptive_error_model: bool = False, **kwargs, ) -> None: # this variable is used to identify MLDA objects which are # not in the finest level (i.e. child MLDA objects) self.is_child = kwargs.get("is_child", False) if not isinstance(coarse_models, list): raise ValueError( "MLDA step method cannot use coarse_models if it is not a list" ) if len(coarse_models) == 0: raise ValueError("MLDA step method was given an empty " "list of coarse models. Give at least " "one coarse model.") # assign internal state model = pm.modelcontext(model) initial_values = model.compute_initial_point() self.model = model self.coarse_models = coarse_models self.model_below = self.coarse_models[-1] self.num_levels = len(self.coarse_models) + 1 # set up variance reduction. self.variance_reduction = variance_reduction self.store_Q_fine = store_Q_fine # check that certain requirements hold # for the variance reduction feature to work if self.variance_reduction or self.store_Q_fine: if not hasattr(self.model, "Q"): raise AttributeError("Model given to MLDA does not contain" "variable 'Q'. You need to include" "the variable in the model definition" "for variance reduction to work or" "for storing the fine Q." "Use pm.Data() to define it.") if not isinstance(self.model.Q, TensorSharedVariable): raise TypeError( "The variable 'Q' in the model definition is not of type " "'TensorSharedVariable'. Use pm.Data() to define the" "variable.") if self.is_child and self.variance_reduction: # this is the subsampling rate applied to the current level # it is stored in the level above and transferred here self.subsampling_rate_above = kwargs.pop("subsampling_rate_above", None) # set up adaptive error model self.adaptive_error_model = adaptive_error_model # check that certain requirements hold # for the adaptive error model feature to work if self.adaptive_error_model: if not hasattr(self.model_below, "mu_B"): raise AttributeError( "Model below in hierarchy does not contain" "variable 'mu_B'. You need to include" "the variable in the model definition" "for adaptive error model to work." "Use pm.Data() to define it.") if not hasattr(self.model_below, "Sigma_B"): raise AttributeError( "Model below in hierarchy does not contain" "variable 'Sigma_B'. You need to include" "the variable in the model definition" "for adaptive error model to work." "Use pm.Data() to define it.") if not (isinstance(self.model_below.mu_B, TensorSharedVariable) and isinstance(self.model_below.Sigma_B, TensorSharedVariable)): raise TypeError( "At least one of the variables 'mu_B' and 'Sigma_B' " "in the definition of the below model is not of type " "'TensorSharedVariable'. 
Use pm.Data() to define those " "variables.") # this object is used to recursively update the mean and # variance of the bias correction given new differences # between levels self.bias = RecursiveSampleMoments( self.model_below.mu_B.get_value(), self.model_below.Sigma_B.get_value()) # this list holds the bias objects from all levels # it is gradually constructed when MLDA objects are # created and then shared between all levels self.bias_all = kwargs.pop("bias_all", None) if self.bias_all is None: self.bias_all = [self.bias] else: self.bias_all.append(self.bias) # variables used for adaptive error model self.last_synced_output_diff = None self.adaptation_started = False # set up subsampling rates. if isinstance(subsampling_rates, int): self.subsampling_rates = [subsampling_rates] * len( self.coarse_models) else: if len(subsampling_rates) != len(self.coarse_models): raise ValueError( f"List of subsampling rates needs to have the same " f"length as list of coarse models but the lengths " f"were {len(subsampling_rates)}, {len(self.coarse_models)}" ) self.subsampling_rates = subsampling_rates self.subsampling_rate = self.subsampling_rates[-1] self.subchain_selection = None # set up base sampling self.base_sampler = base_sampler # VR is not compatible with compound base samplers so an automatic conversion # to a block sampler happens here if if self.variance_reduction and self.base_sampler == "Metropolis" and not base_blocked: warnings.warn( "Variance reduction is not compatible with non-blocked (compound) samplers." "Automatically switching to a blocked Metropolis sampler.") self.base_blocked = True else: self.base_blocked = base_blocked self.base_S = base_S self.base_proposal_dist = base_proposal_dist if base_scaling is None: if self.base_sampler == "Metropolis": self.base_scaling = 1.0 else: self.base_scaling = 0.001 else: self.base_scaling = float(base_scaling) self.tune = tune if not self.tune and self.base_sampler == "DEMetropolisZ": raise ValueError( f"The argument tune was set to False while using" f" a 'DEMetropolisZ' base sampler. 
'DEMetropolisZ' " f" tune needs to be True.") self.base_tune_target = base_tune_target self.base_tune_interval = base_tune_interval self.base_lamb = base_lamb self.base_tune_drop_fraction = float(base_tune_drop_fraction) self.base_tuning_stats = None self.mode = mode # Process model variables if vars is None: vars = model.value_vars else: vars = [model.rvs_to_values.get(var, var) for var in vars] vars = pm.inputvars(vars) self.vars = vars self.var_names = [var.name for var in self.vars] self.accepted = 0 # Construct Aesara function for current-level model likelihood # (for use in acceptance) shared = pm.make_shared_replacements(initial_values, vars, model) self.delta_logp = delta_logp(initial_values, model.logpt(), vars, shared) # Construct Aesara function for below-level model likelihood # (for use in acceptance) model_below = pm.modelcontext(self.model_below) vars_below = [ var for var in model_below.value_vars if var.name in self.var_names ] vars_below = pm.inputvars(vars_below) shared_below = pm.make_shared_replacements(initial_values, vars_below, model_below) self.delta_logp_below = delta_logp(initial_values, model_below.logpt(), vars_below, shared_below) super().__init__(vars, shared) # initialise complete step method hierarchy if self.num_levels == 2: with self.model_below: # make sure the correct variables are selected from model_below vars_below = [ var for var in self.model_below.value_vars if var.name in self.var_names ] # create kwargs if self.variance_reduction: base_kwargs = { "mlda_subsampling_rate_above": self.subsampling_rate, "mlda_variance_reduction": True, } else: base_kwargs = {} if self.base_sampler == "Metropolis": # MetropolisMLDA sampler in base level (level=0), targeting self.model_below self.step_method_below = pm.MetropolisMLDA( vars=vars_below, proposal_dist=self.base_proposal_dist, S=self.base_S, scaling=self.base_scaling, tune=self.tune, tune_interval=self.base_tune_interval, model=None, mode=self.mode, blocked=self.base_blocked, **base_kwargs, ) else: # DEMetropolisZMLDA sampler in base level (level=0), targeting self.model_below self.step_method_below = pm.DEMetropolisZMLDA( vars=vars_below, S=self.base_S, proposal_dist=self.base_proposal_dist, lamb=self.base_lamb, scaling=self.base_scaling, tune=self.base_tune_target, tune_interval=self.base_tune_interval, tune_drop_fraction=self.base_tune_drop_fraction, model=None, mode=self.mode, **base_kwargs, ) else: # drop the last coarse model coarse_models_below = self.coarse_models[:-1] subsampling_rates_below = self.subsampling_rates[:-1] with self.model_below: # make sure the correct variables are selected from model_below vars_below = [ var for var in self.model_below.value_vars if var.name in self.var_names ] # create kwargs if self.variance_reduction: mlda_kwargs = { "is_child": True, "subsampling_rate_above": self.subsampling_rate, } else: mlda_kwargs = {"is_child": True} if self.adaptive_error_model: mlda_kwargs = { **mlda_kwargs, **{ "bias_all": self.bias_all } } # MLDA sampler in some intermediate level, targeting self.model_below self.step_method_below = pm.MLDA( vars=vars_below, base_S=self.base_S, base_sampler=self.base_sampler, base_proposal_dist=self.base_proposal_dist, base_scaling=self.base_scaling, tune=self.tune, base_tune_target=self.base_tune_target, base_tune_interval=self.base_tune_interval, base_lamb=self.base_lamb, base_tune_drop_fraction=self.base_tune_drop_fraction, model=None, mode=self.mode, subsampling_rates=subsampling_rates_below, coarse_models=coarse_models_below, 
base_blocked=self.base_blocked, variance_reduction=self.variance_reduction, store_Q_fine=False, adaptive_error_model=self.adaptive_error_model, **mlda_kwargs, ) # instantiate the recursive DA proposal. # this is the main proposal used for # all levels (Recursive Delayed Acceptance) # (except for level 0 where the step method is MetropolisMLDA # or DEMetropolisZMLDA - not MLDA) self.proposal_dist = RecursiveDAProposal(self.step_method_below, self.model_below, self.tune, self.subsampling_rate) # set up data types of stats. if isinstance(self.step_method_below, MLDA): # get the stat types from the level below if that level is MLDA self.stats_dtypes = self.step_method_below.stats_dtypes else: # otherwise, set it up from scratch. self.stats_dtypes = [{ "accept": np.float64, "accepted": bool, "tune": bool }] if isinstance(self.step_method_below, MetropolisMLDA): self.stats_dtypes.append({"base_scaling": np.float64}) elif isinstance(self.step_method_below, DEMetropolisZMLDA): self.stats_dtypes.append({ "base_scaling": np.float64, "base_lambda": np.float64 }) elif isinstance(self.step_method_below, CompoundStep): for method in self.step_method_below.methods: if isinstance(method, MetropolisMLDA): self.stats_dtypes.append({"base_scaling": np.float64}) elif isinstance(method, DEMetropolisZMLDA): self.stats_dtypes.append({ "base_scaling": np.float64, "base_lambda": np.float64 }) # initialise necessary variables for doing variance reduction if self.variance_reduction: self.sub_counter = 0 self.Q_diff = [] if self.is_child: self.Q_reg = [np.nan] * self.subsampling_rate_above if self.num_levels == 2: self.Q_base_full = [] if not self.is_child: for level in range(self.num_levels - 1, 0, -1): self.stats_dtypes[0][f"Q_{level}_{level - 1}"] = object self.stats_dtypes[0]["Q_0"] = object # initialise necessary variables for doing variance reduction or storing fine Q if self.variance_reduction or self.store_Q_fine: self.Q_last = np.nan self.Q_diff_last = np.nan if self.store_Q_fine and not self.is_child: self.stats_dtypes[0][f"Q_{self.num_levels - 1}"] = object
def sample_numpyro_nuts(
    draws=1000,
    tune=1000,
    chains=4,
    target_accept=0.8,
    random_seed=10,
    model=None,
    progress_bar=True,
    keep_untransformed=False,
):
    from numpyro.infer import MCMC, NUTS

    model = modelcontext(model)

    tic1 = pd.Timestamp.now()
    print("Compiling...", file=sys.stdout)

    rv_names = [rv.name for rv in model.value_vars]
    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
    init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

    logp_fn = get_jaxified_logp(model)

    nuts_kernel = NUTS(
        potential_fn=logp_fn,
        target_accept_prob=target_accept,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        dense_mass=False,
    )

    pmap_numpyro = MCMC(
        nuts_kernel,
        num_warmup=tune,
        num_samples=draws,
        num_chains=chains,
        postprocess_fn=None,
        chain_method="parallel",
        progress_bar=progress_bar,
    )

    tic2 = pd.Timestamp.now()
    print("Compilation time = ", tic2 - tic1, file=sys.stdout)

    print("Sampling...", file=sys.stdout)

    seed = jax.random.PRNGKey(random_seed)
    map_seed = jax.random.split(seed, chains)

    pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",))
    raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)

    tic3 = pd.Timestamp.now()
    print("Sampling time = ", tic3 - tic2, file=sys.stdout)

    print("Transforming variables...", file=sys.stdout)
    mcmc_samples = []
    for i, (value_var, raw_samples) in enumerate(zip(model.value_vars, raw_mcmc_samples)):
        raw_samples = at.constant(np.asarray(raw_samples))

        rv = model.values_to_rvs[value_var]
        transform = getattr(value_var.tag, "transform", None)

        if transform is not None:
            # TODO: This will fail when the transformation depends on another variable
            #  such as in interval transform with RVs as edges
            trans_samples = transform.backward(raw_samples, *rv.owner.inputs)
            trans_samples.name = rv.name
            mcmc_samples.append(trans_samples)

            if keep_untransformed:
                raw_samples.name = value_var.name
                mcmc_samples.append(raw_samples)
        else:
            raw_samples.name = rv.name
            mcmc_samples.append(raw_samples)

    mcmc_varnames = [var.name for var in mcmc_samples]
    mcmc_samples = compile_rv_inplace(
        [],
        mcmc_samples,
        mode="JAX",
    )()

    tic4 = pd.Timestamp.now()
    print("Transformation time = ", tic4 - tic3, file=sys.stdout)

    posterior = {k: v for k, v in zip(mcmc_varnames, mcmc_samples)}
    az_trace = az.from_dict(posterior=posterior)

    return az_trace
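# Usage sketch (added): calling the numpyro-backed NUTS sampler above from a model
# context. The module path `pymc.sampling_jax` is an assumption and has moved
# between releases, so treat this as illustrative only.
import numpy as np
import pymc as pm
from pymc import sampling_jax  # assumed import location

with pm.Model() as m:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=np.random.randn(100))
    # NUTS via numpyro on the JAX-compiled logp of the model in context
    idata = sampling_jax.sample_numpyro_nuts(draws=500, tune=500, chains=2)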
def fit( n=10000, local_rv=None, method="advi", model=None, random_seed=None, start=None, inf_kwargs=None, **kwargs, ): r"""Handy shortcut for using inference methods in functional way Parameters ---------- n: `int` number of iterations local_rv: dict[var->tuple] mapping {model_variable -> approx params} Local Vars are used for Autoencoding Variational Bayes See (AEVB; Kingma and Welling, 2014) for details method: str or :class:`Inference` string name is case insensitive in: - 'advi' for ADVI - 'fullrank_advi' for FullRankADVI - 'svgd' for Stein Variational Gradient Descent - 'asvgd' for Amortized Stein Variational Gradient Descent - 'nfvi' for Normalizing Flow with default `scale-loc` flow - 'nfvi=<formula>' for Normalizing Flow using formula model: :class:`Model` PyMC model for inference random_seed: None or int leave None to use package global RandomStream or other valid value to create instance specific one inf_kwargs: dict additional kwargs passed to :class:`Inference` start: `Point` starting point for inference Other Parameters ---------------- score: bool evaluate loss on each iteration or not callbacks: list[function: (Approximation, losses, i) -> None] calls provided functions after each iteration step progressbar: bool whether to show progressbar or not obj_n_mc: `int` Number of monte carlo samples used for approximation of objective gradients tf_n_mc: `int` Number of monte carlo samples used for approximation of test function gradients obj_optimizer: function (grads, params) -> updates Optimizer that is used for objective params test_optimizer: function (grads, params) -> updates Optimizer that is used for test function params more_obj_params: `list` Add custom params for objective optimizer more_tf_params: `list` Add custom params for test function optimizer more_updates: `dict` Add custom updates to resulting updates total_grad_norm_constraint: `float` Bounds gradient norm, prevents exploding gradient problem fn_kwargs: `dict` Add kwargs to aesara.function (e.g. `{'profile': True}`) more_replacements: `dict` Apply custom replacements before calculating gradients Returns ------- :class:`Approximation` """ if inf_kwargs is None: inf_kwargs = dict() else: inf_kwargs = inf_kwargs.copy() if local_rv is not None: inf_kwargs["local_rv"] = local_rv if random_seed is not None: inf_kwargs["random_seed"] = random_seed if start is not None: inf_kwargs["start"] = start if model is None: model = pm.modelcontext(model) _select = dict(advi=ADVI, fullrank_advi=FullRankADVI, svgd=SVGD, asvgd=ASVGD, nfvi=NFVI) if isinstance(method, str): method = method.lower() if method.startswith("nfvi="): formula = method[5:] inference = NFVI(formula, **inf_kwargs) elif method in _select: inference = _select[method](model=model, **inf_kwargs) else: raise KeyError( f"method should be one of {set(_select.keys())} or Inference instance" ) elif isinstance(method, Inference): inference = method else: raise TypeError( f"method should be one of {set(_select.keys())} or Inference instance" ) return inference.fit(n, **kwargs)
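# Usage sketch (added; hedged against API drift): mean-field ADVI via the fit()
# helper above. The toy model and the draw from the returned Approximation are assumptions.
import numpy as np
import pymc as pm

with pm.Model() as m:
    mu = pm.Normal("mu", 0.0, 10.0)
    pm.Normal("obs", mu, 1.0, observed=np.random.randn(200))
    approx = pm.fit(n=10000, method="advi")

# the returned Approximation can then be sampled from (exact API may vary by version)
trace = approx.sample(1000)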
def __init__(
    self,
    vars=None,
    S=None,
    proposal_dist=None,
    scaling=1.0,
    tune=True,
    tune_interval=100,
    model=None,
    mode=None,
    **kwargs,
):
    """Create an instance of a Metropolis stepper

    Parameters
    ----------
    vars: list
        List of value variables for sampler
    S: standard deviation or covariance matrix
        Some measure of variance to parameterize proposal distribution
    proposal_dist: function
        Function that returns zero-mean deviates when parameterized with
        S (and n). Defaults to normal.
    scaling: scalar or array
        Initial scale factor for proposal. Defaults to 1.
    tune: bool
        Flag for tuning. Defaults to True.
    tune_interval: int
        The frequency of tuning. Defaults to 100 iterations.
    model: PyMC Model
        Optional model for sampling step. Defaults to None (taken from context).
    mode: string or `Mode` instance.
        compilation mode passed to Aesara functions
    """
    model = pm.modelcontext(model)
    initial_values = model.initial_point()

    if vars is None:
        vars = model.value_vars
    else:
        vars = [model.rvs_to_values.get(var, var) for var in vars]
    vars = pm.inputvars(vars)

    initial_values_shape = [initial_values[v.name].shape for v in vars]
    if S is None:
        S = np.ones(int(sum(np.prod(ivs) for ivs in initial_values_shape)))

    if proposal_dist is not None:
        self.proposal_dist = proposal_dist(S)
    elif S.ndim == 1:
        self.proposal_dist = NormalProposal(S)
    elif S.ndim == 2:
        self.proposal_dist = MultivariateNormalProposal(S)
    else:
        raise ValueError("Invalid rank for variance: %s" % S.ndim)

    self.scaling = np.atleast_1d(scaling).astype("d")
    self.tune = tune
    self.tune_interval = tune_interval
    self.steps_until_tune = tune_interval

    # Determine type of variables
    self.discrete = np.concatenate(
        [[v.dtype in pm.discrete_types] * (initial_values[v.name].size or 1) for v in vars]
    )
    self.any_discrete = self.discrete.any()
    self.all_discrete = self.discrete.all()

    # Metropolis will try to handle one batched dimension at a time. This, however,
    # is not safe for discrete multivariate distributions (looking at you Multinomial),
    # due to high dependency among the support dimensions. For continuous multivariate
    # distributions we assume they are being transformed in a way that makes each
    # dimension semi-independent.
    is_scalar = len(initial_values_shape) == 1 and initial_values_shape[0] == ()
    self.elemwise_update = not (
        is_scalar
        or (
            self.any_discrete
            and max(getattr(model.values_to_rvs[var].owner.op, "ndim_supp", 1) for var in vars) > 0
        )
    )
    if self.elemwise_update:
        dims = int(sum(np.prod(ivs) for ivs in initial_values_shape))
    else:
        dims = 1
    self.enum_dims = np.arange(dims, dtype=int)
    self.accept_rate_iter = np.zeros(dims, dtype=float)
    self.accepted_iter = np.zeros(dims, dtype=bool)
    self.accepted_sum = np.zeros(dims, dtype=int)

    # remember initial settings before tuning so they can be reset
    self._untuned_settings = dict(scaling=self.scaling, steps_until_tune=tune_interval)

    # TODO: This is not being used when compiling the logp function!
    self.mode = mode

    shared = pm.make_shared_replacements(initial_values, vars, model)
    self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared)
    super().__init__(vars, shared)
def sample_blackjax_nuts( draws=1000, tune=1000, chains=4, target_accept=0.8, random_seed=10, initvals=None, model=None, var_names=None, keep_untransformed=False, chain_method="parallel", idata_kwargs=None, ): """ Draw samples from the posterior using the NUTS method from the ``blackjax`` library. Parameters ---------- draws : int, default 1000 The number of samples to draw. The number of tuned samples are discarded by default. tune : int, default 1000 Number of iterations to tune. Samplers adjust the step sizes, scalings or similar during tuning. Tuning samples will be drawn in addition to the number specified in the ``draws`` argument. chains : int, default 4 The number of chains to sample. target_accept : float in [0, 1]. The step size is tuned such that we approximate this acceptance rate. Higher values like 0.9 or 0.95 often work better for problematic posteriors. random_seed : int, default 10 Random seed used by the sampling steps. model : Model, optional Model to sample from. The model needs to have free random variables. When inside a ``with`` model context, it defaults to that model, otherwise the model must be passed explicitly. var_names : iterable of str, optional Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior keep_untransformed : bool, default False Include untransformed variables in the posterior samples. Defaults to False. chain_method : str, default "parallel" Specify how samples should be drawn. The choices include "parallel", and "vectorized". idata_kwargs : dict, optional Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value for the ``log_likelihood`` key to indicate that the pointwise log likelihood should not be included in the returned object. Returns ------- InferenceData ArviZ ``InferenceData`` object that contains the posterior samples, together with their respective sample stats and pointwise log likeihood values (unless skipped with ``idata_kwargs``). 
""" import blackjax model = modelcontext(model) if var_names is None: var_names = model.unobserved_value_vars vars_to_sample = list( get_default_varnames(var_names, include_transformed=keep_untransformed)) coords = { cname: np.array(cvals) if isinstance(cvals, tuple) else cvals for cname, cvals in model.coords.items() if cvals is not None } if hasattr(model, "RV_dims"): dims = { var_name: [dim for dim in dims if dim is not None] for var_name, dims in model.RV_dims.items() } else: dims = {} tic1 = datetime.now() print("Compiling...", file=sys.stdout) init_params = _get_batched_jittered_initial_points( model=model, chains=chains, initvals=initvals, random_seed=random_seed, ) if chains == 1: init_params = [np.stack(init_params)] init_params = [ np.stack(init_state) for init_state in zip(*init_params) ] logprob_fn = get_jaxified_logp(model) seed = jax.random.PRNGKey(random_seed) keys = jax.random.split(seed, chains) get_posterior_samples = partial( _blackjax_inference_loop, logprob_fn=logprob_fn, tune=tune, draws=draws, target_accept=target_accept, ) tic2 = datetime.now() print("Compilation time = ", tic2 - tic1, file=sys.stdout) print("Sampling...", file=sys.stdout) # Adapted from numpyro if chain_method == "parallel": map_fn = jax.pmap elif chain_method == "vectorized": map_fn = jax.vmap else: raise ValueError( "Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"') states, _ = map_fn(get_posterior_samples)(keys, init_params) raw_mcmc_samples = states.position tic3 = datetime.now() print("Sampling time = ", tic3 - tic2, file=sys.stdout) print("Transforming variables...", file=sys.stdout) mcmc_samples = {} for v in vars_to_sample: jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[v]) result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0] mcmc_samples[v.name] = result tic4 = datetime.now() print("Transformation time = ", tic4 - tic3, file=sys.stdout) if idata_kwargs is None: idata_kwargs = {} else: idata_kwargs = idata_kwargs.copy() if idata_kwargs.pop("log_likelihood", True): log_likelihood = _get_log_likelihood(model, raw_mcmc_samples) else: log_likelihood = None attrs = { "sampling_time": (tic3 - tic2).total_seconds(), } posterior = mcmc_samples az_trace = az.from_dict( posterior=posterior, log_likelihood=log_likelihood, observed_data=find_observations(model), coords=coords, dims=dims, attrs=make_attrs(attrs, library=blackjax), **idata_kwargs, ) return az_trace
def __init__(
    self,
    vars=None,
    S=None,
    proposal_dist=None,
    scaling=1.0,
    tune=True,
    tune_interval=100,
    model=None,
    mode=None,
    **kwargs,
):
    """Create an instance of a Metropolis stepper

    Parameters
    ----------
    vars: list
        List of value variables for sampler
    S: standard deviation or covariance matrix
        Some measure of variance to parameterize proposal distribution
    proposal_dist: function
        Function that returns zero-mean deviates when parameterized with
        S (and n). Defaults to normal.
    scaling: scalar or array
        Initial scale factor for proposal. Defaults to 1.
    tune: bool
        Flag for tuning. Defaults to True.
    tune_interval: int
        The frequency of tuning. Defaults to 100 iterations.
    model: PyMC Model
        Optional model for sampling step. Defaults to None (taken from context).
    mode: string or `Mode` instance.
        compilation mode passed to Aesara functions
    """
    model = pm.modelcontext(model)
    initial_values = model.initial_point()

    if vars is None:
        vars = model.value_vars
    else:
        vars = [model.rvs_to_values.get(var, var) for var in vars]
    vars = pm.inputvars(vars)

    if S is None:
        S = np.ones(sum(initial_values[v.name].size for v in vars))

    if proposal_dist is not None:
        self.proposal_dist = proposal_dist(S)
    elif S.ndim == 1:
        self.proposal_dist = NormalProposal(S)
    elif S.ndim == 2:
        self.proposal_dist = MultivariateNormalProposal(S)
    else:
        raise ValueError("Invalid rank for variance: %s" % S.ndim)

    self.scaling = np.atleast_1d(scaling).astype("d")
    self.tune = tune
    self.tune_interval = tune_interval
    self.steps_until_tune = tune_interval
    self.accepted = 0

    # Determine type of variables
    self.discrete = np.concatenate(
        [[v.dtype in pm.discrete_types] * (initial_values[v.name].size or 1) for v in vars]
    )
    self.any_discrete = self.discrete.any()
    self.all_discrete = self.discrete.all()

    # remember initial settings before tuning so they can be reset
    self._untuned_settings = dict(
        scaling=self.scaling, steps_until_tune=tune_interval, accepted=self.accepted
    )

    self.mode = mode

    shared = pm.make_shared_replacements(initial_values, vars, model)
    self.delta_logp = delta_logp(initial_values, model.logpt(), vars, shared)
    super().__init__(vars, shared)
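# Usage sketch (added; the variables are assumed for illustration): assigning the
# Metropolis stepper above to a subset of a model's variables and letting
# pm.sample pick defaults for the rest.
import pymc as pm

with pm.Model() as m:
    x = pm.Normal("x", 0.0, 1.0)
    y = pm.Normal("y", x, 1.0)
    step_x = pm.Metropolis([x], tune_interval=200)
    idata = pm.sample(2000, step=[step_x], chains=2)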
def sample_numpyro_nuts( draws: int = 1000, tune: int = 1000, chains: int = 4, target_accept: float = 0.8, random_seed: int = None, initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, model: Optional[Model] = None, var_names=None, progress_bar: bool = True, keep_untransformed: bool = False, chain_method: str = "parallel", idata_kwargs: Optional[Dict] = None, nuts_kwargs: Optional[Dict] = None, ): """ Draw samples from the posterior using the NUTS method from the ``numpyro`` library. Parameters ---------- draws : int, default 1000 The number of samples to draw. The number of tuned samples are discarded by default. tune : int, default 1000 Number of iterations to tune. Samplers adjust the step sizes, scalings or similar during tuning. Tuning samples will be drawn in addition to the number specified in the ``draws`` argument. chains : int, default 4 The number of chains to sample. target_accept : float in [0, 1]. The step size is tuned such that we approximate this acceptance rate. Higher values like 0.9 or 0.95 often work better for problematic posteriors. random_seed : int, default 10 Random seed used by the sampling steps. model : Model, optional Model to sample from. The model needs to have free random variables. When inside a ``with`` model context, it defaults to that model, otherwise the model must be passed explicitly. var_names : iterable of str, optional Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior progress_bar : bool, default True Whether or not to display a progress bar in the command line. The bar shows the percentage of completion, the sampling speed in samples per second (SPS), and the estimated remaining time until completion ("expected time of arrival"; ETA). keep_untransformed : bool, default False Include untransformed variables in the posterior samples. Defaults to False. chain_method : str, default "parallel" Specify how samples should be drawn. The choices include "sequential", "parallel", and "vectorized". idata_kwargs : dict, optional Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value for the ``log_likelihood`` key to indicate that the pointwise log likelihood should not be included in the returned object. Returns ------- InferenceData ArviZ ``InferenceData`` object that contains the posterior samples, together with their respective sample stats and pointwise log likeihood values (unless skipped with ``idata_kwargs``). 
""" import numpyro from numpyro.infer import MCMC, NUTS model = modelcontext(model) if var_names is None: var_names = model.unobserved_value_vars vars_to_sample = list( get_default_varnames(var_names, include_transformed=keep_untransformed)) coords = { cname: np.array(cvals) if isinstance(cvals, tuple) else cvals for cname, cvals in model.coords.items() if cvals is not None } if hasattr(model, "RV_dims"): dims = { var_name: [dim for dim in dims if dim is not None] for var_name, dims in model.RV_dims.items() } else: dims = {} if random_seed is None: random_seed = model.rng_seeder.randint(2**30, dtype=np.int64) tic1 = datetime.now() print("Compiling...", file=sys.stdout) init_params = _get_batched_jittered_initial_points( model=model, chains=chains, initvals=initvals, random_seed=random_seed, ) logp_fn = get_jaxified_logp(model, negative_logp=False) if nuts_kwargs is None: nuts_kwargs = {} nuts_kernel = NUTS( potential_fn=logp_fn, target_accept_prob=target_accept, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, **nuts_kwargs, ) pmap_numpyro = MCMC( nuts_kernel, num_warmup=tune, num_samples=draws, num_chains=chains, postprocess_fn=None, chain_method=chain_method, progress_bar=progress_bar, ) tic2 = datetime.now() print("Compilation time = ", tic2 - tic1, file=sys.stdout) print("Sampling...", file=sys.stdout) map_seed = jax.random.PRNGKey(random_seed) if chains > 1: map_seed = jax.random.split(map_seed, chains) pmap_numpyro.run( map_seed, init_params=init_params, extra_fields=( "num_steps", "potential_energy", "energy", "adapt_state.step_size", "accept_prob", "diverging", ), ) raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True) tic3 = datetime.now() print("Sampling time = ", tic3 - tic2, file=sys.stdout) print("Transforming variables...", file=sys.stdout) mcmc_samples = {} for v in vars_to_sample: jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[v]) result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0] mcmc_samples[v.name] = result tic4 = datetime.now() print("Transformation time = ", tic4 - tic3, file=sys.stdout) if idata_kwargs is None: idata_kwargs = {} else: idata_kwargs = idata_kwargs.copy() if idata_kwargs.pop("log_likelihood", True): log_likelihood = _get_log_likelihood(model, raw_mcmc_samples) else: log_likelihood = None attrs = { "sampling_time": (tic3 - tic2).total_seconds(), } posterior = mcmc_samples az_trace = az.from_dict( posterior=posterior, log_likelihood=log_likelihood, observed_data=find_observations(model), sample_stats=_sample_stats_to_xarray(pmap_numpyro), coords=coords, dims=dims, attrs=make_attrs(attrs, library=numpyro), **idata_kwargs, ) return az_trace
def __init__(
    self,
    vars=None,
    S=None,
    proposal_dist=None,
    lamb=None,
    scaling=0.001,
    tune="lambda",
    tune_interval=100,
    tune_drop_fraction: float = 0.9,
    model=None,
    mode=None,
    **kwargs,
):
    model = pm.modelcontext(model)
    initial_values = model.recompute_initial_point()
    initial_values_size = sum(initial_values[n.name].size for n in model.value_vars)

    if vars is None:
        vars = model.cont_vars
    else:
        vars = [model.rvs_to_values.get(var, var) for var in vars]
    vars = pm.inputvars(vars)

    if S is None:
        S = np.ones(initial_values_size)

    if proposal_dist is not None:
        self.proposal_dist = proposal_dist(S)
    else:
        self.proposal_dist = UniformProposal(S)

    self.scaling = np.atleast_1d(scaling).astype("d")
    if lamb is None:
        # default to the optimal lambda for normally distributed targets
        lamb = 2.38 / np.sqrt(2 * initial_values_size)
    self.lamb = float(lamb)
    if tune not in {None, "scaling", "lambda"}:
        raise ValueError('The parameter "tune" must be one of {None, scaling, lambda}')
    self.tune = True
    self.tune_target = tune
    self.tune_interval = tune_interval
    self.tune_drop_fraction = tune_drop_fraction
    self.steps_until_tune = tune_interval
    self.accepted = 0

    # cache local history for the Z-proposals
    self._history = []

    # remember initial settings before tuning so they can be reset
    self._untuned_settings = dict(
        scaling=self.scaling,
        lamb=self.lamb,
        steps_until_tune=tune_interval,
        accepted=self.accepted,
    )

    self.mode = mode

    shared = pm.make_shared_replacements(initial_values, vars, model)
    self.delta_logp = delta_logp(initial_values, model.logpt, vars, shared)
    super().__init__(vars, shared)
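# Usage sketch (added; the model is an assumed toy example): the DEMetropolisZ-style
# stepper above accepts tune targets of None, "scaling" or "lambda" (the default),
# as validated in its constructor.
import pymc as pm

with pm.Model() as m:
    a = pm.Normal("a", 0.0, 1.0)
    b = pm.Normal("b", 0.0, 1.0)
    step = pm.DEMetropolisZ([a, b], tune="scaling", tune_drop_fraction=0.9)
    idata = pm.sample(3000, tune=1000, step=step, chains=4)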