def main(args):
    """Run an SMC filter over simulated harmonic-oscillator data and log
    the truth, posterior mean, and posterior std at the final time step."""
    pyro.set_rng_seed(args.seed)

    model = SimpleHarmonicModel(args.process_noise, args.measurement_noise)
    guide = SimpleHarmonicModel_Guide(model)
    smc = SMCFilter(model, guide, num_particles=args.num_particles,
                    max_plate_nesting=0)

    logging.info("Generating data")
    zs, ys = generate_data(args)

    logging.info("Filtering")
    # The initial latent state is fixed; filtering consumes observations
    # from time step 1 onward.
    smc.init(initial=torch.tensor([1., 0.]))
    for y in ys[1:]:
        smc.step(y)

    logging.info("At final time step:")
    z = smc.get_empirical()["z"]
    logging.info("truth: {}".format(zs[-1]))
    logging.info("mean: {}".format(z.mean))
    logging.info("std: {}".format(z.variance**0.5))
def main(args):
    """Simulate harmonic-oscillator data, run the SMC filter, and log the
    filtered marginal mean/variance for every time step."""
    pyro.set_rng_seed(args.seed)
    pyro.enable_validation(__debug__)

    model = SimpleHarmonicModel(args.process_noise, args.measurement_noise)
    guide = SimpleHarmonicModel_Guide(model)
    smc = SMCFilter(model, guide, num_particles=args.num_particles,
                    max_plate_nesting=0)

    logging.info('Generating data')
    zs, ys = generate_data(args)

    logging.info('Filtering')
    # The initial latent state is fixed; observations start at time step 1.
    smc.init(initial=torch.tensor([1., 0.]))
    for y in ys[1:]:
        smc.step(y)

    logging.info('Marginals')
    empirical = smc.get_empirical()
    for t in range(1, 1 + args.num_timesteps):
        z = empirical["z_{}".format(t)]
        logging.info("{}\t{}\t{}\t{}".format(t, zs[t], z.mean, z.variance))
def heuristic(self, num_particles=1024, ess_threshold=0.5, retries=10):
    """
    Finds an initial feasible guess of all latent variables, consistent with
    observed data. This is needed because not all hypotheses are feasible and
    HMC needs to start at a feasible solution to progress.

    The default implementation attempts to find a feasible state using
    :class:`~pyro.infer.smcfilter.SMCFilter` with proposals from the prior.
    However this method may be overridden in cases where SMC performs poorly
    e.g. in high-dimensional models.

    :param int num_particles: Number of particles used for SMC.
    :param float ess_threshold: Effective sample size threshold for SMC.
    :param int retries: Number of SMC attempts before giving up and
        re-raising :class:`~pyro.infer.smcfilter.SMCFailed`.
    :returns: A dictionary mapping sample site name to tensor value.
    :rtype: dict
    """
    # Run SMC, retrying on failure since SMC can stochastically degenerate.
    model = _SMCModel(self)
    guide = _SMCGuide(self)
    for attempt in range(1, 1 + retries):
        smc = SMCFilter(model, guide, num_particles=num_particles,
                        ess_threshold=ess_threshold,
                        max_plate_nesting=self.max_plate_nesting)
        try:
            smc.init()
            for t in range(1, self.duration):
                smc.step()
            break
        except SMCFailed as e:
            if attempt == retries:
                raise  # Exhausted retries; propagate the SMC failure.
            logger.info("{}. Retrying...".format(e))
            continue

    # Select the most probable hypothesis (maximum-weight particle).
    i = int(smc.state._log_weights.max(0).indices)
    init = {key: value[i, 0] for key, value in smc.state.items()}

    # Fill in sample site values.
    init = self.generate(init)
    aux = torch.stack([init[name] for name in self.compartments], dim=0)
    # Clamp into the open interval (0, population) so the auxiliary values
    # remain strictly feasible for downstream inference.
    init["auxiliary"] = clamp(aux, min=0.5, max=self.population - 0.5)
    return init
def test_smoke(max_plate_nesting, state_size, plate_size, num_steps):
    """Smoke test: the SMC filter runs end-to-end on simulated data."""
    model = SmokeModel(state_size, plate_size)
    guide = SmokeGuide(state_size, plate_size)
    smc = SMCFilter(model, guide, num_particles=100,
                    max_plate_nesting=max_plate_nesting)

    # Simulate ground truth with an independent model instance.
    true_model = SmokeModel(state_size, plate_size)
    true_model.init()
    truth = [true_model.step() for t in range(num_steps)]

    # Filter the simulated observations and touch the result accessors.
    smc.init()
    for xy in truth:
        smc.step(*xy)
    smc.get_values_and_log_weights()
    smc.get_empirical()
def test_smoke(max_plate_nesting, state_size, plate_size, num_steps):
    """Smoke test: the SMC filter runs and tracks the expected state sites."""
    model = SmokeModel(state_size, plate_size)
    guide = SmokeGuide(state_size, plate_size)
    smc = SMCFilter(model, guide, num_particles=100,
                    max_plate_nesting=max_plate_nesting)

    # Simulate ground-truth (x, y) pairs via an explicit state dict.
    true_model = SmokeModel(state_size, plate_size)
    state = {}
    true_model.init(state)
    truth = [true_model.step(state) for t in range(num_steps)]

    smc.init()
    assert set(smc.state) == {"x_mean", "y_mean"}
    for x, y in truth:
        smc.step(x, y)
    # The set of tracked sites must be stable across steps.
    assert set(smc.state) == {"x_mean", "y_mean"}
    smc.get_empirical()
def test_likelihood_ratio():
    """The SMC-filtered trajectory should score the observed data better
    than unrelated latents, and the observed data better than other data."""
    model = HarmonicModel()
    guide = HarmonicGuide()
    smc = SMCFilter(model, guide, num_particles=100, max_plate_nesting=0)

    zs, ys = generate_data()
    zs_true, ys_true = generate_data()
    smc.init()
    for y in ys_true[1:]:
        smc.step(y)

    # Reconstruct the trajectory of the maximum-weight particle.
    i = smc.state._log_weights.max(0)[1]
    values = {k: v[i] for k, v in smc.state.items()}
    zs_pred = [torch.tensor([1., 0.])]
    zs_pred += [values["z_{}".format(t)] for t in range(1, 51)]

    assert score_latent(zs_true, ys_true) > score_latent(zs, ys_true)
    assert score_latent(zs_pred, ys_true) > score_latent(zs_pred, ys)
    assert score_latent(zs_pred, ys_true) > score_latent(zs, ys_true)
def test_gaussian_filter():
    """Compare SMC filtering marginals against exact GaussianHMM filtering
    at every time step of a linear-Gaussian state-space model."""
    dim = 4
    init_dist = dist.MultivariateNormal(torch.zeros(dim),
                                        scale_tril=torch.eye(dim) * 10)
    trans_mat = torch.eye(dim)
    trans_dist = dist.MultivariateNormal(torch.zeros(dim),
                                         scale_tril=torch.eye(dim))
    obs_mat = torch.eye(dim)
    obs_dist = dist.MultivariateNormal(torch.zeros(dim),
                                       scale_tril=torch.eye(dim) * 2)
    # Exact reference: GaussianHMM performs analytic Kalman filtering.
    hmm = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)

    class Model:
        def init(self, state):
            state["z"] = pyro.sample("z_init", init_dist)
            self.t = 0

        def step(self, state, datum=None):
            # Random-walk transition followed by a noisy observation.
            state["z"] = pyro.sample(
                "z_{}".format(self.t),
                dist.MultivariateNormal(state["z"],
                                        scale_tril=trans_dist.scale_tril))
            datum = pyro.sample(
                "obs_{}".format(self.t),
                dist.MultivariateNormal(state["z"],
                                        scale_tril=obs_dist.scale_tril),
                obs=datum)
            self.t += 1
            return datum

    class Guide:
        def init(self, state):
            pyro.sample("z_init", init_dist)
            self.t = 0

        def step(self, state, datum):
            # Proposal is deliberately overdispersed (2x scale) vs the model.
            pyro.sample(
                "z_{}".format(self.t),
                dist.MultivariateNormal(state["z"],
                                        scale_tril=trans_dist.scale_tril * 2))
            self.t += 1

    # Generate data.
    num_steps = 20
    model = Model()
    state = {}
    model.init(state)
    data = torch.stack([model.step(state) for _ in range(num_steps)])

    # Perform inference, checking marginals against the exact filter at
    # every time step.
    model = Model()
    guide = Guide()
    smc = SMCFilter(model, guide, num_particles=1000, max_plate_nesting=0)
    smc.init()
    for t, datum in enumerate(data):
        smc.step(datum)
        expected = hmm.filter(data[:1 + t])
        actual = smc.get_empirical()["z"]
        assert_close(actual.variance**0.5, expected.variance**0.5,
                     atol=0.1, rtol=0.5)
        sigma = actual.variance.max().item()**0.5
        assert_close(actual.mean, expected.mean, atol=3 * sigma)