Example #1
    def __init__(self, name, model=None, vars=None, test_point=None):
        self.name = name

        model = modelcontext(model)
        self.model = model
        if vars is None:
            vars = model.unobserved_value_vars

        self.vars = vars
        self.varnames = [var.name for var in vars]
        self.fn = model.fastfn(vars)

        # Get variable shapes. Most backends will need this
        # information.
        if test_point is None:
            test_point = model.initial_point
        else:
            test_point_ = model.initial_point.copy()
            test_point_.update(test_point)
            test_point = test_point_
        var_values = list(zip(self.varnames, self.fn(test_point)))
        self.var_shapes = {var: value.shape for var, value in var_values}
        self.var_dtypes = {var: value.dtype for var, value in var_values}
        self.chain = None
        self._is_base_setup = False
        self.sampler_vars = None
        self._warnings = []
Example #2
    def _argument_checks(cls, distribution, **kwargs):
        if "observed" in kwargs:
            raise ValueError("Observed Bound distributions are not supported. "
                             "If you want to model truncated data "
                             "you can use a pm.Potential in combination "
                             "with the cumulative probability function.")

        if not isinstance(distribution, TensorVariable):
            raise ValueError(
                "Passing a distribution class to `Bound` is no longer supported.\n"
                "Please pass the output of a distribution instantiated via the "
                "`.dist()` API such as:\n"
                '`pm.Bound("bound", pm.Normal.dist(0, 1), lower=0)`')

        try:
            model = modelcontext(None)
        except TypeError:
            pass
        else:
            if distribution in model.basic_RVs:
                raise ValueError(
                    f"The distribution passed into `Bound` was already registered "
                    f"in the current model.\nYou should pass an unregistered "
                    f"(unnamed) distribution created via the `.dist()` API, such as:\n"
                    f'`pm.Bound("bound", pm.Normal.dist(0, 1), lower=0)`')

        if distribution.owner.op.ndim_supp != 0:
            raise NotImplementedError(
                "Bounding of MultiVariate RVs is not yet supported.")

        if not isinstance(distribution.owner.op, (Discrete, Continuous)):
            raise ValueError(
                f"`distribution` {distribution} must be a Discrete or Continuous"
                " distribution subclass")
Example #3
def guess_scaling(point, vars=None, model=None, scaling_bound=1e-8):
    model = modelcontext(model)
    try:
        h = find_hessian_diag(point, vars, model=model)
    except NotImplementedError:
        h = fixed_hessian(point, vars, model=model)
    return adjust_scaling(h, scaling_bound)
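
Below is a minimal usage sketch (not from the source): the import path and the initial-point API are assumptions that vary across PyMC versions.

import pymc as pm
from pymc.tuning.scaling import guess_scaling  # import path assumed

with pm.Model() as model:
    x = pm.Normal("x", 0.0, 1.0)
    # Diagonal-Hessian-based scaling at the initial point, falling back to a
    # fixed Hessian when the diagonal Hessian is not implemented.
    scaling = guess_scaling(model.initial_point, model=model)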
Example #4
def trace_cov(trace, vars=None, model=None):
    """
    Calculate the flattened covariance matrix using a sample trace

    Useful if you want to base your covariance matrix for further sampling on some initial samples.

    Parameters
    ----------
    trace: Trace
    vars: list
        variables for which to calculate covariance matrix

    Returns
    -------
    r: array (n,n)
        covariance matrix
    """
    model = modelcontext(model)

    if model is not None:
        vars = model.free_RVs
    elif vars is None:
        vars = trace.varnames

    def flat_t(var):
        x = trace[get_var_name(var)]
        return x.reshape((x.shape[0], np.prod(x.shape[1:], dtype=int)))

    return np.cov(np.concatenate(list(map(flat_t, vars)), 1).T)
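
A usage sketch of the idea in the docstring, hypothetical and version-dependent (the import path, the MultiTrace requirement, and the sampling keywords are assumptions):

import pymc as pm
from pymc.tuning.scaling import trace_cov  # import path assumed

with pm.Model() as model:
    a = pm.Normal("a", 0.0, 1.0)
    b = pm.Normal("b", 0.0, 1.0, shape=3)
    # A short pilot run; trace_cov expects a MultiTrace, not InferenceData.
    pilot = pm.sample(200, tune=200, chains=2, return_inferencedata=False)

cov = trace_cov(pilot, model=model)  # flattened (4, 4) covariance over a and b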
Example #5
def get_steps(
    steps: Optional[Union[int, np.ndarray, TensorVariable]],
    *,
    shape: Optional[Shape] = None,
    dims: Optional[Dims] = None,
    observed: Optional[Any] = None,
    step_shape_offset: int = 0,
):
    """Extract number of steps from shape / dims / observed information

    Parameters
    ----------
    steps:
        User specified steps for timeseries distribution
    shape:
        User specified shape for timeseries distribution
    dims:
        User specified dims for timeseries distribution
    observed:
        User specified observed data from timeseries distribution
    step_shape_offset:
        Difference between last shape dimension and number of steps in timeseries
        distribution, defaults to 0

    Returns
    -------
    steps
        Steps, if specified directly by user, or inferred from the last dimension of
        shape / dims / observed. When two sources of step information are provided,
        a symbolic Assert is added to ensure they are consistent.
    """
    inferred_steps = None
    if shape is not None:
        shape = to_tuple(shape)
        if shape[-1] is not ...:
            inferred_steps = shape[-1] - step_shape_offset

    if inferred_steps is None and dims is not None:
        dims = convert_dims(dims)
        if dims[-1] is not ...:
            model = modelcontext(None)
            inferred_steps = model.dim_lengths[dims[-1]] - step_shape_offset

    if inferred_steps is None and observed is not None:
        observed = convert_observed_data(observed)
        inferred_steps = observed.shape[-1] - step_shape_offset

    if inferred_steps is None:
        inferred_steps = steps
    # If there are two sources of information for the steps, assert they are consistent
    elif steps is not None:
        inferred_steps = Assert(msg="Steps do not match last shape dimension")(
            inferred_steps, at.eq(inferred_steps, steps)
        )
    return inferred_steps
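
A small worked example of the inference rule described above (hypothetical call; get_steps is an internal helper and the import path is an assumption):

from pymc.distributions.timeseries import get_steps  # import path assumed

# The last shape dimension minus step_shape_offset gives the number of steps.
get_steps(steps=None, shape=(10, 100), step_shape_offset=1)  # -> 99
# With both steps and shape given, a symbolic Assert checks their consistency.
get_steps(steps=99, shape=(10, 100), step_shape_offset=1)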
Example #6
def compile_pymc(inputs, outputs, mode=None, **kwargs):
    """Use ``aesara.function`` with specialized pymc rewrites always enabled.

    Included rewrites
    -----------------
    random_make_inplace
        Ensures that compiled functions containing random variables will produce new
        samples on each call.
    local_check_parameter_to_ninf_switch
        Replaces Aeppl's CheckParameterValue assertions in logp expressions with Switches
        that return -inf in place of the assert.

    Optional rewrites
    -----------------
    local_remove_check_parameter
        Replaces Aeppl's CheckParameterValue assertions in logp expressions. This is used
        as an alternative to the default local_check_parameter_to_ninf_switch whenever
        this function is called within a model context and the model `check_bounds` flag
        is set to False.
    """

    # Avoid circular dependency
    from pymc.distributions import NoDistribution

    # Set the default update of a NoDistribution RNG so that it is automatically
    # updated after every function call
    output_to_list = outputs if isinstance(outputs, list) else [outputs]
    for rv in (
        node
        for node in walk_model(output_to_list, walk_past_rvs=True)
        if node.owner and isinstance(node.owner.op, NoDistribution)
    ):
        rng = rv.owner.inputs[0]
        if not hasattr(rng, "default_update"):
            rng.default_update = rv.owner.outputs[0]

    # If called inside a model context, see if check_bounds flag is set to False
    try:
        from pymc.model import modelcontext

        model = modelcontext(None)
        check_bounds = model.check_bounds
    except TypeError:
        check_bounds = True
    check_parameter_opt = (
        "local_check_parameter_to_ninf_switch" if check_bounds else "local_remove_check_parameter"
    )

    mode = get_mode(mode)
    opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt)
    mode = Mode(linker=mode.linker, optimizer=opt_qry)
    aesara_function = aesara.function(inputs, outputs, mode=mode, **kwargs)
    return aesara_function
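
A usage sketch of the behaviour the docstring describes (hypothetical; the pymc.aesaraf import path is an assumption for this Aesara-era version):

import aesara.tensor as at
from pymc.aesaraf import compile_pymc  # import path assumed

x = at.random.normal(0.0, 1.0)
draw = compile_pymc([], x)
# Thanks to the random_make_inplace rewrite, each call returns a fresh sample.
print(draw(), draw())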
Example #7
File: smc.py Project: sthagen/pymc3
    def __init__(
        self,
        draws=2000,
        start=None,
        model=None,
        random_seed=None,
        threshold=0.5,
    ):
        """

        Parameters
        ----------
        draws: int
            The number of samples to draw from the posterior (i.e. the last stage), which is also
            the number of independent chains. Defaults to 2000.
        start: dict, or array of dict
            Starting point in parameter space. It should be a list of dict with length `chains`.
            When None (default) the starting point is sampled from the prior distribution.
        model: Model (optional if in ``with`` context).
        random_seed: int
            Value used to initialize the random number generator.
        threshold: float
            Determines the change of beta from stage to stage, i.e. indirectly the number of stages;
            the higher the value of `threshold`, the higher the number of stages. Defaults to 0.5.
            It should be between 0 and 1.

        """

        self.draws = draws
        self.start = start
        if threshold < 0 or threshold > 1:
            raise ValueError(
                f"Threshold value {threshold} must be between 0 and 1")
        self.threshold = threshold
        self.model = model
        self.rng = np.random.default_rng(seed=random_seed)

        self.model = modelcontext(model)
        self.variables = inputvars(self.model.value_vars)

        self.var_info = {}
        self.tempered_posterior = None
        self.prior_logp = None
        self.likelihood_logp = None
        self.tempered_posterior_logp = None
        self.prior_logp_func = None
        self.likelihood_logp_func = None
        self.log_marginal_likelihood = 0
        self.beta = 0
        self.iteration = 0
        self.resampling_indexes = None
        self.weights = np.ones(self.draws) / self.draws
Example #8
    def __init__(
        self, vars, model=None, blocked=True, dtype=None, logp_dlogp_func=None, **aesara_kwargs
    ):
        model = modelcontext(model)

        if logp_dlogp_func is None:
            func = model.logp_dlogp_function(vars, dtype=dtype, **aesara_kwargs)
        else:
            func = logp_dlogp_func

        self._logp_dlogp_func = func

        super().__init__(vars, func._extra_vars_shared, blocked)
Example #9
def find_hessian(point, vars=None, model=None):
    """
    Returns Hessian of logp at the point passed.

    Parameters
    ----------
    model: Model (optional if in `with` context)
    point: dict
    vars: list
        Variables for which Hessian is to be calculated.
    """
    model = modelcontext(model)
    H = model.fastd2logp(vars)
    return H(Point(point, filter_model_vars=True, model=model))
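
A usage sketch (hypothetical toy model; find_hessian is assumed to be re-exported at the package level):

import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 10.0)
    pm.Normal("obs", mu, 1.0, observed=[0.1, -0.4, 0.7])
    map_point = pm.find_MAP()
    # Hessian of the model logp evaluated at the MAP (a 1x1 matrix here)
    H = pm.find_hessian(map_point, model=model)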
Example #10
def find_hessian_diag(point, vars=None, model=None):
    """
    Returns the diagonal of the Hessian of logp at the point passed.

    Parameters
    ----------
    model: Model (optional if in `with` context)
    point: dict
    vars: list
        Variables for which Hessian is to be calculated.
    """
    model = modelcontext(model)
    H = model.compile_fn(hessian_diag(model.logpt(), vars))
    return H(Point(point, model=model))
Example #11
def replace_with_values(vars_needed, replacements=None, model=None):
    R"""
    Replace random variable nodes in the graph with values given by the replacements dict.
    Uses untransformed versions of the inputs, performs some basic input validation.

    Parameters
    ----------
    vars_needed: list of TensorVariables
        A list of variable outputs
    replacements: dict with string keys, numeric values
        The variable name and values to be replaced in the model graph.
    model: Model
        A PyMC model object
    """
    model = modelcontext(model)

    inputs, input_names = [], []
    for rv in walk_model(vars_needed, walk_past_rvs=True):
        if rv in model.named_vars.values() and not isinstance(
                rv, SharedVariable):
            inputs.append(rv)
            input_names.append(rv.name)

    # Then it's deterministic, no inputs are required, can eval and return
    if len(inputs) == 0:
        return tuple(v.eval() for v in vars_needed)

    fn = compile_pymc(
        inputs,
        vars_needed,
        allow_input_downcast=True,
        accept_inplace=True,
        on_unused_input="ignore",
    )

    # Remove unneeded inputs
    replacements = {
        name: val
        for name, val in replacements.items() if name in input_names
    }
    missing = set(input_names) - set(replacements.keys())

    # Error if more inputs are needed
    if len(missing) > 0:
        missing_str = ", ".join(missing)
        raise ValueError(
            f"Values for {missing_str} must be included in `replacements`.")

    return fn(**replacements)
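
A minimal sketch of how the helper is called (hypothetical; in the version this snippet comes from the function lives in the GP utilities, and the import path is an assumption):

import pymc as pm
from pymc.gp.util import replace_with_values  # import path assumed

with pm.Model() as model:
    x = pm.HalfNormal("x", 1.0)
    y = x ** 2  # a deterministic function of the model RV
    # Evaluate y with the random variable x pinned to 3.0 -> roughly 9.0
    (y_val,) = replace_with_values([y], replacements={"x": 3.0}, model=model)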
Example #12
def check_dist_not_registered(dist, model=None):
    """Check that a dist is not registered in the model already"""
    from pymc.model import modelcontext

    try:
        model = modelcontext(None)
    except TypeError:
        pass
    else:
        if dist in model.basic_RVs:
            raise ValueError(
                f"The dist {dist} was already registered in the current model.\n"
                f"You should use an unregistered (unnamed) distribution created via "
                f"the `.dist()` API instead, such as:\n`dist=pm.Normal.dist(0, 1)`"
            )
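
A sketch of the intended calling pattern (hypothetical; the import path is an assumption and differs between versions):

import pymc as pm
from pymc.distributions.shape_utils import check_dist_not_registered  # import path assumed

with pm.Model():
    ok = pm.Normal.dist(0.0, 1.0)
    check_dist_not_registered(ok)      # unnamed .dist() variable: passes silently
    bad = pm.Normal("bad", 0.0, 1.0)
    # check_dist_not_registered(bad)   # registered variable: would raise ValueError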
Example #13
    def __new__(
        cls,
        name,
        distribution,
        lower=None,
        upper=None,
        size=None,
        shape=None,
        initval=None,
        dims=None,
        **kwargs,
    ):

        cls._argument_checks(distribution, **kwargs)

        if dims is not None:
            model = modelcontext(None)
            if dims in model.coords:
                dim_obj = np.asarray(model.coords[dims])
                size = dim_obj.shape
            else:
                raise ValueError(
                    "Given dims do not exist in model coordinates.")

        lower, upper, initval = cls._set_values(lower, upper, size, shape,
                                                initval)
        distribution.tag.ignore_logprob = True

        if isinstance(distribution.owner.op, Continuous):
            res = _ContinuousBounded(
                name,
                [distribution, lower, upper],
                initval=floatX(initval),
                size=size,
                shape=shape,
                **kwargs,
            )
        else:
            res = _DiscreteBounded(
                name,
                [distribution, lower, upper],
                initval=intX(initval),
                size=size,
                shape=shape,
                **kwargs,
            )
        return res
Example #14
File: base.py Project: wmastersonV/pymc
    def __init__(self, name, model=None, vars=None):
        self.name = name

        model = modelcontext(model)
        self.model = model
        if vars is None:
            vars = model.unobserved_RVs
        self.vars = vars
        self.varnames = [str(var) for var in vars]
        self.fn = model.fastfn(vars)

        ## Get variable shapes. Most backends will need this
        ## information.
        var_values = zip(self.varnames, self.fn(model.test_point))
        self.var_shapes = {var: value.shape
                           for var, value in var_values}
        self.chain = None
Example #15
    def __init__(self,
                 vars=None,
                 prior_cov=None,
                 prior_chol=None,
                 model=None,
                 **kwargs):
        self.model = modelcontext(model)
        chol = get_chol(prior_cov, prior_chol)
        self.prior_chol = at.as_tensor_variable(chol)

        if vars is None:
            vars = self.model.cont_vars
        else:
            vars = [self.model.rvs_to_values.get(var, var) for var in vars]
        vars = inputvars(vars)

        super().__init__(vars, [self.model.compile_logp()], **kwargs)
Example #16
def point_list_to_multitrace(point_list: List[Dict[str, np.ndarray]],
                             model: Optional[Model] = None) -> MultiTrace:
    """transform point list into MultiTrace"""
    _model = modelcontext(model)
    varnames = list(point_list[0].keys())
    with _model:
        chain = NDArray(model=_model, vars=[_model[vn] for vn in varnames])
        chain.setup(draws=len(point_list), chain=0)

        # since we are simply loading a trace by hand, we need only a vacuous function for
        # chain.record() to use. This crushes the default.
        def point_fun(point):
            return [point[vn] for vn in varnames]

        chain.fn = point_fun
        for point in point_list:
            chain.record(point)
    return MultiTrace([chain])
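
A usage sketch (hypothetical; the import path is an assumption):

import numpy as np
import pymc as pm
from pymc.sampling import point_list_to_multitrace  # import path assumed

with pm.Model() as model:
    x = pm.Normal("x", 0.0, 1.0)

# Wrap a hand-built list of points in a MultiTrace with a single chain.
points = [{"x": np.array(0.1)}, {"x": np.array(-0.3)}]
mtrace = point_list_to_multitrace(points, model=model)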
Example #17
    def __new__(cls, *args, **kwargs):
        blocked = kwargs.get("blocked")
        if blocked is None:
            # Try to look up default value from class
            blocked = getattr(cls, "default_blocked", True)
            kwargs["blocked"] = blocked

        model = modelcontext(kwargs.get("model"))
        kwargs.update({"model": model})

        # vars can either be first arg or a kwarg
        if "vars" not in kwargs and len(args) >= 1:
            vars = args[0]
            args = args[1:]
        elif "vars" in kwargs:
            vars = kwargs.pop("vars")
        else:  # Assume all model variables
            vars = model.value_vars

        if not isinstance(vars, (tuple, list)):
            vars = [vars]

        if len(vars) == 0:
            raise ValueError("No free random variables to sample.")

        if not blocked and len(vars) > 1:
            # In this case we create a separate sampler for each var
            # and append them to a CompoundStep
            steps = []
            for var in vars:
                step = super().__new__(cls)
                # If we don't return the instance we have to manually
                # call __init__
                step.__init__([var], *args, **kwargs)
                # Hack for creating the class correctly when unpickling.
                step.__newargs = ([var],) + args, kwargs
                steps.append(step)

            return CompoundStep(steps)
        else:
            step = super().__new__(cls)
            # Hack for creating the class correctly when unpickling.
            step.__newargs = (vars,) + args, kwargs
            return step
Example #18
File: dist_math.py Project: kc611/pymc3
def bound(logp, *conditions, broadcast_conditions=True):
    """
    Bounds a log probability density with several conditions.
    When conditions are not met, the logp values are replaced by -inf.

    Note that bound should not be used to enforce the logic of the logp under its normal
    support, as the check can be disabled by the user via check_bounds = False in pm.Model().

    Parameters
    ----------
    logp: float
    *conditions: booleans
    broadcast_conditions: bool (optional, default=True)
        If True, conditions are broadcasted and applied element-wise to each value in logp.
        If False, conditions are collapsed via at.all(). As a consequence the entire logp
        array is either replaced by -inf or unchanged.

        Setting broadcast_conditions to False is necessary for most (all?) multivariate
        distributions where the dimensions of the conditions do not unambiguously match
        that of the logp.

    Returns
    -------
    logp with elements set to -inf where any condition is False
    """

    # If called inside a model context, see if bounds check is disabled
    try:
        from pymc.model import modelcontext

        model = modelcontext(None)
        if not model.check_bounds:
            return logp
    except TypeError:
        pass  # no model found

    if broadcast_conditions:
        alltrue = alltrue_elemwise
    else:
        alltrue = alltrue_scalar

    return at.switch(alltrue(conditions), logp, -np.inf)
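
A small worked example of the switch behaviour (hypothetical values; the module prefix, pymc vs pymc3, depends on the version of the file named in the header):

import numpy as np
import aesara.tensor as at
from pymc.distributions.dist_math import bound  # module prefix assumed

value = at.vector("value")
logp = bound(-0.5 * value ** 2, value > 0)
# -inf where the condition fails, the raw logp elsewhere: [-inf, -0.5]
print(logp.eval({value: np.array([-1.0, 1.0])}))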
Example #19
    def __init__(self,
                 vars=None,
                 w=1.0,
                 tune=True,
                 model=None,
                 iter_limit=np.inf,
                 **kwargs):
        self.model = modelcontext(model)
        self.w = w
        self.tune = tune
        self.n_tunes = 0.0
        self.iter_limit = iter_limit

        if vars is None:
            vars = self.model.cont_vars
        else:
            vars = [self.model.rvs_to_values.get(var, var) for var in vars]
        vars = inputvars(vars)

        super().__init__(vars, [self.model.compile_logp()], **kwargs)
Example #20
def fixed_hessian(point, vars=None, model=None):
    """
    Returns a fixed Hessian for any chain location.

    Parameters
    ----------
    model: Model (optional if in `with` context)
    point: dict
    vars: list
        Variables for which Hessian is to be calculated.
    """

    model = modelcontext(model)
    if vars is None:
        vars = model.cont_vars
    vars = inputvars(vars)

    point = Point(point, model=model)

    rval = np.ones(DictToArrayBijection.map(point).size) / 10
    return rval
Example #21
def test_mixed_contexts():
    modelA = Model()
    modelB = Model()
    with raises((ValueError, TypeError)):
        modelcontext(None)
    with modelA:
        with modelB:
            assert Model.get_context() == modelB
            assert modelcontext(None) == modelB
        assert Model.get_context() == modelA
        assert modelcontext(None) == modelA
    assert Model.get_context(error_if_none=False) is None
    with raises(TypeError):
        Model.get_context(error_if_none=True)
    with raises((ValueError, TypeError)):
        modelcontext(None)
Example #22
    def __init__(self, vars=None, Z=None, gamma2=0.1, nu2=1., kernel=None,
                 tune=True, tune_interval=100, model=None, dist=None):
        model = modelcontext(model)
        if vars is None:
            vars = model.vars

        self.Z = Z
        self.kernel = kernel
        self.gamma2 = gamma2
        self.nu2 = nu2
        self.tune = tune

        # empty proposal distribution and last likelihood
        self.q_dist = None
        self.log_target = -np.inf

        # statistics for tuning scaling
        self.tune = tune
        self.tune_interval = tune_interval
        self.steps_until_tune = tune_interval
        self.accepted = 0

        super(KameleonOracle, self).__init__(vars, [model.fastlogp])
Example #23
    def __init__(self, vars=None, num_particles=40, max_stages=100, batch="auto", model=None):
        _log.warning("BART is experimental. Use with caution.")
        model = modelcontext(model)
        initial_values = model.recompute_initial_point()
        value_bart = inputvars(vars)[0]
        self.bart = model.values_to_rvs[value_bart].owner.op

        self.X = self.bart.X
        self.Y = self.bart.Y
        self.missing_data = np.any(np.isnan(self.X))
        self.m = self.bart.m
        self.alpha = self.bart.alpha
        self.k = self.bart.k
        self.alpha_vec = self.bart.split_prior
        if self.alpha_vec is None:
            self.alpha_vec = np.ones(self.X.shape[1])

        self.init_mean = self.Y.mean()
        # if data is binary
        Y_unique = np.unique(self.Y)
        if Y_unique.size == 2 and np.all(Y_unique == [0, 1]):
            self.mu_std = 6 / (self.k * self.m ** 0.5)
        # maybe we need to check for count data
        else:
            self.mu_std = (2 * self.Y.std()) / (self.k * self.m ** 0.5)

        self.num_observations = self.X.shape[0]
        self.num_variates = self.X.shape[1]
        self.available_predictors = list(range(self.num_variates))

        self.sum_trees = np.full_like(self.Y, self.init_mean).astype(aesara.config.floatX)
        self.a_tree = Tree.init_tree(
            leaf_node_value=self.init_mean / self.m,
            idx_data_points=np.arange(self.num_observations, dtype="int32"),
        )
        self.mean = fast_mean()

        self.normal = NormalSampler()
        self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
        self.ssv = SampleSplittingVariable(self.alpha_vec)

        self.tune = True

        if batch == "auto":
            batch = max(1, int(self.m * 0.1))
            self.batch = (batch, batch)
        else:
            if isinstance(batch, (tuple, list)):
                self.batch = batch
            else:
                self.batch = (batch, batch)

        self.log_num_particles = np.log(num_particles)
        self.indices = list(range(2, num_particles))
        self.len_indices = len(self.indices)
        self.max_stages = max_stages

        shared = make_shared_replacements(initial_values, vars, model)
        self.likelihood_logp = logp(initial_values, [model.datalogpt], vars, shared)
        self.all_particles = []
        for i in range(self.m):
            self.a_tree.leaf_node_value = self.init_mean / self.m
            p = ParticleTree(self.a_tree)
            self.all_particles.append(p)
        self.all_trees = np.array([p.tree for p in self.all_particles])
        super().__init__(vars, shared)
Example #24
def compile_pymc(
    inputs,
    outputs,
    random_seed: SeedSequenceSeed = None,
    mode=None,
    **kwargs,
) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
    """Use ``aesara.function`` with specialized pymc rewrites always enabled.

    This function also ensures shared RandomState/Generator used by RandomVariables
    in the graph are updated across calls, to ensure independent draws.

    Parameters
    ----------
    inputs: list of TensorVariables, optional
        Inputs of the compiled Aesara function
    outputs: list of TensorVariables, optional
        Outputs of the compiled Aesara function
    random_seed: int, array-like of int or SeedSequence, optional
        Seed used to override any RandomState/Generator shared variables in the graph.
        If not specified, the value of original shared variables will still be overwritten.
    mode: optional
        Aesara mode used to compile the function

    Included rewrites
    -----------------
    random_make_inplace
        Ensures that compiled functions containing random variables will produce new
        samples on each call.
    local_check_parameter_to_ninf_switch
        Replaces Aeppl's CheckParameterValue assertions in logp expressions with Switches
        that return -inf in place of the assert.

    Optional rewrites
    -----------------
    local_remove_check_parameter
        Replaces Aeppl's CheckParameterValue assertions in logp expressions. This is used
        as an alternative to the default local_check_parameter_to_ninf_switch whenever
        this function is called within a model context and the model `check_bounds` flag
        is set to False.
    """
    # Create an update mapping of RandomVariable's RNG so that it is automatically
    # updated after every function call
    rng_updates = {}
    output_to_list = outputs if isinstance(outputs,
                                           (list, tuple)) else [outputs]
    for random_var in (
            var for var in vars_between(inputs, output_to_list)
            if var.owner and isinstance(var.owner.op, (
                RandomVariable, MeasurableVariable)) and var not in inputs):
        if isinstance(random_var.owner.op, RandomVariable):
            rng = random_var.owner.inputs[0]
            if not hasattr(rng, "default_update"):
                rng_updates[rng] = random_var.owner.outputs[0]
            else:
                rng_updates[rng] = rng.default_update
        else:
            update_fn = getattr(random_var.owner.op, "update", None)
            if update_fn is not None:
                rng_updates.update(update_fn(random_var.owner))

    # We always reseed random variables as this provides RNGs with no chances of collision
    if rng_updates:
        reseed_rngs(rng_updates.keys(), random_seed)

    # If called inside a model context, see if check_bounds flag is set to False
    try:
        from pymc.model import modelcontext

        model = modelcontext(None)
        check_bounds = model.check_bounds
    except TypeError:
        check_bounds = True
    check_parameter_opt = ("local_check_parameter_to_ninf_switch"
                           if check_bounds else "local_remove_check_parameter")

    mode = get_mode(mode)
    opt_qry = mode.provided_optimizer.including("random_make_inplace",
                                                check_parameter_opt)
    mode = Mode(linker=mode.linker, optimizer=opt_qry)
    aesara_function = aesara.function(
        inputs,
        outputs,
        updates={
            **rng_updates,
            **kwargs.pop("updates", {})
        },
        mode=mode,
        **kwargs,
    )
    return aesara_function
Example #25
File: sgmcmc.py Project: t-triobox/pymc3
    def __init__(
        self,
        vars=None,
        batch_size=None,
        total_size=None,
        step_size=1.0,
        model=None,
        random_seed=None,
        minibatches=None,
        minibatch_tensors=None,
        **kwargs
    ):
        warnings.warn(EXPERIMENTAL_WARNING)

        model = modelcontext(model)

        if vars is None:
            vars = model.value_vars
        else:
            vars = [model.rvs_to_values.get(var, var) for var in vars]

        vars = inputvars(vars)

        self.model = model
        self.vars = vars
        self.batch_size = batch_size
        self.total_size = total_size
        _value_error(
            total_size != None or batch_size != None,
            "total_size and batch_size of training data have to be specified",
        )
        self.expected_iter = int(total_size / batch_size)

        # set random stream
        self.random = None
        if random_seed is None:
            self.random = at_rng()
        else:
            self.random = at_rng(random_seed)

        self.step_size = step_size

        shared = make_shared_replacements(vars, model)

        self.updates = OrderedDict()
        # XXX: This needs to be refactored
        self.q_size = None  # int(sum(v.dsize for v in self.vars))

        # This seems to be the only place that `Model.flatten` is used.
        # TODO: Why not _actually_ flatten the variables?
        # E.g. `flat_vars = at.concatenate([var.ravel() for var in vars])`
        # or `set_subtensor` the `vars` into a `at.vector`?

        flat_view = model.flatten(vars)
        self.inarray = [flat_view.input]

        self.dlog_prior = prior_dlogp(vars, model, flat_view)
        self.dlogp_elemwise = elemwise_dlogL(vars, model, flat_view)
        # XXX: This needs to be refactored
        self.q_size = None  # int(sum(v.dsize for v in self.vars))

        if minibatch_tensors is not None:
            _check_minibatches(minibatch_tensors, minibatches)
            self.minibatches = minibatches

            # Replace input shared variables with tensors
            def is_shared(t):
                return isinstance(t, aesara.compile.sharedvalue.SharedVariable)

            tensors = [(t.type() if is_shared(t) else t) for t in minibatch_tensors]
            updates = OrderedDict(
                {t: t_ for t, t_ in zip(minibatch_tensors, tensors) if is_shared(t)}
            )
            self.minibatch_tensors = tensors
            self.inarray += self.minibatch_tensors
            self.updates.update(updates)

        self._initialize_values()
        super().__init__(vars, shared)
Example #26
def find_MAP(start=None,
             vars=None,
             method="L-BFGS-B",
             return_raw=False,
             include_transformed=True,
             progressbar=True,
             maxeval=5000,
             model=None,
             *args,
             seed: Optional[int] = None,
             **kwargs):
    """Finds the local maximum a posteriori point given a model.

    `find_MAP` should not be used to initialize the NUTS sampler. Simply call
    ``pymc.sample()`` and it will automatically initialize NUTS in a better
    way.

    Parameters
    ----------
    start: `dict` of parameter values (Defaults to `model.initial_point`)
    vars: list
        List of variables to optimize and set to optimum (Defaults to all continuous).
    method: string or callable
        Optimization algorithm (Defaults to 'L-BFGS-B' unless
        discrete variables are specified in `vars`, then
        `Powell` which will perform better).  For instructions on use of a callable,
        refer to SciPy's documentation of `optimize.minimize`.
    return_raw: bool
        Whether to return the full output of scipy.optimize.minimize (Defaults to `False`)
    include_transformed: bool, optional defaults to True
        Flag for reporting automatically transformed variables in addition
        to original variables.
    progressbar: bool, optional defaults to True
        Whether or not to display a progress bar in the command line.
    maxeval: int, optional, defaults to 5000
        The maximum number of times the posterior distribution is evaluated.
    model: Model (optional if in `with` context)
    *args, **kwargs
        Extra args passed to scipy.optimize.minimize

    Notes
    -----
    Older code examples used `find_MAP` to initialize the NUTS sampler,
    but this is not an effective way of choosing starting values for sampling.
    As a result, we have greatly enhanced the initialization of NUTS and
    wrapped it inside ``pymc.sample()``, so you should avoid using this method for that purpose.
    """
    model = modelcontext(model)

    if vars is None:
        vars = model.cont_vars
        if not vars:
            raise ValueError("Model has no unobserved continuous variables.")
    vars = inputvars(vars)
    disc_vars = list(typefilter(vars, discrete_types))
    allinmodel(vars, model)
    ipfn = make_initial_point_fn(
        model=model,
        jitter_rvs={},
        return_transformed=True,
        overrides=start,
    )
    if seed is None:
        seed = model.rng_seeder.randint(2**30, dtype=np.int64)
    start = ipfn(seed)
    model.check_start_vals(start)

    x0 = DictToArrayBijection.map(start)

    # TODO: If the mapping is fixed, we can simply create graphs for the
    # mapping and avoid all this bijection overhead
    def logp_func(x):
        return DictToArrayBijection.mapf(model.fastlogp_nojac)(RaveledVars(
            x, x0.point_map_info))

    try:
        # This might be needed for calls to `dlogp_func`
        # start_map_info = tuple((v.name, v.shape, v.dtype) for v in vars)

        def dlogp_func(x):
            return DictToArrayBijection.mapf(model.fastdlogp_nojac(vars))(
                RaveledVars(x, x0.point_map_info))

        compute_gradient = True
    except (AttributeError, NotImplementedError, tg.NullTypeGradError):
        compute_gradient = False

    if disc_vars or not compute_gradient:
        pm._log.warning(
            "Warning: gradient not available." +
            "(E.g. vars contains discrete variables). MAP " +
            "estimates may not be accurate for the default " +
            "parameters. Defaulting to non-gradient minimization " +
            "'Powell'.")
        method = "Powell"

    if compute_gradient:
        cost_func = CostFuncWrapper(maxeval, progressbar, logp_func,
                                    dlogp_func)
    else:
        cost_func = CostFuncWrapper(maxeval, progressbar, logp_func)

    try:
        opt_result = minimize(cost_func,
                              x0.data,
                              method=method,
                              jac=compute_gradient,
                              *args,
                              **kwargs)
        mx0 = opt_result["x"]  # r -> opt_result
    except (KeyboardInterrupt, StopIteration) as e:
        mx0, opt_result = cost_func.previous_x, None
        if isinstance(e, StopIteration):
            pm._log.info(e)
    finally:
        last_v = cost_func.n_eval
        if progressbar:
            assert isinstance(cost_func.progress, ProgressBar)
            cost_func.progress.total = last_v
            cost_func.progress.update(last_v)
            print(file=sys.stdout)

    mx0 = RaveledVars(mx0, x0.point_map_info)

    vars = get_default_varnames(model.unobserved_value_vars,
                                include_transformed)
    mx = {
        var.name: value
        for var, value in zip(
            vars,
            model.fastfn(vars)(DictToArrayBijection.rmap(mx0)))
    }

    if return_raw:
        return mx, opt_result
    else:
        return mx
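
A usage sketch matching the docstring (hypothetical toy model):

import pymc as pm

with pm.Model():
    mu = pm.Normal("mu", 0.0, 10.0)
    pm.Normal("obs", mu, 1.0, observed=[0.1, -0.4, 0.7])
    # Returns a dict of optimized values, e.g. {"mu": array(...)}
    map_estimate = pm.find_MAP()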
Example #27
def compile_pymc(
    inputs, outputs, mode=None, **kwargs
) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
    """Use ``aesara.function`` with specialized pymc rewrites always enabled.

    Included rewrites
    -----------------
    random_make_inplace
        Ensures that compiled functions containing random variables will produce new
        samples on each call.
    local_check_parameter_to_ninf_switch
        Replaces Aeppl's CheckParameterValue assertions in logp expressions with Switches
        that return -inf in place of the assert.

    Optional rewrites
    -----------------
    local_remove_check_parameter
        Replaces Aeppl's CheckParameterValue assertions in logp expressions. This is used
        as an alternative to the default local_check_parameter_to_ninf_switch whenever
        this function is called within a model context and the model `check_bounds` flag
        is set to False.
    """
    # Create an update mapping of RandomVariable's RNG so that it is automatically
    # updated after every function call
    # TODO: This won't work for variables with InnerGraphs (Scan and OpFromGraph)
    rng_updates = {}
    output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs]
    for random_var in (
        var
        for var in vars_between(inputs, output_to_list)
        if var.owner
        and isinstance(var.owner.op, (RandomVariable, MeasurableVariable))
        and var not in inputs
    ):
        if isinstance(random_var.owner.op, RandomVariable):
            rng = random_var.owner.inputs[0]
            if not hasattr(rng, "default_update"):
                rng_updates[rng] = random_var.owner.outputs[0]
        else:
            update_fn = getattr(random_var.owner.op, "update", None)
            if update_fn is not None:
                rng_updates.update(update_fn(random_var.owner))

    # If called inside a model context, see if check_bounds flag is set to False
    try:
        from pymc.model import modelcontext

        model = modelcontext(None)
        check_bounds = model.check_bounds
    except TypeError:
        check_bounds = True
    check_parameter_opt = (
        "local_check_parameter_to_ninf_switch" if check_bounds else "local_remove_check_parameter"
    )

    mode = get_mode(mode)
    opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt)
    mode = Mode(linker=mode.linker, optimizer=opt_qry)
    aesara_function = aesara.function(
        inputs,
        outputs,
        updates={**rng_updates, **kwargs.pop("updates", {})},
        mode=mode,
        **kwargs,
    )
    return aesara_function
Example #28
    def __init__(
        self,
        *,
        trace=None,
        prior=None,
        posterior_predictive=None,
        log_likelihood=True,
        predictions=None,
        coords: Optional[CoordSpec] = None,
        dims: Optional[DimSpec] = None,
        model=None,
        save_warmup: Optional[bool] = None,
    ):

        self.save_warmup = rcParams[
            "data.save_warmup"] if save_warmup is None else save_warmup
        self.trace = trace

        # this permits us to get the model from command-line argument or from with model:
        self.model = modelcontext(model)

        self.attrs = None
        if trace is not None:
            self.nchains = trace.nchains if hasattr(trace, "nchains") else 1
            if hasattr(trace.report,
                       "n_draws") and trace.report.n_draws is not None:
                self.ndraws = trace.report.n_draws
                self.attrs = {
                    "sampling_time": trace.report.t_sampling,
                    "tuning_steps": trace.report.n_tune,
                }
            else:
                self.ndraws = len(trace)
                if self.save_warmup:
                    warnings.warn(
                        "Warmup samples will be stored in posterior group and will not be"
                        " excluded from stats and diagnostics."
                        " Do not slice the trace manually before conversion",
                        UserWarning,
                    )
            self.ntune = len(self.trace) - self.ndraws
            self.posterior_trace, self.warmup_trace = self.split_trace()
        else:
            self.nchains = self.ndraws = 0

        self.prior = prior
        self.posterior_predictive = posterior_predictive
        self.log_likelihood = log_likelihood
        self.predictions = predictions

        if all(elem is None
               for elem in (trace, predictions, posterior_predictive, prior)):
            raise ValueError(
                "When constructing InferenceData you must pass at least"
                " one of trace, prior, posterior_predictive or predictions.")

        self.coords = {**self.model.coords, **(coords or {})}
        self.coords = {
            cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
            for cname, cvals in self.coords.items() if cvals is not None
        }

        self.dims = {} if dims is None else dims
        if hasattr(self.model, "RV_dims"):
            model_dims = {
                var_name: [dim for dim in dims if dim is not None]
                for var_name, dims in self.model.RV_dims.items()
            }
            self.dims = {**model_dims, **self.dims}

        self.observations = find_observations(self.model)
Example #29
def sample_smc(
    draws=2000,
    kernel=IMH,
    *,
    start=None,
    model=None,
    random_seed=None,
    chains=None,
    cores=None,
    compute_convergence_checks=True,
    return_inferencedata=True,
    idata_kwargs=None,
    progressbar=True,
    **kernel_kwargs,
):
    r"""
    Sequential Monte Carlo based sampling.

    Parameters
    ----------
    draws: int
        The number of samples to draw from the posterior (i.e. the last stage), which is also
        the number of independent chains. Defaults to 2000.
    kernel: SMC Kernel used. Defaults to pm.smc.IMH (Independent Metropolis Hastings)
    start: dict, or array of dict
        Starting point in parameter space. It should be a list of dict with length `chains`.
        When None (default) the starting point is sampled from the prior distribution.
    model: Model (optional if in ``with`` context).
    random_seed: int
        random seed
    chains : int
        The number of chains to sample. Running independent chains is important for some
        convergence statistics. If ``None`` (default), then set to either ``cores`` or 2, whichever
        is larger.
    cores : int
        The number of chains to run in parallel. If ``None``, set to the number of CPUs in the
        system.
    compute_convergence_checks : bool
        Whether to compute sampler statistics like Gelman-Rubin and ``effective_n``.
        Defaults to ``True``.
    return_inferencedata : bool, default=True
        Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False)
        Defaults to ``True``.
    idata_kwargs : dict, optional
        Keyword arguments for :func:`pymc.to_inference_data`
    progressbar : bool, optional default=True
        Whether or not to display a progress bar in the command line.
    **kernel_kwargs: keyword arguments passed to the SMC kernel.
        The default IMH kernel takes the following keywords:
            threshold: float
                Determines the change of beta from stage to stage, i.e. indirectly the number of stages;
                the higher the value of `threshold`, the higher the number of stages. Defaults to 0.5.
                It should be between 0 and 1.
            n_steps: int
                The number of steps of each Markov Chain. If ``tune_steps == True`` ``n_steps`` will be used
                for the first stage and for the others it will be determined automatically based on the
                acceptance rate and `p_acc_rate`, the max number of steps is ``n_steps``.
            tune_steps: bool
                Whether to compute the number of steps automatically or not. Defaults to True
            p_acc_rate: float
                Used to compute ``n_steps`` when ``tune_steps == True``. The higher the value of
                ``p_acc_rate`` the higher the number of steps computed automatically. Defaults to 0.85.
                It should be between 0 and 1.
        Keyword arguments for other kernels should be checked in the respective docstrings

    Notes
    -----
    SMC works by moving through successive stages. At each stage the inverse temperature
    :math:`\beta` is increased a little bit (starting from 0 up to 1). When :math:`\beta` = 0
    we have the prior distribution and when :math:`\beta` = 1 we have the posterior distribution.
    So in more general terms we are always computing samples from a tempered posterior that we can
    write as:

    .. math::

        p(\theta \mid y)_{\beta} = p(y \mid \theta)^{\beta} p(\theta)

    A summary of the algorithm is:

     1. Initialize :math:`\beta` at zero and stage at zero.
     2. Generate N samples :math:`S_{\beta}` from the prior (because when :math:`\beta = 0` the
        tempered posterior is the prior).
     3. Increase :math:`\beta` in order to make the effective sample size equal some predefined
        value (we use :math:`Nt`, where :math:`t` is 0.5 by default).
     4. Compute a set of N importance weights W. The weights are computed as the ratio of the
        likelihoods of a sample at stage i+1 and stage i.
     5. Obtain :math:`S_{w}` by re-sampling according to W.
     6. Use W to compute the mean and covariance for the proposal distribution, a MVNormal.
     7. For stages other than 0 use the acceptance rate from the previous stage to estimate
        `n_steps`.
     8. Run N independent Metropolis-Hastings (IMH) chains (each one of length `n_steps`),
        starting each one from a different sample in :math:`S_{w}`. Samples are IMH as the proposal
        mean is that of the previous posterior stage and not the current point in parameter space.
     9. Repeat from step 3 until :math:`\beta \ge 1`.
     10. The final result is a collection of N samples from the posterior.


    References
    ----------
    .. [Minson2013] Minson, S. E. and Simons, M. and Beck, J. L., (2013),
        Bayesian inversion for finite fault earthquake source models I- Theory and algorithm.
        Geophysical Journal International, 2013, 194(3), pp.1701-1726,
        `link <https://gji.oxfordjournals.org/content/194/3/1701.full>`__

    .. [Ching2007] Ching, J. and Chen, Y. (2007).
        Transitional Markov Chain Monte Carlo Method for Bayesian Model Updating, Model Class
        Selection, and Model Averaging. J. Eng. Mech., 10.1061/(ASCE)0733-9399(2007)133:7(816),
        816-832. `link <http://ascelibrary.org/doi/abs/10.1061/%28ASCE%290733-9399
        %282007%29133:7%28816%29>`__
    """

    if isinstance(kernel, str) and kernel.lower() in ("abc", "metropolis"):
        warnings.warn(
            f'The kernel string argument "{kernel}" in sample_smc has been deprecated. '
            f"It is no longer needed to distinguish between `abc` and `metropolis`",
            FutureWarning,
            stacklevel=2,
        )
        kernel = IMH

    if kernel_kwargs.pop("save_sim_data", None) is not None:
        warnings.warn(
            "save_sim_data has been deprecated. Use pm.sample_posterior_predictive "
            "to obtain the same type of samples.",
            FutureWarning,
            stacklevel=2,
        )

    if kernel_kwargs.pop("save_log_pseudolikelihood", None) is not None:
        warnings.warn(
            "save_log_pseudolikelihood has been deprecated. This information is "
            "now saved as log_likelihood in models with Simulator distributions.",
            FutureWarning,
            stacklevel=2,
        )

    parallel = kernel_kwargs.pop("parallel", None)
    if parallel is not None:
        warnings.warn(
            "The argument parallel is deprecated, use the argument cores instead.",
            FutureWarning,
            stacklevel=2,
        )
        if parallel is False:
            cores = 1

    if cores is None:
        cores = _cpu_count()

    if chains is None:
        chains = max(2, cores)
    else:
        cores = min(chains, cores)

    if random_seed == -1:
        raise FutureWarning(
            f"random_seed should be a non-negative integer or None, got: {random_seed}"
            "This will raise a ValueError in the Future")
        random_seed = None
    if isinstance(random_seed, int) or random_seed is None:
        rng = np.random.default_rng(seed=random_seed)
        random_seed = list(rng.integers(2**30, size=chains))
    elif isinstance(random_seed, Iterable):
        if len(random_seed) != chains:
            raise ValueError(
                f"Length of seeds ({len(seeds)}) must match number of chains {chains}"
            )
    else:
        raise TypeError(
            "Invalid value for `random_seed`. Must be tuple, list, int or None"
        )

    model = modelcontext(model)

    _log = logging.getLogger("pymc")
    _log.info("Initializing SMC sampler...")
    _log.info(f"Sampling {chains} chain{'s' if chains > 1 else ''} "
              f"in {cores} job{'s' if cores > 1 else ''}")

    params = (
        draws,
        kernel,
        start,
        model,
    )

    t1 = time.time()
    if cores > 1:
        pbar = progress_bar((), total=100, display=progressbar)
        pbar.update(0)
        pbars = [pbar] + [None] * (chains - 1)

        pool = mp.Pool(cores)

        # "manually" (de)serialize params before/after multiprocessing
        params = tuple(cloudpickle.dumps(p) for p in params)
        kernel_kwargs = {
            key: cloudpickle.dumps(value)
            for key, value in kernel_kwargs.items()
        }
        results = _starmap_with_kwargs(
            pool,
            _sample_smc_int,
            [(*params, random_seed[chain], chain, pbars[chain])
             for chain in range(chains)],
            repeat(kernel_kwargs),
        )
        results = tuple(cloudpickle.loads(r) for r in results)
        pool.close()
        pool.join()

    else:
        results = []
        pbar = progress_bar((), total=100 * chains, display=progressbar)
        pbar.update(0)
        for chain in range(chains):
            pbar.offset = 100 * chain
            pbar.base_comment = f"Chain: {chain+1}/{chains}"
            results.append(
                _sample_smc_int(*params, random_seed[chain], chain, pbar,
                                **kernel_kwargs))

    (
        traces,
        sample_stats,
        sample_settings,
    ) = zip(*results)

    trace = MultiTrace(traces)
    idata = None

    # Save sample_stats
    _t_sampling = time.time() - t1
    sample_settings_dict = sample_settings[0]
    sample_settings_dict["_t_sampling"] = _t_sampling

    sample_stats_dict = sample_stats[0]
    if chains > 1:
        # Collect the stat values from each chain in a single list
        for stat in sample_stats[0].keys():
            value_list = []
            for chain_sample_stats in sample_stats:
                value_list.append(chain_sample_stats[stat])
            sample_stats_dict[stat] = value_list

    if not return_inferencedata:
        for stat, value in sample_stats_dict.items():
            setattr(trace.report, stat, value)
        for stat, value in sample_settings_dict.items():
            setattr(trace.report, stat, value)
    else:
        for stat, value in sample_stats_dict.items():
            if chains > 1:
                # Different chains might have more iteration steps, leading to a
                # non-square `sample_stats` dataset, we cast as `object` to avoid
                # numpy ragged array deprecation warning
                sample_stats_dict[stat] = np.array(value, dtype=object)
            else:
                sample_stats_dict[stat] = np.array(value)

        sample_stats = dict_to_dataset(
            sample_stats_dict,
            attrs=sample_settings_dict,
            library=pymc,
        )

        ikwargs = dict(model=model)
        if idata_kwargs is not None:
            ikwargs.update(idata_kwargs)
        idata = to_inference_data(trace, **ikwargs)
        idata = InferenceData(**idata, sample_stats=sample_stats)

    if compute_convergence_checks:
        if draws < 100:
            warnings.warn(
                "The number of samples is too small to check convergence reliably.",
                stacklevel=2,
            )
        else:
            if idata is None:
                idata = to_inference_data(trace, log_likelihood=False)
            trace.report._run_convergence_checks(idata, model)
    trace.report._log_summary()

    return idata if return_inferencedata else trace
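
A usage sketch (hypothetical toy model) running the default IMH kernel:

import pymc as pm

with pm.Model():
    mu = pm.Normal("mu", 0.0, 5.0)
    pm.Normal("obs", mu, 1.0, observed=[0.2, -0.1, 0.3])
    # Returns InferenceData by default; pass return_inferencedata=False for a MultiTrace.
    idata = pm.sample_smc(draws=1000, chains=2)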
Example #30
    def __init__(
        self,
        vars=None,
        scaling=None,
        step_scale=0.25,
        is_cov=False,
        model=None,
        blocked=True,
        potential=None,
        dtype=None,
        Emax=1000,
        target_accept=0.8,
        gamma=0.05,
        k=0.75,
        t0=10,
        adapt_step_size=True,
        step_rand=None,
        **aesara_kwargs
    ):
        """Set up Hamiltonian samplers with common structures.

        Parameters
        ----------
        vars: list, default=None
            List of Aesara variables. If None, all continuous RVs from the
            model are included.
        scaling: array_like, ndim={1,2}
            Scaling for momentum distribution. 1d arrays are interpreted as the
            matrix diagonal.
        step_scale: float, default=0.25
            Size of steps to take, automatically scaled down by 1/n**(1/4),
            where n is the dimensionality of the parameter space
        is_cov: bool, default=False
            Treat scaling as a covariance matrix/vector if True, else treat
            it as a precision matrix/vector
        model: pymc.Model
        blocked: bool, default=True
        potential: Potential, optional
            An object that represents the Hamiltonian with `velocity`,
            `energy`, and `random` methods.
        **aesara_kwargs: passed to Aesara functions
        """
        self._model = modelcontext(model)

        if vars is None:
            vars = self._model.cont_vars
        else:
            vars = [self._model.rvs_to_values.get(var, var) for var in vars]

        super().__init__(vars, blocked=blocked, model=self._model, dtype=dtype, **aesara_kwargs)

        self.adapt_step_size = adapt_step_size
        self.Emax = Emax
        self.iter_count = 0

        # We're using the initial/test point to determine the (initial) step
        # size.
        # XXX: If the dimensions of these terms change, the step size
        # dimension-scaling should change as well, no?
        test_point = self._model.initial_point

        nuts_vars = [test_point[v.name] for v in vars]
        size = sum(v.size for v in nuts_vars)

        self.step_size = step_scale / (size ** 0.25)
        self.step_adapt = step_sizes.DualAverageAdaptation(
            self.step_size, target_accept, gamma, k, t0
        )
        self.target_accept = target_accept
        self.tune = True

        if scaling is None and potential is None:
            mean = floatX(np.zeros(size))
            var = floatX(np.ones(size))
            potential = QuadPotentialDiagAdapt(size, mean, var, 10)

        if isinstance(scaling, dict):
            point = Point(scaling, model=self._model)
            scaling = guess_scaling(point, model=self._model, vars=vars)

        if scaling is not None and potential is not None:
            raise ValueError("Can not specify both potential and scaling.")

        if potential is not None:
            self.potential = potential
        else:
            self.potential = quad_potential(scaling, is_cov)

        self.integrator = integration.CpuLeapfrogIntegrator(self.potential, self._logp_dlogp_func)

        self._step_rand = step_rand
        self._warnings = []
        self._samples_after_tune = 0
        self._num_divs_sample = 0
Example #31
    def __init__(
        self,
        *,
        trace=None,
        prior=None,
        posterior_predictive=None,
        log_likelihood=True,
        predictions=None,
        coords: Optional[CoordSpec] = None,
        dims: Optional[DimSpec] = None,
        model=None,
        save_warmup: Optional[bool] = None,
        density_dist_obs: bool = True,
    ):

        self.save_warmup = rcParams[
            "data.save_warmup"] if save_warmup is None else save_warmup
        self.trace = trace

        # this permits us to get the model from command-line argument or from with model:
        self.model = modelcontext(model)

        self.attrs = None
        if trace is not None:
            self.nchains = trace.nchains if hasattr(trace, "nchains") else 1
            if hasattr(trace.report,
                       "n_draws") and trace.report.n_draws is not None:
                self.ndraws = trace.report.n_draws
                self.attrs = {
                    "sampling_time": trace.report.t_sampling,
                    "tuning_steps": trace.report.n_tune,
                }
            else:
                self.ndraws = len(trace)
                if self.save_warmup:
                    warnings.warn(
                        "Warmup samples will be stored in posterior group and will not be"
                        " excluded from stats and diagnostics."
                        " Do not slice the trace manually before conversion",
                        UserWarning,
                    )
            self.ntune = len(self.trace) - self.ndraws
            self.posterior_trace, self.warmup_trace = self.split_trace()
        else:
            self.nchains = self.ndraws = 0

        self.prior = prior
        self.posterior_predictive = posterior_predictive
        self.log_likelihood = log_likelihood
        self.predictions = predictions

        def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
            return next(iter(dct.values()))

        if trace is None:
            # if you have a posterior_predictive built with keep_dims,
            # you'll lose here, but there's nothing I can do about that.
            self.nchains = 1
            get_from = None
            if predictions is not None:
                get_from = predictions
            elif posterior_predictive is not None:
                get_from = posterior_predictive
            elif prior is not None:
                get_from = prior
            if get_from is None:
                # pylint: disable=line-too-long
                raise ValueError(
                    "When constructing InferenceData must have at least"
                    " one of trace, prior, posterior_predictive or predictions."
                )

            aelem = arbitrary_element(get_from)
            self.ndraws = aelem.shape[0]

        self.coords = {**self.model.coords, **(coords or {})}
        self.coords = {
            cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
            for cname, cvals in self.coords.items() if cvals is not None
        }

        self.dims = {} if dims is None else dims
        if hasattr(self.model, "RV_dims"):
            model_dims = {
                var_name: [dim for dim in dims if dim is not None]
                for var_name, dims in self.model.RV_dims.items()
            }
            self.dims = {**model_dims, **self.dims}

        self.density_dist_obs = density_dist_obs
        self.observations = find_observations(self.model)