Example #1
0
File: mvnormal.py — Project: rlouf/mcx
    def __init__(self, mu, covariance_matrix):
        """Validate and broadcast the parameters of a multivariate normal.

        Parameters
        ----------
        mu
            Mean; a 0-d input is first reshaped to a length-1 vector so it
            always has an event dimension.
        covariance_matrix
            Covariance; its last two dimensions must both equal the size of
            the last dimension of `mu`.

        Raises
        ------
        ValueError
            If the trailing dimension of `mu` does not match the trailing
            two (event) dimensions of `covariance_matrix`.
        """

        # Promote a scalar mean to a 1-d vector so the event-size checks
        # below can index jnp.shape(mu)[-1:].
        if jnp.ndim(mu) < 1:
            mu = jnp.reshape(mu, (1,) + jnp.shape(mu))

        # Event sizes: last dim of mu vs. last two dims of the covariance.
        (mu_event_shape,) = jnp.shape(mu)[-1:]
        covariance_event_shape = jnp.shape(covariance_matrix)[-2:]
        if (mu_event_shape, mu_event_shape) != covariance_event_shape:
            raise ValueError(
                (
                    f"The number of dimensions implied by `mu`(dims = {mu_event_shape})"
                    ", does not match the dimensions implied by `covariance_matrix`"
                    f"(dims = {covariance_event_shape})"
                )
            )

        # Append a trailing axis so mu (..., n, 1) can be shape-promoted
        # against the (..., n, n) covariance; it is stripped again below.
        mu = mu[..., jnp.newaxis]
        mu, covariance_matrix = promote_shapes(mu, covariance_matrix)
        # Event shape is the covariance's last dimension; batch shape is the
        # broadcast of everything in front of the event dimensions.
        self.event_shape = jnp.shape(covariance_matrix)[-1:]
        self.batch_shape = lax.broadcast_shapes(
            jnp.shape(mu)[:-2], jnp.shape(covariance_matrix)[:-2]
        )

        # NOTE(review): .squeeze() removes *every* size-1 axis, including
        # size-1 batch axes, so these attributes may not carry the leading
        # dimensions recorded in batch_shape — confirm callers expect this.
        self.mu = mu[..., 0].squeeze()
        self.covariance_matrix = covariance_matrix.squeeze()
Example #2
0
 def __init__(self, a, b):
     """Store both parameters, broadcast to a common batch shape."""
     self.event_shape = ()  # scalar distribution: one value per draw
     a, b = promote_shapes(a, b)
     shape = lax.broadcast_shapes(jnp.shape(a), jnp.shape(b))
     self.batch_shape = shape
     self.a, self.b = (jnp.broadcast_to(v, shape) for v in (a, b))
Example #3
0
 def __init__(self, loc, scale):
     """Normalize `loc` and `scale` to arrays sharing one batch shape."""
     self.event_shape = ()  # each sample is a scalar
     loc, scale = promote_shapes(loc, scale)
     common = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
     self.batch_shape = common
     self.loc = jnp.broadcast_to(loc, common)
     self.scale = jnp.broadcast_to(scale, common)
Example #4
0
 def __init__(self, mu, sigma):
     """Broadcast `mu` and `sigma` against each other and record shapes."""
     self.event_shape = ()  # scalar event
     mu, sigma = promote_shapes(mu, sigma)
     self.batch_shape = lax.broadcast_shapes(jnp.shape(mu), jnp.shape(sigma))
     self.mu = jnp.broadcast_to(mu, self.batch_shape)
     self.sigma = jnp.broadcast_to(sigma, self.batch_shape)
Example #5
0
    def __init__(self, lower, upper):
        """Set the integer-interval support and broadcast the two bounds."""
        # Support is built from the raw (un-floored) arguments.
        self.support = constraints.integer_interval(lower, upper)

        self.event_shape = ()  # scalar draws
        lower, upper = promote_shapes(lower, upper)
        shape = lax.broadcast_shapes(jnp.shape(lower), jnp.shape(upper))
        self.batch_shape = shape
        # NOTE(review): both bounds are floored — for a non-integer `lower`
        # this widens the interval downward; confirm ceil was not intended.
        self.lower, self.upper = (
            jnp.broadcast_to(jnp.floor(bound), shape) for bound in (lower, upper)
        )
Example #6
0
File: binomial.py — Project: tblazina/mcx
    def __init__(self, p, n):
        """Record the [0, n] integer support and broadcast `p` and `n`."""
        self.support = constraints.integer_interval(0, n)

        self.event_shape = ()  # one scalar outcome per draw
        p, n = promote_shapes(p, n)
        shape = lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
        self.batch_shape = shape
        self.n, self.p = (jnp.broadcast_to(v, shape) for v in (n, p))
Example #7
0
 def __init__(self, n, a, b):
     """Fix the [0, n] integer support and broadcast all three parameters."""
     self.support = constraints.integer_interval(0, n)
     self.event_shape = ()  # scalar event
     n, a, b = promote_shapes(n, a, b)
     shape = lax.broadcast_shapes(*(jnp.shape(v) for v in (n, a, b)))
     self.batch_shape = shape
     self.a = jnp.broadcast_to(a, shape)
     self.b = jnp.broadcast_to(b, shape)
     self.n = jnp.broadcast_to(n, shape)