コード例 #1
0
 def model(key):
     keys = jax.random.split(key)
     weight = func.sample('weight', dist.normal(jnp.array(0.),
                                                jnp.array(1.)), keys[0])
     measurement = func.sample('measurement',
                               dist.normal(weight, jnp.array(1.)), keys[1])
     return measurement
コード例 #2
0
 def model(key):
     keys = jax.random.split(key, 2)
     n1 = func.sample('n1',
                      dist.normal(jnp.zeros((2, 2)), jnp.full((2, 2), 1.)),
                      keys[0])
     n2 = func.sample('n2', dist.normal(n1, jnp.ones((2, 2))), keys[1])
     return n2
コード例 #3
0
 def model2(key):
     return func.sample(
         'n',
         dist.categorical(
             jnp.stack([
                 jnp.full((2, 2), 0.1),
                 jnp.full((2, 2), 0.8),
                 jnp.full((2, 2), 0.1)
             ], -1)), key)
コード例 #4
0
 def model2(key):
     return func.sample(
         'n',
         dist.multinomial(
             jnp.full((2, 2), 10),
             jnp.stack([
                 jnp.full((2, 2), 0.1),
                 jnp.full((2, 2), 0.8),
                 jnp.full((2, 2), 0.1)
             ], -1)), key)
コード例 #5
0
 def model(key):
     n1 = func.sample('n1', dist.dirichlet(jnp.array([1.0, 0.5])), key)
     return n1
コード例 #6
0
 def model3(key):
     return func.sample('n', dist.dirichlet(jnp.full((2, 3), 0.5)), key)
コード例 #7
0
 def proposal(key, **current):
     n4 = func.sample('n4', dist.normal(current['n4'], jnp.array(1.)), key)
     return {'n4': n4}
コード例 #8
0
 def model2(key):
     n = func.sample('n', dist.uniform((2, 2)), key)
     return n
コード例 #9
0
ファイル: test_log_prob.py プロジェクト: branislav1991/piper
 def model(key):
     keys = jax.random.split(key)
     n1 = func.sample('n1', dist.normal(jnp.array(0.), jnp.array(1.)),
                      keys[0])
     n2 = func.sample('n2', dist.normal(n1, jnp.array(1.)), keys[1])
     return n2
コード例 #10
0
 def model(key):
     keys = jax.random.split(key)
     n1 = func.sample('n1', dist.bernoulli(jnp.array(2.0)), keys[1])
     return n1
コード例 #11
0
 def model(key):
     keys = jax.random.split(key)
     n1 = func.sample('n1', dist.bernoulli(jnp.array(0.5)), keys[0])
     n2 = func.sample('n2', dist.binomial(jnp.array(1), n1), keys[1])
     return n2
コード例 #12
0
 def model3(key):
     return func.sample(
         'n', dist.beta(jnp.full((2, 2), 0.5), jnp.full((2, 2), 0.5)), key)
コード例 #13
0
 def model(key):
     n1 = func.sample('n1', dist.beta(jnp.array(0.5), jnp.array(0.5)), key)
     return n1
コード例 #14
0
 def model2(key):
     return func.sample('n', dist.beta(jnp.array(2.), jnp.array(5.)), key)
コード例 #15
0
 def model1(key):
     return func.sample('n', dist.beta(jnp.array(0.5), jnp.array(0.5)), key)
コード例 #16
0
 def proposal(key, **current):
     n1 = func.sample('n1', dist.normal(current['n1'], jnp.array(5.)), key)
     return {'n1': n1}
コード例 #17
0
 def model(key):
     n1 = func.sample('n1', dist.categorical(jnp.array([0.1, 0.8, 0.1])),
                      key)
     return n1
コード例 #18
0
 def model3(key):
     return func.sample('n', dist.bernoulli(jnp.array(1.)), key)
コード例 #19
0
 def model1(key):
     return func.sample('n', dist.categorical(jnp.array([0.1, 0.8, 0.1])),
                        key)
コード例 #20
0
 def model4(key):
     return func.sample('n', dist.bernoulli(jnp.full((2, 2), 0.5)), key)
コード例 #21
0
 def model1(key):
     return func.sample(
         'n', dist.multinomial(jnp.array(10), jnp.array([0.1, 0.8, 0.1])),
         key)
コード例 #22
0
 def model4(key):
     return func.sample('n', dist.binomial(jnp.array(10), jnp.array(0.5)),
                        key)
コード例 #23
0
 def model(key):
     keys = jax.random.split(key, 3)
     n1 = func.sample('n1', dist.categorical(jnp.array([0.5, 1.5])),
                      keys[2])
     return n1
コード例 #24
0
 def model(key):
     n1 = func.sample('n1', dist.normal(jnp.array(10.), jnp.array(10.)),
                      key)
     return n1
コード例 #25
0
 def model1(key):
     n = func.sample('n', dist.uniform(), key)
     return n
コード例 #26
0
 def model(key):
     keys = jax.random.split(key, 2)
     n1 = func.sample('n1', dist.beta(jnp.array(0.5), jnp.array(0.5)),
                      keys[0])
     n2 = func.sample('n2', dist.bernoulli(n1), keys[1])
     return n2
コード例 #27
0
 def model1(key):
     return func.sample('n', dist.dirichlet(jnp.array([0.5, 0.5])), key)
コード例 #28
0
 def proposal(key, **current):  # use prior as proposal
     n1 = func.sample('n1', dist.beta(jnp.array(0.5), jnp.array(0.5)), key)
     return {'n1': n1}
コード例 #29
0
 def q(params, key):
     n1 = func.sample('n1', dist.normal(params['n1_mean'],
                                        params['n1_std']), key)
     return {'n1': n1}
コード例 #30
0
 def model2(key):
     return func.sample('n', dist.dirichlet(jnp.array([2., 3., 0.5])), key)