示例#1
0
def main(args):
    """Filter synthetic harmonic-oscillator data with SMC and log the marginals.

    ``args`` must supply ``seed``, ``process_noise``, ``measurement_noise``,
    ``num_particles`` and ``num_timesteps``. It is also passed whole to
    ``generate_data``, which may read further fields.
    """
    pyro.set_rng_seed(args.seed)
    pyro.enable_validation(__debug__)

    harmonic_model = SimpleHarmonicModel(args.process_noise,
                                         args.measurement_noise)
    smc_filter = SMCFilter(harmonic_model,
                           SimpleHarmonicModel_Guide(harmonic_model),
                           num_particles=args.num_particles,
                           max_plate_nesting=0)

    logging.info('Generating data')
    true_states, observations = generate_data(args)

    logging.info('Filtering')
    smc_filter.init(initial=torch.tensor([1., 0.]))
    # Index 0 is never fed to step(). Presumably it pairs with the initial
    # state handed to init() above — confirm against generate_data.
    for observation in observations[1:]:
        smc_filter.step(observation)

    logging.info('Marginals')
    marginals = smc_filter.get_empirical()
    for step in range(1, args.num_timesteps + 1):
        posterior = marginals["z_{}".format(step)]
        logging.info("{}\t{}\t{}\t{}".format(step, true_states[step],
                                             posterior.mean,
                                             posterior.variance))
示例#2
0
def main(args):
    """Filter synthetic harmonic-oscillator data with SMC and log the final marginal.

    Logs the true final state, then the mean and standard deviation of the
    empirical distribution the filter stores under ``"z"``.
    """
    pyro.set_rng_seed(args.seed)

    oscillator = SimpleHarmonicModel(args.process_noise, args.measurement_noise)
    proposal = SimpleHarmonicModel_Guide(oscillator)
    smc_filter = SMCFilter(oscillator, proposal,
                           num_particles=args.num_particles,
                           max_plate_nesting=0)

    logging.info("Generating data")
    true_states, observations = generate_data(args)

    logging.info("Filtering")

    smc_filter.init(initial=torch.tensor([1., 0.]))
    # The first entry is skipped; the loop consumes observations from index 1.
    for observation in observations[1:]:
        smc_filter.step(observation)

    logging.info("At final time step:")
    final_z = smc_filter.get_empirical()["z"]
    logging.info("truth: {}".format(true_states[-1]))
    logging.info("mean: {}".format(final_z.mean))
    final_std = final_z.variance ** 0.5
    logging.info("std: {}".format(final_std))
示例#3
0
def test_smoke(max_plate_nesting, state_size, plate_size, num_steps):
    """Smoke-test SMCFilter end to end on data simulated from a SmokeModel.

    Nothing is asserted. The test only checks that init, step,
    get_values_and_log_weights and get_empirical run without raising.
    """
    smc = SMCFilter(SmokeModel(state_size, plate_size),
                    SmokeGuide(state_size, plate_size),
                    num_particles=100,
                    max_plate_nesting=max_plate_nesting)

    # Ground truth comes from a separate model instance, not the filtered one.
    reference = SmokeModel(state_size, plate_size)
    reference.init()
    observations = [reference.step() for _ in range(num_steps)]

    smc.init()
    for step_args in observations:
        # Each simulated step returns a sequence that is forwarded positionally.
        smc.step(*step_args)
    smc.get_values_and_log_weights()
    smc.get_empirical()
示例#4
0
def test_smoke(max_plate_nesting, state_size, plate_size, num_steps):
    """Smoke-test SMCFilter with the state-dict model API.

    Checks that the filter's state holds exactly the keys ``x_mean`` and
    ``y_mean`` right after init() and again after all steps. Otherwise it
    only checks that get_empirical() runs.
    """
    smc = SMCFilter(SmokeModel(state_size, plate_size),
                    SmokeGuide(state_size, plate_size),
                    num_particles=100,
                    max_plate_nesting=max_plate_nesting)

    # Simulate ground truth with a separate model instance and its own state.
    reference = SmokeModel(state_size, plate_size)
    reference_state = {}
    reference.init(reference_state)
    observations = [reference.step(reference_state) for _ in range(num_steps)]

    expected_keys = {"x_mean", "y_mean"}
    smc.init()
    assert set(smc.state) == expected_keys
    for x, y in observations:
        smc.step(x, y)
    assert set(smc.state) == expected_keys
    smc.get_empirical()
示例#5
0
def test_gaussian_filter():
    """Compare SMCFilter's posterior over ``z`` with ``GaussianHMM.filter``.

    Data is simulated from a random-walk model with identity transition and
    observation matrices. After every step, the filter's empirical standard
    deviation and mean are compared with the filtering distribution that
    ``hmm.filter`` returns for the data seen so far.
    """
    dim = 4

    def isotropic_mvn(scale):
        # Zero-mean MVN whose scale_tril is a scaled identity matrix.
        return dist.MultivariateNormal(torch.zeros(dim),
                                       scale_tril=torch.eye(dim) * scale)

    init_dist = isotropic_mvn(10)
    trans_mat = torch.eye(dim)
    trans_dist = isotropic_mvn(1)
    obs_mat = torch.eye(dim)
    obs_dist = isotropic_mvn(2)
    hmm = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)

    class Model:
        def init(self, state):
            state["z"] = pyro.sample("z_init", init_dist)
            self.t = 0

        def step(self, state, datum=None):
            # Random-walk transition, then a noisy observation of the new z.
            transition = dist.MultivariateNormal(
                state["z"], scale_tril=trans_dist.scale_tril)
            state["z"] = pyro.sample("z_{}".format(self.t), transition)
            emission = dist.MultivariateNormal(
                state["z"], scale_tril=obs_dist.scale_tril)
            observed = pyro.sample("obs_{}".format(self.t), emission,
                                   obs=datum)
            self.t += 1
            return observed

    class Guide:
        def init(self, state):
            # Only samples the site; Model.init is what writes state["z"].
            pyro.sample("z_init", init_dist)
            self.t = 0

        def step(self, state, datum):
            # Proposal is centred like the transition but twice as wide.
            proposal = dist.MultivariateNormal(
                state["z"], scale_tril=trans_dist.scale_tril * 2)
            pyro.sample("z_{}".format(self.t), proposal)
            self.t += 1

    # Simulate a trajectory by running the model forward.
    num_steps = 20
    simulator = Model()
    sim_state = {}
    simulator.init(sim_state)
    data = torch.stack([simulator.step(sim_state) for _ in range(num_steps)])

    # Run the filter and compare with the reference after every step.
    smc = SMCFilter(Model(), Guide(), num_particles=1000, max_plate_nesting=0)
    smc.init()
    for i, datum in enumerate(data):
        smc.step(datum)
        expected = hmm.filter(data[:i + 1])
        actual = smc.get_empirical()["z"]
        assert_close(actual.variance ** 0.5,
                     expected.variance ** 0.5,
                     atol=0.1,
                     rtol=0.5)
        # Allow the mean to be off by up to three of the largest empirical stds.
        sigma = actual.variance.max().item() ** 0.5
        assert_close(actual.mean, expected.mean, atol=3 * sigma)