示例#1
0
 def test_convert_shape(self):
     # A bare int is promoted to a one-element shape tuple.
     assert convert_shape(5) == (5,)
     # Each malformed shape must be rejected with a ValueError whose
     # message matches the given pattern (pytest applies it via re.search).
     invalid_cases = (
         ("notashape", "tuple, TensorVariable, int or list"),
         ((3, ..., 2), "may only appear in the last position"),
     )
     for bad_shape, expected_msg in invalid_cases:
         with pytest.raises(ValueError, match=expected_msg):
             convert_shape(shape=bad_shape)
示例#2
0
    def dist(
        cls,
        dist_params,
        *,
        shape: Optional[Shape] = None,
        size: Optional[Size] = None,
        **kwargs,
    ) -> TensorVariable:
        """Creates a TensorVariable corresponding to the `cls` symbolic distribution.

        Parameters
        ----------
        dist_params : array-like
            The inputs to the `RandomVariable` `Op`.
        shape : int, tuple, Variable, optional
            A tuple of sizes for each dimension of the new RV.
            An Ellipsis (...) may be inserted in the last position to short-hand refer to
            all the dimensions that the RV would get if no shape/size/dims were passed at all.
        size : int, tuple, Variable, optional
            For creating the RV like in Aesara/NumPy.
        **kwargs
            Forwarded to ``cls.rv_op``. ``testval`` is accepted but removed and
            ignored (a ``FutureWarning`` is emitted).

        Returns
        -------
        var : TensorVariable

        Raises
        ------
        TypeError
            If ``initval`` is passed; it is not supported by ``.dist()``.
        NotImplementedError
            If ``dims`` is passed.
        ValueError
            If both ``shape`` and ``size`` are given. ``convert_shape`` may
            also raise ``ValueError`` for a malformed ``shape``.
        """

        if "testval" in kwargs:
            kwargs.pop("testval")
            warnings.warn(
                "The `.dist(testval=...)` argument is deprecated and has no effect. "
                "Initial values for sampling/optimization can be specified with `initval` in a modelcontext. "
                "For using Aesara's test value features, you must assign the `.tag.test_value` yourself.",
                FutureWarning,
                stacklevel=2,
            )
        if "initval" in kwargs:
            raise TypeError(
                "Unexpected keyword argument `initval`. "
                "This argument is not available for the `.dist()` API.")

        if "dims" in kwargs:
            raise NotImplementedError(
                "The use of a `.dist(dims=...)` API is not supported.")
        if shape is not None and size is not None:
            raise ValueError(
                f"Passing both `shape` ({shape}) and `size` ({size}) is not supported!"
            )

        shape = convert_shape(shape)
        size = convert_size(size)

        # Only the size used to create the RV is needed here; the remaining
        # dimensionality values returned by `find_size` are not used by the
        # symbolic-distribution path. The support dimensionality depends on
        # the parameters, hence `cls.ndim_supp(*dist_params)`.
        create_size, _, _, _ = find_size(
            shape=shape, size=size, ndim_supp=cls.ndim_supp(*dist_params))
        # Create the RV with a `size` right away.
        # This is not necessarily the final result.
        graph = cls.rv_op(*dist_params, size=create_size, **kwargs)

        # Replicate dimensions may be prepended via a shape with Ellipsis as the last element:
        if shape is not None and Ellipsis in shape:
            replicate_shape = cast(StrongShape, shape[:-1])
            graph = cls.change_size(rv=graph,
                                    new_size=replicate_shape,
                                    expand=True)

        # TODO: Create new attr error stating that these are not available for DerivedDistribution
        # rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
        # rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
        # rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()")
        return graph
示例#3
0
    def dist(
        cls,
        dist_params,
        *,
        shape: Optional[Shape] = None,
        size: Optional[Size] = None,
        **kwargs,
    ) -> TensorVariable:
        """Creates a random variable corresponding to the `cls` distribution.

        Parameters
        ----------
        dist_params : array-like
            The inputs to the `RandomVariable` `Op`.
        shape : int, tuple, Variable, optional
            A tuple of sizes for each dimension of the new RV.

            An Ellipsis (...) may be inserted in the last position to short-hand refer to
            all the dimensions that the RV would get if no shape/size/dims were passed at all.
        size : int, tuple, Variable, optional
            For creating the RV like in Aesara/NumPy.
        **kwargs
            Forwarded to ``cls.rv_op`` and ``maybe_resize``. ``testval`` is
            accepted but removed and ignored (a ``FutureWarning`` is emitted).
            If ``rng`` is a ``RandomStateSharedVariable`` without a
            ``default_update``, its update is wired to the new RNG output.

        Returns
        -------
        rv : TensorVariable
            The created RV: a variable produced by ``cls.rv_op`` (possibly
            resized), not the `RandomVariable` `Op` itself.

        Raises
        ------
        TypeError
            If ``initval`` is passed; it is not supported by ``.dist()``.
        NotImplementedError
            If ``dims`` is passed.
        ValueError
            If both ``shape`` and ``size`` are given. ``convert_shape`` may
            also raise ``ValueError`` for a malformed ``shape``.
        """
        if "testval" in kwargs:
            kwargs.pop("testval")
            # FutureWarning (not DeprecationWarning) so the notice is shown to
            # end users by default, matching the symbolic-distribution `dist`.
            warnings.warn(
                "The `.dist(testval=...)` argument is deprecated and has no effect. "
                "Initial values for sampling/optimization can be specified with `initval` in a modelcontext. "
                "For using Aesara's test value features, you must assign the `.tag.test_value` yourself.",
                FutureWarning,
                stacklevel=2,
            )
        if "initval" in kwargs:
            raise TypeError(
                "Unexpected keyword argument `initval`. "
                "This argument is not available for the `.dist()` API.")

        if "dims" in kwargs:
            raise NotImplementedError(
                "The use of a `.dist(dims=...)` API is not supported.")
        if shape is not None and size is not None:
            raise ValueError(
                f"Passing both `shape` ({shape}) and `size` ({size}) is not supported!"
            )

        shape = convert_shape(shape)
        size = convert_size(size)

        create_size, ndim_expected, ndim_batch, ndim_supp = find_size(
            shape=shape, size=size, ndim_supp=cls.rv_op.ndim_supp)
        # Create the RV with a `size` right away.
        # This is not necessarily the final result.
        rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
        rv_out = maybe_resize(
            rv_out,
            cls.rv_op,
            dist_params,
            ndim_expected,
            ndim_batch,
            ndim_supp,
            shape,
            size,
            **kwargs,
        )

        rng = kwargs.pop("rng", None)
        if (rv_out.owner and isinstance(rv_out.owner.op, RandomVariable)
                and isinstance(rng, RandomStateSharedVariable)
                and not getattr(rng, "default_update", None)):
            # This tells `aesara.function` that the shared RNG variable
            # is mutable, which--in turn--tells the `FunctionGraph`
            # `Supervisor` feature to allow in-place updates on the variable.
            # Without it, the `RandomVariable`s could not be optimized to allow
            # in-place RNG updates, forcing all sample results from compiled
            # functions to be the same on repeated evaluations.
            new_rng = rv_out.owner.outputs[0]
            rv_out.update = (rng, new_rng)
            rng.default_update = new_rng

        # Point users of the removed RV methods to the functional API.
        rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
        rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)",
                                              "pm.logcdf(rv, x)")
        rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()")
        return rv_out