Beispiel #1
0
 def rvs(self, *args, **kwargs):
     """Draw random variates using a JAX PRNG key.

     ``args`` holds the ``self.numargs`` shape parameters followed by the
     location/scale arguments understood by ``self._parse_args``; one extra
     trailing positional argument (beyond numargs + 2) is treated as ``size``.

     Keyword arguments:
         random_state: JAX PRNG key. If omitted or None, ``self.random_state``
             is used; anything that is not a PRNG key fails an assertion.
         size: sample shape (int or tuple). An int ``n`` becomes ``(n,)``.
             If None (or an empty shape), the broadcast shape of the shape
             parameters, ``loc`` and ``scale`` is used.
         Any other keywords are forwarded to ``self._parse_args``.

     Returns:
         ``self._rvs(*args) * scale + loc``.
     """
     # Default to None so callers may omit `random_state`; a bare
     # `kwargs.pop('random_state')` raises KeyError in that case.
     rng = kwargs.pop('random_state', None)
     if rng is None:
         rng = self.random_state
     # assert that rng is PRNGKey and not mtrand.RandomState object from numpy.
     assert _is_prng_key(rng), 'random_state must be a JAX PRNG key'
     args = list(args)
     # `size` comes from kwargs, or else from the trailing positional argument
     # when more than numargs shape parameters + loc + scale were given.
     # Note: self.numargs is the number of shape parameters.
     # The positional value is popped only when `size` was not given by
     # keyword; as an eagerly-evaluated `kwargs.pop` default it used to be
     # silently discarded.
     size = kwargs.pop('size', None)
     if size is None and len(args) > (self.numargs + 2):
         size = args.pop()
     # TODO: when args is not empty, parse_args requires either _pdf or _cdf method is implemented
     # to recognize valid arg signatures (e.g. `a` in `gamma` or `s` in lognormal)
     args, loc, scale = self._parse_args(*args, **kwargs)
     # FIXME(fehiepsi): Using _promote_args_like requires calling `super(jax_continuous, self).rvs` but
     # it will call `self._rvs` (which is written using JAX and requires JAX random state).
     loc, scale, *args = _promote_args("rvs", loc, scale, *args)
     # Check for an int before the truthiness test so that `size=0` becomes
     # the empty shape (0,) instead of falling back to the broadcast shape.
     # TODO(fehiepsi): add test for int size
     if isinstance(size, int):
         size = (size, )
     elif not size:
         shapes = [np.shape(arg) for arg in args] + [np.shape(loc), np.shape(scale)]
         size = lax.broadcast_shapes(*shapes)
     # NOTE(review): these attributes are presumably read by `self._rvs` -- confirm.
     self._random_state = rng
     self._size = size
     vals = self._rvs(*args)
     return vals * scale + loc
Beispiel #2
0
    def rvs(self, *args, **kwargs):
        """Draw random variates using a JAX PRNG key.

        ``args`` holds the ``self.numargs`` shape parameters followed by the
        location/scale arguments understood by ``self._parse_args``; one extra
        trailing positional argument (beyond numargs + 2) is treated as ``size``.

        Keyword arguments:
            random_state: JAX PRNG key. If omitted or None,
                ``self.random_state`` is used; anything that is not a PRNG key
                fails an assertion.
            size: sample shape (int or tuple). An int ``n`` becomes ``(n,)``.
                If None (or an empty shape), the broadcast shape of the shape
                parameters, ``loc`` and ``scale`` is used.
            Any other keywords are forwarded to ``self._parse_args``.

        Returns:
            ``self._rvs(*args) * scale + loc``.
        """
        # Default to None so callers may omit `random_state`; a bare
        # `kwargs.pop('random_state')` raises KeyError in that case.
        rng = kwargs.pop('random_state', None)
        if rng is None:
            rng = self.random_state
        # assert that rng is PRNGKey and not mtrand.RandomState object from numpy.
        assert _is_prng_key(rng), 'random_state must be a JAX PRNG key'

        args = list(args)
        # `size` comes from kwargs, or else from the trailing positional
        # argument when more than numargs shape parameters + loc + scale were
        # given. Note: self.numargs is the number of shape parameters.
        # The positional value is popped only when `size` was not given by
        # keyword; as an eagerly-evaluated `kwargs.pop` default it used to be
        # silently discarded.
        size = kwargs.pop('size', None)
        if size is None and len(args) > (self.numargs + 2):
            size = args.pop()
        # XXX when args is not empty, parse_args requires either _pdf or _cdf method is implemented
        # to recognize valid arg signatures (e.g. `a` in `gamma` or `s` in lognormal)
        args, loc, scale = self._parse_args(*args, **kwargs)
        # Check for an int before the truthiness test so that `size=0` becomes
        # the empty shape (0,) instead of falling back to the broadcast shape.
        if isinstance(size, int):
            size = (size, )
        elif not size:
            shapes = [np.shape(arg) for arg in args] + [np.shape(loc), np.shape(scale)]
            size = lax.broadcast_shapes(*shapes)

        # NOTE(review): these attributes are presumably read by `self._rvs` -- confirm.
        self._random_state = rng
        self._size = size
        vals = self._rvs(*args)
        return vals * scale + loc
Beispiel #3
0
 def rvs(self, *args, **kwargs):
     """Draw random variates by delegating to the parent class's ``rvs``.

     The JAX PRNG key given as ``random_state`` (or ``self.random_state``
     when omitted or None) seeds a NumPy ``RandomState``, which is passed to
     the parent ``rvs`` in its place. The resulting sample is returned via
     ``device_put``.
     """
     # Default to None so callers may omit `random_state`; a bare
     # `kwargs.pop('random_state')` raises KeyError in that case.
     rng = kwargs.pop('random_state', None)
     if rng is None:
         rng = self.random_state
     # assert that rng is PRNGKey and not mtrand.RandomState object from numpy.
     assert _is_prng_key(rng), 'random_state must be a JAX PRNG key'
     # NOTE(review): the parent implementation (presumably scipy's
     # rv_discrete.rvs) works with a NumPy random state, so one is seeded
     # from the key.
     kwargs['random_state'] = onp.random.RandomState(rng)
     sample = super(jax_discrete, self).rvs(*args, **kwargs)
     return device_put(sample)
Beispiel #4
0
    def rvs(self, *args, **kwargs):
        """Draw random variates using a JAX PRNG key.

        ``args`` holds the ``self.numargs`` shape parameters; one extra
        trailing positional argument is treated as ``size``. The loc/scale
        values returned by ``self._parse_args`` are discarded.

        Keyword arguments:
            random_state: JAX PRNG key. If omitted or None,
                ``self.random_state`` is used; anything that is not a PRNG key
                fails an assertion.
            size: sample shape (int or tuple). An int ``n`` becomes ``(n,)``.
                If None (or an empty shape), ``self._batch_shape(*args)`` is
                used.
            Any other keywords are forwarded to ``self._parse_args``.

        Returns:
            ``self._rvs(*args)``.
        """
        # Default to None so callers may omit `random_state`; a bare
        # `kwargs.pop('random_state')` raises KeyError in that case.
        rng_key = kwargs.pop('random_state', None)
        if rng_key is None:
            rng_key = self.random_state
        # assert that rng_key is PRNGKey and not mtrand.RandomState object from numpy.
        assert _is_prng_key(rng_key), 'random_state must be a JAX PRNG key'

        args = list(args)
        # The trailing positional `size` is popped only when `size` was not
        # given by keyword; as an eagerly-evaluated `kwargs.pop` default it
        # used to be silently discarded.
        size = kwargs.pop('size', None)
        if size is None and len(args) > self.numargs:
            size = args.pop()
        args, _, _ = self._parse_args(*args, **kwargs)
        # Check for an int before the truthiness test so that `size=0` becomes
        # the empty shape (0,) instead of falling back to the batch shape.
        if isinstance(size, int):
            size = (size,)
        elif not size:
            size = self._batch_shape(*args)

        # NOTE(review): these attributes are presumably read by `self._rvs` -- confirm.
        self._random_state = rng_key
        self._size = size
        return self._rvs(*args)
Beispiel #5
0
    def rvs(self, *args, **kwargs):
        """Draw random variates using a JAX PRNG key.

        ``args`` holds the ``self.numargs`` shape parameters followed by the
        location argument understood by ``self._parse_args``; one extra
        trailing positional argument (beyond numargs + 1) is treated as
        ``size``. The scale value returned by ``self._parse_args`` is
        discarded.

        Keyword arguments:
            random_state: JAX PRNG key. If omitted or None,
                ``self.random_state`` is used; anything that is not a PRNG key
                fails an assertion.
            size: sample shape (int or tuple). An int ``n`` becomes ``(n,)``.
                If None (or an empty shape), the broadcast shape of the shape
                parameters and ``loc`` is used.
            Any other keywords are forwarded to ``self._parse_args``.

        Returns:
            ``self._rvs(*args) + loc``.
        """
        # Default to None so callers may omit `random_state`; a bare
        # `kwargs.pop('random_state')` raises KeyError in that case.
        rng_key = kwargs.pop('random_state', None)
        if rng_key is None:
            rng_key = self.random_state
        # assert that rng_key is PRNGKey and not mtrand.RandomState object from numpy.
        assert _is_prng_key(rng_key), 'random_state must be a JAX PRNG key'

        args = list(args)
        # The trailing positional `size` is popped only when `size` was not
        # given by keyword; as an eagerly-evaluated `kwargs.pop` default it
        # used to be silently discarded.
        size = kwargs.pop('size', None)
        if size is None and len(args) > (self.numargs + 1):
            size = args.pop()
        args, loc, _ = self._parse_args(*args, **kwargs)
        # Check for an int before the truthiness test so that `size=0` becomes
        # the empty shape (0,) instead of falling back to the broadcast shape.
        if isinstance(size, int):
            size = (size,)
        elif not size:
            shapes = [jnp.shape(arg) for arg in args] + [jnp.shape(loc)]
            size = lax.broadcast_shapes(*shapes)

        # NOTE(review): these attributes are presumably read by `self._rvs` -- confirm.
        self._random_state = rng_key
        self._size = size
        vals = self._rvs(*args)
        return vals + loc
Beispiel #6
0
    def rvs(self, *args, **kwargs):
        """Draw random variates using a JAX PRNG key.

        ``args`` holds the ``self.numargs`` parameters; one extra trailing
        positional argument is treated as ``size``. The loc/scale values
        returned by ``self._parse_args`` are discarded, and the remaining
        parameters are promoted with ``_promote_args``.

        Keyword arguments:
            random_state: JAX PRNG key. If omitted or None,
                ``self.random_state`` is used; anything that is not a PRNG key
                fails an assertion.
            size: sample shape (int or tuple). An int ``n`` becomes ``(n,)``.
                If None (or an empty shape), the shape of the last parameter
                without its trailing dimension is used.
            Any other keywords are forwarded to ``self._parse_args``.

        Returns:
            ``self._rvs(*args)``.
        """
        # Default to None so callers may omit `random_state`; a bare
        # `kwargs.pop('random_state')` raises KeyError in that case.
        rng = kwargs.pop('random_state', None)
        if rng is None:
            rng = self.random_state
        # assert that rng is PRNGKey and not mtrand.RandomState object from numpy.
        assert _is_prng_key(rng), 'random_state must be a JAX PRNG key'

        args = list(args)
        # The trailing positional `size` is popped only when `size` was not
        # given by keyword; as an eagerly-evaluated `kwargs.pop` default it
        # used to be silently discarded.
        size = kwargs.pop('size', None)
        if size is None and len(args) > self.numargs:
            size = args.pop()

        args, _, _ = self._parse_args(*args, **kwargs)
        # XXX we might not need to verify that args is empty for multivariate distributions
        args = _promote_args("rvs", *args)

        # TODO: make this code compatible to mvn distribution
        # Check for an int before the truthiness test so that `size=0` becomes
        # the empty shape (0,) instead of falling back to the parameter shape.
        if isinstance(size, int):
            size = (size, )
        elif not size:
            # NOTE(review): raises IndexError when there are no parameters.
            size = args[-1].shape[:-1]

        # NOTE(review): these attributes are presumably read by `self._rvs` -- confirm.
        self._random_state = rng
        self._size = size
        return self._rvs(*args)