Beispiel #1
0
def test_coalescent_likelihood_smoke(duration, forecast, options, algo):
    """Smoke-test a SuperspreadingSEIRModel that also conditions on coalescent times.

    Simulates case counts from the model, draws random leaf times and
    coalescent times for them, fits a fresh model with MCMC when
    ``algo == "mcmc"`` (SVI otherwise), and checks the shapes of the
    S/E/I trajectories returned by ``predict``.
    """
    population, incubation_time, recovery_time = 100, 2.0, 7.0

    # Simulate observations from a model whose data are all None placeholders;
    # retry up to 100 times until at least one case is observed.
    simulator = SuperspreadingSEIRModel(
        population, incubation_time, recovery_time, [None] * duration)
    for _ in range(100):
        data = simulator.generate({"R0": 1.5, "rho": 0.5, "k": 1.0})["obs"]
        if data.sum():
            break
    assert data.sum() > 0, "failed to generate positive data"

    # Five leaf times scaled by `duration`, coalescent times sampled for them,
    # then a random permutation of the coalescent times along the last dim.
    leaf_times = torch.rand(5).pow(0.5) * duration
    coal_times = dist.CoalescentTimes(leaf_times).sample()
    shuffle = torch.randperm(coal_times.size(-1))
    coal_times = coal_times[..., shuffle]

    # Fit a fresh model to the simulated data plus the leaf/coalescent times,
    # with a deliberately tiny budget since this is only a smoke test.
    num_samples = 5
    model = SuperspreadingSEIRModel(
        population, incubation_time, recovery_time, data,
        leaf_times=leaf_times, coal_times=coal_times)
    if algo != "mcmc":
        model.fit_svi(num_steps=2, num_samples=num_samples, **options)
    else:
        model.fit_mcmc(warmup_steps=2, num_samples=num_samples,
                       max_tree_depth=2, **options)

    # Each compartment trajectory should have one row per sample and span the
    # observed window plus the forecast horizon.
    samples = model.predict(forecast=forecast)
    expected = (num_samples, duration + forecast)
    for compartment in ("S", "E", "I"):
        assert samples[compartment].shape == expected
Beispiel #2
0
def test_superspreading_seir_smoke(duration, forecast, options):
    """Smoke-test SuperspreadingSEIRModel end to end with MCMC.

    Checks the model's ``full_mass`` grouping, simulates case counts, fits a
    fresh model to them with a tiny MCMC run, and verifies the shapes of the
    S/E/I trajectories produced by ``predict``.
    """
    population, incubation_time, recovery_time = 100, 2.0, 7.0

    # Simulate observations from a model whose data are all None placeholders;
    # retry up to 100 times until at least one case is observed.
    simulator = SuperspreadingSEIRModel(
        population, incubation_time, recovery_time, [None] * duration
    )
    assert simulator.full_mass == [("R0", "k", "rho")]
    for _ in range(100):
        data = simulator.generate({"R0": 1.5, "rho": 0.5, "k": 1.0})["obs"]
        if data.sum():
            break
    assert data.sum() > 0, "failed to generate positive data"

    # Infer with a deliberately tiny MCMC budget: this is only a smoke test.
    num_samples = 5
    model = SuperspreadingSEIRModel(population, incubation_time, recovery_time, data)
    model.fit_mcmc(
        warmup_steps=2, num_samples=num_samples, max_tree_depth=2, **options
    )

    # Each compartment trajectory should have one row per sample and span the
    # observed window plus the forecast horizon.
    samples = model.predict(forecast=forecast)
    expected = (num_samples, duration + forecast)
    for compartment in ("S", "E", "I"):
        assert samples[compartment].shape == expected
Beispiel #3
0
def Model(args, data):
    """Dispatch between different model classes.

    Chooses a compartmental model class from command-line style options and
    constructs it with ``data``.

    :param args: Options object accessed by attribute. This function reads
        ``heterogeneous``, ``population``, ``incubation_time``,
        ``recovery_time``, ``concentration`` and ``overdispersion``.
    :param data: Observed data, passed unchanged to the model constructor.
    :returns: A ``HeterogeneousSIRModel`` if ``args.heterogeneous``.
        Otherwise an SEIR variant when ``incubation_time > 0``, or an SIR
        variant when it is not. Within each family, a finite ``concentration``
        selects the superspreading model; failing that, a positive
        ``overdispersion`` selects the overdispersed model; otherwise the
        simple model is used.
    :raises ValueError: If the heterogeneous model is requested with a nonzero
        ``incubation_time`` or ``overdispersion``, or if an SEIR model is
        requested with ``incubation_time <= 1``.
    """
    # Validate with explicit exceptions rather than ``assert``: these are
    # user-supplied options, and asserts are stripped under ``python -O``.
    if args.heterogeneous:
        if args.incubation_time != 0:
            raise ValueError(
                "heterogeneous model requires incubation_time == 0, got {}"
                .format(args.incubation_time))
        if args.overdispersion != 0:
            raise ValueError(
                "heterogeneous model requires overdispersion == 0, got {}"
                .format(args.overdispersion))
        return HeterogeneousSIRModel(args.population, args.recovery_time, data)
    elif args.incubation_time > 0:
        if args.incubation_time <= 1:
            raise ValueError(
                "SEIR models require incubation_time > 1, got {}"
                .format(args.incubation_time))
        if args.concentration < math.inf:
            return SuperspreadingSEIRModel(args.population,
                                           args.incubation_time,
                                           args.recovery_time, data)
        elif args.overdispersion > 0:
            return OverdispersedSEIRModel(args.population,
                                          args.incubation_time,
                                          args.recovery_time, data)
        else:
            return SimpleSEIRModel(args.population, args.incubation_time,
                                   args.recovery_time, data)
    else:
        if args.concentration < math.inf:
            return SuperspreadingSIRModel(args.population, args.recovery_time,
                                          data)
        elif args.overdispersion > 0:
            return OverdispersedSIRModel(args.population, args.recovery_time,
                                         data)
        else:
            return SimpleSIRModel(args.population, args.recovery_time, data)