def model():
    """Sample site "x" from a transformed standard normal under a particle plate.

    A standard Normal (zeros/ones shaped like ``loc``/``scale``) is pushed
    through ``AffineTransform(loc, scale)`` and then ``ExpTransform()``, and
    a leading particle dimension is added with ``expand_by``. ``shape``,
    ``loc`` and ``scale`` are read from the enclosing scope.

    Returns:
        The value returned by ``numpyro.sample("x", ...)``.
    """
    num_particles = 100000
    # The plate size and the explicit expand_by must agree, hence one constant.
    with numpyro.plate_stack("plates", shape), numpyro.plate("particles", num_particles):
        base = dist.Normal(jnp.zeros_like(loc), jnp.ones_like(scale))
        transformed = dist.TransformedDistribution(
            base, [AffineTransform(loc, scale), ExpTransform()])
        return numpyro.sample("x", transformed.expand_by([num_particles]))
def model():
    """Sample site "x" from a transformed standard normal with batch/event dims.

    A standard Normal (zeros/ones shaped like ``loc``/``scale``) is pushed
    through ``AffineTransform(loc, scale)`` then ``ExpTransform()`` and
    expanded to ``shape``. When ``event_shape`` is non-empty, its trailing
    dimensions are declared as event dims and a leading particle dimension is
    added explicitly. Otherwise the particle dimension is presumably supplied
    by the enclosing "particles" plate (TODO confirm against numpyro's plate
    broadcasting). ``shape``, ``loc``, ``scale``, ``event_shape`` and
    ``batch_shape`` are read from the enclosing scope.

    Returns:
        The value returned by ``numpyro.sample("x", ...)``.
    """
    num_particles = 100000
    base = dist.Normal(jnp.zeros_like(loc), jnp.ones_like(scale))
    transformed = dist.TransformedDistribution(
        base, [AffineTransform(loc, scale), ExpTransform()])
    site_dist = transformed.expand(shape)
    if event_shape:
        reinterpreted = site_dist.to_event(len(event_shape))
        site_dist = reinterpreted.expand_by([num_particles])
    with numpyro.plate_stack("plates", batch_shape), numpyro.plate("particles", num_particles):
        return numpyro.sample("x", site_dist)
def __init__(self, alpha, scale=1., validate_args=None):
    """Build a Pareto distribution as a transformed Exponential.

    An ``Exponential(alpha)`` base is pushed through ``ExpTransform()`` and
    then ``AffineTransform(loc=0, scale=scale)``.

    Args:
        alpha: shape parameter, used as the Exponential base's argument.
        scale: scale parameter applied by the affine step (default ``1.``).
        validate_args: forwarded unchanged to the parent initializer.
    """
    # Bring both parameters to one common batch shape before storing them.
    common_shape = lax.broadcast_shapes(np.shape(scale), np.shape(alpha))
    self.scale = np.broadcast_to(scale, common_shape)
    self.alpha = np.broadcast_to(alpha, common_shape)
    exponential = Exponential(self.alpha)
    chain = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
    # Explicit two-argument super(): the class header is not visible here, and
    # zero-argument super() requires definition inside the class body.
    super(Pareto, self).__init__(exponential, chain, validate_args=validate_args)
def __init__(self, loc=0., scale=1., validate_args=None):
    """Build a log-normal distribution as ``Normal(loc, scale)`` through ``ExpTransform``.

    Args:
        loc: location of the underlying Normal (default ``0.``).
        scale: scale of the underlying Normal (default ``1.``).
        validate_args: forwarded unchanged to the parent initializer.
    """
    normal = Normal(loc, scale)
    # Store the values the Normal kept, not the raw arguments.
    self.loc = normal.loc
    self.scale = normal.scale
    super(LogNormal, self).__init__(normal, ExpTransform(), validate_args=validate_args)