コード例 #1
0
def main(args):
    """Particle-filter synthetic simple-harmonic data and log the final state.

    Seeds pyro's RNG, builds a ``SimpleHarmonicModel`` with its guide and an
    ``SMCFilter``, feeds every observation after the first into the filter,
    and finally logs the true last latent state next to the particle
    estimate (mean and standard deviation) of ``z``.

    :param args: Parsed command-line namespace. This function reads
        ``seed``, ``process_noise``, ``measurement_noise`` and
        ``num_particles``; ``generate_data`` may read further fields.
    """
    pyro.set_rng_seed(args.seed)

    harmonic_model = SimpleHarmonicModel(args.process_noise,
                                         args.measurement_noise)
    smc = SMCFilter(harmonic_model,
                    SimpleHarmonicModel_Guide(harmonic_model),
                    num_particles=args.num_particles,
                    max_plate_nesting=0)

    logging.info("Generating data")
    true_states, observations = generate_data(args)

    logging.info("Filtering")

    # The filter starts from an explicit initial state and consumes the
    # observations from index 1 onward (presumably index 0 corresponds to
    # that initial state — confirm against generate_data).
    smc.init(initial=torch.tensor([1., 0.]))
    for observation in observations[1:]:
        smc.step(observation)

    logging.info("At final time step:")
    posterior_z = smc.get_empirical()["z"]
    logging.info("truth: {}".format(true_states[-1]))
    logging.info("mean: {}".format(posterior_z.mean))
    logging.info("std: {}".format(posterior_z.variance**0.5))
コード例 #2
0
def main(args):
    """Particle-filter synthetic simple-harmonic data and log every marginal.

    Seeds pyro's RNG and turns on pyro validation whenever Python runs
    without ``-O``. It then filters every observation after the first and
    logs, for each time step ``1..args.num_timesteps``, a tab-separated row
    of the step index, the true latent, and the empirical marginal's mean
    and variance.

    :param args: Parsed command-line namespace. This function reads
        ``seed``, ``process_noise``, ``measurement_noise``,
        ``num_particles`` and ``num_timesteps``; ``generate_data`` may read
        further fields.
    """
    pyro.set_rng_seed(args.seed)
    pyro.enable_validation(__debug__)

    harmonic_model = SimpleHarmonicModel(args.process_noise,
                                         args.measurement_noise)
    smc = SMCFilter(harmonic_model,
                    SimpleHarmonicModel_Guide(harmonic_model),
                    num_particles=args.num_particles,
                    max_plate_nesting=0)

    logging.info('Generating data')
    true_states, observations = generate_data(args)

    logging.info('Filtering')
    smc.init(initial=torch.tensor([1., 0.]))
    for observation in observations[1:]:
        smc.step(observation)

    logging.info('Marginals')
    empirical = smc.get_empirical()
    # Per-step latent sites are stored under the names "z_1", "z_2", ...
    for step in range(1, args.num_timesteps + 1):
        marginal = empirical["z_{}".format(step)]
        logging.info("{}\t{}\t{}\t{}".format(step,
                                             true_states[step],
                                             marginal.mean,
                                             marginal.variance))
コード例 #3
0
def test_smoke(max_plate_nesting, state_size, plate_size, num_steps):
    """Smoke-test ``SMCFilter`` end to end with ``SmokeModel``/``SmokeGuide``.

    A second, independent ``SmokeModel`` generates ``num_steps`` synthetic
    steps. Each step's return value is unpacked as positional arguments to
    ``smc.step``. The test passes as long as filtering and both result
    accessors run without raising.
    """
    smc = SMCFilter(SmokeModel(state_size, plate_size),
                    SmokeGuide(state_size, plate_size),
                    num_particles=100,
                    max_plate_nesting=max_plate_nesting)

    generator = SmokeModel(state_size, plate_size)
    generator.init()
    truth = [generator.step() for _ in range(num_steps)]

    smc.init()
    for step_args in truth:
        smc.step(*step_args)
    smc.get_values_and_log_weights()
    smc.get_empirical()
コード例 #4
0
def test_smoke(max_plate_nesting, state_size, plate_size, num_steps):
    """Smoke-test ``SMCFilter`` and check which keys its state exposes.

    A separate ``SmokeModel`` instance, driven with an explicit state dict,
    generates ``num_steps`` ``(x, y)`` pairs. The filter's state must
    contain exactly the keys ``x_mean`` and ``y_mean`` both right after
    ``init()`` and after all steps have been consumed.
    """
    expected_keys = {"x_mean", "y_mean"}

    smc = SMCFilter(SmokeModel(state_size, plate_size),
                    SmokeGuide(state_size, plate_size),
                    num_particles=100,
                    max_plate_nesting=max_plate_nesting)

    generator = SmokeModel(state_size, plate_size)
    generator_state = {}
    generator.init(generator_state)
    truth = [generator.step(generator_state) for _ in range(num_steps)]

    smc.init()
    assert set(smc.state) == expected_keys
    for x_obs, y_obs in truth:
        smc.step(x_obs, y_obs)
    assert set(smc.state) == expected_keys
    smc.get_empirical()
コード例 #5
0
    def heuristic(self, num_particles=1024, ess_threshold=0.5, retries=10):
        """
        Finds an initial feasible guess of all latent variables, consistent
        with observed data. This is needed because not all hypotheses are
        feasible and HMC needs to start at a feasible solution to progress.

        The default implementation attempts to find a feasible state using
        :class:`~pyro.infer.smcfilter.SMCFilter` with proposals from the
        prior.  However this method may be overridden in cases where SMC
        performs poorly e.g. in high-dimensional models.

        :param int num_particles: Number of particles used for SMC.
        :param float ess_threshold: Effective sample size threshold for SMC.
        :param int retries: Maximum number of SMC attempts; must be at least
            1. Each attempt that raises ``SMCFailed`` is logged and retried
            with a freshly constructed filter. The failure of the final
            attempt is re-raised.
        :returns: A dictionary mapping sample site name to tensor value.
        :rtype: dict
        :raises ValueError: If ``retries`` is less than 1.
        :raises SMCFailed: If every one of the ``retries`` attempts fails.
        """
        if retries < 1:
            # Without this guard the retry loop never executes and ``smc``
            # would be referenced before assignment below.
            raise ValueError(
                "retries must be at least 1, got {}".format(retries))

        # Run SMC.
        model = _SMCModel(self)
        guide = _SMCGuide(self)
        for attempt in range(1, 1 + retries):
            smc = SMCFilter(model,
                            guide,
                            num_particles=num_particles,
                            ess_threshold=ess_threshold,
                            max_plate_nesting=self.max_plate_nesting)
            try:
                # init() is followed by duration - 1 calls to step(), i.e.
                # presumably one call per time step of the model.
                smc.init()
                for _ in range(1, self.duration):
                    smc.step()
            except SMCFailed as e:
                if attempt == retries:
                    raise
                logger.info("{}. Retrying...".format(e))
            else:
                break

        # Select the most probable hypothesis: the particle with the
        # largest log weight.
        i = int(smc.state._log_weights.max(0).indices)
        # NOTE(review): the trailing ``0`` index assumes each state value
        # has a size-1 second dimension after the particle dimension —
        # confirm against _SMCModel.
        init = {key: value[i, 0] for key, value in smc.state.items()}

        # Fill in sample site values.
        init = self.generate(init)
        aux = torch.stack([init[name] for name in self.compartments], dim=0)
        # Keep the auxiliary values within [0.5, population - 0.5].
        init["auxiliary"] = clamp(aux, min=0.5, max=self.population - 0.5)
        return init
コード例 #6
0
def test_likelihood_ratio():
    """Check that the SMC point estimate fits the data it was filtered on.

    Two independent trajectories are drawn with ``generate_data``. The
    filter is run on the second one (``ys_true``), and the latent trajectory
    of the highest-weight particle becomes ``zs_pred``. The assertions
    require that, under ``score_latent``:

    * the true latents score higher on ``ys_true`` than the unrelated ``zs``;
    * ``zs_pred`` scores higher on ``ys_true`` than on the unrelated ``ys``;
    * ``zs_pred`` scores higher on ``ys_true`` than the unrelated ``zs``.
    """
    model = HarmonicModel()
    guide = HarmonicGuide()

    smc = SMCFilter(model, guide, num_particles=100, max_plate_nesting=0)

    zs, ys = generate_data()
    zs_true, ys_true = generate_data()
    smc.init()
    for y in ys_true[1:]:
        smc.step(y)
    values, logweights = smc.get_values_and_log_weights()
    # Keep only the highest-weight particle.
    i = logweights.max(0).indices
    values = {k: v[i] for k, v in values.items()}

    # One latent site "z_t" exists per observation consumed above. Deriving
    # the horizon from ys_true keeps zs_pred aligned with the filtered data
    # instead of hard-coding generate_data's length. Indexing (rather than
    # .get) makes a missing site fail loudly instead of inserting None.
    num_steps = len(ys_true) - 1
    zs_pred = [torch.tensor([1., 0.])]
    zs_pred += [values["z_{}".format(t)] for t in range(1, 1 + num_steps)]

    assert score_latent(zs_true, ys_true) > score_latent(zs, ys_true)
    assert score_latent(zs_pred, ys_true) > score_latent(zs_pred, ys)
    assert score_latent(zs_pred, ys_true) > score_latent(zs, ys_true)
コード例 #7
0
def test_gaussian_filter():
    """Compare ``SMCFilter`` with exact filtering on a linear-Gaussian HMM.

    Builds a 4-dimensional random walk with identity transition and
    observation matrices and simulates 20 observations from it. An
    ``SMCFilter`` with 1000 particles is then run, using a guide whose
    proposal scale is twice the transition scale.

    After every step, the particle estimate of ``z`` is compared with
    ``GaussianHMM.filter`` applied to the observations seen so far:

    * standard deviations must agree within ``atol=0.1, rtol=0.5``;
    * means must agree within three times the largest particle standard
      deviation.
    """
    dim = 4

    def zero_mean_mvn(scale_tril):
        # Zero-mean multivariate normal with the given Cholesky factor.
        return dist.MultivariateNormal(torch.zeros(dim),
                                       scale_tril=scale_tril)

    init_dist = zero_mean_mvn(torch.eye(dim) * 10)
    trans_dist = zero_mean_mvn(torch.eye(dim))
    obs_dist = zero_mean_mvn(torch.eye(dim) * 2)
    hmm = dist.GaussianHMM(init_dist, torch.eye(dim), trans_dist,
                           torch.eye(dim), obs_dist)

    class Model:
        # Generative model: z_init ~ init_dist, then at every step a
        # random-walk move of z followed by a noisy observation of z.
        def init(self, state):
            state["z"] = pyro.sample("z_init", init_dist)
            self.t = 0

        def step(self, state, datum=None):
            move = dist.MultivariateNormal(state["z"],
                                           scale_tril=trans_dist.scale_tril)
            state["z"] = pyro.sample("z_{}".format(self.t), move)
            emission = dist.MultivariateNormal(state["z"],
                                               scale_tril=obs_dist.scale_tril)
            datum = pyro.sample("obs_{}".format(self.t), emission, obs=datum)
            self.t += 1
            return datum

    class Guide:
        # Proposal: z_init from the prior, then a random-walk move whose
        # scale is twice the model's transition scale.
        def init(self, state):
            pyro.sample("z_init", init_dist)
            self.t = 0

        def step(self, state, datum):
            proposal = dist.MultivariateNormal(
                state["z"], scale_tril=trans_dist.scale_tril * 2)
            pyro.sample("z_{}".format(self.t), proposal)
            self.t += 1

    # Generate data.
    num_steps = 20
    simulator = Model()
    sim_state = {}
    simulator.init(sim_state)
    data = torch.stack([simulator.step(sim_state) for _ in range(num_steps)])

    # Perform inference.
    smc = SMCFilter(Model(), Guide(), num_particles=1000, max_plate_nesting=0)
    smc.init()
    for num_seen, datum in enumerate(data, start=1):
        smc.step(datum)
        expected = hmm.filter(data[:num_seen])
        actual = smc.get_empirical()["z"]
        assert_close(actual.variance**0.5,
                     expected.variance**0.5,
                     atol=0.1,
                     rtol=0.5)
        sigma = actual.variance.max().item()**0.5
        assert_close(actual.mean, expected.mean, atol=3 * sigma)