def test_coalescent_likelihood_smoke(duration, forecast, options, algo):
    """Smoke-test SuperspreadingSEIRModel with phylogenetic (coalescent) data.

    Simulates observations, draws random leaf and coalescent times, fits the
    model for a couple of steps with MCMC (when ``algo == "mcmc"``) or SVI
    otherwise, and checks the shapes of the predicted ``S``/``E``/``I``
    trajectories over the observed window plus the forecast horizon.
    """
    seir_args = (100, 2.0, 7.0)  # population, incubation_time, recovery_time
    horizon = duration + forecast
    num_samples = 5

    # Simulate observations, retrying until at least one case is observed.
    simulator = SuperspreadingSEIRModel(*seir_args, [None] * duration)
    for _ in range(100):
        data = simulator.generate({"R0": 1.5, "rho": 0.5, "k": 1.0})["obs"]
        if data.sum():
            break
    assert data.sum() > 0, "failed to generate positive data"

    # Leaf times are sqrt(uniform) * duration, i.e. biased toward the end of
    # the window; coalescent times are sampled from them and then shuffled.
    leaf_times = torch.rand(5).pow(0.5) * duration
    coal_times = dist.CoalescentTimes(leaf_times).sample()
    shuffle = torch.randperm(coal_times.size(-1))
    coal_times = coal_times[..., shuffle]

    # Fit on the simulated data with the requested inference algorithm.
    model = SuperspreadingSEIRModel(
        *seir_args, data, leaf_times=leaf_times, coal_times=coal_times
    )
    if algo != "mcmc":
        model.fit_svi(num_steps=2, num_samples=num_samples, **options)
    else:
        model.fit_mcmc(
            warmup_steps=2,
            num_samples=num_samples,
            max_tree_depth=2,
            **options,
        )

    # Predict and forecast, then check every compartment's trajectory shape.
    samples = model.predict(forecast=forecast)
    for compartment in ("S", "E", "I"):
        assert samples[compartment].shape == (num_samples, horizon)
def random_phylo_logits(num_leaves, dtype):
    """Build a random one-two matching problem from a sampled phylogeny.

    Leaf times are drawn from a standard normal and coalescent times are
    sampled from ``dist.CoalescentTimes``. Every node except the root (the
    node with the smallest time) is a source. Every coalescent node is a
    destination.

    Args:
        num_leaves: number of leaf times to draw.
        dtype: torch dtype for the sampled times; the returned logits share it.

    Returns:
        A ``(logits, times)`` pair. ``times`` concatenates leaf then coalescent
        times and has ``requires_grad`` set. ``logits`` has one row per source
        and one column per destination. Entries where the source time is not
        strictly greater than the destination time are ``-inf``. Other entries
        are the negated, rescaled time difference plus uniform jitter.
    """
    # Sample a random phylogeny: leaves first, then coalescent nodes.
    leaf_times = torch.randn(num_leaves, dtype=dtype)
    coal_times = dist.CoalescentTimes(leaf_times).sample()
    times = torch.cat([leaf_times, coal_times]).requires_grad_()
    assert times.dtype == dtype

    # Sources: all nodes but the root. Destinations: coalescent nodes only.
    node_ids = torch.arange(times.size(0))
    root = times.min(0).indices.item()
    sources = node_ids[node_ids != root]
    destins = node_ids[num_leaves:]

    # Pairwise (source, destination) time gaps, rescaled to a fixed spread.
    dt = times[sources].unsqueeze(-1) - times[destins]
    dt = dt * 10 / dt.detach().std()

    # Only edges going backward in time (dt > 0) are feasible.
    neg_inf = dt.new_tensor(-math.inf)
    logits = torch.where(dt > 0, -dt, neg_inf)
    assert logits.dtype == dtype

    # Uniform jitter, written through .data so it is not tracked by autograd
    # (presumably to break ties between equal logits).
    jitter = torch.empty_like(logits).uniform_()
    logits.data += jitter
    return logits, times