Beispiel #1
0
    def transition_fn(carry, t):
        """Advance the ``(lam, x)`` carry by one step ``t``.

        The carry is unpacked into the previous ``lam`` values and the previous
        latent ``x``. The step records the deterministic sites ``'gamma'`` and
        ``'lams'``, then, inside the ``'subjects'`` plate, samples ``'y'`` from
        a mixture of masked categoricals and ``'dw'`` from a standard normal.

        All other names (``mu``, ``beliefs``, ``mask``, ``eta``, ``weights``,
        ``N``, ``rho``, ``sigma``, ``logits``) come from the enclosing scope,
        which is not visible here.
        NOTE(review): the ``(carry, t) -> (carry, None)`` signature looks like
        a scan body -- confirm against the caller.

        Returns:
            ``((lam_next, x_next), None)``: the updated carry and no per-step
            output.
        """
        prev_lam, prev_x = carry

        gamma = npyro.deterministic('gamma', nn.softplus(mu + prev_x))

        # Log of prev_lam minus the log of its sum over the last axis.
        log_lam = jnp.log(prev_lam)
        log_total = jnp.log(prev_lam.sum(-1, keepdims=True))
        U = log_lam - log_total

        step_beliefs = (beliefs[0][:, t], beliefs[1][:, t])
        choice_logits = logits(step_beliefs,
                               jnp.expand_dims(gamma, -1),
                               jnp.expand_dims(U, -2))

        # One-hot increment on the entry indexed by beliefs[2][t], scaled by
        # mask[t] * eta.
        increment = (nn.one_hot(beliefs[2][t], 4) *
                     jnp.expand_dims(mask[t] * eta, -1))
        next_lam = npyro.deterministic('lams', prev_lam + increment)

        mixing_dist = dist.CategoricalProbs(weights)
        component_dist = dist.CategoricalLogits(
            choice_logits.swapaxes(0, 1)).mask(mask[t][..., None])
        mixture = dist.MixtureSameFamily(mixing_dist, component_dist)
        with npyro.plate('subjects', N):
            y = npyro.sample('y', mixture)
            noise = npyro.sample('dw', dist.Normal(0., 1.))

        # Linear update of the latent x: rho * x_prev + sigma * noise.
        next_x = rho * prev_x + sigma * noise

        return (next_lam, next_x), None
Beispiel #2
0
    def transition_fn(carry, t):
        """Sample the ``'y'`` site for step ``t``; no state is carried.

        ``carry`` is ignored. ``beliefs``, ``gamma``, ``U``, ``mask`` and
        ``logits`` come from the enclosing scope, which is not visible here.
        NOTE(review): the ``(carry, t)`` signature looks like a scan body --
        confirm against the caller.

        Returns:
            ``(None, None)``: no carry and no per-step output.
        """
        step_beliefs = (beliefs[0][t], beliefs[1][t])
        choice_logits = logits(step_beliefs,
                               jnp.expand_dims(gamma, -1),
                               jnp.expand_dims(U, -2))

        # Categorical over the computed logits, masked by mask[t].
        likelihood = dist.CategoricalLogits(choice_logits).mask(mask[t])
        npyro.sample('y', likelihood)

        return None, None
Beispiel #3
0
    def transition_fn(carry, t):
        """Sample ``'y'`` for step ``t`` from a mixture of categoricals.

        ``carry`` is ignored. The mixture weights come from ``weights`` and the
        components are categoricals over the computed logits, masked by
        ``mask[t]``. ``beliefs``, ``gamma``, ``U``, ``weights``, ``mask`` and
        ``logits`` come from the enclosing scope, which is not visible here.

        Returns:
            ``(None, None)``: no carry and no per-step output.
        """
        step_beliefs = (beliefs[0][:, t], beliefs[1][:, t])
        choice_logits = logits(step_beliefs,
                               jnp.expand_dims(gamma, -1),
                               jnp.expand_dims(U, -2))

        mixing_dist = dist.CategoricalProbs(weights)
        component_dist = dist.CategoricalLogits(choice_logits).mask(mask[t])
        mixture = dist.MixtureSameFamily(mixing_dist, component_dist)
        npyro.sample('y', mixture)

        return None, None
Beispiel #4
0
    def transition_fn(carry, t):
        """Advance the latent ``x`` by one step and sample ``'y'`` for ``t``.

        The carry is the previous latent ``x``. The step records the
        deterministic site ``'dyn_gamma'`` as ``softplus(mu + x_prev)``,
        samples ``'y'`` from a masked categorical and ``'dw'`` from a standard
        normal. ``mu``, ``beliefs``, ``U``, ``mask``, ``rho``, ``sigma`` and
        ``logits`` come from the enclosing scope, which is not visible here.

        Returns:
            ``(x_next, None)`` with ``x_next = rho * x_prev + sigma * noise``.
        """
        prev_x = carry

        dyn_gamma = npyro.deterministic('dyn_gamma', nn.softplus(mu + prev_x))

        step_beliefs = (beliefs[0][t], beliefs[1][t])
        choice_logits = logits(step_beliefs,
                               jnp.expand_dims(dyn_gamma, -1),
                               jnp.expand_dims(U, -2))

        likelihood = dist.CategoricalLogits(choice_logits).mask(mask[t])
        npyro.sample('y', likelihood)
        noise = npyro.sample('dw', dist.Normal(0., 1.))

        next_x = rho * prev_x + sigma * noise

        return next_x, None
Beispiel #5
0
    def transition_fn(carry, t):
        """Update the ``lam`` carry for step ``t`` and sample ``'y'``.

        The carry is the previous ``lam`` values. The logits are built from the
        previous values, before the update is applied. The updated values are
        recorded as the deterministic site ``'lams'``. ``beliefs``, ``gamma``,
        ``mask``, ``eta`` and ``logits`` come from the enclosing scope, which
        is not visible here.

        Returns:
            ``(lam_next, None)``: the updated carry and no per-step output.
        """
        prev_lam = carry

        # Log of prev_lam minus the log of its sum over the last axis.
        log_lam = jnp.log(prev_lam)
        log_total = jnp.log(prev_lam.sum(-1, keepdims=True))
        U = log_lam - log_total

        step_beliefs = (beliefs[0][t], beliefs[1][t])
        choice_logits = logits(step_beliefs,
                               jnp.expand_dims(gamma, -1),
                               jnp.expand_dims(U, -2))

        # One-hot increment on the entry indexed by beliefs[2][t], scaled by
        # mask[t] * eta.
        increment = (nn.one_hot(beliefs[2][t], 4) *
                     jnp.expand_dims(mask[t] * eta, -1))
        next_lam = npyro.deterministic('lams', prev_lam + increment)

        likelihood = dist.CategoricalLogits(choice_logits).mask(mask[t])
        npyro.sample('y', likelihood)

        return next_lam, None
Beispiel #6
0
def generative_model(beliefs, y=None, mask=True):
    """Define the NumPyro model for the choices ``y`` given ``beliefs``.

    Inside plate ``'N'`` the model samples ``'gamma'`` from ``Gamma(20, 2)``
    and ``'lams'`` from a two-element normal passed through
    ``OrderedTransform``. ``lams`` is padded with two trailing zeros on its
    second axis to form ``U``. Inside the nested plate ``'T'`` the logits are
    recorded as the deterministic site ``'logits'`` and ``'y'`` is sampled
    from a masked categorical.

    Args:
        beliefs: indexable whose first element exposes ``.shape``; its first
            two dimensions give the sizes of plates ``'T'`` and ``'N'``. It is
            passed unchanged to ``logits``.
        y: observations for the ``'y'`` site, or ``None`` to leave the site
            unobserved.
        mask: mask applied to the categorical likelihood.
    """
    T, N = beliefs[0].shape[:2]
    with npyro.plate('N', N):
        gamma = npyro.sample('gamma', dist.Gamma(20., 2.))

        base_loc = jnp.array([-1., .7])
        base_scale = jnp.array([1., 0.2])
        base_dist = dist.Normal(base_loc, base_scale).to_event(1)
        td = TransformedDistribution(base_dist, transforms.OrderedTransform())

        lams = npyro.sample('lams', td)
        # Append two zero entries after the sampled values on axis 1.
        pad_width = ((0, 0), (0, 2))
        U = jnp.pad(lams, pad_width, 'constant', constant_values=(0, ))

        with npyro.plate('T', T):
            raw_logits = logits(beliefs,
                                jnp.expand_dims(gamma, -1),
                                jnp.expand_dims(U, -2))
            logs = npyro.deterministic('logits', raw_logits)
            likelihood = dist.CategoricalLogits(logs).mask(mask)
            npyro.sample('y', likelihood, obs=y)
Beispiel #7
0
    def transition_fn(carry, t):
        """Sample ``'y'`` for step ``t`` and draw the next latent ``x``.

        The carry is the previous latent ``x``. The step records the
        deterministic site ``'gamma_dyn'`` as ``softplus(mu + x_prev)`` and
        samples ``'y'`` from a mixture of masked categoricals. It then samples
        ``'x_next'`` as a standard normal pushed through an affine transform
        with location ``rho * x_prev`` and scale ``sigma``, under
        ``TransformReparam``.

        ``mu``, ``beliefs``, ``U``, ``weights``, ``mask``, ``rho``, ``sigma``
        and ``logits`` come from the enclosing scope, which is not visible
        here.

        Returns:
            ``(x_next, None)``: the updated carry and no per-step output.
        """
        prev_x = carry

        gamma_dyn = npyro.deterministic('gamma_dyn', nn.softplus(mu + prev_x))

        step_beliefs = (beliefs[0][:, t], beliefs[1][:, t])
        choice_logits = logits(step_beliefs,
                               jnp.expand_dims(gamma_dyn, -1),
                               jnp.expand_dims(U, -2))

        mixing_dist = dist.CategoricalProbs(weights)
        component_dist = dist.CategoricalLogits(choice_logits).mask(mask[t])
        mixture = dist.MixtureSameFamily(mixing_dist, component_dist)
        npyro.sample('y', mixture)

        reparam_config = {"x_next": npyro.infer.reparam.TransformReparam()}
        with npyro.handlers.reparam(config=reparam_config):
            affine = dist.transforms.AffineTransform(rho * prev_x, sigma)
            transition = dist.TransformedDistribution(
                dist.Normal(0., 1.), affine)
            next_x = npyro.sample('x_next', transition)

        return next_x, None
Beispiel #8
0
    def transition_fn(carry, t):
        """Advance the ``(lam, x)`` carry by one step ``t`` and sample ``'y'``.

        The carry is unpacked into the previous ``lam`` values and the previous
        latent ``x``. The step records the deterministic sites ``'gamma'`` and
        ``'lams'``, samples ``'y'`` from a masked categorical and ``'dw'`` from
        a standard normal. ``mu``, ``beliefs``, ``mask``, ``eta``, ``rho``,
        ``sigma`` and ``logits`` come from the enclosing scope, which is not
        visible here.

        Returns:
            ``((lam_next, x_next), None)``: the updated carry and no per-step
            output.
        """
        prev_lam, prev_x = carry

        gamma = npyro.deterministic('gamma', nn.softplus(mu + prev_x))

        # Log of prev_lam minus the log of its sum over the last axis.
        log_lam = jnp.log(prev_lam)
        log_total = jnp.log(prev_lam.sum(-1, keepdims=True))
        U = log_lam - log_total

        step_beliefs = (beliefs[0][t], beliefs[1][t])
        choice_logits = logits(step_beliefs,
                               jnp.expand_dims(gamma, -1),
                               jnp.expand_dims(U, -2))

        # One-hot increment on the entry indexed by beliefs[2][t], scaled by
        # mask[t] * eta.
        increment = (nn.one_hot(beliefs[2][t], 4) *
                     jnp.expand_dims(mask[t] * eta, -1))
        next_lam = npyro.deterministic('lams', prev_lam + increment)

        likelihood = dist.CategoricalLogits(choice_logits).mask(mask[t])
        npyro.sample('y', likelihood)
        noise = npyro.sample('dw', dist.Normal(0., 1.))

        # Linear update of the latent x: rho * x_prev + sigma * noise.
        next_x = rho * prev_x + sigma * noise

        return (next_lam, next_x), None