Esempio n. 1
0
    def __init__(self, lower, upper):
        """Initialise the instance from a lower and an upper bound.

        The support constraint is built from the bounds exactly as they are
        passed in, whereas the stored ``lower`` / ``upper`` attributes hold
        the floored bounds.

        :param lower: lower bound; anything ``jnp.shape`` / ``jnp.floor`` accept.
        :param upper: upper bound; anything ``jnp.shape`` / ``jnp.floor`` accept.
        """
        # NOTE(review): the support sees the un-floored bounds -- confirm that
        # this is intended for non-integer inputs.
        self.support = constraints.integer_interval(lower, upper)

        self.event_shape = ()
        lower_shape = jnp.shape(lower)
        upper_shape = jnp.shape(upper)
        self.batch_shape = broadcast_batch_shape(lower_shape, upper_shape)
        self.lower, self.upper = jnp.floor(lower), jnp.floor(upper)
Esempio n. 2
0
    def __init__(self, p, n):
        """Initialise the instance from ``p`` and ``n``.

        Both arguments are stored unchanged; the batch shape is whatever
        ``broadcast_batch_shape`` returns for their two shapes.

        :param p: stored as ``self.p``.
        :param n: stored as ``self.n``; also the upper end passed to
            ``constraints.integer_interval(0, n)`` for the support.
        """
        self.support = constraints.integer_interval(0, n)

        self.event_shape = ()
        shapes = (np.shape(p), np.shape(n))
        self.batch_shape = broadcast_batch_shape(*shapes)
        self.n, self.p = n, p
Esempio n. 3
0
    def __init__(self, lower, upper):
        """Initialise the instance from a lower and an upper bound.

        The stored ``lower`` / ``upper`` attributes are the floored bounds,
        broadcast to the common batch shape.

        :param lower: lower bound; anything ``promote_shapes`` accepts.
        :param upper: upper bound; anything ``promote_shapes`` accepts.
        """
        # The support is deliberately built first, from the bounds as passed
        # in (before shape promotion and flooring).
        self.support = constraints.integer_interval(lower, upper)

        self.event_shape = ()
        promoted_lower, promoted_upper = promote_shapes(lower, upper)
        common_shape = lax.broadcast_shapes(
            jnp.shape(promoted_lower), jnp.shape(promoted_upper)
        )
        self.batch_shape = common_shape
        floored_lower = jnp.floor(promoted_lower)
        self.lower = jnp.broadcast_to(floored_lower, common_shape)
        floored_upper = jnp.floor(promoted_upper)
        self.upper = jnp.broadcast_to(floored_upper, common_shape)
Esempio n. 4
0
    def __init__(self, p, n):
        """Initialise the instance from ``p`` and ``n``.

        Both arguments are shape-promoted and then stored broadcast to their
        common batch shape.

        :param p: stored (after broadcasting) as ``self.p``.
        :param n: stored (after broadcasting) as ``self.n``; the un-promoted
            value is the upper end given to ``constraints.integer_interval``.
        """
        # Support uses ``n`` as passed in, before any promotion/broadcasting.
        self.support = constraints.integer_interval(0, n)

        self.event_shape = ()
        promoted_p, promoted_n = promote_shapes(p, n)
        common_shape = lax.broadcast_shapes(
            jnp.shape(promoted_p), jnp.shape(promoted_n)
        )
        self.batch_shape = common_shape
        self.n = jnp.broadcast_to(promoted_n, common_shape)
        self.p = jnp.broadcast_to(promoted_p, common_shape)
Esempio n. 5
0
 def __init__(self, n, a, b):
     """Initialise the instance from ``n``, ``a`` and ``b``.

     The three arguments are shape-promoted together and each is stored
     broadcast to their common batch shape.

     :param n: stored (after broadcasting) as ``self.n``; the un-promoted
         value is the upper end given to ``constraints.integer_interval``.
     :param a: stored (after broadcasting) as ``self.a``.
     :param b: stored (after broadcasting) as ``self.b``.
     """
     # Support uses ``n`` as passed in, before any promotion/broadcasting.
     self.support = constraints.integer_interval(0, n)
     self.event_shape = ()
     promoted_n, promoted_a, promoted_b = promote_shapes(n, a, b)
     common_shape = lax.broadcast_shapes(
         *map(jnp.shape, (promoted_n, promoted_a, promoted_b))
     )
     self.batch_shape = common_shape
     self.a = jnp.broadcast_to(promoted_a, common_shape)
     self.b = jnp.broadcast_to(promoted_b, common_shape)
     self.n = jnp.broadcast_to(promoted_n, common_shape)
Esempio n. 6
0
 def __init__(self, probs):
     """Initialise the instance from ``probs``.

     The size of the last axis of ``probs`` fixes the support, which is
     ``constraints.integer_interval(0, size - 1)``; the remaining leading
     axes form the batch shape. ``probs`` itself is stored unchanged.

     :param probs: value with at least one axis, stored as ``self.probs``.
     """
     probs_shape = jnp.shape(probs)
     last_index = probs_shape[-1] - 1
     self.support = constraints.integer_interval(0, last_index)
     self.event_shape = ()
     self.batch_shape = probs_shape[:-1]
     self.probs = probs