def test_density_dist_default_moment_multivariate(with_random, size):
    """Check the default moment of a vector-valued DensityDist with no custom get_moment.

    With a ``random`` function, the moment is expected to be all zeros with
    shape ``to_tuple(size) + (5,)``. Without one, asking for the moment must
    raise a ``TypeError``.
    """

    def _sample(mu, rng=None, size=None):
        batch_shape = to_tuple(size) + mu.shape
        return rng.normal(mu, scale=1, size=batch_shape)

    random_fn = _sample if with_random else None

    mu_val = np.random.normal(loc=2, scale=1, size=5).astype(aesara.config.floatX)
    with pm.Model():
        mu = pm.Normal("mu", size=5)
        a = pm.DensityDist(
            "a", mu, random=random_fn, ndims_params=[1], ndim_supp=1, size=size
        )

    if not with_random:
        with pytest.raises(
            TypeError,
            match="Cannot safely infer the size of a multivariate random variable's moment.",
        ):
            get_moment(a).eval({mu: mu_val})
        return

    evaled_moment = get_moment(a).eval({mu: mu_val})
    assert evaled_moment.shape == to_tuple(size) + (5,)
    assert np.all(evaled_moment == 0)
def test_basic(self):
    """Check get_moment for a standard distribution and for Flat / HalfFlat."""
    # A standard distribution: the moment should match the `mu` parameter.
    normal_rv = pm.Normal.dist(mu=2.3)
    np.testing.assert_allclose(get_moment(normal_rv).eval(), 2.3)

    # Special distributions: Flat's moment is all zeros, HalfFlat's all ones.
    flat_cases = ((pm.Flat, np.zeros), (pm.HalfFlat, np.ones))

    # Scalar (unsized) variants first ...
    for dist_cls, fill in flat_cases:
        assert get_moment(dist_cls.dist()).eval() == fill(())

    # ... then the sized variants, compared element-wise.
    for dist_cls, fill in flat_cases:
        sized_rv = dist_cls.dist(size=(2, 4))
        assert np.all(get_moment(sized_rv).eval() == fill((2, 4)))
def get_moment_marginal_mixture(op, rv, rng, weights, *components):
    """Return the moment of a marginal mixture as the weighted sum of component moments.

    Parameters
    ----------
    op, rv, rng
        Unused here; part of the dispatch signature.
    weights
        Mixture weights; padded on the right with ``ndim_supp`` dims so they
        broadcast against the component moments.
    *components
        Mixture component random variables. The support dimensionality is
        read from the first component.

    Returns
    -------
    The weighted sum over the mixture axis, rounded when the first
    component's dtype is one of ``discrete_types``.
    """
    first = components[0]
    ndim_supp = first.owner.op.ndim_supp

    # The mixture axis sits immediately before the support dimensions.
    mix_axis = -(ndim_supp + 1)
    padded_weights = at.shape_padright(weights, ndim_supp)

    if len(components) > 1:
        component_moments = at.stack(
            [get_moment(component) for component in components],
            axis=mix_axis,
        )
    else:
        # NOTE(review): presumably a lone component is already batched along
        # the mixture axis, so no stacking is needed — confirm with callers.
        component_moments = get_moment(first)

    weighted = at.sum(padded_weights * component_moments, axis=mix_axis)
    if first.dtype not in discrete_types:
        return weighted
    return at.round(weighted)
def test_moment_from_dims(self, rv_cls):
    """The moment shape must follow the lengths of the model coords named in `dims`."""
    coords = {
        "year": [2019, 2020, 2021, 2022],
        "city": ["Bonn", "Paris", "Lisbon"],
    }
    with pm.Model(coords=coords):
        rv = rv_cls("rv", dims=("year", "city"))
        assert not hasattr(rv.tag, "test_value")
        moment_shape = get_moment(rv).shape.eval()
        assert tuple(moment_shape) == (4, 3)
def test_density_dist_custom_moment_multivariate(size):
    """A user-supplied get_moment for a vector DensityDist is used as-is.

    The custom moment broadcasts ``mu`` over ``size``, so the result has shape
    ``to_tuple(size) + (5,)`` and every row equals ``mu_val``.
    """

    def _custom_moment(rv, size, mu):
        # Add a trailing axis so `ones(size)` broadcasts against the length-5 mu.
        batched_ones = at.ones(size)[..., None]
        return (batched_ones * mu).astype(rv.dtype)

    mu_val = np.random.normal(loc=2, scale=1, size=5).astype(aesara.config.floatX)
    with pm.Model():
        mu = pm.Normal("mu", size=5)
        a = pm.DensityDist(
            "a", mu, get_moment=_custom_moment, ndims_params=[1], ndim_supp=1, size=size
        )

    evaled_moment = get_moment(a).eval({mu: mu_val})
    expected_shape = to_tuple(size) + (5,)
    assert evaled_moment.shape == expected_shape
    assert np.all(evaled_moment == mu_val)
def test_density_dist_custom_moment_univariate(size):
    """A user-supplied get_moment for a scalar DensityDist broadcasts mu to `size`."""

    def _custom_moment(rv, size, mu):
        broadcast = at.ones(size) * mu
        return broadcast.astype(rv.dtype)

    mu_val = np.array(np.random.normal(loc=2, scale=1)).astype(aesara.config.floatX)
    with pm.Model():
        mu = pm.Normal("mu")
        a = pm.DensityDist("a", mu, get_moment=_custom_moment, size=size)

    evaled_moment = get_moment(a).eval({mu: mu_val})
    assert evaled_moment.shape == to_tuple(size)
    assert np.all(evaled_moment == mu_val)
def make_initial_point_expression(
    *,
    free_rvs: Sequence[TensorVariable],
    rvs_to_values: Dict[TensorVariable, TensorVariable],
    initval_strategies: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]],
    jitter_rvs: Optional[Set[TensorVariable]] = None,
    default_strategy: str = "moment",
    return_transformed: bool = False,
) -> List[TensorVariable]:
    """Creates the tensor variables that need to be evaluated to obtain an initial point.

    Parameters
    ----------
    free_rvs : list
        Tensors of free random variables in the model.
    rvs_to_values : dict
        Mapping of free random variable tensors to value variable tensors.
    initval_strategies : dict
        Mapping of free random variable tensors to initial value strategies.
        For example the `Model.initial_values` dictionary.
        A strategy is either one of the strings "moment" / "prior", or a
        concrete value (array or tensor) that is converted to the variable's
        dtype. Missing or None entries use ``default_strategy``.
    jitter_rvs : set
        The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be
        added to the initial value. Only available for variables that have a transform or real-valued support.
    default_strategy : str
        Which of { "moment", "prior" } to prefer if the initval strategy setting for an RV is None.
    return_transformed : bool
        Switches between returning the tensors for untransformed or transformed initial points.

    Returns
    -------
    initial_points : list of TensorVariable
        Aesara expressions for initial values of the free random variables.

    Raises
    ------
    ValueError
        If a string strategy is neither "moment" nor "prior".
    """
    # Imported locally (presumably to avoid an import cycle with the
    # distributions package — confirm).
    from pymc.distributions.distribution import get_moment

    if jitter_rvs is None:
        jitter_rvs = set()

    initial_values = []
    initial_values_transformed = []

    for variable in free_rvs:
        strategy = initval_strategies.get(variable, None)

        if strategy is None:
            strategy = default_strategy

        # Only compare against strategy names for actual strings; other
        # strategies are treated as explicit initial values.
        if isinstance(strategy, str):
            if strategy == "moment":
                try:
                    value = get_moment(variable)
                except NotImplementedError:
                    # No moment implementation for this distribution: warn and
                    # use the random variable itself (a prior draw) instead.
                    warnings.warn(
                        f"Moment not defined for variable {variable} of type "
                        f"{variable.owner.op.__class__.__name__}, defaulting to "
                        f"a draw from the prior. This can lead to difficulties "
                        f"during tuning. You can manually define an initval or "
                        f"implement a get_moment dispatched function for this "
                        f"distribution.",
                        UserWarning,
                    )
                    value = variable
            elif strategy == "prior":
                value = variable
            else:
                raise ValueError(
                    f'Invalid string strategy: {strategy}. It must be one of ["moment", "prior"]'
                )
        else:
            value = at.as_tensor(strategy, dtype=variable.dtype).astype(variable.dtype)

        transform = getattr(rvs_to_values[variable].tag, "transform", None)

        # Move the value into the transformed space, so the jitter below is
        # applied there.
        if transform is not None:
            value = transform.forward(value, *variable.owner.inputs)

        if variable in jitter_rvs:
            jitter = at.random.uniform(-1, 1, size=value.shape)
            jitter.name = f"{variable.name}_jitter"
            value = value + jitter

        # Cast back to the RV's dtype (e.g. in case adding jitter changed it).
        value = value.astype(variable.dtype)
        initial_values_transformed.append(value)

        # Map back to the constrained (untransformed) space.
        if transform is not None:
            value = transform.backward(value, *variable.owner.inputs)

        initial_values.append(value)

    # Clone everything in one FunctionGraph so the replacements below operate on
    # copies. The outputs are laid out as
    # [free_rvs..., initial_values..., initial_values_transformed...],
    # each section n_variables long.
    all_outputs = []
    all_outputs.extend(free_rvs)
    all_outputs.extend(initial_values)
    all_outputs.extend(initial_values_transformed)

    copy_graph = FunctionGraph(outputs=all_outputs, clone=True)

    n_variables = len(free_rvs)
    free_rvs_clone = copy_graph.outputs[:n_variables]
    initial_values_clone = copy_graph.outputs[n_variables:-n_variables]
    initial_values_transformed_clone = copy_graph.outputs[-n_variables:]

    # We now replace all rvs by the respective initial_point expressions
    # in the constrained (untransformed) space. We do this in reverse topological
    # order, so that later nodes do not reintroduce expressions with earlier
    # rvs that would need to once again be replaced by their initial_points
    # NOTE(review): this assumes `free_rvs` is ordered so that each RV comes
    # after the RVs it depends on — TODO confirm at call sites.
    graph = FunctionGraph(outputs=free_rvs_clone, clone=False)
    replacements = reversed(list(zip(free_rvs_clone, initial_values_clone)))
    graph.replace_all(replacements, import_missing=True)

    if not return_transformed:
        return graph.outputs
    # Because the unconstrained (transformed) expressions are a subgraph of the
    # constrained initial point they were also automatically updated inplace
    # when calling graph.replace_all above, so we don't need to do anything else
    return initial_values_transformed_clone
def make_initial_point_expression(
    *,
    free_rvs: Sequence[TensorVariable],
    rvs_to_values: Dict[TensorVariable, TensorVariable],
    initval_strategies: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]],
    jitter_rvs: Optional[Set[TensorVariable]] = None,
    default_strategy: str = "prior",
    return_transformed: bool = False,
) -> List[TensorVariable]:
    """Creates the tensor variables that need to be evaluated to obtain an initial point.

    Parameters
    ----------
    free_rvs : list
        Tensors of free random variables in the model.
    rvs_to_values : dict
        Mapping of free random variable tensors to value variable tensors.
    initval_strategies : dict
        Mapping of free random variable tensors to initial value strategies.
        For example the `Model.initial_values` dictionary.
        A strategy is either one of the strings "moment" / "prior", or a
        concrete value (array or tensor) that is converted to the variable's
        dtype. Missing or None entries use ``default_strategy``.
    jitter_rvs : set
        The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be
        added to the initial value. Only available for variables that have a transform or real-valued support.
    default_strategy : str
        Which of { "moment", "prior" } to prefer if the initval strategy setting for an RV is None.
    return_transformed : bool
        Switches between returning the tensors for untransformed or transformed initial points.

    Returns
    -------
    initial_points : list of TensorVariable
        Aesara expressions for initial values of the free random variables.

    Raises
    ------
    ValueError
        If a string strategy is neither "moment" nor "prior".
    """
    from pymc.distributions.distribution import get_moment

    if jitter_rvs is None:
        jitter_rvs = set()

    initial_values = []
    initial_values_transformed = []

    for variable in free_rvs:
        strategy = initval_strategies.get(variable, None)

        if strategy is None:
            strategy = default_strategy

        # Only compare against the strategy names when the strategy really is a
        # string: `ndarray == "moment"` is element-wise and cannot be used in `if`.
        if isinstance(strategy, str):
            if strategy == "moment":
                value = get_moment(variable)
            elif strategy == "prior":
                value = variable
            else:
                raise ValueError(
                    f'Invalid string strategy: {strategy}. It must be one of ["moment", "prior"]'
                )
        else:
            value = at.as_tensor(strategy, dtype=variable.dtype).astype(variable.dtype)

        transform = getattr(rvs_to_values[variable].tag, "transform", None)

        # Move the value into the transformed space, so the jitter is applied there.
        if transform is not None:
            value = transform.forward(value, *variable.owner.inputs)

        if variable in jitter_rvs:
            jitter = at.random.uniform(-1, 1, size=value.shape)
            jitter.name = f"{variable.name}_jitter"
            value = value + jitter

        # Adding the jitter can upcast (e.g. float32 + float64, or int + float);
        # cast back so the initial point keeps the variable's dtype and can
        # replace the variable in the graph below.
        value = value.astype(variable.dtype)
        initial_values_transformed.append(value)

        # Map back to the constrained (untransformed) space.
        if transform is not None:
            value = transform.backward(value, *variable.owner.inputs)

        initial_values.append(value)

    # Clone everything in one FunctionGraph so the replacements below operate on
    # copies. The outputs are laid out as
    # [free_rvs..., initial_values..., initial_values_transformed...],
    # each section n_variables long.
    all_outputs = []
    all_outputs.extend(free_rvs)
    all_outputs.extend(initial_values)
    all_outputs.extend(initial_values_transformed)

    copy_graph = FunctionGraph(outputs=all_outputs, clone=True)

    n_variables = len(free_rvs)
    free_rvs_clone = copy_graph.outputs[:n_variables]
    initial_values_clone = copy_graph.outputs[n_variables:-n_variables]
    initial_values_transformed_clone = copy_graph.outputs[-n_variables:]

    # In the order the variables were created, replace each previous variable
    # with the (already resolved) init_point for that variable.
    resolved_values = []
    resolved_values_transformed = []
    for i in range(n_variables):
        graph = FunctionGraph(
            outputs=[initial_values_clone[i], initial_values_transformed_clone[i]],
            clone=False,
        )
        graph.replace_all(zip(free_rvs_clone[:i], resolved_values), import_missing=True)
        resolved_values.append(graph.outputs[0])
        resolved_values_transformed.append(graph.outputs[1])

    if return_transformed:
        return resolved_values_transformed
    return resolved_values
def test_symbolic_moment_shape(self, rv_cls):
    """The moment of an RV with a symbolic shape must resolve once the length is given."""
    length = at.scalar()
    rv = rv_cls.dist(shape=(length,))
    assert not hasattr(rv.tag, "test_value")

    moment_shape = get_moment(rv).shape.eval({length: 4})
    assert tuple(moment_shape) == (4,)
def test_numeric_moment_shape(self, rv_cls):
    """The moment of an RV with a fixed numeric shape must have that same shape."""
    rv = rv_cls.dist(shape=(2,))
    assert not hasattr(rv.tag, "test_value")

    moment_shape = tuple(get_moment(rv).shape.eval())
    assert moment_shape == (2,)