def test_init_to_value():
    """Check that ``init_to_value`` seeds NUTS with the requested value.

    A random positive value is chosen for the single LogNormal site "x".
    ``kernel.initial_params["x"]`` is exponentiated before comparison,
    i.e. the kernel holds it in log (unconstrained) space.
    """
    # Positive target value for site "x" (LogNormal support is > 0).
    expected = 10 * torch.randn(()).exp()

    def model():
        pyro.sample("x", dist.LogNormal(0, 1))

    strategy = partial(init_to_value, values={"x": expected})
    nuts = NUTS(model, init_strategy=strategy)
    nuts.setup(warmup_steps=10)

    # Map the stored unconstrained value back to the constrained space.
    recovered = nuts.initial_params["x"].exp()
    assert_close(expected, recovered)
def test_init_strategy_smoke(init_strategy):
    """Smoke test: NUTS ``setup`` runs with the given ``init_strategy``.

    ``init_strategy`` is supplied by the test harness (presumably a pytest
    parametrization outside this view — TODO confirm). The test passes
    as long as ``setup`` does not raise.
    """
    def model():
        pyro.sample("x", dist.LogNormal(0, 1))

    NUTS(model, init_strategy=init_strategy).setup(warmup_steps=10)