# Example 1
def AutoMixed(model_full, init_loc=None, delta=None):
    """Build an :class:`AutoGuideList` that mixes guide families per site block.

    The model is split into two blocks: every latent site except ``tau``
    (the "marginalised" block) and ``tau`` alone (the "full rank" block),
    and each block gets its own automatic guide.

    Parameters
    ----------
    model_full : callable
        The Pyro model to construct a guide for.
    init_loc : dict, optional
        Initial values forwarded to ``init_to_value`` for both sub-guides.
        Defaults to an empty dict.
    delta : str or None, optional
        ``None``   -> AutoNormal over non-tau sites, AutoMultivariateNormal over tau.
        ``'part'`` -> AutoDelta over non-tau sites, AutoMultivariateNormal over tau.
        ``'all'``  -> AutoDelta over both blocks (MAP-style point estimates).

    Returns
    -------
    AutoGuideList
        The composed guide.

    Raises
    ------
    ValueError
        If ``delta`` is not one of ``None``, ``'part'``, ``'all'``.
    """
    # NOTE(review): original signature used a mutable default (init_loc={});
    # normalised to None to avoid the shared-mutable-default pitfall.
    if init_loc is None:
        init_loc = {}
    # NOTE(review): previously an unrecognised `delta` silently produced a
    # guide missing its first sub-guide; fail fast instead.
    if delta not in (None, 'part', 'all'):
        raise ValueError(
            "delta must be None, 'part' or 'all', got {!r}".format(delta))

    guide = AutoGuideList(model_full)

    # All latent sites except tau.
    marginalised_guide_block = poutine.block(model_full,
                                             expose_all=True,
                                             hide_all=False,
                                             hide=['tau'])
    if delta is None:
        guide.append(
            AutoNormal(marginalised_guide_block,
                       init_loc_fn=autoguide.init_to_value(values=init_loc),
                       init_scale=0.05))
    else:  # 'part' or 'all': point estimates for the non-tau block.
        guide.append(
            AutoDelta(marginalised_guide_block,
                      init_loc_fn=autoguide.init_to_value(values=init_loc)))

    # Only the tau site.
    full_rank_guide_block = poutine.block(model_full,
                                          hide_all=True,
                                          expose=['tau'])
    if delta is None or delta == 'part':
        guide.append(
            AutoMultivariateNormal(
                full_rank_guide_block,
                init_loc_fn=autoguide.init_to_value(values=init_loc),
                init_scale=0.05))
    else:  # delta == 'all': point estimate for tau as well.
        guide.append(
            AutoDelta(full_rank_guide_block,
                      init_loc_fn=autoguide.init_to_value(values=init_loc)))

    return guide
# Example 2
def _infer_hmc(args, data, model, init_values=None):
    """Run NUTS-based MCMC inference on ``model`` and return posterior samples.

    Parameters
    ----------
    args : Namespace
        Command-line options; reads ``max_tree_depth``, ``jit``,
        ``num_samples``, ``warmup_steps``, ``verbose`` and ``plot``.
    data : object
        Data object passed through to ``mcmc.run``.
    model : callable
        The Pyro model.
    init_values : dict, optional
        Initial values for ``init_to_value``. Defaults to an empty dict.

    Returns
    -------
    dict
        Posterior samples from ``mcmc.get_samples()``.
    """
    # NOTE(review): original signature used a mutable default (init_values={});
    # normalised to None to avoid the shared-mutable-default pitfall.
    if init_values is None:
        init_values = {}
    logging.info("Running inference...")
    kernel = NUTS(model,
                  full_mass=[("R0", "rho")],
                  max_tree_depth=args.max_tree_depth,
                  init_strategy=init_to_value(values=init_values),
                  jit_compile=args.jit, ignore_jit_warnings=True)

    # We'll define a hook_fn to log potential energy values during inference.
    # This is helpful to diagnose whether the chain is mixing.
    energies = []

    def hook_fn(kernel, *unused):
        e = float(kernel._potential_energy_last)
        energies.append(e)
        if args.verbose:
            logging.info("potential = {:0.6g}".format(e))

    mcmc = MCMC(kernel, hook_fn=hook_fn,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps)
    mcmc.run(args, data)
    mcmc.summary()
    if args.plot:
        # Imported lazily so matplotlib is only required when plotting.
        import matplotlib.pyplot as plt
        plt.figure(figsize=(6, 3))
        plt.plot(energies)
        plt.xlabel("MCMC step")
        plt.ylabel("potential energy")
        plt.title("MCMC energy trace")
        plt.tight_layout()

    samples = mcmc.get_samples()
    return samples
# Example 3
 def _heuristic(self, haar, **options):
     """Run the subclass heuristic to get initial values and wrap them in an
     ``init_to_value`` strategy, optionally mapping through the haar transform.
     """
     # Run the user-supplied heuristic outside any active Pyro handlers.
     with poutine.block():
         values = self.heuristic(**options)
     assert isinstance(values, dict)
     assert "auxiliary" in values, \
         ".heuristic() did not define auxiliary value"
     if haar:
         haar.user_to_aux(values)
     # Log every scalar initial value for debugging.
     scalars = ("{}={:0.3g}".format(name, tensor.item())
                for name, tensor in sorted(values.items())
                if tensor.numel() == 1)
     logger.info("Heuristic init: {}".format(", ".join(scalars)))
     return init_to_value(values=values)
# Example 4
    value = torch.randn(()).exp() * 10
    kernel = NUTS(model,
                  init_strategy=partial(init_to_value, values={"x": value}))
    kernel.setup(warmup_steps=10)
    assert_close(value, kernel.initial_params['x'].exp())


@pytest.mark.parametrize("init_strategy", [
    init_to_feasible,
    init_to_mean,
    init_to_median,
    init_to_sample,
    init_to_uniform,
    init_to_value,
    init_to_feasible(),
    init_to_mean(),
    init_to_median(num_samples=4),
    init_to_sample(),
    init_to_uniform(radius=0.1),
    init_to_value(values={"x": torch.tensor(3.)}),
    init_to_generated(
        generate=lambda: init_to_value(values={"x": torch.rand(())})),
], ids=str_erase_pointers)
def test_init_strategy_smoke(init_strategy):
    """Smoke test: every init strategy (bare or pre-configured) must allow
    NUTS to set up on a model with a single constrained latent."""
    def simple_model():
        pyro.sample("x", dist.LogNormal(0, 1))

    sampler = NUTS(simple_model, init_strategy=init_strategy)
    sampler.setup(warmup_steps=10)