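import jax.numpy as jnp
from jax import lax

# `promote_shapes` and `constraints` are assumed to be imported from the
# surrounding distributions package.


def __init__(self, mu, covariance_matrix):
    # Treat a scalar mean as a one-dimensional event.
    if jnp.ndim(mu) < 1:
        mu = jnp.reshape(mu, (1,) + jnp.shape(mu))

    # The event size implied by `mu` must match the trailing (d, d)
    # shape of the covariance matrix.
    (mu_event_shape,) = jnp.shape(mu)[-1:]
    covariance_event_shape = jnp.shape(covariance_matrix)[-2:]
    if (mu_event_shape, mu_event_shape) != covariance_event_shape:
        raise ValueError(
            f"The number of dimensions implied by `mu` (dims = {mu_event_shape})"
            " does not match the dimensions implied by `covariance_matrix`"
            f" (dims = {covariance_event_shape})"
        )

    # Add a trailing axis so the last two axes of `mu` (d, 1) line up with
    # the (d, d) axes of the covariance; promote_shapes then aligns the
    # leading batch axes.
    mu = mu[..., jnp.newaxis]
    mu, covariance_matrix = promote_shapes(mu, covariance_matrix)

    self.event_shape = jnp.shape(covariance_matrix)[-1:]
    self.batch_shape = lax.broadcast_shapes(
        jnp.shape(mu)[:-2], jnp.shape(covariance_matrix)[:-2]
    )
    # Note: squeeze() removes every size-1 axis, including size-1 batch or
    # event axes.
    self.mu = mu[..., 0].squeeze()
    self.covariance_matrix = covariance_matrix.squeeze()
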
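# Illustrative sketch (not part of the original module): the shape logic of
# the multivariate-normal __init__ above, re-expressed with NumPy so it can be
# run on plain arrays. Assumptions: NumPy stands in for jax.numpy and lax,
# `mvn_shapes` is a hypothetical helper name, and `promote_shapes` is
# approximated by left-padding both arrays with singleton axes to a common rank.
import numpy as np


def mvn_shapes(mu, covariance_matrix):
    """Return (batch_shape, event_shape) for a mean vector and covariance matrix."""
    mu = np.asarray(mu)
    covariance_matrix = np.asarray(covariance_matrix)
    # A scalar mean is treated as a one-dimensional event.
    if mu.ndim < 1:
        mu = mu.reshape((1,))
    (mu_event_shape,) = mu.shape[-1:]
    covariance_event_shape = covariance_matrix.shape[-2:]
    if (mu_event_shape, mu_event_shape) != covariance_event_shape:
        raise ValueError(
            f"`mu` implies {mu_event_shape} dimensions, but `covariance_matrix` "
            f"has event shape {covariance_event_shape}"
        )
    # Trailing axis so mu's last two axes (d, 1) line up with the covariance's (d, d).
    mu = mu[..., np.newaxis]
    # Assumed promote_shapes behavior: left-pad the lower-rank array with 1s.
    ndim = max(mu.ndim, covariance_matrix.ndim)
    mu = mu.reshape((1,) * (ndim - mu.ndim) + mu.shape)
    covariance_matrix = covariance_matrix.reshape(
        (1,) * (ndim - covariance_matrix.ndim) + covariance_matrix.shape
    )
    event_shape = covariance_matrix.shape[-1:]
    batch_shape = np.broadcast_shapes(mu.shape[:-2], covariance_matrix.shape[:-2])
    return batch_shape, event_shape


# For example, a batch of four 3-dimensional means sharing one covariance
# matrix gives batch shape (4,) and event shape (3,):
# mvn_shapes(np.zeros((4, 3)), np.eye(3))
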
def __init__(self, a, b):
    self.event_shape = ()
    a, b = promote_shapes(a, b)
    batch_shape = lax.broadcast_shapes(jnp.shape(a), jnp.shape(b))
    self.batch_shape = batch_shape
    self.a = jnp.broadcast_to(a, batch_shape)
    self.b = jnp.broadcast_to(b, batch_shape)

def __init__(self, loc, scale):
    self.event_shape = ()
    loc, scale = promote_shapes(loc, scale)
    batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
    self.batch_shape = batch_shape
    self.loc = jnp.broadcast_to(loc, batch_shape)
    self.scale = jnp.broadcast_to(scale, batch_shape)

def __init__(self, mu, sigma):
    self.event_shape = ()
    mu, sigma = promote_shapes(mu, sigma)
    batch_shape = lax.broadcast_shapes(jnp.shape(mu), jnp.shape(sigma))
    self.batch_shape = batch_shape
    self.mu = jnp.broadcast_to(mu, batch_shape)
    self.sigma = jnp.broadcast_to(sigma, batch_shape)

def __init__(self, lower, upper):
    # Note: the support is built from `lower` and `upper` as passed, while
    # the stored bounds below are floored to integers.
    self.support = constraints.integer_interval(lower, upper)
    self.event_shape = ()
    lower, upper = promote_shapes(lower, upper)
    batch_shape = lax.broadcast_shapes(jnp.shape(lower), jnp.shape(upper))
    self.batch_shape = batch_shape
    self.lower = jnp.broadcast_to(jnp.floor(lower), batch_shape)
    self.upper = jnp.broadcast_to(jnp.floor(upper), batch_shape)

def __init__(self, p, n):
    self.support = constraints.integer_interval(0, n)
    self.event_shape = ()
    p, n = promote_shapes(p, n)
    batch_shape = lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
    self.batch_shape = batch_shape
    self.n = jnp.broadcast_to(n, batch_shape)
    self.p = jnp.broadcast_to(p, batch_shape)

def __init__(self, n, a, b):
    self.support = constraints.integer_interval(0, n)
    self.event_shape = ()
    n, a, b = promote_shapes(n, a, b)
    batch_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(a), jnp.shape(b))
    self.batch_shape = batch_shape
    self.a = jnp.broadcast_to(a, batch_shape)
    self.b = jnp.broadcast_to(b, batch_shape)
    self.n = jnp.broadcast_to(n, batch_shape)
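
# Illustrative sketch (not part of the original module): the parameter
# broadcasting pattern shared by the scalar-event __init__ methods above,
# written with NumPy. Assumptions: `broadcast_params` is a hypothetical helper,
# np.broadcast_shapes / np.broadcast_to stand in for lax.broadcast_shapes /
# jnp.broadcast_to, and the promote_shapes step is omitted because
# broadcasting already aligns shapes from the right.
import numpy as np


def broadcast_params(*params):
    """Broadcast scalar-event parameters to a common batch shape."""
    arrays = [np.asarray(p) for p in params]
    # The batch shape is the broadcast of every parameter's shape;
    # incompatible shapes raise ValueError.
    batch_shape = np.broadcast_shapes(*(a.shape for a in arrays))
    # Expand each parameter to the full batch shape (read-only views).
    return batch_shape, [np.broadcast_to(a, batch_shape) for a in arrays]


# For example, a scalar location with a length-3 scale yields batch shape (3,),
# and both parameters come back with shape (3,):
# broadcast_params(0.0, np.ones(3))
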