示例#1
0
def test_mixed_contexts():
    """Interleave Model contexts with draw-values contexts and check lookups.

    Two models are nested, and inside them a ``_DrawValuesContext`` and a
    ``_DrawValuesContextBlocker`` are entered in turn. At every nesting level
    the test checks what ``get_context`` / ``modelcontext`` return, and that
    the lookups fail again once the corresponding context has been left.
    """
    outer_model = Model()
    inner_model = Model()

    # Nothing has been entered yet, so looking up a model must fail.
    with raises((ValueError, TypeError)):
        modelcontext(None)

    with outer_model:
        with inner_model:
            # The most recently entered model is the current one.
            assert Model.get_context() == inner_model
            assert modelcontext(None) == inner_model

            draw_ctx = _DrawValuesContext()
            with draw_ctx:
                # Entering a draw-values context leaves the model lookup alone.
                assert Model.get_context() == inner_model
                assert modelcontext(None) == inner_model
                assert _DrawValuesContext.get_context() == draw_ctx

                blocker_ctx = _DrawValuesContextBlocker()
                with blocker_ctx:
                    # Both classes report the blocker while it is active.
                    assert _DrawValuesContext.get_context() == blocker_ctx
                    assert _DrawValuesContextBlocker.get_context() == blocker_ctx

                # Once the blocker is left, both lookups yield draw_ctx again.
                assert _DrawValuesContext.get_context() == draw_ctx
                assert _DrawValuesContextBlocker.get_context() is draw_ctx
                assert Model.get_context() == inner_model
                assert modelcontext(None) == inner_model

            # No draw-values context remains: None when errors are disabled,
            # TypeError otherwise.
            assert _DrawValuesContext.get_context(error_if_none=False) is None
            with raises(TypeError):
                _DrawValuesContext.get_context()
            assert Model.get_context() == inner_model
            assert modelcontext(None) == inner_model

        # Leaving the inner model makes the outer one current again.
        assert Model.get_context() == outer_model
        assert modelcontext(None) == outer_model

    # Back outside of every model context.
    assert Model.get_context(error_if_none=False) is None
    with raises(TypeError):
        Model.get_context(error_if_none=True)
    with raises((ValueError, TypeError)):
        modelcontext(None)
示例#2
0
 def test_blocking_context(self):
     """Check how ``drawn_vars`` is shared across nested draw-values contexts.

     A nested ``_DrawValuesContext`` is expected to reuse its parent's
     ``drawn_vars`` object, while a ``_DrawValuesContextBlocker`` is expected
     to have no parent and its own ``drawn_vars``, so writes made below the
     blocker do not show up in the contexts above it.
     """
     with _DrawValuesContext() as root_ctx:
         assert root_ctx.parent is None
         root_ctx.drawn_vars['root_test'] = 1

         with _DrawValuesContext() as child_ctx:
             # The child shares the very same drawn_vars object as the root.
             assert child_ctx.drawn_vars is root_ctx.drawn_vars
             assert child_ctx.parent == root_ctx

             with _DrawValuesContextBlocker() as blocker_ctx:
                 # The blocker has no parent and a separate drawn_vars.
                 assert blocker_ctx.drawn_vars is not root_ctx.drawn_vars
                 assert blocker_ctx.parent is None
                 blocker_ctx.drawn_vars['root_test'] = 2

                 with _DrawValuesContext() as leaf_ctx:
                     # Below the blocker, sharing happens with the blocker.
                     assert leaf_ctx.drawn_vars is blocker_ctx.drawn_vars
                     assert leaf_ctx.parent == blocker_ctx
                     leaf_ctx.drawn_vars['root_test'] = 3
                     leaf_ctx.drawn_vars['leaf_test'] = 4

                 # The leaf's write is visible through the blocker ...
                 assert blocker_ctx.drawn_vars['root_test'] == 3

             # ... but neither write reached the contexts above the blocker.
             assert 'leaf_test' not in child_ctx.drawn_vars
         assert root_ctx.drawn_vars['root_test'] == 1
示例#3
0
    def draw_value(self, param, trace: _TraceDict | None = None, givens=None):
        """Draw a set of random values from a distribution or return a constant.

        Parameters
        ----------
        param: number, array like, aesara variable or pymc3 random variable
            The value or distribution. Numbers and arrays are returned
            unchanged; constants and shared variables are returned through
            their ``value`` / ``get_value()``. Aesara variables
            are evaluated. If `param` is a pymc3 random variable, draw
            values from it and return that (as ``np.ndarray``), unless a
            value is specified in the ``trace``.
        trace: pm.MultiTrace, optional
            A dictionary from pymc3 variable names to samples of their values
            used to provide context for evaluating ``param``.
        givens: dict, optional
            A dictionary from aesara variables to their values. These values
            are used to evaluate ``param`` if it is an aesara variable.
            NOTE(review): the body unpacks this with ``zip(*givens)``, i.e.
            as an iterable of ``(variable, value)`` pairs -- confirm against
            the callers.

        Returns
        -------
        The value found or drawn for ``param``. When ``param`` is evaluated
        through a compiled function that takes no inputs, the output is
        repeated ``self.samples`` times along a new leading axis.

        Raises
        ------
        ValueError
            If ``param`` is of a type that none of the branches handles.
        AssertionError
            If a ``random`` method returns a value whose shape is not
            ``(size,) + shape``, unless ``param.distribution`` has a falsy
            ``wrap_random_with_dist_shape`` attribute.
        """
        samples = self.samples

        def random_sample(
            meth: Callable[..., np.ndarray],
            param,
            point: _TraceDict,
            size: int,
            shape: tuple[int, ...],
        ) -> np.ndarray:
            # Call ``meth`` and verify that the result has shape
            # ``(size,) + shape``.
            val = meth(point=point, size=size)
            try:
                assert val.shape == (size, ) + shape, (
                    "Sampling from random of %s yields wrong shape" % param)
            # error-quashing here is *extremely* ugly, but it seems to be what the logic in DensityDist wants.
            except AssertionError:
                if (hasattr(param, "distribution") and hasattr(
                        param.distribution, "wrap_random_with_dist_shape") and
                        not param.distribution.wrap_random_with_dist_shape):
                    pass
                else:
                    # Bare re-raise keeps the original traceback intact.
                    raise

            return val

        if isinstance(param, (numbers.Number, np.ndarray)):
            return param
        elif isinstance(param, Constant):
            return param.value
        elif isinstance(param, SharedVariable):
            return param.get_value()
        elif isinstance(param, (TensorVariable, MultiObservedRV)):
            if hasattr(param,
                       "model") and trace and param.name in trace.varnames:
                # A value for this variable is already present in the trace.
                return trace[param.name]
            elif hasattr(param, "random") and param.random is not None:
                model = modelcontext(None)
                assert isinstance(model, Model)
                shape: tuple[int, ...] = tuple(_param_shape(param, model))
                return random_sample(param.random,
                                     param,
                                     point=trace,
                                     size=samples,
                                     shape=shape)
            elif (hasattr(param, "distribution")
                  and hasattr(param.distribution, "random")
                  and param.distribution.random is not None):
                if hasattr(param, "observations"):
                    # shape inspection for ObservedRV
                    dist_tmp = param.distribution
                    try:
                        distshape: tuple[int, ...] = tuple(
                            param.observations.shape.eval())
                    except AttributeError:
                        distshape = tuple(param.observations.shape)

                    dist_tmp.shape = distshape
                    try:
                        return random_sample(
                            dist_tmp.random,
                            param,
                            point=trace,
                            size=samples,
                            shape=distshape,
                        )
                    except (ValueError, TypeError):
                        # reset shape to account for shape changes
                        # with aesara.shared inputs
                        dist_tmp.shape = ()
                        # We want to draw values to infer the dist_shape,
                        # we don't want to store these drawn values to the context
                        with _DrawValuesContextBlocker():
                            point = trace[0] if trace else None
                            temp_val = np.atleast_1d(
                                dist_tmp.random(point=point, size=None))
                        # Sometimes point may change the size of val but not the
                        # distribution's shape
                        if point and samples is not None:
                            temp_size = np.atleast_1d(samples)
                            if all(temp_val.shape[:len(temp_size)] ==
                                   temp_size):
                                dist_tmp.shape = tuple(
                                    temp_val.shape[len(temp_size):])
                            else:
                                dist_tmp.shape = tuple(temp_val.shape)
                        # I am not sure why I need to do this, but I do in order to trim off a
                        # degenerate dimension [2019/09/05:rpg]
                        # Check the rank first: the shape is still ``()`` when
                        # no point was available above, and indexing an empty
                        # shape would raise IndexError.
                        if len(dist_tmp.shape) > 1 and dist_tmp.shape[0] == 1:
                            dist_tmp.shape = dist_tmp.shape[1:]
                        return random_sample(
                            dist_tmp.random,
                            point=trace,
                            size=samples,
                            param=param,
                            shape=tuple(dist_tmp.shape),
                        )
                else:  # has a distribution, but no observations
                    distshape = tuple(param.distribution.shape)
                    return random_sample(
                        meth=param.distribution.random,
                        param=param,
                        point=trace,
                        size=samples,
                        shape=distshape,
                    )
            # NOTE: I think the following is already vectorized.
            else:
                if givens:
                    variables, values = list(zip(*givens))
                else:
                    variables = values = []
                # We only truly care if the ancestors of param that were given
                # value have the matching dshape and val.shape
                param_ancestors = set(
                    aesara.graph.basic.ancestors([param],
                                                 blockers=list(variables)))
                inputs = [(var, val) for var, val in zip(variables, values)
                          if var in param_ancestors]
                if inputs:
                    input_vars, input_vals = list(zip(*inputs))
                else:
                    input_vars = []
                    input_vals = []
                func = _compile_aesara_function(param, input_vars)
                if not input_vars:
                    assert input_vals == [
                    ]  # AFAICT if there are now vars, there can't be vals
                    output = func(*input_vals)
                    # The function has no inputs, so its single output is
                    # replicated once per requested sample.
                    if hasattr(output, "shape"):
                        val = np.repeat(np.expand_dims(output, 0),
                                        samples,
                                        axis=0)
                    else:
                        val = np.full(samples, output)

                else:
                    val = func(*input_vals)
                    # np.ndarray([func(*input_vals) for inp in zip(*input_vals)])
                return val
        raise ValueError("Unexpected type in draw_value: %s" % type(param))
示例#4
0
    def infer_comp_dist_shapes(self, point=None):
        """Infer the component distribution shapes and their broadcast shape.

        How the shapes are obtained depends on whether `comp_dists` is a
        single `Distribution` or a list of `Distribution`s. For a list, every
        element is handled as follows:
        1. Take the shape stored for the element.
        2. If that shape is not empty, use it as is.
        3. If it is an empty tuple, draw one sample with
        `random(point=point, size=None)` and use the sample's shape instead.
        For a single `Distribution` the same fallback is applied once to
        the stored shape.

        Parameters
        ----------
        point: None or dict (optional)
            Dictionary that maps rv names to values, forwarded to the
            `random` method of the component distribution(s).

        Returns
        -------
        comp_dist_shapes: shape tuple or list of shape tuples.
            A single shape tuple if `comp_dists` is a `Distribution`,
            otherwise a list with one shape tuple per element of
            `comp_dists`.
        broadcast_shape: shape tuple
            The shape obtained by broadcasting all component shapes together.
            For a single `Distribution` it is the same object as
            `comp_dist_shapes`.

        Raises
        ------
        TypeError
            If the shapes inferred for a list of components cannot be
            broadcast with each other.
        """
        if self.comp_is_distribution:
            comp_dist_shapes = self._comp_dist_shapes
            if len(comp_dist_shapes) == 0:
                # No usable stored shape (scalar distribution, or no shape was
                # given). Draw one value and look at its shape; the supplied
                # point is passed along in the hope that it circumvents the
                # undrawable Flat and HalfFlat distributions.
                with _DrawValuesContextBlocker():
                    drawn = self._comp_dists.random(point=point, size=None)
                    comp_dist_shapes = drawn.shape
            return comp_dist_shapes, comp_dist_shapes

        # List of components: collect one shape per component, then work out
        # the shape they broadcast to. That broadcast shape is the shape of a
        # single random sample from the mixture.
        comp_dist_shapes = []
        for stored_shape, component in zip(self._comp_dist_shapes, self._comp_dists):
            inferred = stored_shape
            if inferred == ():
                # Same fallback as above: draw a single value to learn the
                # shape, using the supplied point.
                with _DrawValuesContextBlocker():
                    drawn = component.random(point=point, size=None)
                    inferred = drawn.shape
            comp_dist_shapes.append(inferred)

        # All component distributions must broadcast with each other
        try:
            placeholders = [np.empty(shape) for shape in comp_dist_shapes]
            broadcast_shape = np.broadcast(*placeholders).shape
        except Exception:
            raise TypeError(
                "Inferred comp_dist shapes do not broadcast "
                "with each other. comp_dists inferred shapes "
                "are: {}".format(comp_dist_shapes)
            )
        return comp_dist_shapes, broadcast_shape