def test_coalescent_likelihood_smoke(duration, forecast, options, algo):
    """Smoke-test fitting a superspreading SEIR model with coalescent data.

    Synthetic observations are generated from the model, random leaf and
    coalescent times are attached, the model is fit with MCMC when ``algo``
    is ``"mcmc"`` and with SVI otherwise, and the predicted ``S``/``E``/``I``
    series are checked to have shape ``(num_samples, duration + forecast)``.
    """
    population = 100
    incubation_time = 2.0
    recovery_time = 7.0
    true_params = {"R0": 1.5, "rho": 0.5, "k": 1.0}

    # Simulate observations, retrying until the series is not all zeros.
    unobserved = [None] * duration
    generator = SuperspreadingSEIRModel(
        population, incubation_time, recovery_time, unobserved
    )
    for _ in range(100):
        # Pass a fresh dict on every attempt, as a literal would.
        data = generator.generate(dict(true_params))["obs"]
        if data.sum():
            break
    assert data.sum() > 0, "failed to generate positive data"

    # Draw random leaf times, sample coalescent times given them, then
    # shuffle the coalescent times along the last dimension.
    leaf_times = torch.rand(5).pow(0.5) * duration
    coal_times = dist.CoalescentTimes(leaf_times).sample()
    shuffle = torch.randperm(coal_times.size(-1))
    coal_times = coal_times[..., shuffle]

    # Fit on the simulated observations together with the coalescent data.
    model = SuperspreadingSEIRModel(
        population,
        incubation_time,
        recovery_time,
        data,
        leaf_times=leaf_times,
        coal_times=coal_times,
    )
    num_samples = 5
    if algo != "mcmc":
        model.fit_svi(num_steps=2, num_samples=num_samples, **options)
    else:
        model.fit_mcmc(
            warmup_steps=2, num_samples=num_samples, max_tree_depth=2, **options
        )

    # Predictions must cover the observed window plus the forecast horizon.
    samples = model.predict(forecast=forecast)
    expected_shape = (num_samples, duration + forecast)
    for compartment in ("S", "E", "I"):
        assert samples[compartment].shape == expected_shape
def test_superspreading_seir_smoke(duration, forecast, options):
    """Smoke-test generate / MCMC fit / predict for the superspreading SEIR model.

    Checks the model's ``full_mass`` setting, simulates a non-zero observation
    series, fits it with a very short MCMC run, and verifies that the predicted
    ``S``/``E``/``I`` series have shape ``(num_samples, duration + forecast)``.
    """
    population = 100
    incubation_time = 2.0
    recovery_time = 7.0
    true_params = {"R0": 1.5, "rho": 0.5, "k": 1.0}

    # Build a model with no observations, used only to simulate data.
    unobserved = [None] * duration
    generator = SuperspreadingSEIRModel(
        population, incubation_time, recovery_time, unobserved
    )
    assert generator.full_mass == [("R0", "k", "rho")]

    # Retry the simulation until the series is not all zeros.
    for _ in range(100):
        # Pass a fresh dict on every attempt, as a literal would.
        data = generator.generate(dict(true_params))["obs"]
        if data.sum():
            break
    assert data.sum() > 0, "failed to generate positive data"

    # Fit the simulated observations with a minimal MCMC run.
    model = SuperspreadingSEIRModel(population, incubation_time, recovery_time, data)
    num_samples = 5
    model.fit_mcmc(
        warmup_steps=2, num_samples=num_samples, max_tree_depth=2, **options
    )

    # Predictions must cover the observed window plus the forecast horizon.
    samples = model.predict(forecast=forecast)
    expected_shape = (num_samples, duration + forecast)
    for compartment in ("S", "E", "I"):
        assert samples[compartment].shape == expected_shape
def Model(args, data):
    """Construct the model class selected by the command-line style ``args``.

    Selection rules, in order:

    * ``args.heterogeneous`` set: ``HeterogeneousSIRModel`` (requires
      ``incubation_time == 0`` and ``overdispersion == 0``).
    * ``args.incubation_time > 0``: an SEIR model (requires
      ``incubation_time > 1``); otherwise an SIR model.
    * Within each family: the superspreading variant when
      ``args.concentration`` is finite (``< math.inf``), else the
      overdispersed variant when ``args.overdispersion > 0``, else the
      simple variant.

    Args:
        args: Object exposing ``heterogeneous``, ``incubation_time``,
            ``overdispersion``, ``concentration``, ``population`` and
            ``recovery_time`` attributes.
        data: Observations forwarded unchanged as the final constructor
            argument of the chosen model.

    Returns:
        An instance of the selected model class.

    Raises:
        AssertionError: If the argument combination violates one of the
            requirements listed above.
    """
    # The heterogeneous model is handled first and ignores concentration.
    if args.heterogeneous:
        assert args.incubation_time == 0
        assert args.overdispersion == 0
        return HeterogeneousSIRModel(args.population, args.recovery_time, data)

    with_incubation = args.incubation_time > 0
    if with_incubation:
        assert args.incubation_time > 1

    # Overdispersion is only consulted when superspreading is not selected.
    superspreading = args.concentration < math.inf
    overdispersed = not superspreading and args.overdispersion > 0

    if with_incubation:
        if superspreading:
            seir_cls = SuperspreadingSEIRModel
        elif overdispersed:
            seir_cls = OverdispersedSEIRModel
        else:
            seir_cls = SimpleSEIRModel
        return seir_cls(
            args.population, args.incubation_time, args.recovery_time, data
        )

    if superspreading:
        sir_cls = SuperspreadingSIRModel
    elif overdispersed:
        sir_cls = OverdispersedSIRModel
    else:
        sir_cls = SimpleSIRModel
    return sir_cls(args.population, args.recovery_time, data)