def rvs(self, *args, **kwargs):
    """Draw random variates, sampling with a JAX PRNG key.

    Positional ``args`` are the shape parameters (``self.numargs`` of them),
    optionally followed by ``loc`` and ``scale``; if more than
    ``self.numargs + 2`` positionals are given, the last one is taken as
    ``size``. They are interpreted by ``self._parse_args``.

    Keyword arguments:
        random_state: key to sample with; must satisfy ``_is_prng_key``.
            If omitted or ``None``, ``self.random_state`` is used.
        size: sample shape. An ``int`` ``n`` means ``(n,)`` (so ``size=0``
            yields an empty sample). If omitted, ``None`` or empty, the
            broadcast shape of the shape parameters, ``loc`` and ``scale``
            is used.
        Remaining keywords are forwarded to ``self._parse_args``.

    Returns:
        ``self._rvs(*args) * scale + loc``, after ``loc``, ``scale`` and the
        shape parameters have been passed through ``_promote_args``.

    Raises:
        AssertionError: if the resolved key fails ``_is_prng_key`` (only
            when assertions are enabled).
    """
    # Default to None so that omitting ``random_state`` behaves exactly like
    # passing ``random_state=None`` instead of raising KeyError.
    rng = kwargs.pop('random_state', None)
    if rng is None:
        rng = self.random_state
    # assert that rng is PRNGKey and not mtrand.RandomState object from numpy.
    assert _is_prng_key(rng)
    args = list(args)
    # If 'size' is not in kwargs, then it is either the last element of args
    # or it will take default value (which is None).
    # Note: self.numargs is the number of shape parameters.
    size = kwargs.pop(
        'size', args.pop() if len(args) > (self.numargs + 2) else None)
    # TODO: when args is not empty, parse_args requires either _pdf or _cdf method is implemented
    # to recognize valid arg signatures (e.g. `a` in `gamma` or `s` in lognormal)
    args, loc, scale = self._parse_args(*args, **kwargs)
    # FIXME(fehiepsi): Using _promote_args_like requires calling `super(jax_continuous, self).rvs` but
    # it will call `self._rvs` (which is written using JAX and requires JAX random state).
    loc, scale, *args = _promote_args("rvs", loc, scale, *args)
    # Normalize an integer size *before* the emptiness test: ``size=0`` is
    # falsy and would otherwise be replaced by the default broadcast shape.
    # TODO(fehiepsi): add test for int size
    if isinstance(size, int):
        size = (size,)
    if not size:
        shapes = [np.shape(arg) for arg in args] + [np.shape(loc), np.shape(scale)]
        size = lax.broadcast_shapes(*shapes)
    # Key and shape are stored on the instance rather than passed to
    # ``self._rvs`` (which only receives the shape parameters) — presumably
    # ``_rvs`` reads them back; confirm in the subclass implementations.
    self._random_state = rng
    self._size = size
    vals = self._rvs(*args)
    return vals * scale + loc
def rvs(self, *args, **kwargs):
    """Draw random variates, sampling with a JAX PRNG key.

    Positional ``args`` are the shape parameters (``self.numargs`` of them),
    optionally followed by ``loc`` and ``scale``; if more than
    ``self.numargs + 2`` positionals are given, the last one is taken as
    ``size``. They are interpreted by ``self._parse_args``.

    Keyword arguments:
        random_state: key to sample with; must satisfy ``_is_prng_key``.
            If omitted or ``None``, ``self.random_state`` is used.
        size: sample shape. An ``int`` ``n`` means ``(n,)`` (so ``size=0``
            yields an empty sample). If omitted, ``None`` or empty, the
            broadcast shape of the shape parameters, ``loc`` and ``scale``
            is used.
        Remaining keywords are forwarded to ``self._parse_args``.

    Returns:
        ``self._rvs(*args) * scale + loc``.

    Raises:
        AssertionError: if the resolved key fails ``_is_prng_key`` (only
            when assertions are enabled).
    """
    # Default to None so that omitting ``random_state`` behaves exactly like
    # passing ``random_state=None`` instead of raising KeyError.
    rng = kwargs.pop('random_state', None)
    if rng is None:
        rng = self.random_state
    # assert that rng is PRNGKey and not mtrand.RandomState object from numpy.
    assert _is_prng_key(rng)
    args = list(args)
    # If 'size' is not in kwargs, then it is either the last element of args
    # or it will take default value (which is None).
    # Note: self.numargs is the number of shape parameters.
    size = kwargs.pop(
        'size', args.pop() if len(args) > (self.numargs + 2) else None)
    # XXX when args is not empty, parse_args requires either _pdf or _cdf method is implemented
    # to recognize valid arg signatures (e.g. `a` in `gamma` or `s` in lognormal)
    args, loc, scale = self._parse_args(*args, **kwargs)
    # Normalize an integer size *before* the emptiness test: ``size=0`` is
    # falsy and would otherwise be replaced by the default broadcast shape.
    if isinstance(size, int):
        size = (size,)
    if not size:
        shapes = [np.shape(arg) for arg in args] + [np.shape(loc), np.shape(scale)]
        size = lax.broadcast_shapes(*shapes)
    # Key and shape are stored on the instance rather than passed to
    # ``self._rvs`` — presumably ``_rvs`` reads them back; confirm.
    self._random_state = rng
    self._size = size
    vals = self._rvs(*args)
    return vals * scale + loc
def rvs(self, *args, **kwargs):
    """Draw random variates by delegating to the parent class's ``rvs``.

    The JAX PRNG key given as ``random_state`` (or ``self.random_state``
    when omitted or ``None``) is used as the seed of a NumPy
    ``RandomState``, which replaces ``random_state`` in ``kwargs`` before
    calling ``super(jax_discrete, self).rvs(*args, **kwargs)``. Sampling
    therefore happens in NumPy; the result is transferred with
    ``device_put`` before being returned.

    Note: the explicit ``super(jax_discrete, self)`` means this method only
    works on ``jax_discrete`` instances (or subclasses).

    Raises:
        AssertionError: if the resolved key fails ``_is_prng_key`` (only
            when assertions are enabled).
    """
    # Default to None so that omitting ``random_state`` behaves exactly like
    # passing ``random_state=None`` instead of raising KeyError.
    rng = kwargs.pop('random_state', None)
    if rng is None:
        rng = self.random_state
    # assert that rng is PRNGKey and not mtrand.RandomState object from numpy.
    assert _is_prng_key(rng)
    kwargs['random_state'] = onp.random.RandomState(rng)
    sample = super(jax_discrete, self).rvs(*args, **kwargs)
    return device_put(sample)
def rvs(self, *args, **kwargs):
    """Draw random variates, sampling with a JAX PRNG key.

    Positional ``args`` are the shape parameters; if more than
    ``self.numargs`` positionals are given, the last one is taken as
    ``size``. They are interpreted by ``self._parse_args``, whose second and
    third return values are ignored here.

    Keyword arguments:
        random_state: key to sample with; must satisfy ``_is_prng_key``.
            If omitted or ``None``, ``self.random_state`` is used.
        size: sample shape. An ``int`` ``n`` means ``(n,)`` (so ``size=0``
            yields an empty sample). If omitted, ``None`` or empty,
            ``self._batch_shape(*args)`` is used.
        Remaining keywords are forwarded to ``self._parse_args``.

    Returns:
        ``self._rvs(*args)``.

    Raises:
        AssertionError: if the resolved key fails ``_is_prng_key`` (only
            when assertions are enabled).
    """
    # Default to None so that omitting ``random_state`` behaves exactly like
    # passing ``random_state=None`` instead of raising KeyError.
    rng_key = kwargs.pop('random_state', None)
    if rng_key is None:
        rng_key = self.random_state
    # assert that rng_key is PRNGKey and not mtrand.RandomState object from numpy.
    assert _is_prng_key(rng_key)
    args = list(args)
    size = kwargs.pop('size', args.pop() if len(args) > self.numargs else None)
    args, _, _ = self._parse_args(*args, **kwargs)
    # Normalize an integer size *before* the emptiness test: ``size=0`` is
    # falsy and would otherwise be replaced by the default batch shape.
    if isinstance(size, int):
        size = (size,)
    if not size:
        size = self._batch_shape(*args)
    # Key and shape are stored on the instance rather than passed to
    # ``self._rvs`` — presumably ``_rvs`` reads them back; confirm.
    self._random_state = rng_key
    self._size = size
    return self._rvs(*args)
def rvs(self, *args, **kwargs):
    """Draw random variates, sampling with a JAX PRNG key.

    Positional ``args`` are the shape parameters, optionally followed by
    ``loc``; if more than ``self.numargs + 1`` positionals are given, the
    last one is taken as ``size``. They are interpreted by
    ``self._parse_args`` (its third return value is ignored).

    Keyword arguments:
        random_state: key to sample with; must satisfy ``_is_prng_key``.
            If omitted or ``None``, ``self.random_state`` is used.
        size: sample shape. An ``int`` ``n`` means ``(n,)`` (so ``size=0``
            yields an empty sample). If omitted, ``None`` or empty, the
            broadcast shape of the shape parameters and ``loc`` is used.
        Remaining keywords are forwarded to ``self._parse_args``.

    Returns:
        ``self._rvs(*args) + loc``.

    Raises:
        AssertionError: if the resolved key fails ``_is_prng_key`` (only
            when assertions are enabled).
    """
    # Default to None so that omitting ``random_state`` behaves exactly like
    # passing ``random_state=None`` instead of raising KeyError.
    rng_key = kwargs.pop('random_state', None)
    if rng_key is None:
        rng_key = self.random_state
    # assert that rng_key is PRNGKey and not mtrand.RandomState object from numpy.
    assert _is_prng_key(rng_key)
    args = list(args)
    size = kwargs.pop('size', args.pop() if len(args) > (self.numargs + 1) else None)
    args, loc, _ = self._parse_args(*args, **kwargs)
    # Normalize an integer size *before* the emptiness test: ``size=0`` is
    # falsy and would otherwise be replaced by the default broadcast shape.
    if isinstance(size, int):
        size = (size,)
    if not size:
        shapes = [jnp.shape(arg) for arg in args] + [jnp.shape(loc)]
        size = lax.broadcast_shapes(*shapes)
    # Key and shape are stored on the instance rather than passed to
    # ``self._rvs`` — presumably ``_rvs`` reads them back; confirm.
    self._random_state = rng_key
    self._size = size
    vals = self._rvs(*args)
    return vals + loc
def rvs(self, *args, **kwargs):
    """Draw random variates, sampling with a JAX PRNG key.

    Positional ``args`` are the shape parameters; if more than
    ``self.numargs`` positionals are given, the last one is taken as
    ``size``. They are interpreted by ``self._parse_args`` (its second and
    third return values are ignored) and then passed through
    ``_promote_args``.

    Keyword arguments:
        random_state: key to sample with; must satisfy ``_is_prng_key``.
            If omitted or ``None``, ``self.random_state`` is used.
        size: sample shape. An ``int`` ``n`` means ``(n,)`` (so ``size=0``
            yields an empty sample). If omitted, ``None`` or empty, all but
            the last axis of the last shape parameter is used, or ``()``
            when there are no shape parameters.
        Remaining keywords are forwarded to ``self._parse_args``.

    Returns:
        ``self._rvs(*args)``.

    Raises:
        AssertionError: if the resolved key fails ``_is_prng_key`` (only
            when assertions are enabled).
    """
    # Default to None so that omitting ``random_state`` behaves exactly like
    # passing ``random_state=None`` instead of raising KeyError.
    rng = kwargs.pop('random_state', None)
    if rng is None:
        rng = self.random_state
    # assert that rng is PRNGKey and not mtrand.RandomState object from numpy.
    assert _is_prng_key(rng)
    args = list(args)
    size = kwargs.pop('size', args.pop() if len(args) > self.numargs else None)
    args, _, _ = self._parse_args(*args, **kwargs)
    # XXX we might not need to verify that args is empty for multivariate distributions
    args = _promote_args("rvs", *args)
    # Normalize an integer size *before* the emptiness test: ``size=0`` is
    # falsy and would otherwise be replaced by the default shape.
    if isinstance(size, int):
        size = (size,)
    # TODO: make this code compatible to mvn distribution
    if not size:
        # Guard the empty case: ``args[-1]`` would raise IndexError when the
        # distribution has no shape parameters.
        size = args[-1].shape[:-1] if args else ()
    # Key and shape are stored on the instance rather than passed to
    # ``self._rvs`` — presumably ``_rvs`` reads them back; confirm.
    self._random_state = rng
    self._size = size
    return self._rvs(*args)