import math

import pytest
import torch

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.util import warn_if_nan

# Assumed available from the surrounding project: the test-harness helpers
# `assert_equal` and `set_approx_log_prob_tol`, and the module-private
# helpers referenced in `binomial_dist` below.


def reparameterized_discrete_model(args, data):
    # Sample global parameters.
    rate_s, prob_i, rho = global_model(args.population)

    # Sequentially sample time-local variables.
    S_curr = torch.tensor(args.population - 1.0)
    I_curr = torch.tensor(1.0)
    for t, datum in enumerate(data):
        # Sample reparameterizing variables.
        # When reparameterizing to a factor graph, we ignored density via
        # .mask(False). Thus distributions are used only for initialization.
        S_prev, I_prev = S_curr, I_curr
        S_curr = pyro.sample("S_{}".format(t),
                             dist.Binomial(args.population, 0.5).mask(False))
        I_curr = pyro.sample("I_{}".format(t),
                             dist.Binomial(args.population, 0.5).mask(False))

        # Now we reverse the computation.
        S2I = S_prev - S_curr
        I2R = I_prev - I_curr + S2I
        pyro.sample("S2I_{}".format(t),
                    dist.ExtendedBinomial(S_prev, -(rate_s * I_prev).expm1()),
                    obs=S2I)
        pyro.sample("I2R_{}".format(t),
                    dist.ExtendedBinomial(I_prev, prob_i),
                    obs=I2R)
        pyro.sample("obs_{}".format(t),
                    dist.ExtendedBinomial(S2I, rho),
                    obs=datum)

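# The models in this file call a `global_model` helper that samples the
# global parameters (rate_s, prob_i, rho). A minimal sketch of such a helper
# follows; the priors, the site names ("R0", "rho"), and the `recovery_time`
# hyperparameter are illustrative assumptions, not a definitive
# implementation. Note rate_s is negative, so -(rate_s * I).expm1() below is
# a valid probability in (0, 1).
def global_model(population, recovery_time=14.0):
    tau = recovery_time  # Assume recovery time is known.
    R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0))
    rho = pyro.sample("rho", dist.Uniform(0.0, 1.0))

    # Convert interpretable parameters to transition parameters.
    rate_s = -R0 / (tau * population)
    prob_i = 1 / (1 + tau)
    return rate_s, prob_i, rho
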
def continuous_model(args, data):
    # Sample global parameters.
    rate_s, prob_i, rho = global_model(args.population)

    # Sample reparameterizing variables.
    S_aux = pyro.sample("S_aux",
                        dist.Uniform(-0.5, args.population + 0.5)
                            .mask(False).expand(data.shape).to_event(1))
    I_aux = pyro.sample("I_aux",
                        dist.Uniform(-0.5, args.population + 0.5)
                            .mask(False).expand(data.shape).to_event(1))

    # Sequentially sample time-local variables.
    S_curr = torch.tensor(args.population - 1.0)
    I_curr = torch.tensor(1.0)
    for t, datum in poutine.markov(enumerate(data)):
        S_prev, I_prev = S_curr, I_curr
        S_curr = quantize("S_{}".format(t), S_aux[..., t],
                          min=0, max=args.population)
        I_curr = quantize("I_{}".format(t), I_aux[..., t],
                          min=0, max=args.population)

        # Now we reverse the computation.
        S2I = S_prev - S_curr
        I2R = I_prev - I_curr + S2I
        pyro.sample("S2I_{}".format(t),
                    dist.ExtendedBinomial(S_prev, -(rate_s * I_prev).expm1()),
                    obs=S2I)
        pyro.sample("I2R_{}".format(t),
                    dist.ExtendedBinomial(I_prev, prob_i),
                    obs=I2R)
        pyro.sample("obs_{}".format(t),
                    dist.ExtendedBinomial(S2I, rho),
                    obs=datum)

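# `continuous_model` relies on a `quantize` helper that maps each real-valued
# auxiliary variable to a nearby integer while contributing the rounding
# probability to the joint density. The sketch below is a simplified
# two-point version of that contract; the actual helper interpolates over
# the four nearest integers (hence Q = 4 in `vectorized_model` below).
def quantize(name, x_real, min, max):
    lb = x_real.detach().floor()
    # Stochastically round to one of the two nearest integers. The Bernoulli
    # log-probability enters the model density, and gradients flow back into
    # x_real through the rounding weight.
    prob_upper = (x_real - lb).clamp(0.0, 1.0)
    upper = pyro.sample("{}_round".format(name),
                        dist.Bernoulli(prob_upper),
                        infer={"enumerate": "parallel"})
    x = (lb + upper).clamp(min=min, max=max)
    return pyro.deterministic(name, x)
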
def test_extended_binomial(tol):
    with set_approx_log_prob_tol(tol):
        total_count = torch.tensor([0.0, 1.0, 2.0, 10.0])
        probs = torch.tensor([0.5, 0.5, 0.4, 0.2]).requires_grad_()
        d1 = dist.Binomial(total_count, probs)
        d2 = dist.ExtendedBinomial(total_count, probs)

        # Check on good data.
        data = d1.sample((100,))
        assert_equal(d1.log_prob(data), d2.log_prob(data))

        # Check on extended data.
        data = torch.arange(-10.0, 20.0).unsqueeze(-1)
        with pytest.raises(ValueError):
            d1.log_prob(data)
        log_prob = d2.log_prob(data)
        valid = d1.support.check(data)
        assert ((log_prob > -math.inf) == valid).all()
        check_grad(log_prob, probs)

        # Check on shape error.
        with pytest.raises(ValueError):
            d2.log_prob(torch.tensor([0.0, 0.0]))

        # Check on value error.
        with pytest.raises(ValueError):
            d2.log_prob(torch.tensor(0.5))

        # Check on negative total_count.
        total_count = torch.arange(-10, 0.0)
        probs = torch.tensor(0.5).requires_grad_()
        d = dist.ExtendedBinomial(total_count, probs)
        log_prob = d.log_prob(data)
        assert (log_prob == -math.inf).all()
        check_grad(log_prob, probs)

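# The test above assumes a `check_grad` helper from the test suite. A minimal
# sketch consistent with its usage here (an illustrative assumption, not the
# suite's exact helper): clamp -inf entries to a large finite value before
# back-propagating, so gradients are defined even where log_prob is -inf.
def check_grad(log_prob, *params):
    clamped = log_prob.clamp(min=torch.finfo(log_prob.dtype).min)
    grads = torch.autograd.grad(clamped.sum(), params, create_graph=True)
    assert not any(torch.isnan(g).any() for g in grads)
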
def vectorized_model(args, data):
    # Sample global parameters.
    rate_s, prob_i, rho = global_model(args.population)

    # Sample reparameterizing variables.
    S_aux = pyro.sample("S_aux",
                        dist.Uniform(-0.5, args.population + 0.5)
                            .mask(False).expand(data.shape).to_event(1))
    I_aux = pyro.sample("I_aux",
                        dist.Uniform(-0.5, args.population + 0.5)
                            .mask(False).expand(data.shape).to_event(1))

    # Manually enumerate.
    S_curr, S_logp = quantize_enumerate(S_aux, min=0, max=args.population)
    I_curr, I_logp = quantize_enumerate(I_aux, min=0, max=args.population)
    # Truncate final value from the right then pad initial value onto the left.
    S_prev = torch.nn.functional.pad(S_curr[:-1], (0, 0, 1, 0),
                                     value=args.population - 1)
    I_prev = torch.nn.functional.pad(I_curr[:-1], (0, 0, 1, 0), value=1)

    # Reshape to support broadcasting, similar to EnumMessenger.
    T = len(data)
    Q = 4  # Number of enumerated quantization points per auxiliary variable.
    S_prev = S_prev.reshape(T, Q, 1, 1, 1)
    I_prev = I_prev.reshape(T, 1, Q, 1, 1)
    S_curr = S_curr.reshape(T, 1, 1, Q, 1)
    S_logp = S_logp.reshape(T, 1, 1, Q, 1)
    I_curr = I_curr.reshape(T, 1, 1, 1, Q)
    I_logp = I_logp.reshape(T, 1, 1, 1, Q)
    data = data.reshape(T, 1, 1, 1, 1)

    # Reverse the S2I, I2R computation.
    S2I = S_prev - S_curr
    I2R = I_prev - I_curr + S2I

    # Compute probability factors.
    S2I_logp = dist.ExtendedBinomial(
        S_prev, -(rate_s * I_prev).expm1()).log_prob(S2I)
    I2R_logp = dist.ExtendedBinomial(I_prev, prob_i).log_prob(I2R)
    obs_logp = dist.ExtendedBinomial(S2I, rho).log_prob(data)

    # Manually perform variable elimination.
    logp = S_logp + (I_logp + obs_logp) + S2I_logp + I2R_logp
    logp = logp.reshape(-1, Q * Q, Q * Q)
    logp = pyro.distributions.hmm._sequential_logmatmulexp(logp)
    logp = logp.reshape(-1).logsumexp(0)
    logp = logp - math.log(4)  # Account for S,I initial distributions.
    warn_if_nan(logp)
    pyro.factor("obs", logp)

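# `vectorized_model` assumes a `quantize_enumerate` helper with the contract
# (values, log-weights): for each real-valued entry it returns the Q = 4
# nearest integers together with normalized log-weights for the quantization.
# The sketch below uses simple triangular-kernel weights; the exact
# interpolation scheme is an assumption, only the contract matters to the
# caller.
def quantize_enumerate(x_real, min, max):
    lb = x_real.detach().floor()
    values = lb.unsqueeze(-1) + torch.arange(-1.0, 3.0)  # 4 nearest integers.
    distance = (values - x_real.unsqueeze(-1)).abs()
    weights = (1 - distance / 2).clamp(min=0.05)  # Triangular kernel, no zeros.
    logits = weights.log() - weights.sum(-1, keepdim=True).log()
    return values.clamp(min=min, max=max), logits


# For reference, `pyro.distributions.hmm._sequential_logmatmulexp(logp)`
# contracts the (T, Q*Q, Q*Q) time-local factors into a single matrix via
# log-space matrix products. A naive O(T) equivalent of that contraction
# (Pyro's private helper parallelizes the work):
def logmatmulexp_chain(logp):
    result = logp[0]
    for t in range(1, logp.size(0)):
        result = torch.logsumexp(result.unsqueeze(-1) + logp[t].unsqueeze(-3),
                                 dim=-2)
    return result
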
def binomial_dist(total_count, probs, *, overdispersion=0.0):
    """
    Returns a Beta-Binomial distribution that is an overdispersed version of
    a Binomial distribution, according to a parameter ``overdispersion``,
    typically set in the range 0.1 to 0.5.

    This is useful for (1) fitting real data that is overdispersed relative
    to a Binomial distribution, and (2) relaxing models of large populations
    to improve inference. In particular the ``overdispersion`` parameter
    lower bounds the relative uncertainty in stochastic models such that
    increasing population leads to a limiting scale-free dynamical system
    with bounded stochasticity, in contrast to Binomial-based SDEs that
    converge to deterministic ODEs in the large population limit.

    This parameterization satisfies the following properties:

    1. Variance increases monotonically in ``overdispersion``.
    2. ``overdispersion = 0`` results in a Binomial distribution.
    3. ``overdispersion`` lower bounds the relative uncertainty ``std_dev /
       (total_count * p * q)``, where ``probs = p = 1 - q``, and serves as an
       asymptote for relative uncertainty as ``total_count → ∞``. This
       contrasts the Binomial whose relative uncertainty tends to zero.
    4. If ``X ~ binomial_dist(n, p, overdispersion=σ)`` then in the large
       population limit ``n → ∞``, the scaled random variable ``X / n``
       converges in distribution to ``LogitNormal(log(p/(1-p)), σ)``.

    To achieve these properties we set ``p = probs``, ``q = 1 - p``, and::

        concentration = 1 / (p * q * overdispersion**2) - 1

    :param total_count: Number of Bernoulli trials.
    :type total_count: int or torch.Tensor
    :param probs: Event probabilities.
    :type probs: float or torch.Tensor
    :param overdispersion: Amount of overdispersion, in the half open
        interval [0,2). Defaults to zero.
    :type overdispersion: float or torch.Tensor
    """
    _validate_overdispersion(overdispersion)
    if _is_zero(overdispersion):
        if _RELAX:
            return _relaxed_binomial(total_count, probs)
        return dist.ExtendedBinomial(total_count, probs)

    p = probs
    q = 1 - p
    od2 = (overdispersion + 1e-8) ** 2
    concentration1 = 1 / (q * od2 + 1e-8) - p
    concentration0 = 1 / (p * od2 + 1e-8) - q
    # At this point we have (up to the 1e-8 epsilons)
    #   concentration1 + concentration0 == 1 / (p * q * od2) - 1.
    if _RELAX:
        return _relaxed_beta_binomial(concentration1, concentration0,
                                      total_count)
    return dist.ExtendedBetaBinomial(concentration1, concentration0,
                                     total_count)

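# A quick numerical check of property 3 above (illustrative usage; the
# numbers are assumptions, and `variance` relies on the usual Beta-Binomial
# moments): for large ``total_count`` the relative uncertainty approaches
# the overdispersion parameter.
#
#   >>> n, p = 10000.0, 0.3
#   >>> d = binomial_dist(torch.tensor(n), torch.tensor(p), overdispersion=0.2)
#   >>> d.variance.sqrt() / (n * p * (1 - p))  # ≈ 0.2, tending to 0.2 as n → ∞
#
# (The module-private helpers `_validate_overdispersion`, `_is_zero`,
# `_RELAX`, and `_relaxed_*` above are assumed from the surrounding module.)
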
def discrete_model(args, data):
    # Sample global parameters.
    rate_s, prob_i, rho = global_model(args.population)

    # Sequentially sample time-local variables.
    S = torch.tensor(args.population - 1.0)
    I = torch.tensor(1.0)
    for t, datum in enumerate(data):
        S2I = pyro.sample("S2I_{}".format(t),
                          dist.Binomial(S, -(rate_s * I).expm1()))
        I2R = pyro.sample("I2R_{}".format(t),
                          dist.Binomial(I, prob_i))
        S = pyro.deterministic("S_{}".format(t), S - S2I)
        I = pyro.deterministic("I_{}".format(t), I + S2I - I2R)
        pyro.sample("obs_{}".format(t),
                    dist.ExtendedBinomial(S2I, rho),
                    obs=datum)

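# Because `pyro.sample(..., obs=None)` behaves as an ordinary latent draw,
# `discrete_model` can also forward-simulate synthetic data: tracing it with
# a list of Nones samples every site. A minimal sketch (the `duration` field
# on `args` is an assumption for illustration):
def generate_data(args):
    empty_data = [None] * args.duration
    with poutine.trace() as tr:
        discrete_model(args, empty_data)
    return torch.stack([tr.trace.nodes["obs_{}".format(t)]["value"]
                        for t in range(args.duration)])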