def __init__(self, a, loc, scale):
    """Initialize from parameters ``a``, ``loc`` and ``scale``.

    Events are scalar (``event_shape == ()``); the batch shape is obtained by
    passing the shapes of all three parameters to ``broadcast_batch_shape``.
    The parameters are stored unchanged.
    """
    param_shapes = (jnp.shape(a), jnp.shape(loc), jnp.shape(scale))
    self.event_shape = ()
    self.batch_shape = broadcast_batch_shape(*param_shapes)
    self.a, self.loc, self.scale = a, loc, scale
def __init__(self, lower, upper):
    """Initialize an integer-interval distribution on ``[lower, upper]``.

    Both bounds are floored with ``jnp.floor``. The support constraint is
    built from the *floored* bounds, so ``self.support`` agrees with
    ``self.lower`` / ``self.upper``. Previously it used the raw bounds, which
    disagree with the stored ones when a bound is non-integer.

    Args:
        lower: lower bound (scalar or array).
        upper: upper bound (scalar or array).
    """
    floored_lower = jnp.floor(lower)
    floored_upper = jnp.floor(upper)
    # Build the support from the same floored values that are stored below so
    # that the constraint and the stored bounds describe the same set.
    self.support = constraints.integer_interval(floored_lower, floored_upper)
    self.event_shape = ()
    self.batch_shape = broadcast_batch_shape(jnp.shape(lower), jnp.shape(upper))
    self.lower = floored_lower
    self.upper = floored_upper
def __init__(self, p, n):
    """Initialize from parameters ``p`` and ``n``.

    The support is the integer interval ``[0, n]`` and events are scalar.
    The batch shape comes from ``broadcast_batch_shape`` applied to the
    NumPy shapes of ``p`` and ``n``.
    """
    self.support = constraints.integer_interval(0, n)
    self.event_shape = ()
    p_shape, n_shape = np.shape(p), np.shape(n)
    self.batch_shape = broadcast_batch_shape(p_shape, n_shape)
    self.n, self.p = n, p
def __init__(self, mu, sigma):
    """Initialize from ``mu`` and ``sigma`` and then run the base initializer.

    Events are scalar. The batch shape is derived from the shapes of both
    parameters via ``broadcast_batch_shape``. The parent ``__init__`` is
    called last, after all attributes have been set.
    """
    shapes = [jnp.shape(mu), jnp.shape(sigma)]
    self.event_shape = ()
    self.batch_shape = broadcast_batch_shape(*shapes)
    self.mu, self.sigma = mu, sigma
    super(Normal, self).__init__()
def __init__(self, lower, upper):
    """Initialize with support ``constraints.closed_interval(lower, upper)``.

    Events are scalar. The batch shape is obtained from the shapes of both
    bounds, and the bounds themselves are stored unchanged.
    """
    self.support = constraints.closed_interval(lower, upper)
    self.event_shape = ()
    bound_shapes = (jnp.shape(lower), jnp.shape(upper))
    self.batch_shape = broadcast_batch_shape(*bound_shapes)
    self.lower, self.upper = lower, upper
def __init__(self, mu, covariance_matrix):
    """Initialize from a mean ``mu`` and a ``covariance_matrix``.

    The last axis of ``mu`` is the event dimension ``d``. The last two axes
    of ``covariance_matrix`` must be ``(d, d)``. Any leading axes of either
    argument are batch axes and are combined with ``broadcast_batch_shape``.

    Args:
        mu: mean; must have at least one dimension.
        covariance_matrix: covariance whose trailing shape is ``(d, d)``.

    Raises:
        ValueError: if ``mu`` has no dimensions, or if the trailing shape of
            ``covariance_matrix`` is not ``(d, d)``.
    """
    mu_shape = jnp.shape(mu)
    cov_shape = jnp.shape(covariance_matrix)
    if len(mu_shape) < 1:
        # Explicit check; otherwise unpacking mu_shape[-1:] fails with an
        # unhelpful "not enough values to unpack" ValueError.
        raise ValueError(
            "`mu` must have at least one dimension (the event dimension), "
            f"got shape {mu_shape}")
    mu_event_shape = mu_shape[-1]
    covariance_event_shape = cov_shape[-2:]
    if (mu_event_shape, mu_event_shape) != covariance_event_shape:
        raise ValueError(
            f"The number of dimensions implied by `mu` ({mu_event_shape}) "
            "does not match the dimensions implied by `covariance_matrix` "
            f"({covariance_event_shape})")
    self.batch_shape = broadcast_batch_shape(mu_shape[:-1], cov_shape[:-2])
    # The event is a single length-d vector, so the event shape is (d,).
    # Broadcasting (d,) against (d, d) would instead give (d, d), assuming
    # broadcast_batch_shape follows NumPy-style broadcasting.
    self.event_shape = (mu_event_shape,)
    self.mu = mu
    self.covariance_matrix = covariance_matrix
    super().__init__()
def __init__(self, p):
    """Initialize from parameter ``p``.

    Events are scalar, and the batch shape is derived from ``np.shape(p)``.
    ``p`` is stored multiplied by ``1.0``.
    """
    p_shape = np.shape(p)
    self.event_shape = ()
    self.batch_shape = broadcast_batch_shape(p_shape)
    # For numeric scalars/arrays, multiplying by 1.0 promotes integer input to
    # floating point, so an int ``p`` is accepted here and stored as a float.
    self.p = p * 1.0
def __init__(self, lmbda):
    """Initialize from parameter ``lmbda``.

    Events are scalar. The batch shape comes from the NumPy shape of
    ``lmbda``, and ``lmbda`` is stored unchanged.
    """
    rate_shape = np.shape(lmbda)
    self.event_shape = ()
    self.batch_shape = broadcast_batch_shape(rate_shape)
    self.lmbda = lmbda
def __init__(self, df):
    """Initialize from parameter ``df``.

    Events are scalar. The batch shape comes from ``jnp.shape(df)``, and
    ``df`` is stored unchanged.
    """
    df_shape = jnp.shape(df)
    self.event_shape = ()
    self.batch_shape = broadcast_batch_shape(df_shape)
    self.df = df
def __init__(self, probs):
    """Initialize from ``probs``, whose last axis enumerates the categories.

    With ``k`` entries on the last axis, the support is the integer interval
    ``[0, k - 1]``. Events are scalar, and all leading axes of ``probs`` form
    the batch shape.
    """
    probs_shape = jnp.shape(probs)
    last_index = probs_shape[-1] - 1
    self.support = constraints.integer_interval(0, last_index)
    self.event_shape = ()
    self.batch_shape = broadcast_batch_shape(probs_shape[:-1])
    self.probs = probs
def __init__(self, mu, sigma):
    """Initialize from ``mu`` and ``sigma``.

    Events are scalar. The batch shape is derived from the NumPy shapes of
    both parameters, and both are stored unchanged.
    """
    shapes = (np.shape(mu), np.shape(sigma))
    self.event_shape = ()
    self.batch_shape = broadcast_batch_shape(*shapes)
    self.mu, self.sigma = mu, sigma
def __init__(self, a, b):
    """Initialize from parameters ``a`` and ``b``.

    Events are scalar. The batch shape is derived from the NumPy shapes of
    ``a`` and ``b``, and both are stored unchanged.
    """
    a_shape, b_shape = np.shape(a), np.shape(b)
    self.event_shape = ()
    self.batch_shape = broadcast_batch_shape(a_shape, b_shape)
    self.a, self.b = a, b