def transition_fn(carry, t):
    # Scan step combining AR(1) dynamics of the response precision with
    # count-based learning of outcome utilities, under a mixture likelihood.
    # carry: (lam_prev, x_prev) = (outcome counts, latent AR(1) state).
    lam_prev, x_prev = carry
    # Positive response precision at trial t (softplus link on mu + x).
    gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))
    # Utilities are the normalised log-counts accumulated so far.
    U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))
    logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                  jnp.expand_dims(gamma, -1),
                  jnp.expand_dims(U, -2))
    # Count update: add eta in the bin of the observed outcome, zeroed on
    # masked (invalid) trials.
    # NOTE(review): beliefs[2] is indexed as [t] while beliefs[0]/[1] use
    # [:, t] — presumably different array layouts; confirm against callers.
    lam_next = npyro.deterministic(
        'lams',
        lam_prev + nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1))
    mixing_dist = dist.CategoricalProbs(weights)
    # swapaxes moves the component axis where MixtureSameFamily expects it;
    # the extra [..., None] broadcasts the mask over mixture components.
    component_dist = dist.CategoricalLogits(logs.swapaxes(0, 1)).mask(
        mask[t][..., None])
    with npyro.plate('subjects', N):
        y = npyro.sample(
            'y', dist.MixtureSameFamily(mixing_dist, component_dist))
    # AR(1) innovation sampled OUTSIDE the subjects plate —
    # NOTE(review): confirm whether per-subject noise was intended here.
    noise = npyro.sample('dw', dist.Normal(0., 1.))
    x_next = rho * x_prev + sigma * noise
    return (lam_next, x_next), None
def transition_fn(carry, t):
    """Scan step for the static model: score the response at trial t.

    The carry is unused; the precision `gamma` and utilities `U` are
    closed over and constant across trials.
    """
    gamma_col = jnp.expand_dims(gamma, -1)
    util_row = jnp.expand_dims(U, -2)
    step_logits = logits((beliefs[0][t], beliefs[1][t]), gamma_col, util_row)
    # Masked observation site: invalid trials contribute zero log-prob.
    npyro.sample('y', dist.CategoricalLogits(step_logits).mask(mask[t]))
    return None, None
def transition_fn(carry, t):
    """Scan step for the static mixture model (carry unused)."""
    gamma_col = jnp.expand_dims(gamma, -1)
    util_row = jnp.expand_dims(U, -2)
    step_logits = logits((beliefs[0][:, t], beliefs[1][:, t]),
                         gamma_col, util_row)
    # Mixture over components with fixed weights; invalid trials are masked
    # out of the likelihood.
    mix = dist.CategoricalProbs(weights)
    comps = dist.CategoricalLogits(step_logits).mask(mask[t])
    npyro.sample('y', dist.MixtureSameFamily(mix, comps))
    return None, None
def transition_fn(carry, t):
    """Scan step with an AR(1) latent state driving the response precision.

    carry holds the latent state x from the previous trial; returns the
    propagated state as the next carry.
    """
    x_prev = carry
    # Precision at trial t; softplus keeps it positive.
    dyn_gamma = npyro.deterministic('dyn_gamma', nn.softplus(mu + x_prev))
    step_logits = logits((beliefs[0][t], beliefs[1][t]),
                         jnp.expand_dims(dyn_gamma, -1),
                         jnp.expand_dims(U, -2))
    npyro.sample('y', dist.CategoricalLogits(step_logits).mask(mask[t]))
    # AR(1) propagation: x_next = rho * x_prev + sigma * dw, dw ~ N(0, 1).
    innovation = npyro.sample('dw', dist.Normal(0., 1.))
    return rho * x_prev + sigma * innovation, None
def transition_fn(carry, t):
    """Scan step: count-based learning of outcome utilities, fixed gamma.

    carry holds the running outcome counts (lam); returns the updated
    counts as the next carry.
    """
    lam_prev = carry
    # Utilities are the normalised log-counts accumulated so far.
    U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))
    step_logits = logits((beliefs[0][t], beliefs[1][t]),
                         jnp.expand_dims(gamma, -1),
                         jnp.expand_dims(U, -2))
    # Add eta to the bin of the observed outcome; masked trials add nothing.
    lam_next = npyro.deterministic(
        'lams',
        lam_prev
        + nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1))
    npyro.sample('y', dist.CategoricalLogits(step_logits).mask(mask[t]))
    return lam_next, None
def generative_model(beliefs, y=None, mask=True):
    """Static generative model over all trials and subjects at once.

    Args:
        beliefs: belief arrays whose leading dims are (T, N, ...).
        y: observed responses, or None to sample from the prior predictive.
        mask: validity mask applied to the response likelihood.
    """
    T, N = beliefs[0].shape[:2]
    with npyro.plate('N', N):
        # Per-subject response precision.
        gamma = npyro.sample('gamma', dist.Gamma(20., 2.))
        # Ordered pair of utility parameters per subject.
        ordered_prior = TransformedDistribution(
            dist.Normal(jnp.array([-1., .7]), jnp.array([1., 0.2])).to_event(1),
            transforms.OrderedTransform())
        lams = npyro.sample('lams', ordered_prior)
        # Append two zero entries so the utility vector has four slots.
        U = jnp.pad(lams, ((0, 0), (0, 2)), 'constant', constant_values=(0, ))
        with npyro.plate('T', T):
            all_logits = npyro.deterministic(
                'logits',
                logits(beliefs,
                       jnp.expand_dims(gamma, -1),
                       jnp.expand_dims(U, -2)))
            npyro.sample('y',
                         dist.CategoricalLogits(all_logits).mask(mask),
                         obs=y)
def transition_fn(carry, t):
    # Scan step for the mixture model with AR(1) dynamics on the response
    # precision. carry: latent state x from the previous trial.
    x_prev = carry
    # Positive response precision at trial t (softplus link on mu + x).
    gamma_dyn = npyro.deterministic('gamma_dyn', nn.softplus(mu + x_prev))
    logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                  jnp.expand_dims(gamma_dyn, -1),
                  jnp.expand_dims(U, -2))
    # Mixture over components with fixed weights; masked trials drop out of
    # the likelihood.
    mixing_dist = dist.CategoricalProbs(weights)
    component_dist = dist.CategoricalLogits(logs).mask(mask[t])
    npyro.sample('y', dist.MixtureSameFamily(mixing_dist, component_dist))
    # Non-centred parameterisation of the AR(1) transition:
    # x_next = rho * x_prev + sigma * eps with eps ~ N(0, 1); TransformReparam
    # makes the base N(0, 1) variable the actual sample site. The config key
    # must match the sample name 'x_next'.
    with npyro.handlers.reparam(
            config={"x_next": npyro.infer.reparam.TransformReparam()}):
        affine = dist.transforms.AffineTransform(rho * x_prev, sigma)
        x_next = npyro.sample(
            'x_next',
            dist.TransformedDistribution(dist.Normal(0., 1.), affine))
    return (x_next), None
def transition_fn(carry, t):
    """Scan step: AR(1) precision dynamics plus utility learning from counts.

    carry is (lam, x): the running outcome counts and the latent AR(1)
    state; both are propagated to the next trial.
    """
    lam_prev, x_prev = carry
    # Positive response precision via softplus.
    gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))
    # Utilities are the normalised log-counts accumulated so far.
    U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))
    step_logits = logits((beliefs[0][t], beliefs[1][t]),
                         jnp.expand_dims(gamma, -1),
                         jnp.expand_dims(U, -2))
    # Add eta to the bin of the observed outcome; masked trials add nothing.
    lam_next = npyro.deterministic(
        'lams',
        lam_prev
        + nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1))
    npyro.sample('y', dist.CategoricalLogits(step_logits).mask(mask[t]))
    # Propagate the latent state: x_next = rho * x_prev + sigma * dw.
    innovation = npyro.sample('dw', dist.Normal(0., 1.))
    return (lam_next, rho * x_prev + sigma * innovation), None