예제 #1
0
def test_coalescent_likelihood_smoke(duration, forecast, options, algo):
    """Smoke-test SuperspreadingSEIRModel with a coalescent likelihood.

    Simulates observations, refits the model on them together with random
    leaf/coalescent times, then checks the shapes of the ``S``, ``E`` and
    ``I`` series returned by ``predict``.

    NOTE(review): arguments are presumably supplied by pytest parametrization.

    Args:
        duration: Number of observed time steps to simulate.
        forecast: Number of extra time steps ``predict`` should forecast.
        options: Extra keyword arguments forwarded to the fit method.
        algo: ``"mcmc"`` selects ``fit_mcmc``; any other value uses ``fit_svi``.
    """
    population = 100
    incubation_time = 2.0
    recovery_time = 7.0

    # Simulate observations, retrying (up to 100 times) until the simulated
    # series has a nonzero total.
    simulator = SuperspreadingSEIRModel(
        population, incubation_time, recovery_time, [None] * duration)
    for _ in range(100):
        data = simulator.generate({"R0": 1.5, "rho": 0.5, "k": 1.0})["obs"]
        if data.sum():
            break
    assert data.sum() > 0, "failed to generate positive data"

    # Random leaf times in [0, duration), coalescent times sampled given those
    # leaves, then shuffled along the last dimension.
    leaf_times = torch.rand(5).pow(0.5) * duration
    coal_times = dist.CoalescentTimes(leaf_times).sample()
    shuffle = torch.randperm(coal_times.size(-1))
    coal_times = coal_times[..., shuffle]

    # Fit a fresh model on the simulated data plus the phylogenetic times.
    model = SuperspreadingSEIRModel(
        population, incubation_time, recovery_time, data,
        leaf_times=leaf_times, coal_times=coal_times)
    num_samples = 5
    if algo == "mcmc":
        model.fit_mcmc(warmup_steps=2, num_samples=num_samples,
                       max_tree_depth=2, **options)
    else:
        model.fit_svi(num_steps=2, num_samples=num_samples, **options)

    # Every compartment series covers the observed window plus the forecast.
    samples = model.predict(forecast=forecast)
    expected_shape = (num_samples, duration + forecast)
    for compartment in ("S", "E", "I"):
        assert samples[compartment].shape == expected_shape
예제 #2
0
def random_phylo_logits(num_leaves, dtype):
    """Sample a random phylogeny and encode it as one-two-matching logits.

    Args:
        num_leaves: Number of leaf times drawn from a standard normal.
        dtype: Floating-point dtype for the sampled times; the returned
            ``logits`` are asserted to share it.

    Returns:
        A ``(logits, times)`` pair.

        ``times`` concatenates the leaf times with the sampled coalescent
        times and has ``requires_grad`` set.

        ``logits`` has shape ``(len(times) - 1, len(times) - num_leaves)``.
        Rows are every node except the one with the smallest time; columns
        are the nodes after the first ``num_leaves``. An entry is finite only
        where the row node's time exceeds the column node's time, and is
        ``-inf`` otherwise.
    """
    # Random phylogeny: standard-normal leaf times plus coalescent times
    # sampled conditional on them.
    leaves = torch.randn(num_leaves, dtype=dtype)
    internal = dist.CoalescentTimes(leaves).sample()
    times = torch.cat([leaves, internal])
    times.requires_grad_()
    assert times.dtype == dtype

    # Sources: every node except the earliest one (presumably the root,
    # which has no parent). Destinations: the non-leaf (coalescent) nodes.
    node_ids = torch.arange(times.size(0))
    root = int(times.min(dim=0).indices)
    sources = node_ids[node_ids != root]
    destins = node_ids[num_leaves:]

    # Pairwise time gaps (source minus destination), rescaled by a factor
    # that is detached from the autograd graph.
    gaps = times[sources].unsqueeze(-1) - times[destins]
    scale = gaps.detach().std()
    gaps = gaps * 10 / scale

    # Only pairs with a positive gap are allowed; all others get -inf.
    forbidden = torch.full_like(gaps, -math.inf)
    logits = torch.where(gaps > 0, -gaps, forbidden)
    assert logits.dtype == dtype

    # Add uniform [0, 1) jitter in place via .data so autograd does not
    # record it.
    jitter = torch.empty_like(logits).uniform_()
    logits.data += jitter

    return logits, times