Example #1
 def test_nest_context_works(self):
     with pm.Model() as m:
         new = NewModel()
         with new:
             assert pm.modelcontext(None) is new
         assert pm.modelcontext(None) is m
     assert "v1" in m.named_vars
     assert "v2" in m.named_vars
Example #2
    def __init__(self, *args, **kwargs):
        """
        Initialise DEMetropolisZMLDA. Uses the parent class's __init__
        plus extra code specific to use within MLDA.
        """

        # flag used for signaling the end of tuning
        self.tuning_end_trigger = False

        model = pm.modelcontext(kwargs.get("model", None))
        initial_values = model.initial_point

        # flag to indicate that variance reduction is activated - forces DEMetropolisZMLDA
        # to store quantities of interest in a register if True
        self.mlda_variance_reduction = kwargs.pop("mlda_variance_reduction",
                                                  False)
        if self.mlda_variance_reduction:
            # Subsampling rate of MLDA sampler one level up
            self.mlda_subsampling_rate_above = kwargs.pop(
                "mlda_subsampling_rate_above")
            self.sub_counter = 0
            self.Q_last = np.nan
            self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above

        # call parent class __init__
        super().__init__(*args, **kwargs)

        # modify the delta function and point to model if VR is used
        if self.mlda_variance_reduction:
            self.model = model
            self.delta_logp_factory = self.delta_logp
            self.delta_logp = lambda q, q0: -self.delta_logp_factory(q0, q)
Example #3
def model_to_graphviz(model=None, *, formatting: str = "plain"):
    """Produce a graphviz Digraph from a PyMC model.

    Requires graphviz, which may be installed most easily with
        conda install -c conda-forge python-graphviz

    Alternatively, you may install the `graphviz` binaries yourself,
    and then `pip install graphviz` to get the python bindings.  See
    http://graphviz.readthedocs.io/en/stable/manual.html
    for more information.

    Parameters
    ----------
    model : pm.Model
        The model to plot. Not required when called from inside a modelcontext.
    formatting : str
        one of { "plain" }
    """
    if not "plain" in formatting:
        raise ValueError(
            f"Unsupported formatting for graph nodes: '{formatting}'. See docstring."
        )
    if formatting != "plain":
        warnings.warn(
            "Formattings other than 'plain' are currently not supported.",
            UserWarning,
            stacklevel=2)
    model = pm.modelcontext(model)
    return ModelGraph(model).make_graph(formatting=formatting)
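As the docstring says, the model argument is optional when a model context is active. A minimal usage sketch, assuming graphviz and its Python bindings are installed:

import pymc as pm

with pm.Model() as m:
    x = pm.Normal("x", 0, 1)

pm.model_to_graphviz(m)     # model passed explicitly
with m:
    pm.model_to_graphviz()  # model taken from the surrounding context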
Example #4
def subsample(
    draws=1,
    step=None,
    start=None,
    trace=None,
    tune=0,
    model=None,
):
    """
    A stripped down version of sample(), which is called only
    by the RecursiveDAProposal (which is the proposal used in the MLDA
    stepper). RecursiveDAProposal only requires a small set of the input
    parameters and checks normally performed by sample(), and this
    function thus skips some of the code in sample(). It directly calls
    _iter_sample(), rather than sample_many(). The result is a reduced
    overhead when running multiple levels in MLDA.
    """

    model = pm.modelcontext(model)
    chain = 0
    random_seed = np.random.randint(2**30)
    callback = None

    draws += tune

    sampling = pm.sampling._iter_sample(draws, step, start, trace, chain, tune,
                                        model, random_seed, callback)

    try:
        for trace, _ in sampling:
            pass
    except KeyboardInterrupt:
        pass

    return trace
Example #5
    def __init__(self, *args, **kwargs):
        """
        Initialise MetropolisMLDA. This is a mix of the parent class's initialisation
        and some extra code specific to MLDA.
        """
        model = pm.modelcontext(kwargs.get("model", None))
        initial_values = model.compute_initial_point()

        # flag to indicate that variance reduction is activated - forces MetropolisMLDA
        # to store quantities of interest in a register if True
        self.mlda_variance_reduction = kwargs.pop("mlda_variance_reduction",
                                                  False)
        if self.mlda_variance_reduction:
            # Subsampling rate of MLDA sampler one level up
            self.mlda_subsampling_rate_above = kwargs.pop(
                "mlda_subsampling_rate_above")
            self.sub_counter = 0
            self.Q_last = np.nan
            self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above

        # call parent class __init__
        super().__init__(*args, **kwargs)

        # modify the delta function and point to model if VR is used
        if self.mlda_variance_reduction:
            self.model = model
            self.delta_logp_factory = self.delta_logp
            self.delta_logp = lambda q, q0: -self.delta_logp_factory(q0, q)
Example #6
def model_to_graphviz(model=None,
                      *,
                      var_names: Optional[Iterable[VarName]] = None,
                      formatting: str = "plain"):
    """Produce a graphviz Digraph from a PyMC model.

    Requires graphviz, which may be installed most easily with
        conda install -c conda-forge python-graphviz

    Alternatively, you may install the `graphviz` binaries yourself,
    and then `pip install graphviz` to get the python bindings.  See
    http://graphviz.readthedocs.io/en/stable/manual.html
    for more information.

    Parameters
    ----------
    model : pm.Model
        The model to plot. Not required when called from inside a modelcontext.
    var_names : iterable of variable names, optional
        Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
    formatting : str, optional
        one of { "plain" }

    Examples
    --------
    How to plot the graph of the model.

    .. code-block:: python

        import numpy as np
        from pymc import HalfCauchy, Model, Normal, model_to_graphviz

        J = 8
        y = np.array([28, 8, -3, 7, -1, 1, 18, 12])
        sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])

        with Model() as schools:

            eta = Normal("eta", 0, 1, shape=J)
            mu = Normal("mu", 0, sigma=1e6)
            tau = HalfCauchy("tau", 25)

            theta = mu + tau * eta

            obs = Normal("obs", theta, sigma=sigma, observed=y)

        model_to_graphviz(schools)
    """
    if not "plain" in formatting:
        raise ValueError(
            f"Unsupported formatting for graph nodes: '{formatting}'. See docstring."
        )
    if formatting != "plain":
        warnings.warn(
            "Formattings other than 'plain' are currently not supported.",
            UserWarning,
            stacklevel=2)
    model = pm.modelcontext(model)
    return ModelGraph(model).make_graph(var_names=var_names,
                                        formatting=formatting)
Example #7
    def __init__(self, vars, order="random", transit_p=0.8, model=None):

        model = pm.modelcontext(model)

        # transition probabilities
        self.transit_p = transit_p

        initial_point = model.initial_point()
        vars = [model.rvs_to_values.get(var, var) for var in vars]
        self.dim = sum(initial_point[v.name].size for v in vars)

        if order == "random":
            self.shuffle_dims = True
            self.order = list(range(self.dim))
        else:
            if sorted(order) != list(range(self.dim)):
                raise ValueError("Argument 'order' has to be a permutation")
            self.shuffle_dims = False
            self.order = order

        if not all(v.dtype in pm.discrete_types for v in vars):
            raise ValueError(
                "All variables must be binary for BinaryGibbsMetropolis")

        super().__init__(vars, [model.compile_logp()])
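A short sketch of wiring this stepper into sampling, assuming the class is exposed as pm.BinaryGibbsMetropolis as in released PyMC; all sampled variables must be binary:

import pymc as pm

with pm.Model():
    z = pm.Bernoulli("z", 0.5, shape=3)
    step = pm.BinaryGibbsMetropolis([z], transit_p=0.8)
    idata = pm.sample(500, step=step, chains=1)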
Example #8
    def __init__(self, vars, proposal="uniform", order="random", model=None):

        model = pm.modelcontext(model)

        vars = [model.rvs_to_values.get(var, var) for var in vars]
        vars = pm.inputvars(vars)

        initial_point = model.initial_point()

        dimcats = []
        # dimcats is a list of pairs (aggregate dimension, number of
        # categories). For example, if vars = [x, y] with x a 2-element
        # variable with M categories and y a 3-element variable with N
        # categories, dimcats = [(0, M), (1, M), (2, N), (3, N), (4, N)].
        for v in vars:

            v_init_val = initial_point[v.name]

            rv_var = model.values_to_rvs[v]
            distr = getattr(rv_var.owner, "op", None)

            if isinstance(distr, CategoricalRV):
                k_graph = rv_var.owner.inputs[3].shape[-1]
                (k_graph, ), _ = rvs_to_value_vars((k_graph, ),
                                                   apply_transforms=True)
                k = model.compile_fn(k_graph,
                                     inputs=model.value_vars,
                                     on_unused_input="ignore")(initial_point)
            elif isinstance(distr, BernoulliRV):
                k = 2
            else:
                raise ValueError(
                    "All variables must be categorical or binary" +
                    "for CategoricalGibbsMetropolis")
            start = len(dimcats)
            dimcats += [(dim, k)
                        for dim in range(start, start + v_init_val.size)]

        if order == "random":
            self.shuffle_dims = True
            self.dimcats = dimcats
        else:
            if sorted(order) != list(range(len(dimcats))):
                raise ValueError("Argument 'order' has to be a permutation")
            self.shuffle_dims = False
            self.dimcats = [dimcats[j] for j in order]

        if proposal == "uniform":
            self.astep = self.astep_unif
        elif proposal == "proportional":
            # Use the optimized "Metropolized Gibbs Sampler" described in Liu96.
            self.astep = self.astep_prop
        else:
            raise ValueError(
                "Argument 'proposal' should either be 'uniform' or 'proportional'"
            )

        super().__init__(vars, [model.compile_logp()])
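The dimcats bookkeeping above can be checked in isolation. This toy snippet, with symbolic category counts "M" and "N", reproduces the worked example from the comment (a 2-element variable followed by a 3-element one):

dimcats = []
for size, k in [(2, "M"), (3, "N")]:  # (number of elements, number of categories)
    start = len(dimcats)
    dimcats += [(dim, k) for dim in range(start, start + size)]
assert dimcats == [(0, "M"), (1, "M"), (2, "N"), (3, "N"), (4, "N")]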
Example #9
 def __init__(self, name="", model=None):
     super().__init__(name, model)
     assert pm.modelcontext(None) is self
     # 1) init variables with the register_rv method
     self.register_rv(pm.Normal.dist(), "v1")
     self.v2 = pm.Normal("v2", mu=0, sigma=1)
     # 2) Potentials and Deterministic variables are created with the usual methods too;
     # be sure that their names will not overlap with those of other models of the same kind
     pm.Deterministic("d", at.constant(1))
     pm.Potential("p", at.constant(1))
Example #10
    def __init__(self,
                 vars=None,
                 S=None,
                 proposal_dist=None,
                 lamb=None,
                 scaling=0.001,
                 tune=None,
                 tune_interval=100,
                 model=None,
                 mode=None,
                 **kwargs):

        model = pm.modelcontext(model)
        initial_values = model.initial_point()
        initial_values_size = sum(initial_values[n.name].size
                                  for n in model.value_vars)

        if vars is None:
            vars = model.cont_vars
        else:
            vars = [model.rvs_to_values.get(var, var) for var in vars]
        vars = pm.inputvars(vars)

        if S is None:
            S = np.ones(initial_values_size)

        if proposal_dist is not None:
            self.proposal_dist = proposal_dist(S)
        else:
            self.proposal_dist = UniformProposal(S)

        self.scaling = np.atleast_1d(scaling).astype("d")
        if lamb is None:
            # default to the optimal lambda for normally distributed targets
            lamb = 2.38 / np.sqrt(2 * initial_values_size)
        self.lamb = float(lamb)
        if tune not in {None, "scaling", "lambda"}:
            raise ValueError(
                'The parameter "tune" must be one of {None, scaling, lambda}')
        self.tune = tune
        self.tune_interval = tune_interval
        self.steps_until_tune = tune_interval
        self.accepted = 0

        self.mode = mode

        shared = pm.make_shared_replacements(initial_values, vars, model)
        self.delta_logp = delta_logp(initial_values, model.logpt(), vars,
                                     shared)
        super().__init__(vars, shared)
Example #11
    def _get_priors(self, model=None):
        """Return prior distributions of the likelihood.

        Returns
        -------
        dict : mapping name -> pymc distribution
        """
        model = pymc.modelcontext(model)
        priors = {}
        for key, val in self.priors.items():
            if isinstance(val, numbers.Number):
                priors[key] = val
            else:
                priors[key] = model.Var(val[0], val[1])

        return priors
Example #12
    def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None):

        model = pm.modelcontext(model)

        self.scaling = scaling
        self.tune = tune
        self.tune_interval = tune_interval
        self.steps_until_tune = tune_interval
        self.accepted = 0

        vars = [model.rvs_to_values.get(var, var) for var in vars]

        if not all(v.dtype in pm.discrete_types for v in vars):
            raise ValueError("All variables must be Bernoulli for BinaryMetropolis")

        super().__init__(vars, [model.fastlogp])
Example #13
def sample_numpyro_nuts(
    draws=1000,
    tune=1000,
    chains=4,
    target_accept=0.8,
    random_seed=10,
    model=None,
    var_names=None,
    progress_bar=True,
    keep_untransformed=False,
    chain_method="parallel",
):
    from numpyro.infer import MCMC, NUTS

    model = modelcontext(model)

    if var_names is None:
        var_names = model.unobserved_value_vars

    vars_to_sample = list(
        get_default_varnames(var_names,
                             include_transformed=keep_untransformed))

    coords = {
        cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
        for cname, cvals in model.coords.items() if cvals is not None
    }

    if hasattr(model, "RV_dims"):
        dims = {
            var_name: [dim for dim in dims if dim is not None]
            for var_name, dims in model.RV_dims.items()
        }
    else:
        dims = {}

    tic1 = pd.Timestamp.now()
    print("Compiling...", file=sys.stdout)

    rv_names = [rv.name for rv in model.value_vars]
    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
    init_state_batched = jax.tree_map(
        lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

    logp_fn = get_jaxified_logp(model)

    nuts_kernel = NUTS(
        potential_fn=logp_fn,
        target_accept_prob=target_accept,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        dense_mass=False,
    )

    pmap_numpyro = MCMC(
        nuts_kernel,
        num_warmup=tune,
        num_samples=draws,
        num_chains=chains,
        postprocess_fn=None,
        chain_method=chain_method,
        progress_bar=progress_bar,
    )

    tic2 = pd.Timestamp.now()
    print("Compilation time = ", tic2 - tic1, file=sys.stdout)

    print("Sampling...", file=sys.stdout)

    seed = jax.random.PRNGKey(random_seed)
    map_seed = jax.random.split(seed, chains)

    if chains == 1:
        init_params = init_state
        map_seed = seed
    else:
        init_params = init_state_batched

    pmap_numpyro.run(
        map_seed,
        init_params=init_params,
        extra_fields=(
            "num_steps",
            "potential_energy",
            "energy",
            "adapt_state.step_size",
            "accept_prob",
            "diverging",
        ),
    )

    raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)

    tic3 = pd.Timestamp.now()
    print("Sampling time = ", tic3 - tic2, file=sys.stdout)

    print("Transforming variables...", file=sys.stdout)
    mcmc_samples = {}
    for v in vars_to_sample:
        fgraph = FunctionGraph(model.value_vars, [v], clone=False)
        optimize_graph(fgraph,
                       include=["fast_run"],
                       exclude=["cxx_only", "BlasOpt"])
        jax_fn = jax_funcify(fgraph)
        result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
        mcmc_samples[v.name] = result

    tic4 = pd.Timestamp.now()
    print("Transformation time = ", tic4 - tic3, file=sys.stdout)

    posterior = mcmc_samples
    az_trace = az.from_dict(
        posterior=posterior,
        log_likelihood=_get_log_likelihood(model, raw_mcmc_samples),
        observed_data=find_observations(model),
        sample_stats=_sample_stats_to_xarray(pmap_numpyro),
        coords=coords,
        dims=dims,
    )

    return az_trace
Example #14
    def __init__(
        self,
        coarse_models: List[Model],
        vars: Optional[list] = None,
        base_sampler="DEMetropolisZ",
        base_S: Optional = None,
        base_proposal_dist: Optional[Type[Proposal]] = None,
        base_scaling: Optional = None,
        tune: bool = True,
        base_tune_target: str = "lambda",
        base_tune_interval: int = 100,
        base_lamb: Optional = None,
        base_tune_drop_fraction: float = 0.9,
        model: Optional[Model] = None,
        mode: Optional = None,
        subsampling_rates: List[int] = 5,
        base_blocked: bool = False,
        variance_reduction: bool = False,
        store_Q_fine: bool = False,
        adaptive_error_model: bool = False,
        **kwargs,
    ) -> None:

        # this variable is used to identify MLDA objects which are
        # not in the finest level (i.e. child MLDA objects)
        self.is_child = kwargs.get("is_child", False)

        if not isinstance(coarse_models, list):
            raise ValueError(
                "MLDA step method cannot use coarse_models if it is not a list"
            )
        if len(coarse_models) == 0:
            raise ValueError("MLDA step method was given an empty "
                             "list of coarse models. Give at least "
                             "one coarse model.")

        # assign internal state
        model = pm.modelcontext(model)
        initial_values = model.compute_initial_point()
        self.model = model
        self.coarse_models = coarse_models
        self.model_below = self.coarse_models[-1]
        self.num_levels = len(self.coarse_models) + 1

        # set up variance reduction.
        self.variance_reduction = variance_reduction
        self.store_Q_fine = store_Q_fine

        # check that certain requirements hold
        # for the variance reduction feature to work
        if self.variance_reduction or self.store_Q_fine:
            if not hasattr(self.model, "Q"):
                raise AttributeError("Model given to MLDA does not contain"
                                     "variable 'Q'. You need to include"
                                     "the variable in the model definition"
                                     "for variance reduction to work or"
                                     "for storing the fine Q."
                                     "Use pm.Data() to define it.")
            if not isinstance(self.model.Q, TensorSharedVariable):
                raise TypeError(
                    "The variable 'Q' in the model definition is not of type "
                    "'TensorSharedVariable'. Use pm.Data() to define the"
                    "variable.")

        if self.is_child and self.variance_reduction:
            # this is the subsampling rate applied to the current level
            # it is stored in the level above and transferred here
            self.subsampling_rate_above = kwargs.pop("subsampling_rate_above",
                                                     None)

        # set up adaptive error model
        self.adaptive_error_model = adaptive_error_model

        # check that certain requirements hold
        # for the adaptive error model feature to work
        if self.adaptive_error_model:
            if not hasattr(self.model_below, "mu_B"):
                raise AttributeError(
                    "Model below in hierarchy does not contain"
                    "variable 'mu_B'. You need to include"
                    "the variable in the model definition"
                    "for adaptive error model to work."
                    "Use pm.Data() to define it.")
            if not hasattr(self.model_below, "Sigma_B"):
                raise AttributeError(
                    "Model below in hierarchy does not contain"
                    "variable 'Sigma_B'. You need to include"
                    "the variable in the model definition"
                    "for adaptive error model to work."
                    "Use pm.Data() to define it.")
            if not (isinstance(self.model_below.mu_B, TensorSharedVariable)
                    and isinstance(self.model_below.Sigma_B,
                                   TensorSharedVariable)):
                raise TypeError(
                    "At least one of the variables 'mu_B' and 'Sigma_B' "
                    "in the definition of the below model is not of type "
                    "'TensorSharedVariable'. Use pm.Data() to define those "
                    "variables.")

            # this object is used to recursively update the mean and
            # variance of the bias correction given new differences
            # between levels
            self.bias = RecursiveSampleMoments(
                self.model_below.mu_B.get_value(),
                self.model_below.Sigma_B.get_value())

            # this list holds the bias objects from all levels
            # it is gradually constructed when MLDA objects are
            # created and then shared between all levels
            self.bias_all = kwargs.pop("bias_all", None)
            if self.bias_all is None:
                self.bias_all = [self.bias]
            else:
                self.bias_all.append(self.bias)

            # variables used for adaptive error model
            self.last_synced_output_diff = None
            self.adaptation_started = False

        # set up subsampling rates.
        if isinstance(subsampling_rates, int):
            self.subsampling_rates = [subsampling_rates] * len(
                self.coarse_models)
        else:
            if len(subsampling_rates) != len(self.coarse_models):
                raise ValueError(
                    f"List of subsampling rates needs to have the same "
                    f"length as list of coarse models but the lengths "
                    f"were {len(subsampling_rates)}, {len(self.coarse_models)}"
                )
            self.subsampling_rates = subsampling_rates

        self.subsampling_rate = self.subsampling_rates[-1]
        self.subchain_selection = None

        # set up base sampling
        self.base_sampler = base_sampler

        # VR is not compatible with compound base samplers, so an automatic
        # conversion to a blocked sampler happens here if needed
        if self.variance_reduction and self.base_sampler == "Metropolis" and not base_blocked:
            warnings.warn(
                "Variance reduction is not compatible with non-blocked (compound) samplers."
                "Automatically switching to a blocked Metropolis sampler.")
            self.base_blocked = True
        else:
            self.base_blocked = base_blocked

        self.base_S = base_S
        self.base_proposal_dist = base_proposal_dist

        if base_scaling is None:
            if self.base_sampler == "Metropolis":
                self.base_scaling = 1.0
            else:
                self.base_scaling = 0.001
        else:
            self.base_scaling = float(base_scaling)

        self.tune = tune
        if not self.tune and self.base_sampler == "DEMetropolisZ":
            raise ValueError(
                f"The argument tune was set to False while using"
                f" a 'DEMetropolisZ' base sampler. 'DEMetropolisZ' "
                f" tune needs to be True.")

        self.base_tune_target = base_tune_target
        self.base_tune_interval = base_tune_interval
        self.base_lamb = base_lamb
        self.base_tune_drop_fraction = float(base_tune_drop_fraction)
        self.base_tuning_stats = None

        self.mode = mode

        # Process model variables
        if vars is None:
            vars = model.value_vars
        else:
            vars = [model.rvs_to_values.get(var, var) for var in vars]
        vars = pm.inputvars(vars)
        self.vars = vars
        self.var_names = [var.name for var in self.vars]

        self.accepted = 0

        # Construct Aesara function for current-level model likelihood
        # (for use in acceptance)
        shared = pm.make_shared_replacements(initial_values, vars, model)
        self.delta_logp = delta_logp(initial_values, model.logpt(), vars,
                                     shared)

        # Construct Aesara function for below-level model likelihood
        # (for use in acceptance)
        model_below = pm.modelcontext(self.model_below)
        vars_below = [
            var for var in model_below.value_vars if var.name in self.var_names
        ]
        vars_below = pm.inputvars(vars_below)
        shared_below = pm.make_shared_replacements(initial_values, vars_below,
                                                   model_below)
        self.delta_logp_below = delta_logp(initial_values, model_below.logpt(),
                                           vars_below, shared_below)

        super().__init__(vars, shared)

        # initialise complete step method hierarchy
        if self.num_levels == 2:
            with self.model_below:
                # make sure the correct variables are selected from model_below
                vars_below = [
                    var for var in self.model_below.value_vars
                    if var.name in self.var_names
                ]

                # create kwargs
                if self.variance_reduction:
                    base_kwargs = {
                        "mlda_subsampling_rate_above": self.subsampling_rate,
                        "mlda_variance_reduction": True,
                    }
                else:
                    base_kwargs = {}

                if self.base_sampler == "Metropolis":
                    # MetropolisMLDA sampler in base level (level=0), targeting self.model_below
                    self.step_method_below = pm.MetropolisMLDA(
                        vars=vars_below,
                        proposal_dist=self.base_proposal_dist,
                        S=self.base_S,
                        scaling=self.base_scaling,
                        tune=self.tune,
                        tune_interval=self.base_tune_interval,
                        model=None,
                        mode=self.mode,
                        blocked=self.base_blocked,
                        **base_kwargs,
                    )
                else:
                    # DEMetropolisZMLDA sampler in base level (level=0), targeting self.model_below
                    self.step_method_below = pm.DEMetropolisZMLDA(
                        vars=vars_below,
                        S=self.base_S,
                        proposal_dist=self.base_proposal_dist,
                        lamb=self.base_lamb,
                        scaling=self.base_scaling,
                        tune=self.base_tune_target,
                        tune_interval=self.base_tune_interval,
                        tune_drop_fraction=self.base_tune_drop_fraction,
                        model=None,
                        mode=self.mode,
                        **base_kwargs,
                    )
        else:
            # drop the last coarse model
            coarse_models_below = self.coarse_models[:-1]
            subsampling_rates_below = self.subsampling_rates[:-1]

            with self.model_below:
                # make sure the correct variables are selected from model_below
                vars_below = [
                    var for var in self.model_below.value_vars
                    if var.name in self.var_names
                ]

                # create kwargs
                if self.variance_reduction:
                    mlda_kwargs = {
                        "is_child": True,
                        "subsampling_rate_above": self.subsampling_rate,
                    }
                else:
                    mlda_kwargs = {"is_child": True}
                if self.adaptive_error_model:
                    mlda_kwargs = {**mlda_kwargs, "bias_all": self.bias_all}

                # MLDA sampler in some intermediate level, targeting self.model_below
                self.step_method_below = pm.MLDA(
                    vars=vars_below,
                    base_S=self.base_S,
                    base_sampler=self.base_sampler,
                    base_proposal_dist=self.base_proposal_dist,
                    base_scaling=self.base_scaling,
                    tune=self.tune,
                    base_tune_target=self.base_tune_target,
                    base_tune_interval=self.base_tune_interval,
                    base_lamb=self.base_lamb,
                    base_tune_drop_fraction=self.base_tune_drop_fraction,
                    model=None,
                    mode=self.mode,
                    subsampling_rates=subsampling_rates_below,
                    coarse_models=coarse_models_below,
                    base_blocked=self.base_blocked,
                    variance_reduction=self.variance_reduction,
                    store_Q_fine=False,
                    adaptive_error_model=self.adaptive_error_model,
                    **mlda_kwargs,
                )

        # instantiate the recursive DA proposal.
        # this is the main proposal used for
        # all levels (Recursive Delayed Acceptance)
        # (except for level 0 where the step method is MetropolisMLDA
        # or DEMetropolisZMLDA - not MLDA)
        self.proposal_dist = RecursiveDAProposal(self.step_method_below,
                                                 self.model_below, self.tune,
                                                 self.subsampling_rate)

        # set up data types of stats.
        if isinstance(self.step_method_below, MLDA):
            # get the stat types from the level below if that level is MLDA
            self.stats_dtypes = self.step_method_below.stats_dtypes

        else:
            # otherwise, set it up from scratch.
            self.stats_dtypes = [{
                "accept": np.float64,
                "accepted": bool,
                "tune": bool
            }]

            if isinstance(self.step_method_below, MetropolisMLDA):
                self.stats_dtypes.append({"base_scaling": np.float64})
            elif isinstance(self.step_method_below, DEMetropolisZMLDA):
                self.stats_dtypes.append({
                    "base_scaling": np.float64,
                    "base_lambda": np.float64
                })
            elif isinstance(self.step_method_below, CompoundStep):
                for method in self.step_method_below.methods:
                    if isinstance(method, MetropolisMLDA):
                        self.stats_dtypes.append({"base_scaling": np.float64})
                    elif isinstance(method, DEMetropolisZMLDA):
                        self.stats_dtypes.append({
                            "base_scaling": np.float64,
                            "base_lambda": np.float64
                        })

        # initialise necessary variables for doing variance reduction
        if self.variance_reduction:
            self.sub_counter = 0
            self.Q_diff = []
            if self.is_child:
                self.Q_reg = [np.nan] * self.subsampling_rate_above
            if self.num_levels == 2:
                self.Q_base_full = []
            if not self.is_child:
                for level in range(self.num_levels - 1, 0, -1):
                    self.stats_dtypes[0][f"Q_{level}_{level - 1}"] = object
                self.stats_dtypes[0]["Q_0"] = object

        # initialise necessary variables for doing variance reduction or storing fine Q
        if self.variance_reduction or self.store_Q_fine:
            self.Q_last = np.nan
            self.Q_diff_last = np.nan
        if self.store_Q_fine and not self.is_child:
            self.stats_dtypes[0][f"Q_{self.num_levels - 1}"] = object
Example #15
def sample_numpyro_nuts(
    draws=1000,
    tune=1000,
    chains=4,
    target_accept=0.8,
    random_seed=10,
    model=None,
    progress_bar=True,
    keep_untransformed=False,
):
    from numpyro.infer import MCMC, NUTS

    model = modelcontext(model)

    tic1 = pd.Timestamp.now()
    print("Compiling...", file=sys.stdout)

    rv_names = [rv.name for rv in model.value_vars]
    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
    init_state_batched = jax.tree_map(
        lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

    logp_fn = get_jaxified_logp(model)

    nuts_kernel = NUTS(
        potential_fn=logp_fn,
        target_accept_prob=target_accept,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        dense_mass=False,
    )

    pmap_numpyro = MCMC(
        nuts_kernel,
        num_warmup=tune,
        num_samples=draws,
        num_chains=chains,
        postprocess_fn=None,
        chain_method="parallel",
        progress_bar=progress_bar,
    )

    tic2 = pd.Timestamp.now()
    print("Compilation time = ", tic2 - tic1, file=sys.stdout)

    print("Sampling...", file=sys.stdout)

    seed = jax.random.PRNGKey(random_seed)
    map_seed = jax.random.split(seed, chains)

    pmap_numpyro.run(map_seed,
                     init_params=init_state_batched,
                     extra_fields=("num_steps", ))
    raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)

    tic3 = pd.Timestamp.now()
    print("Sampling time = ", tic3 - tic2, file=sys.stdout)

    print("Transforming variables...", file=sys.stdout)
    mcmc_samples = []
    for value_var, raw_samples in zip(model.value_vars, raw_mcmc_samples):
        raw_samples = at.constant(np.asarray(raw_samples))

        rv = model.values_to_rvs[value_var]
        transform = getattr(value_var.tag, "transform", None)

        if transform is not None:
            # TODO: This will fail when the transformation depends on another variable
            #  such as in interval transform with RVs as edges
            trans_samples = transform.backward(raw_samples, *rv.owner.inputs)
            trans_samples.name = rv.name
            mcmc_samples.append(trans_samples)

            if keep_untransformed:
                raw_samples.name = value_var.name
                mcmc_samples.append(raw_samples)
        else:
            raw_samples.name = rv.name
            mcmc_samples.append(raw_samples)

    mcmc_varnames = [var.name for var in mcmc_samples]
    mcmc_samples = compile_rv_inplace(
        [],
        mcmc_samples,
        mode="JAX",
    )()

    tic4 = pd.Timestamp.now()
    print("Transformation time = ", tic4 - tic3, file=sys.stdout)

    posterior = {k: v for k, v in zip(mcmc_varnames, mcmc_samples)}
    az_trace = az.from_dict(posterior=posterior)

    return az_trace
Example #16
def fit(
    n=10000,
    local_rv=None,
    method="advi",
    model=None,
    random_seed=None,
    start=None,
    inf_kwargs=None,
    **kwargs,
):
    r"""Handy shortcut for using inference methods in functional way

    Parameters
    ----------
    n: `int`
        number of iterations
    local_rv: dict[var->tuple]
        mapping {model_variable -> approx params}
        Local Vars are used for Autoencoding Variational Bayes
        See (AEVB; Kingma and Welling, 2014) for details
    method: str or :class:`Inference`
        string name (case insensitive), one of:

        -   'advi'  for ADVI
        -   'fullrank_advi'  for FullRankADVI
        -   'svgd'  for Stein Variational Gradient Descent
        -   'asvgd'  for Amortized Stein Variational Gradient Descent
        -   'nfvi'  for Normalizing Flow with default `scale-loc` flow
        -   'nfvi=<formula>'  for Normalizing Flow using formula

    model: :class:`Model`
        PyMC model for inference
    random_seed: None or int
        leave None to use package global RandomStream or other
        valid value to create instance specific one
    inf_kwargs: dict
        additional kwargs passed to :class:`Inference`
    start: `Point`
        starting point for inference

    Other Parameters
    ----------------
    score: bool
        evaluate loss on each iteration or not
    callbacks: list[function: (Approximation, losses, i) -> None]
        calls provided functions after each iteration step
    progressbar: bool
        whether to show progressbar or not
    obj_n_mc: `int`
        Number of monte carlo samples used for approximation of objective gradients
    tf_n_mc: `int`
        Number of monte carlo samples used for approximation of test function gradients
    obj_optimizer: function (grads, params) -> updates
        Optimizer that is used for objective params
    test_optimizer: function (grads, params) -> updates
        Optimizer that is used for test function params
    more_obj_params: `list`
        Add custom params for objective optimizer
    more_tf_params: `list`
        Add custom params for test function optimizer
    more_updates: `dict`
        Add custom updates to resulting updates
    total_grad_norm_constraint: `float`
        Bounds gradient norm, prevents exploding gradient problem
    fn_kwargs: `dict`
        Add kwargs to aesara.function (e.g. `{'profile': True}`)
    more_replacements: `dict`
        Apply custom replacements before calculating gradients

    Returns
    -------
    :class:`Approximation`
    """
    if inf_kwargs is None:
        inf_kwargs = dict()
    else:
        inf_kwargs = inf_kwargs.copy()
    if local_rv is not None:
        inf_kwargs["local_rv"] = local_rv
    if random_seed is not None:
        inf_kwargs["random_seed"] = random_seed
    if start is not None:
        inf_kwargs["start"] = start
    if model is None:
        model = pm.modelcontext(model)
    _select = dict(advi=ADVI,
                   fullrank_advi=FullRankADVI,
                   svgd=SVGD,
                   asvgd=ASVGD,
                   nfvi=NFVI)
    if isinstance(method, str):
        method = method.lower()
        if method.startswith("nfvi="):
            formula = method[5:]
            inference = NFVI(formula, **inf_kwargs)
        elif method in _select:
            inference = _select[method](model=model, **inf_kwargs)
        else:
            raise KeyError(
                f"method should be one of {set(_select.keys())} or Inference instance"
            )
    elif isinstance(method, Inference):
        inference = method
    else:
        raise TypeError(
            f"method should be one of {set(_select.keys())} or Inference instance"
        )
    return inference.fit(n, **kwargs)
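A minimal usage sketch for fit() with a toy model; data here is a hypothetical observation array, and the returned Approximation can be sampled from afterwards:

import numpy as np
import pymc as pm

data = np.random.randn(100)  # hypothetical observations
with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=data)
    approx = pm.fit(n=10_000, method="advi")  # returns an Approximation
trace = approx.sample(1000)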
Example #17
    def __init__(self,
                 vars=None,
                 S=None,
                 proposal_dist=None,
                 scaling=1.0,
                 tune=True,
                 tune_interval=100,
                 model=None,
                 mode=None,
                 **kwargs):
        """Create an instance of a Metropolis stepper

        Parameters
        ----------
        vars: list
            List of value variables for sampler
        S: standard deviation or covariance matrix
            Some measure of variance to parameterize proposal distribution
        proposal_dist: function
            Function that returns zero-mean deviates when parameterized with
            S (and n). Defaults to normal.
        scaling: scalar or array
            Initial scale factor for proposal. Defaults to 1.
        tune: bool
            Flag for tuning. Defaults to True.
        tune_interval: int
            The frequency of tuning. Defaults to 100 iterations.
        model: PyMC Model
            Optional model for sampling step. Defaults to None (taken from context).
        mode: string or `Mode` instance.
            compilation mode passed to Aesara functions
        """

        model = pm.modelcontext(model)
        initial_values = model.initial_point()

        if vars is None:
            vars = model.value_vars
        else:
            vars = [model.rvs_to_values.get(var, var) for var in vars]

        vars = pm.inputvars(vars)

        initial_values_shape = [initial_values[v.name].shape for v in vars]
        if S is None:
            S = np.ones(int(sum(np.prod(ivs) for ivs in initial_values_shape)))

        if proposal_dist is not None:
            self.proposal_dist = proposal_dist(S)
        elif S.ndim == 1:
            self.proposal_dist = NormalProposal(S)
        elif S.ndim == 2:
            self.proposal_dist = MultivariateNormalProposal(S)
        else:
            raise ValueError("Invalid rank for variance: %s" % S.ndim)

        self.scaling = np.atleast_1d(scaling).astype("d")
        self.tune = tune
        self.tune_interval = tune_interval
        self.steps_until_tune = tune_interval

        # Determine type of variables
        self.discrete = np.concatenate([[v.dtype in pm.discrete_types] *
                                        (initial_values[v.name].size or 1)
                                        for v in vars])
        self.any_discrete = self.discrete.any()
        self.all_discrete = self.discrete.all()

        # Metropolis will try to handle one batched dimension at a time. This, however,
        # is not safe for discrete multivariate distributions (looking at you Multinomial),
        # due to high dependency among the support dimensions. For continuous multivariate
        # distributions we assume they are being transformed in a way that makes each
        # dimension semi-independent.
        is_scalar = len(
            initial_values_shape) == 1 and initial_values_shape[0] == ()
        self.elemwise_update = not (is_scalar or (self.any_discrete and max(
            getattr(model.values_to_rvs[var].owner.op, "ndim_supp", 1)
            for var in vars) > 0))
        if self.elemwise_update:
            dims = int(sum(np.prod(ivs) for ivs in initial_values_shape))
        else:
            dims = 1
        self.enum_dims = np.arange(dims, dtype=int)
        self.accept_rate_iter = np.zeros(dims, dtype=float)
        self.accepted_iter = np.zeros(dims, dtype=bool)
        self.accepted_sum = np.zeros(dims, dtype=int)

        # remember initial settings before tuning so they can be reset
        self._untuned_settings = dict(scaling=self.scaling,
                                      steps_until_tune=tune_interval)

        # TODO: This is not being used when compiling the logp function!
        self.mode = mode

        shared = pm.make_shared_replacements(initial_values, vars, model)
        self.delta_logp = delta_logp(initial_values, model.logp(), vars,
                                     shared)
        super().__init__(vars, shared)
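A short sketch of using the stepper through the usual entry point; when vars is omitted, all value variables are used, as in the __init__ above:

import pymc as pm

with pm.Model():
    x = pm.Normal("x", 0.0, 1.0)
    idata = pm.sample(1000, step=pm.Metropolis(), chains=2)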
Example #18
def sample_blackjax_nuts(
    draws=1000,
    tune=1000,
    chains=4,
    target_accept=0.8,
    random_seed=10,
    initvals=None,
    model=None,
    var_names=None,
    keep_untransformed=False,
    chain_method="parallel",
    idata_kwargs=None,
):
    """
    Draw samples from the posterior using the NUTS method from the ``blackjax`` library.

    Parameters
    ----------
    draws : int, default 1000
        The number of samples to draw. The number of tuned samples are discarded by default.
    tune : int, default 1000
        Number of iterations to tune. Samplers adjust the step sizes, scalings or
        similar during tuning. Tuning samples will be drawn in addition to the number specified in
        the ``draws`` argument.
    chains : int, default 4
        The number of chains to sample.
    target_accept : float in [0, 1].
        The step size is tuned such that we approximate this acceptance rate. Higher values like
        0.9 or 0.95 often work better for problematic posteriors.
    random_seed : int, default 10
        Random seed used by the sampling steps.
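    initvals : dict or sequence of dict, optional
        Initial values for the model's free random variables, used when building
        the jittered starting points for each chain.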
    model : Model, optional
        Model to sample from. The model needs to have free random variables. When inside a ``with`` model
        context, it defaults to that model, otherwise the model must be passed explicitly.
    var_names : iterable of str, optional
        Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior
    keep_untransformed : bool, default False
        Include untransformed variables in the posterior samples. Defaults to False.
    chain_method : str, default "parallel"
        Specify how samples should be drawn. The choices include "parallel", and "vectorized".
    idata_kwargs : dict, optional
        Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
        for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
        not be included in the returned object.

    Returns
    -------
    InferenceData
        ArviZ ``InferenceData`` object that contains the posterior samples, together with their respective sample stats and
        pointwise log likelihood values (unless skipped with ``idata_kwargs``).
    """
    import blackjax

    model = modelcontext(model)

    if var_names is None:
        var_names = model.unobserved_value_vars

    vars_to_sample = list(
        get_default_varnames(var_names,
                             include_transformed=keep_untransformed))

    coords = {
        cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
        for cname, cvals in model.coords.items() if cvals is not None
    }

    if hasattr(model, "RV_dims"):
        dims = {
            var_name: [dim for dim in dims if dim is not None]
            for var_name, dims in model.RV_dims.items()
        }
    else:
        dims = {}

    tic1 = datetime.now()
    print("Compiling...", file=sys.stdout)

    init_params = _get_batched_jittered_initial_points(
        model=model,
        chains=chains,
        initvals=initvals,
        random_seed=random_seed,
    )

    if chains == 1:
        init_params = [np.stack(init_params)]
        init_params = [
            np.stack(init_state) for init_state in zip(*init_params)
        ]

    logprob_fn = get_jaxified_logp(model)

    seed = jax.random.PRNGKey(random_seed)
    keys = jax.random.split(seed, chains)

    get_posterior_samples = partial(
        _blackjax_inference_loop,
        logprob_fn=logprob_fn,
        tune=tune,
        draws=draws,
        target_accept=target_accept,
    )

    tic2 = datetime.now()
    print("Compilation time = ", tic2 - tic1, file=sys.stdout)

    print("Sampling...", file=sys.stdout)

    # Adapted from numpyro
    if chain_method == "parallel":
        map_fn = jax.pmap
    elif chain_method == "vectorized":
        map_fn = jax.vmap
    else:
        raise ValueError(
            "Only supporting the following methods to draw chains:"
            ' "parallel" or "vectorized"')

    states, _ = map_fn(get_posterior_samples)(keys, init_params)
    raw_mcmc_samples = states.position

    tic3 = datetime.now()
    print("Sampling time = ", tic3 - tic2, file=sys.stdout)

    print("Transforming variables...", file=sys.stdout)
    mcmc_samples = {}
    for v in vars_to_sample:
        jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[v])
        result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
        mcmc_samples[v.name] = result

    tic4 = datetime.now()
    print("Transformation time = ", tic4 - tic3, file=sys.stdout)

    if idata_kwargs is None:
        idata_kwargs = {}
    else:
        idata_kwargs = idata_kwargs.copy()

    if idata_kwargs.pop("log_likelihood", True):
        log_likelihood = _get_log_likelihood(model, raw_mcmc_samples)
    else:
        log_likelihood = None

    attrs = {
        "sampling_time": (tic3 - tic2).total_seconds(),
    }

    posterior = mcmc_samples
    az_trace = az.from_dict(
        posterior=posterior,
        log_likelihood=log_likelihood,
        observed_data=find_observations(model),
        coords=coords,
        dims=dims,
        attrs=make_attrs(attrs, library=blackjax),
        **idata_kwargs,
    )

    return az_trace
Example #19
    def __init__(self,
                 vars=None,
                 S=None,
                 proposal_dist=None,
                 scaling=1.0,
                 tune=True,
                 tune_interval=100,
                 model=None,
                 mode=None,
                 **kwargs):
        """Create an instance of a Metropolis stepper

        Parameters
        ----------
        vars: list
            List of value variables for sampler
        S: standard deviation or covariance matrix
            Some measure of variance to parameterize proposal distribution
        proposal_dist: function
            Function that returns zero-mean deviates when parameterized with
            S (and n). Defaults to normal.
        scaling: scalar or array
            Initial scale factor for proposal. Defaults to 1.
        tune: bool
            Flag for tuning. Defaults to True.
        tune_interval: int
            The frequency of tuning. Defaults to 100 iterations.
        model: PyMC Model
            Optional model for sampling step. Defaults to None (taken from context).
        mode: string or `Mode` instance.
            compilation mode passed to Aesara functions
        """

        model = pm.modelcontext(model)
        initial_values = model.initial_point()

        if vars is None:
            vars = model.value_vars
        else:
            vars = [model.rvs_to_values.get(var, var) for var in vars]
        vars = pm.inputvars(vars)

        if S is None:
            S = np.ones(sum(initial_values[v.name].size for v in vars))

        if proposal_dist is not None:
            self.proposal_dist = proposal_dist(S)
        elif S.ndim == 1:
            self.proposal_dist = NormalProposal(S)
        elif S.ndim == 2:
            self.proposal_dist = MultivariateNormalProposal(S)
        else:
            raise ValueError("Invalid rank for variance: %s" % S.ndim)

        self.scaling = np.atleast_1d(scaling).astype("d")
        self.tune = tune
        self.tune_interval = tune_interval
        self.steps_until_tune = tune_interval
        self.accepted = 0

        # Determine type of variables
        self.discrete = np.concatenate([[v.dtype in pm.discrete_types] *
                                        (initial_values[v.name].size or 1)
                                        for v in vars])
        self.any_discrete = self.discrete.any()
        self.all_discrete = self.discrete.all()

        # remember initial settings before tuning so they can be reset
        self._untuned_settings = dict(scaling=self.scaling,
                                      steps_until_tune=tune_interval,
                                      accepted=self.accepted)

        self.mode = mode

        shared = pm.make_shared_replacements(initial_values, vars, model)
        self.delta_logp = delta_logp(initial_values, model.logpt(), vars,
                                     shared)
        super().__init__(vars, shared)
Example #20
def sample_numpyro_nuts(
    draws: int = 1000,
    tune: int = 1000,
    chains: int = 4,
    target_accept: float = 0.8,
    random_seed: Optional[int] = None,
    initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
    model: Optional[Model] = None,
    var_names=None,
    progress_bar: bool = True,
    keep_untransformed: bool = False,
    chain_method: str = "parallel",
    idata_kwargs: Optional[Dict] = None,
    nuts_kwargs: Optional[Dict] = None,
):
    """
    Draw samples from the posterior using the NUTS method from the ``numpyro`` library.

    Parameters
    ----------
    draws : int, default 1000
        The number of samples to draw. The number of tuned samples are discarded by default.
    tune : int, default 1000
        Number of iterations to tune. Samplers adjust the step sizes, scalings or
        similar during tuning. Tuning samples will be drawn in addition to the number specified in
        the ``draws`` argument.
    chains : int, default 4
        The number of chains to sample.
    target_accept : float in [0, 1].
        The step size is tuned such that we approximate this acceptance rate. Higher values like
        0.9 or 0.95 often work better for problematic posteriors.
    random_seed : int, optional
        Random seed used by the sampling steps. When not given, a seed is drawn
        from the model's ``rng_seeder``.
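    initvals : dict or sequence of dict, optional
        Initial values for the model's free random variables, used when building
        the jittered starting points for each chain.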
    model : Model, optional
        Model to sample from. The model needs to have free random variables. When inside a ``with`` model
        context, it defaults to that model, otherwise the model must be passed explicitly.
    var_names : iterable of str, optional
        Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior
    progress_bar : bool, default True
        Whether or not to display a progress bar in the command line. The bar shows the percentage
        of completion, the sampling speed in samples per second (SPS), and the estimated remaining
        time until completion ("expected time of arrival"; ETA).
    keep_untransformed : bool, default False
        Include untransformed variables in the posterior samples. Defaults to False.
    chain_method : str, default "parallel"
        Specify how samples should be drawn. The choices include "sequential", "parallel", and "vectorized".
    idata_kwargs : dict, optional
        Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
        for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
        not be included in the returned object.
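    nuts_kwargs : dict, optional
        Keyword arguments passed through to the ``numpyro`` NUTS kernel.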

    Returns
    -------
    InferenceData
        ArviZ ``InferenceData`` object that contains the posterior samples, together with their respective sample stats and
        pointwise log likelihood values (unless skipped with ``idata_kwargs``).
    """

    import numpyro

    from numpyro.infer import MCMC, NUTS

    model = modelcontext(model)

    if var_names is None:
        var_names = model.unobserved_value_vars

    vars_to_sample = list(
        get_default_varnames(var_names,
                             include_transformed=keep_untransformed))

    coords = {
        cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
        for cname, cvals in model.coords.items() if cvals is not None
    }

    if hasattr(model, "RV_dims"):
        dims = {
            var_name: [dim for dim in dims if dim is not None]
            for var_name, dims in model.RV_dims.items()
        }
    else:
        dims = {}

    if random_seed is None:
        random_seed = model.rng_seeder.randint(2**30, dtype=np.int64)

    tic1 = datetime.now()
    print("Compiling...", file=sys.stdout)

    init_params = _get_batched_jittered_initial_points(
        model=model,
        chains=chains,
        initvals=initvals,
        random_seed=random_seed,
    )

    logp_fn = get_jaxified_logp(model, negative_logp=False)

    if nuts_kwargs is None:
        nuts_kwargs = {}
    nuts_kernel = NUTS(
        potential_fn=logp_fn,
        target_accept_prob=target_accept,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        dense_mass=False,
        **nuts_kwargs,
    )

    pmap_numpyro = MCMC(
        nuts_kernel,
        num_warmup=tune,
        num_samples=draws,
        num_chains=chains,
        postprocess_fn=None,
        chain_method=chain_method,
        progress_bar=progress_bar,
    )

    tic2 = datetime.now()
    print("Compilation time = ", tic2 - tic1, file=sys.stdout)

    print("Sampling...", file=sys.stdout)

    map_seed = jax.random.PRNGKey(random_seed)
    if chains > 1:
        map_seed = jax.random.split(map_seed, chains)

    pmap_numpyro.run(
        map_seed,
        init_params=init_params,
        extra_fields=(
            "num_steps",
            "potential_energy",
            "energy",
            "adapt_state.step_size",
            "accept_prob",
            "diverging",
        ),
    )

    raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)

    tic3 = datetime.now()
    print("Sampling time = ", tic3 - tic2, file=sys.stdout)

    print("Transforming variables...", file=sys.stdout)
    mcmc_samples = {}
    for v in vars_to_sample:
        jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[v])
        result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
        mcmc_samples[v.name] = result

    tic4 = datetime.now()
    print("Transformation time = ", tic4 - tic3, file=sys.stdout)

    if idata_kwargs is None:
        idata_kwargs = {}
    else:
        idata_kwargs = idata_kwargs.copy()

    if idata_kwargs.pop("log_likelihood", True):
        log_likelihood = _get_log_likelihood(model, raw_mcmc_samples)
    else:
        log_likelihood = None

    attrs = {
        "sampling_time": (tic3 - tic2).total_seconds(),
    }

    posterior = mcmc_samples
    az_trace = az.from_dict(
        posterior=posterior,
        log_likelihood=log_likelihood,
        observed_data=find_observations(model),
        sample_stats=_sample_stats_to_xarray(pmap_numpyro),
        coords=coords,
        dims=dims,
        attrs=make_attrs(attrs, library=numpyro),
        **idata_kwargs,
    )

    return az_trace
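A usage sketch for the JAX-backed sampler. The import path has moved between releases (pymc3.sampling_jax, pymc.sampling_jax, pymc.sampling.jax), so the import below is indicative only:

import pymc as pm
from pymc import sampling_jax  # location varies across versions

with pm.Model() as m:
    x = pm.Normal("x", 0.0, 1.0)

idata = sampling_jax.sample_numpyro_nuts(draws=1000, tune=1000, chains=4, model=m)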
Example #21
    def __init__(self,
                 vars=None,
                 S=None,
                 proposal_dist=None,
                 lamb=None,
                 scaling=0.001,
                 tune="lambda",
                 tune_interval=100,
                 tune_drop_fraction: float = 0.9,
                 model=None,
                 mode=None,
                 **kwargs):
        model = pm.modelcontext(model)
        initial_values = model.recompute_initial_point()
        initial_values_size = sum(initial_values[n.name].size
                                  for n in model.value_vars)

        if vars is None:
            vars = model.cont_vars
        else:
            vars = [model.rvs_to_values.get(var, var) for var in vars]
        vars = pm.inputvars(vars)

        if S is None:
            S = np.ones(initial_values_size)

        if proposal_dist is not None:
            self.proposal_dist = proposal_dist(S)
        else:
            self.proposal_dist = UniformProposal(S)

        self.scaling = np.atleast_1d(scaling).astype("d")
        if lamb is None:
            # default to the optimal lambda for normally distributed targets
            lamb = 2.38 / np.sqrt(2 * initial_values_size)
        self.lamb = float(lamb)
        if tune not in {None, "scaling", "lambda"}:
            raise ValueError(
                'The parameter "tune" must be one of {None, scaling, lambda}')
        self.tune = True
        self.tune_target = tune
        self.tune_interval = tune_interval
        self.tune_drop_fraction = tune_drop_fraction
        self.steps_until_tune = tune_interval
        self.accepted = 0

        # cache local history for the Z-proposals
        self._history = []
        # remember initial settings before tuning so they can be reset
        self._untuned_settings = dict(
            scaling=self.scaling,
            lamb=self.lamb,
            steps_until_tune=tune_interval,
            accepted=self.accepted,
        )

        self.mode = mode

        shared = pm.make_shared_replacements(initial_values, vars, model)
        self.delta_logp = delta_logp(initial_values, model.logpt, vars, shared)
        super().__init__(vars, shared)
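A brief sketch of the tuning knobs exposed above, assuming the class is exported as pm.DEMetropolisZ as in released PyMC:

import pymc as pm

with pm.Model():
    x = pm.Normal("x", 0.0, 1.0)
    step = pm.DEMetropolisZ(tune="lambda", tune_interval=100,
                            tune_drop_fraction=0.9)
    idata = pm.sample(2000, step=step, chains=1)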