Example #1
0
def test_hmc_conjugate_gaussian(fixture,
                                num_samples,
                                warmup_steps,
                                hmc_params,
                                expected_means,
                                expected_precs,
                                mean_tol,
                                std_tol):
    """Run HMC on a conjugate Gaussian chain fixture and check that each
    latent's posterior mean and std match the analytic values within tolerance."""
    pyro.get_param_store().clear()
    kernel = HMC(fixture.model, **hmc_params)
    posterior = MCMC(kernel, num_samples, warmup_steps).run(fixture.data)
    for idx in range(1, fixture.chain_len + 1):
        site = 'loc_' + str(idx)
        site_marginal = EmpiricalMarginal(posterior, sites=site)
        actual_mean = site_marginal.mean
        actual_std = site_marginal.variance.sqrt()
        target_mean = torch.ones(fixture.dim) * expected_means[idx - 1]
        # std is derived from the expected precision: std = 1 / sqrt(prec)
        target_std = 1 / torch.sqrt(torch.ones(fixture.dim) * expected_precs[idx - 1])

        # Actual vs expected posterior means for the latents
        logger.info('Posterior mean (actual) - {}'.format(site))
        logger.info(actual_mean)
        logger.info('Posterior mean (expected) - {}'.format(site))
        logger.info(target_mean)
        assert_equal(rmse(actual_mean, target_mean).item(), 0.0, prec=mean_tol)

        # Actual vs expected posterior stds for the latents
        logger.info('Posterior std (actual) - {}'.format(site))
        logger.info(actual_std)
        logger.info('Posterior std (expected) - {}'.format(site))
        logger.info(target_std)
        assert_equal(rmse(actual_std, target_std).item(), 0.0, prec=std_tol)
Example #2
0
def test_categorical_gradient_with_logits(init_tensor_type):
    """A -inf logit must give log-pdf 0 for the other outcome and a zero,
    non-NaN gradient through the masked-out logit."""
    logits = Variable(init_tensor_type([-float('inf'), 0]), requires_grad=True)
    values = Variable(init_tensor_type([0, 1]))
    log_pdf = Categorical(logits=logits).batch_log_pdf(values)
    log_pdf.sum().backward()
    assert_equal(log_pdf.data[0], 0)
    assert_equal(logits.grad.data[0], 0)
Example #3
0
def test_decorator_interface_primitives():
    """Exercise poutine.trace both as a bare decorator and with keyword
    arguments, then replay a traced model against its own trace."""

    # Bare decorator form: the resulting trace defaults to graph_type "flat".
    @poutine.trace
    def model():
        pyro.param("p", torch.zeros(1, requires_grad=True))
        pyro.sample("a", Bernoulli(torch.tensor([0.5])),
                    infer={"enumerate": "parallel"})
        pyro.sample("b", Bernoulli(torch.tensor([0.5])))

    tr = model.get_trace()
    assert isinstance(tr, poutine.Trace)
    assert tr.graph_type == "flat"

    # Parameterized decorator form: graph_type is forwarded into the trace.
    @poutine.trace(graph_type="dense")
    def model():
        pyro.param("p", torch.zeros(1, requires_grad=True))
        pyro.sample("a", Bernoulli(torch.tensor([0.5])),
                    infer={"enumerate": "parallel"})
        pyro.sample("b", Bernoulli(torch.tensor([0.5])))

    tr = model.get_trace()
    assert isinstance(tr, poutine.Trace)
    assert tr.graph_type == "dense"

    # Replaying the decorated model against its own trace must reproduce
    # the recorded sample values.
    tr2 = poutine.trace(poutine.replay(model, trace=tr)).get_trace()

    assert_equal(tr2.nodes["a"]["value"], tr.nodes["a"]["value"])
Example #4
0
 def test_mean_and_var(self):
     """Empirical mean/variance of Delta samples match the analytic values."""
     draws = [dist.Delta(self.v).sample().detach().cpu().numpy()
              for _ in range(self.n_samples)]
     assert_equal(np.mean(draws), self.analytic_mean)
     assert_equal(np.var(draws), self.analytic_var)
def test_batch_log_dims(dim, vs, one_hot, ps):
    """batch_log_pdf over the enumerated support must keep the expected
    batch shape and values (log of the category probabilities)."""
    batch_pdf_shape = (3,) + (1,) * dim
    expected = np.array(wrap_nested(list(np.log(ps)), dim - 1)).reshape(*batch_pdf_shape)
    ps, vs = modify_params_using_dims(ps, vs, dim)
    support = dist.categorical.enumerate_support(ps, vs, one_hot=one_hot)
    actual = dist.categorical.batch_log_pdf(support, ps, vs, one_hot=one_hot)
    assert_equal(actual.data.cpu().numpy(), expected)
Example #6
0
def test_bern_elbo_gradient(enum_discrete, trace_graph):
    """Compare ELBO gradients from the surrogate loss against a
    finite-difference estimate for a Bernoulli model/guide pair."""
    pyro.clear_param_store()
    num_particles = 2000

    def model():
        p = Variable(torch.Tensor([0.25]))
        pyro.sample("z", dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", Variable(torch.Tensor([0.5]), requires_grad=True))
        pyro.sample("z", dist.Bernoulli(p))

    print("Computing gradients using surrogate loss")
    Elbo = TraceGraph_ELBO if trace_graph else Trace_ELBO
    # Exact enumeration needs only one particle; MC sampling needs many.
    elbo = Elbo(enum_discrete=enum_discrete,
                num_particles=(1 if enum_discrete else num_particles))
    with xfail_if_not_implemented():
        elbo.loss_and_grads(model, guide)
    params = sorted(pyro.get_param_store().get_all_param_names())
    assert params, "no params found"
    actual_grads = {name: pyro.param(name).grad.clone() for name in params}

    print("Computing gradients using finite difference")
    elbo = Trace_ELBO(num_particles=num_particles)
    expected_grads = finite_difference(lambda: elbo.loss(model, guide))

    for name in params:
        print("{} {}{}{}".format(name, "-" * 30, actual_grads[name].data,
                                 expected_grads[name].data))
    # Loose tolerance: both estimates are Monte Carlo unless enumerating.
    assert_equal(actual_grads, expected_grads, prec=0.1)
Example #7
0
def test_optimizers(factory):
    """Any optimizer produced by ``factory`` should drive params of various
    shapes (including a constrained one) toward the broadcast MVN mean by
    directly minimizing the negative trace log-probability."""
    optim = factory()

    def model(loc, cov):
        # Three params with different shapes; "z" carries a greater_than(-1)
        # constraint so its optimization happens in unconstrained space.
        x = pyro.param("x", torch.randn(2))
        y = pyro.param("y", torch.randn(3, 2))
        z = pyro.param("z", torch.randn(4, 2).abs(), constraint=constraints.greater_than(-1))
        pyro.sample("obs_x", dist.MultivariateNormal(loc, cov), obs=x)
        with pyro.iarange("y_iarange", 3):
            pyro.sample("obs_y", dist.MultivariateNormal(loc, cov), obs=y)
        with pyro.iarange("z_iarange", 4):
            pyro.sample("obs_z", dist.MultivariateNormal(loc, cov), obs=z)

    loc = torch.tensor([-0.5, 0.5])
    cov = torch.tensor([[1.0, 0.09], [0.09, 0.1]])
    for step in range(100):
        tr = poutine.trace(model).get_trace(loc, cov)
        loss = -tr.log_prob_sum()
        params = {name: pyro.param(name).unconstrained() for name in ["x", "y", "z"]}
        optim.step(loss, params)

    for name in ["x", "y", "z"]:
        actual = pyro.param(name)
        expected = loc.expand(actual.shape)
        # Fixed typo in the failure message ("in correct" -> "incorrect").
        assert_equal(actual, expected, prec=1e-2,
                     msg='{} incorrect: {} vs {}'.format(name, actual, expected))
Example #8
0
def test_quantiles(auto_class, Elbo):
    """After a short SVI run, autoguide quantiles must be ordered, agree with
    median() at 0.5, and respect each latent's support."""

    def model():
        pyro.sample("x", dist.Normal(0.0, 1.0))
        pyro.sample("y", dist.LogNormal(0.0, 1.0))
        pyro.sample("z", dist.Beta(2.0, 2.0))

    guide = auto_class(model)
    infer = SVI(model, guide, Adam({'lr': 0.01}), Elbo(strict_enumeration_warning=False))
    for _ in range(100):
        infer.step()

    quantiles = guide.quantiles([0.1, 0.5, 0.9])
    median = guide.median()
    for name in ["x", "y", "z"]:
        assert_equal(median[name], quantiles[name][1])
    quantiles = {name: [v.item() for v in value] for name, value in quantiles.items()}

    # x ~ Normal: unbounded support, quantiles separated by at least 1.0.
    x_lo, x_mid, x_hi = quantiles["x"]
    assert -3.0 < x_lo
    assert x_lo + 1.0 < x_mid
    assert x_mid + 1.0 < x_hi
    assert x_hi < 3.0

    # y ~ LogNormal: positive support, multiplicative separation.
    y_lo, y_mid, y_hi = quantiles["y"]
    assert 0.01 < y_lo
    assert y_lo * 2.0 < y_mid
    assert y_mid * 2.0 < y_hi
    assert y_hi < 100.0

    # z ~ Beta: support inside (0, 1), additive separation.
    z_lo, z_mid, z_hi = quantiles["z"]
    assert 0.01 < z_lo
    assert z_lo + 0.1 < z_mid
    assert z_mid + 0.1 < z_hi
    assert z_hi < 0.99
Example #9
0
def test_iter_discrete_traces_vector(graph_type):
    """iter_discrete_traces should enumerate every joint value of the
    vectorized discrete sites, yielding (scale, trace) pairs."""
    pyro.clear_param_store()

    def model():
        # Batched params: two independent rows of Bernoulli / Categorical params.
        p = pyro.param("p", Variable(torch.Tensor([[0.05], [0.15]])))
        ps = pyro.param("ps", Variable(torch.Tensor([[0.1, 0.2, 0.3, 0.4],
                                                     [0.4, 0.3, 0.2, 0.1]])))
        x = pyro.sample("x", dist.Bernoulli(p))
        y = pyro.sample("y", dist.Categorical(ps, one_hot=False))
        assert x.size() == (2, 1)
        assert y.size() == (2, 1)
        return dict(x=x, y=y)

    traces = list(iter_discrete_traces(graph_type, model))

    p = pyro.param("p").data
    ps = pyro.param("ps").data
    # 2 Bernoulli values x (number of categories) enumerated combinations.
    assert len(traces) == 2 * ps.size(-1)

    for scale, trace in traces:
        x = trace.nodes["x"]["value"].data.squeeze().long()[0]
        y = trace.nodes["y"]["value"].data.squeeze().long()[0]
        # NOTE(review): multiplying log-pdfs inside exp looks suspicious —
        # exp(log p(x) + log p(y)) would be the joint probability; confirm
        # the intended semantics of `scale` against upstream.
        expected_scale = torch.exp(dist.Bernoulli(p).log_pdf(x) *
                                   dist.Categorical(ps, one_hot=False).log_pdf(y))
        expected_scale = expected_scale.data.view(-1)[0]
        assert_equal(scale, expected_scale)
def test_compute_downstream_costs_iarange_reuse(dim1, dim2):
    """_compute_downstream_costs must agree with the brute-force reference on
    a model/guide pair that reuses an iarange."""
    guide_trace = poutine.trace(iarange_reuse_model_guide,
                                graph_type="dense").get_trace(include_obs=False, dim1=dim1, dim2=dim2)
    # Replay the same function as the model against the guide's sampled values.
    model_trace = poutine.trace(poutine.replay(iarange_reuse_model_guide, trace=guide_trace),
                                graph_type="dense").get_trace(include_obs=True, dim1=dim1, dim2=dim2)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace,
                                             non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
    assert dc_nodes == dc_nodes_brute

    # Each downstream cost must be shaped like its site's log_prob and match
    # the brute-force value exactly.
    for k in dc:
        assert(guide_trace.nodes[k]['log_prob'].size() == dc[k].size())
        assert_equal(dc[k], dc_brute[k])

    # Spot-check c1: its downstream cost accumulates itself, b1 (summed over
    # its iarange), c2, and the observation term.
    expected_c1 = model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob']
    expected_c1 += (model_trace.nodes['b1']['log_prob'] - guide_trace.nodes['b1']['log_prob']).sum()
    expected_c1 += model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob']
    expected_c1 += model_trace.nodes['obs']['log_prob']
    assert_equal(expected_c1, dc['c1'])
Example #11
0
def test_mask(batch_dim, event_dim, mask_dim):
    """A masked distribution must preserve shapes and zero out the log-prob
    and score parts wherever the mask is zero."""
    # Construct base distribution.
    shape = torch.Size([2, 3, 4, 5, 6][:batch_dim + event_dim])
    batch_shape = shape[:batch_dim]
    mask_shape = batch_shape[batch_dim - mask_dim:]
    base_dist = Bernoulli(0.1).expand_by(shape).independent(event_dim)

    # Construct masked distribution.
    mask = checker_mask(mask_shape)
    dist = base_dist.mask(mask)

    # Check shape.
    sample = base_dist.sample()
    assert dist.batch_shape == base_dist.batch_shape
    assert dist.event_shape == base_dist.event_shape
    # Bug fix: the original asserted `sample.shape == sample.shape` (a
    # tautology); compare the masked distribution's sample shape to the base's.
    assert dist.sample().shape == sample.shape
    assert dist.log_prob(sample).shape == base_dist.log_prob(sample).shape

    # Check values.
    assert_equal(dist.mean, base_dist.mean)
    assert_equal(dist.variance, base_dist.variance)
    assert_equal(dist.log_prob(sample), base_dist.log_prob(sample) * mask)
    assert_equal(dist.score_parts(sample), base_dist.score_parts(sample) * mask, prec=0)
    if not dist.event_shape:
        assert_equal(dist.enumerate_support(), base_dist.enumerate_support())
Example #12
0
def test_bernoulli_with_logits_overflow_gradient(init_tensor_type):
    """A huge positive logit must not overflow: observing 1 gives log-pdf 0
    and a zero (not NaN/inf) gradient."""
    logits = Variable(init_tensor_type([1e40]), requires_grad=True)
    observed_one = Variable(init_tensor_type([1]))
    log_pdf = Bernoulli(logits=logits).batch_log_pdf(observed_one)
    log_pdf.sum().backward()
    assert_equal(log_pdf.data[0], 0)
    assert_equal(logits.grad.data[0], 0)
Example #13
0
def test_bernoulli_underflow_gradient(init_tensor_type):
    """With probability forced to exactly 0, observing 0 must give log-pdf 0
    and a zero (not NaN) gradient through the sigmoid."""
    p = Variable(init_tensor_type([0]), requires_grad=True)
    zero_prob = sigmoid(p) * 0.0
    observed_zero = Variable(init_tensor_type([0]))
    log_pdf = Bernoulli(zero_prob).batch_log_pdf(observed_zero)
    log_pdf.sum().backward()
    assert_equal(log_pdf.data[0], 0)
    assert_equal(p.grad.data[0], 0)
Example #14
0
def test_unweighted_samples(batch_shape, sample_shape, dtype):
    """Empirical.sample returns sample_shape + batch_shape tensors drawn only
    from the added values {0, ..., 4}."""
    emp = Empirical()
    for value in range(5):
        emp.add(torch.ones(batch_shape, dtype=dtype) * value)
    draws = emp.sample(sample_shape=sample_shape)
    assert_equal(draws.size(), sample_shape + batch_shape)
    assert_equal(set(draws.view(-1).tolist()), set(range(5)))
Example #15
0
def test_elbo_bern(quantity, enumerate1):
    """TraceEnum_ELBO loss and gradient should match the analytic KL between
    the Bernoulli guide and prior — exactly when enumerating, approximately
    when sampling."""
    pyro.clear_param_store()
    # Enumeration is exact, so one particle and a tight tolerance suffice;
    # otherwise fall back to many Monte Carlo particles.
    num_particles = 1 if enumerate1 else 10000
    prec = 0.001 if enumerate1 else 0.1
    q = pyro.param("q", torch.tensor(0.5, requires_grad=True))
    kl = kl_divergence(dist.Bernoulli(q), dist.Bernoulli(0.25))

    def model():
        with pyro.iarange("particles", num_particles):
            pyro.sample("z", dist.Bernoulli(0.25).expand_by([num_particles]))

    @config_enumerate(default=enumerate1)
    def guide():
        q = pyro.param("q")
        with pyro.iarange("particles", num_particles):
            pyro.sample("z", dist.Bernoulli(q).expand_by([num_particles]))

    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1]))

    if quantity == "loss":
        # For this model/guide pair the ELBO loss equals KL(guide || prior).
        actual = elbo.loss(model, guide) / num_particles
        expected = kl.item()
        assert_equal(actual, expected, prec=prec, msg="".join([
            "\nexpected = {}".format(expected),
            "\n  actual = {}".format(actual),
        ]))
    else:
        # Compare d(loss)/dq against autograd through the analytic KL.
        elbo.loss_and_grads(model, guide)
        actual = q.grad / num_particles
        expected = grad(kl, [q])[0]
        assert_equal(actual, expected, prec=prec, msg="".join([
            "\nexpected = {}".format(expected.detach().cpu().numpy()),
            "\n  actual = {}".format(actual.detach().cpu().numpy()),
        ]))
Example #16
0
def test_elbo_hmm_in_guide(enumerate1, num_steps):
    """Check TraceEnum_ELBO gradients for an HMM whose hidden chain lives in
    the guide, against golden values per chain length."""
    pyro.clear_param_store()
    data = torch.ones(num_steps)
    init_probs = torch.tensor([0.5, 0.5])

    def model(data):
        # transition_probs is shared with the guide via the param store.
        transition_probs = pyro.param("transition_probs",
                                      torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                      constraint=constraints.simplex)
        emission_probs = pyro.param("emission_probs",
                                    torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                    constraint=constraints.simplex)

        x = None
        for i, y in enumerate(data):
            # First step uses the initial distribution; later steps condition
            # on the previous hidden state.
            probs = init_probs if x is None else transition_probs[x]
            x = pyro.sample("x_{}".format(i), dist.Categorical(probs))
            pyro.sample("y_{}".format(i), dist.Categorical(emission_probs[x]), obs=y)

    @config_enumerate(default=enumerate1)
    def guide(data):
        transition_probs = pyro.param("transition_probs",
                                      torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                      constraint=constraints.simplex)
        x = None
        for i, y in enumerate(data):
            probs = init_probs if x is None else transition_probs[x]
            x = pyro.sample("x_{}".format(i), dist.Categorical(probs))

    elbo = TraceEnum_ELBO(max_iarange_nesting=0)
    elbo.loss_and_grads(model, guide, data)

    # These golden values simply test agreement between parallel and sequential.
    expected_grads = {
        2: {
            "transition_probs": [[0.1029949, -0.1029949], [0.1029949, -0.1029949]],
            "emission_probs": [[0.75, -0.75], [0.25, -0.25]],
        },
        3: {
            "transition_probs": [[0.25748726, -0.25748726], [0.25748726, -0.25748726]],
            "emission_probs": [[1.125, -1.125], [0.375, -0.375]],
        },
        10: {
            "transition_probs": [[1.64832076, -1.64832076], [1.64832076, -1.64832076]],
            "emission_probs": [[3.75, -3.75], [1.25, -1.25]],
        },
        20: {
            "transition_probs": [[3.70781687, -3.70781687], [3.70781687, -3.70781687]],
            "emission_probs": [[7.5, -7.5], [2.5, -2.5]],
        },
    }

    # Compare every param's gradient against the golden table for num_steps.
    for name, value in pyro.get_param_store().named_parameters():
        actual = value.grad
        expected = torch.tensor(expected_grads[num_steps][name])
        assert_equal(actual, expected, msg=''.join([
            '\nexpected {}.grad = {}'.format(name, expected.cpu().numpy()),
            '\n  actual {}.grad = {}'.format(name, actual.detach().cpu().numpy()),
        ]))
Example #17
0
def test_unweighted_mean_and_var(size, dtype):
    """Equally weighted values {0..4} have mean 2 and variance 2."""
    emp = Empirical()
    for value in range(5):
        emp.add(torch.ones(size, dtype=dtype) * value)
    expected = torch.ones(size) * 2
    assert_equal(emp.mean, expected)
    assert_equal(emp.variance, expected)
Example #18
0
def test_log_pdf(dist):
    """Pyro log_pdf agrees with the scipy reference for every test datum."""
    pyro_dist = dist.pyro_dist
    for idx in dist.get_test_data_indices():
        params = dist.get_dist_params(idx)
        data = dist.get_test_data(idx)
        actual = unwrap_variable(pyro_dist.log_pdf(data, **params))[0]
        expected = dist.get_scipy_logpdf(idx)
        assert_equal(actual, expected)
Example #19
0
def test_batch_log_pdf(dist):
    """Summed pyro batch_log_pdf matches the summed scipy batch log-pdfs."""
    pyro_dist = dist.pyro_dist
    for idx in dist.get_batch_data_indices():
        params = dist.get_dist_params(idx)
        data = dist.get_test_data(idx)
        actual_sum = unwrap_variable(torch.sum(pyro_dist.batch_log_pdf(data, **params)))[0]
        expected_sum = np.sum(dist.get_scipy_batch_logpdf(-1))
        assert_equal(actual_sum, expected_sum)
Example #20
0
def test_double_type(test_data, alpha, beta):
    """Beta.log_prob keeps double precision and matches scipy's logpdf."""
    log_px = dist.Beta(alpha, beta).log_prob(test_data).data
    assert isinstance(log_px, torch.DoubleTensor)
    expected = sp.beta.logpdf(
        test_data.detach().cpu().numpy(),
        alpha.detach().cpu().numpy(),
        beta.detach().cpu().numpy())
    assert_equal(log_px.numpy(), expected, prec=1e-4)
Example #21
0
def test_sample_shape(dist):
    """Functional- and object-style samples agree in size, and both match the
    distribution's declared shape when implemented."""
    d = dist.pyro_dist
    for idx in range(dist.get_num_test_data()):
        params = dist.get_dist_params(idx)
        func_sample = d.sample(**params)
        obj_sample = dist.pyro_dist_obj(**params).sample()
        assert_equal(obj_sample.size(), func_sample.size())
        with xfail_if_not_implemented():
            assert func_sample.size() == d.shape(func_sample, **params)
Example #22
0
def test_float_type(float_test_data, float_alpha, float_beta, test_data, alpha, beta):
    """Beta.log_prob keeps float precision and still matches the scipy logpdf
    computed from the double-precision inputs."""
    log_px = dist.Beta(float_alpha, float_beta).log_prob(float_test_data).data
    assert isinstance(log_px, torch.FloatTensor)
    expected = sp.beta.logpdf(
        test_data.detach().cpu().numpy(),
        alpha.detach().cpu().numpy(),
        beta.detach().cpu().numpy())
    assert_equal(log_px.numpy(), expected, prec=1e-4)
Example #23
0
def test_batch_log_prob(dist):
    """Summed log_prob over a batch agrees with the scipy reference."""
    if dist.scipy_arg_fn is None:
        pytest.skip('{}.log_prob_sum has no scipy equivalent'.format(dist.pyro_dist.__name__))
    for idx in dist.get_batch_data_indices():
        params = dist.get_dist_params(idx)
        data = dist.get_test_data(idx)
        actual = dist.pyro_dist(**params).log_prob(data).sum().item()
        expected = np.sum(dist.get_scipy_batch_logpdf(-1))
        assert_equal(actual, expected)
Example #24
0
def test_enumerate_support(discrete_dist):
    """enumerate_support matches the fixture's expected support, both for the
    vectorized (-1) and non-vectorized (0) parameter sets."""
    expected_support = discrete_dist.expected_support
    expected_support_non_vec = discrete_dist.expected_support_non_vec
    if not expected_support:
        pytest.skip("enumerate_support not tested for distribution")
    Dist = discrete_dist.pyro_dist
    actual_non_vec = Dist(**discrete_dist.get_dist_params(0)).enumerate_support()
    actual_vec = Dist(**discrete_dist.get_dist_params(-1)).enumerate_support()
    assert_equal(actual_vec.data, torch.tensor(expected_support))
    assert_equal(actual_non_vec.data, torch.tensor(expected_support_non_vec))
Example #25
0
def test_scale_tril():
    """LowRankMultivariateNormal reproduces the full MVN Cholesky factor for
    cov = diag(D) + W^T W."""
    loc = torch.tensor([1.0, 2.0, 1.0, 2.0, 0.0])
    diag = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
    factor = torch.tensor([[1.0, -1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 1.0, 2.0, 4.0]])
    full_cov = diag.diag() + factor.t().matmul(factor)
    assert_equal(MultivariateNormal(loc, full_cov).scale_tril,
                 LowRankMultivariateNormal(loc, factor, diag).scale_tril)
def assert_correct_dimensions(sample, ps, vs, one_hot):
    """Check a categorical sample's shape against its parameter shape."""
    ps_shape = list(ps.data.size())
    if isinstance(sample, torch.autograd.Variable):
        sample_shape = list(sample.data.size())
    else:
        sample_shape = list(sample.shape)
    # One-hot samples (without an explicit value list) mirror ps exactly;
    # otherwise the rightmost dim collapses to a single index.
    expected = ps_shape if (one_hot and not vs) else ps_shape[:-1] + [1]
    assert_equal(sample_shape, expected)
Example #27
0
def test_posterior_predictive():
    """Posterior predictive counts for Binomial(1000, 0.7) concentrate
    near 700."""
    true_probs = torch.ones(5) * 0.7
    num_trials = torch.ones(5) * 1000
    num_success = dist.Binomial(num_trials, true_probs).sample()
    conditioned_model = poutine.condition(model, data={"obs": num_success})
    kernel = NUTS(conditioned_model, adapt_step_size=True)
    posterior = MCMC(kernel, num_samples=1000, warmup_steps=200).run(num_trials)
    predictive = TracePredictive(model, posterior, num_samples=10000).run(num_trials)
    marginal = EmpiricalMarginal(predictive)
    assert_equal(marginal.mean, torch.ones(5) * 700, prec=30)
Example #28
0
def test_trajectory(example):
    """velocity_verlet integration must land on the fixture's expected
    phase-space point within the fixture's precision."""
    model, args = example
    q_f, p_f = velocity_verlet(args.q_i, args.p_i, model.potential_fn,
                               args.step_size, args.num_steps)
    logger.info("initial q: {}".format(args.q_i))
    logger.info("final q: {}".format(q_f))
    assert_equal(q_f, args.q_f, args.prec)
    assert_equal(p_f, args.p_f, args.prec)
Example #29
0
def test_log_prob():
    """Low-rank and full-covariance MVN assign identical log-probs for
    cov = diag(D) + W^T W."""
    loc = torch.tensor([2.0, 1.0, 1.0, 2.0, 2.0])
    diag = torch.tensor([1.0, 2.0, 3.0, 1.0, 3.0])
    factor = torch.tensor([[1.0, -1.0, 2.0, 2.0, 4.0], [2.0, 1.0, 1.0, 2.0, 6.0]])
    value = torch.tensor([2.0, 3.0, 4.0, 1.0, 7.0])
    full_cov = diag.diag() + factor.t().matmul(factor)
    assert_equal(MultivariateNormal(loc, full_cov).log_prob(value),
                 LowRankMultivariateNormal(loc, factor, diag).log_prob(value))
Example #30
0
 def test_replay_partial(self):
     """Replay copies guide values only at the replayed sites; all other
     sites must be freshly resampled."""
     guide_trace = poutine.trace(self.guide).get_trace()
     replayed = poutine.replay(self.model,
                               guide_trace,
                               sites=self.partial_sample_sites)
     model_trace = poutine.trace(replayed).get_trace()
     for name in self.full_sample_sites.keys():
         model_val = model_trace.nodes[name]["value"]
         guide_val = guide_trace.nodes[name]["value"]
         if name in self.partial_sample_sites:
             assert_equal(model_val, guide_val)
         else:
             assert not eq(model_val, guide_val)
Example #31
0
def test_tmc_categoricals(depth, max_plate_nesting, num_samples, tmc_strategy):
    """Compare TraceTMC_ELBO against exact enumeration (TraceEnum_ELBO) on a
    chain of categorical latents: the (exponentiated) losses and their
    gradients w.r.t. the unconstrained params should agree."""
    # Build the chain's parameters: q0 (initial), q1..q{depth-1} (transitions),
    # and qy (emission).
    qs = [pyro.param("q0", torch.tensor([0.4, 0.6], requires_grad=True))]
    for i in range(1, depth):
        qs.append(
            pyro.param("q{}".format(i),
                       torch.randn(2, 2).abs().detach().requires_grad_(),
                       constraint=constraints.simplex))
    qs.append(pyro.param("qy", torch.tensor([0.75, 0.25], requires_grad=True)))

    # Gradients are taken w.r.t. the unconstrained representations.
    qs = [q.unconstrained() for q in qs]

    data = (torch.rand(4, 3) > 0.5).to(dtype=qs[-1].dtype,
                                       device=qs[-1].device)

    def model():
        x = pyro.sample("x0", dist.Categorical(pyro.param("q0")))
        with pyro.plate("local", 3):
            for i in range(1, depth):
                x = pyro.sample(
                    "x{}".format(i),
                    dist.Categorical(pyro.param("q{}".format(i))[..., x, :]))
            with pyro.plate("data", 4):
                pyro.sample("y",
                            dist.Bernoulli(pyro.param("qy")[..., x]),
                            obs=data)

    # Exact reference: full parallel enumeration (num_samples=None).
    elbo = TraceEnum_ELBO(max_plate_nesting=max_plate_nesting)
    enum_model = config_enumerate(model,
                                  default="parallel",
                                  expand=False,
                                  num_samples=None,
                                  tmc=tmc_strategy)
    expected_loss = (-elbo.differentiable_loss(enum_model, lambda: None)).exp()
    expected_grads = grad(expected_loss, qs)

    # TMC estimate with num_samples particles per site.
    tmc = TraceTMC_ELBO(max_plate_nesting=max_plate_nesting)
    tmc_model = config_enumerate(model,
                                 default="parallel",
                                 expand=False,
                                 num_samples=num_samples,
                                 tmc=tmc_strategy)
    actual_loss = (-tmc.differentiable_loss(tmc_model, lambda: None)).exp()
    actual_grads = grad(actual_loss, qs)

    prec = 0.05
    assert_equal(actual_loss,
                 expected_loss,
                 prec=prec,
                 msg="".join([
                     "\nexpected loss = {}".format(expected_loss),
                     "\n  actual loss = {}".format(actual_loss),
                 ]))

    for actual_grad, expected_grad in zip(actual_grads, expected_grads):
        assert_equal(actual_grad,
                     expected_grad,
                     prec=prec,
                     msg="".join([
                         "\nexpected grad = {}".format(
                             expected_grad.detach().cpu().numpy()),
                         "\n  actual grad = {}".format(
                             actual_grad.detach().cpu().numpy()),
                     ]))
Example #32
0
 def test_support(self):
     """Categorical enumerate_support returns the fixture's support tensor."""
     enumerated = dist.Categorical(self.d_ps).enumerate_support()
     assert_equal(enumerated.data, self.support)
Example #33
0
 def test_trace_return(self):
     """The _RETURN node must record the model's return value (latent1)."""
     tr = poutine.trace(self.model).get_trace()
     assert_equal(tr.nodes["latent1"]["value"],
                  tr.nodes["_RETURN"]["value"])
Example #34
0
def test_elbo_mapdata(batch_size, map_type):
    """SVI on a conjugate normal-normal model should recover the analytic
    posterior mean and log-std, for each pyro.map_data subsampling style
    (list / tensor / plain loop)."""
    # normal-normal: known covariance
    lam0 = Variable(torch.Tensor([0.1, 0.1]))   # precision of prior
    mu0 = Variable(torch.Tensor([0.0, 0.5]))   # prior mean
    # known precision of observation noise
    lam = Variable(torch.Tensor([6.0, 4.0]))
    data = []
    sum_data = Variable(torch.zeros(2))

    def add_data_point(x, y):
        # Keep a running sum so the analytic posterior can be formed below.
        data.append(Variable(torch.Tensor([x, y])))
        sum_data.data.add_(data[-1].data)

    add_data_point(0.1, 0.21)
    add_data_point(0.16, 0.11)
    add_data_point(0.06, 0.31)
    add_data_point(-0.01, 0.07)
    add_data_point(0.23, 0.25)
    add_data_point(0.19, 0.18)
    add_data_point(0.09, 0.41)
    add_data_point(-0.04, 0.17)

    # Conjugate posterior: precision adds, mean is a precision-weighted blend.
    n_data = Variable(torch.Tensor([len(data)]))
    analytic_lam_n = lam0 + n_data.expand_as(lam) * lam
    analytic_log_sig_n = -0.5 * torch.log(analytic_lam_n)
    analytic_mu_n = sum_data * (lam / analytic_lam_n) +\
        mu0 * (lam0 / analytic_lam_n)
    verbose = True
    n_steps = 7000

    if verbose:
        print("DOING ELBO TEST [bs = {}, map_type = {}]".format(
            batch_size, map_type))
    pyro.clear_param_store()

    def model():
        mu_latent = pyro.sample("mu_latent", dist.normal,
                                mu0, torch.pow(lam0, -0.5))
        # Observe the data through whichever map_data flavor is under test.
        if map_type == "list":
            pyro.map_data("aaa", data, lambda i,
                          x: pyro.observe(
                              "obs_%d" % i, dist.normal,
                              x, mu_latent, torch.pow(lam, -0.5)), batch_size=batch_size)
        elif map_type == "tensor":
            tdata = torch.cat([xi.view(1, -1) for xi in data], 0)
            pyro.map_data("aaa", tdata,
                          # XXX get batch size args to dist right
                          lambda i, x: pyro.observe("obs", dist.normal, x, mu_latent,
                                                    torch.pow(lam, -0.5)),
                          batch_size=batch_size)
        else:
            # Plain python loop: no subsampling.
            for i, x in enumerate(data):
                pyro.observe('obs_%d' % i,
                             dist.normal, x, mu_latent, torch.pow(lam, -0.5))
        return mu_latent

    def guide():
        # Initialize deliberately offset from the analytic posterior so the
        # test actually exercises optimization.
        mu_q = pyro.param("mu_q", Variable(analytic_mu_n.data + torch.Tensor([-0.18, 0.23]),
                                           requires_grad=True))
        log_sig_q = pyro.param("log_sig_q", Variable(
            analytic_log_sig_n.data - torch.Tensor([-0.18, 0.23]),
            requires_grad=True))
        sig_q = torch.exp(log_sig_q)
        pyro.sample("mu_latent", dist.normal, mu_q, sig_q)
        # The guide must mirror the model's map_data so subsample indices align.
        if map_type == "list" or map_type is None:
            pyro.map_data("aaa", data, lambda i, x: None, batch_size=batch_size)
        elif map_type == "tensor":
            tdata = torch.cat([xi.view(1, -1) for xi in data], 0)
            # dummy map_data to do subsampling for observe
            pyro.map_data("aaa", tdata, lambda i, x: None, batch_size=batch_size)
        else:
            pass

    adam = optim.Adam({"lr": 0.0008, "betas": (0.95, 0.999)})
    svi = SVI(model, guide, adam, loss="ELBO", trace_graph=True)

    for k in range(n_steps):
        svi.step()

        # Squared-error distance of the variational params from the
        # analytic posterior values.
        mu_error = torch.sum(
            torch.pow(
                analytic_mu_n -
                pyro.param("mu_q"),
                2.0))
        log_sig_error = torch.sum(
            torch.pow(
                analytic_log_sig_n -
                pyro.param("log_sig_q"),
                2.0))

        if verbose and k % 500 == 0:
            print("errors", mu_error.data.cpu().numpy()[0], log_sig_error.data.cpu().numpy()[0])

    assert_equal(Variable(torch.zeros(1)), mu_error, prec=0.05)
    assert_equal(Variable(torch.zeros(1)), log_sig_error, prec=0.06)
Example #35
0
def test_sample_dims(dim, probs):
    """A categorical sample's size matches the distribution's declared shape."""
    probs = modify_params_using_dims(probs, dim)
    drawn = dist.Categorical(probs).sample()
    assert_equal(drawn.size(), dist.Categorical(probs).shape())
Example #36
0
    def do_elbo_test(self, reparameterized, n_steps, lr, prec, difficulty=1.0):
        """Run SVI on the Gaussian-chain model and assert that the worst-case
        loc / log-scale / kappa parameter errors fall below ``prec``.

        :param reparameterized: reparameterize every node (else use the
            fixture's ``which_nodes_reparam`` selection).
        :param n_steps: number of SVI steps to run.
        :param lr: Adam learning rate.
        :param prec: tolerance on the maximum parameter errors.
        :param difficulty: forwarded to the model/guide via ``svi.step``.
        """
        n_repa_nodes = (torch.sum(self.which_nodes_reparam)
                        if not reparameterized else self.N)
        logger.info(
            " - - - - - DO GAUSSIAN %d-CHAIN ELBO TEST  [reparameterized = %s; %d/%d] - - - - - "
            % (self.N, reparameterized, n_repa_nodes, self.N))
        # NOTE(review): this debug branch is unreachable for non-negative N;
        # presumably a disabled debugging aid — confirm intent.
        if self.N < 0:

            def array_to_string(y):
                return str(
                    map(lambda x: "%.3f" % x.detach().cpu().numpy()[0], y))

            logger.debug("lambdas: " + array_to_string(self.lambdas))
            logger.debug("target_mus: " + array_to_string(self.target_mus[1:]))
            # Bug fix: the original line here was corrupted
            # ('"target_kappas: "******"lambda_posts: " ...'), which is not
            # valid Python. Reconstructed as two debug lines following the
            # surrounding pattern — confirm against upstream history.
            logger.debug("target_kappas: " +
                         array_to_string(self.target_kappas[1:]))
            logger.debug("lambda_posts: " +
                         array_to_string(self.lambda_posts[1:]))
            logger.debug("lambda_tilde_posts: " +
                         array_to_string(self.lambda_tilde_posts))
            pyro.clear_param_store()

        adam = optim.Adam({"lr": lr, "betas": (0.95, 0.999)})
        elbo = TraceGraph_ELBO()
        loss_and_grads = elbo.loss_and_grads
        # loss_and_grads = elbo.jit_loss_and_grads  # This fails.
        svi = SVI(self.model,
                  self.guide,
                  adam,
                  loss=elbo.loss,
                  loss_and_grads=loss_and_grads)

        for step in range(n_steps):
            t0 = time.time()
            svi.step(reparameterized=reparameterized, difficulty=difficulty)

            # Periodically (and on the final step, which guarantees the error
            # tuples below are defined) measure param errors vs the targets.
            if step % 5000 == 0 or step == n_steps - 1:
                kappa_errors, log_sig_errors, loc_errors = [], [], []
                for k in range(1, self.N + 1):
                    if k != self.N:
                        kappa_error = param_mse("kappa_q_%d" % k,
                                                self.target_kappas[k])
                        kappa_errors.append(kappa_error)

                    loc_errors.append(
                        param_mse("loc_q_%d" % k, self.target_mus[k]))
                    log_sig_error = param_mse(
                        "log_sig_q_%d" % k,
                        -0.5 * torch.log(self.lambda_posts[k]))
                    log_sig_errors.append(log_sig_error)

                max_errors = (
                    np.max(loc_errors),
                    np.max(log_sig_errors),
                    np.max(kappa_errors),
                )
                min_errors = (
                    np.min(loc_errors),
                    np.min(log_sig_errors),
                    np.min(kappa_errors),
                )
                mean_errors = (
                    np.mean(loc_errors),
                    np.mean(log_sig_errors),
                    np.mean(kappa_errors),
                )
                logger.debug(
                    "[max errors]   (loc, log_scale, kappa) = (%.4f, %.4f, %.4f)"
                    % max_errors)
                logger.debug(
                    "[min errors]   (loc, log_scale, kappa) = (%.4f, %.4f, %.4f)"
                    % min_errors)
                logger.debug(
                    "[mean errors]  (loc, log_scale, kappa) = (%.4f, %.4f, %.4f)"
                    % mean_errors)
                logger.debug("[step time = %.3f;  N = %d;  step = %d]\n" %
                             (time.time() - t0, self.N, step))

        assert_equal(0.0, max_errors[0], prec=prec)
        assert_equal(0.0, max_errors[1], prec=prec)
        assert_equal(0.0, max_errors[2], prec=prec)
Example #37
0
def test_mean_gradient(K, D, flat_logits, cost_function, mix_dist, batch_mode):
    """Check pathwise (rsample) gradients of a Monte Carlo cost estimate
    against analytically derived gradients for mixture distributions.

    NOTE(review): the branches below only handle mix_dist being one of
    MixtureOfDiagNormals, MixtureOfDiagNormalsSharedCovariance or
    GaussianScaleMixture -- presumably the parametrize decorator supplies
    exactly these; confirm against the test's fixtures.

    K: number of mixture components; D: event dimension.
    flat_logits: use near-uniform mixture weights when True.
    cost_function: 'cosine' or 'quadratic' test functional.
    batch_mode: if True, batch the parameters instead of the sample shape.
    """
    n_samples = 200000
    # In batch mode the n_samples dimension lives in the distribution's
    # batch shape (parameters are expanded below) rather than sample_shape.
    if batch_mode:
        sample_shape = torch.Size(())
    else:
        sample_shape = torch.Size((n_samples,))
    # GaussianScaleMixture is centered at zero by construction, so its locs
    # are fixed at zero (gradient should still flow through them).
    if mix_dist == GaussianScaleMixture:
        locs = torch.zeros(K, D, requires_grad=True)
    else:
        locs = torch.rand(K, D).requires_grad_(True)
    if mix_dist == GaussianScaleMixture:
        component_scale = 1.5 * torch.ones(K) + 0.5 * torch.rand(K)
        component_scale.requires_grad_(True)
    else:
        # Other mixtures have no per-component scalar scale; keep it at 1.
        component_scale = torch.ones(K, requires_grad=True)
    # MixtureOfDiagNormals has a per-component coordinate scale (K, D);
    # the other distributions share a single (D,) coordinate scale.
    if mix_dist == MixtureOfDiagNormals:
        coord_scale = torch.ones(K, D) + 0.5 * torch.rand(K, D)
        coord_scale.requires_grad_(True)
    else:
        coord_scale = torch.ones(D) + 0.5 * torch.rand(D)
        coord_scale.requires_grad_(True)
    if not flat_logits:
        component_logits = (1.5 * torch.rand(K)).requires_grad_(True)
    else:
        # near-uniform mixture weights
        component_logits = (0.1 * torch.rand(K)).requires_grad_(True)
    # frequency vector for the cosine cost; fixed (no gradient)
    omega = (0.2 * torch.ones(D) + 0.1 * torch.rand(D)).requires_grad_(False)

    # normalized mixture weights
    _pis = torch.exp(component_logits)
    pis = _pis / _pis.sum()

    # Analytic expectation of the cost under the mixture, differentiated
    # by autograd to obtain the reference gradients.
    if cost_function == 'cosine':
        # E[cos(omega . z)] for a diagonal Gaussian has closed form:
        # cos(omega . mu) * exp(-0.5 * sum((omega * sigma)^2))
        analytic1 = torch.cos((omega * locs).sum(-1))
        analytic2 = torch.exp(-0.5 * torch.pow(omega * coord_scale * component_scale.unsqueeze(-1), 2.0).sum(-1))
        analytic = (pis * analytic1 * analytic2).sum()
        analytic.backward()
    elif cost_function == 'quadratic':
        # E[|z|^2] = sum(sigma^2) + |mu|^2 per component
        analytic = torch.pow(coord_scale * component_scale.unsqueeze(-1), 2.0).sum(-1) + torch.pow(locs, 2.0).sum(-1)
        analytic = (pis * analytic).sum()
        analytic.backward()

    # Snapshot the analytic gradients before re-using the leaves.
    analytic_grads = {}
    analytic_grads['locs'] = locs.grad.clone()
    analytic_grads['coord_scale'] = coord_scale.grad.clone()
    analytic_grads['component_logits'] = component_logits.grad.clone()
    analytic_grads['component_scale'] = component_scale.grad.clone()

    assert locs.grad.shape == locs.shape
    assert coord_scale.grad.shape == coord_scale.shape
    assert component_logits.grad.shape == component_logits.shape
    assert component_scale.grad.shape == component_scale.shape

    # Reset gradients so the Monte Carlo backward pass starts from zero.
    coord_scale.grad.zero_()
    component_logits.grad.zero_()
    locs.grad.zero_()
    component_scale.grad.zero_()

    # Build the distribution parameters; in batch mode the leaf tensors are
    # expanded along a leading n_samples dimension, but gradients still
    # accumulate into the original leaves kept in `params`.
    if mix_dist == MixtureOfDiagNormalsSharedCovariance:
        params = {'locs': locs, 'coord_scale': coord_scale, 'component_logits': component_logits}
        if batch_mode:
            locs = locs.unsqueeze(0).expand(n_samples, K, D)
            coord_scale = coord_scale.unsqueeze(0).expand(n_samples, D)
            component_logits = component_logits.unsqueeze(0).expand(n_samples, K)
            dist_params = {'locs': locs, 'coord_scale': coord_scale, 'component_logits': component_logits}
        else:
            dist_params = params
    elif mix_dist == MixtureOfDiagNormals:
        params = {'locs': locs, 'coord_scale': coord_scale, 'component_logits': component_logits}
        if batch_mode:
            locs = locs.unsqueeze(0).expand(n_samples, K, D)
            coord_scale = coord_scale.unsqueeze(0).expand(n_samples, K, D)
            component_logits = component_logits.unsqueeze(0).expand(n_samples, K)
            dist_params = {'locs': locs, 'coord_scale': coord_scale, 'component_logits': component_logits}
        else:
            dist_params = params
    elif mix_dist == GaussianScaleMixture:
        params = {'coord_scale': coord_scale, 'component_logits': component_logits, 'component_scale': component_scale}
        if batch_mode:
            return  # distribution does not support batched parameters
        else:
            dist_params = params

    # Monte Carlo estimate of the same expectation via reparameterized samples.
    dist = mix_dist(**dist_params)
    z = dist.rsample(sample_shape=sample_shape)
    assert z.shape == (n_samples, D)
    if cost_function == 'cosine':
        cost = torch.cos((omega * z).sum(-1)).sum() / float(n_samples)
    elif cost_function == 'quadratic':
        cost = torch.pow(z, 2.0).sum() / float(n_samples)
    cost.backward()

    # The MC cost and every MC gradient must agree with the analytic values
    # up to Monte Carlo error (hence the loose prec=0.1).
    assert_equal(analytic, cost, prec=0.1,
                 msg='bad cost function evaluation for {} test (expected {}, got {})'.format(
                     mix_dist.__name__, analytic.item(), cost.item()))
    logger.debug("analytic_grads_logit: {}"
                 .format(analytic_grads['component_logits'].detach().cpu().numpy()))

    for param_name, param in params.items():
        assert_equal(param.grad, analytic_grads[param_name], prec=0.1,
                     msg='bad {} grad for {} (expected {}, got {})'.format(
                         param_name, mix_dist.__name__, analytic_grads[param_name], param.grad))
Example #38
0
def test_elbo_mapdata(batch_size, map_type):
    """SVI on a conjugate normal-normal model must recover the analytic
    posterior mean and (log) scale under each data-subsampling style.

    map_type selects how observations are handled in the model:
      "iplate" -- sequential pyro.plate used as an iterator,
      "plate"  -- vectorized pyro.plate context manager,
      anything else -- plain enumeration over all data (no subsampling).
    """
    # normal-normal: known covariance
    lam0 = torch.tensor([0.1, 0.1])  # precision of prior
    loc0 = torch.tensor([0.0, 0.5])  # prior mean
    # known precision of observation noise
    lam = torch.tensor([6.0, 4.0])
    data = []
    sum_data = torch.zeros(2)

    def add_data_point(x, y):
        # accumulate the running sum used by the conjugate posterior formula
        data.append(torch.tensor([x, y]))
        sum_data.data.add_(data[-1].data)

    add_data_point(0.1, 0.21)
    add_data_point(0.16, 0.11)
    add_data_point(0.06, 0.31)
    add_data_point(-0.01, 0.07)
    add_data_point(0.23, 0.25)
    add_data_point(0.19, 0.18)
    add_data_point(0.09, 0.41)
    add_data_point(-0.04, 0.17)

    data = torch.stack(data)
    n_data = torch.tensor([float(len(data))])
    # closed-form posterior of the normal-normal conjugate pair
    analytic_lam_n = lam0 + n_data.expand_as(lam) * lam
    analytic_log_sig_n = -0.5 * torch.log(analytic_lam_n)
    analytic_loc_n = sum_data * (lam / analytic_lam_n) +\
        loc0 * (lam0 / analytic_lam_n)
    n_steps = 7000

    logger.debug("DOING ELBO TEST [bs = {}, map_type = {}]".format(
        batch_size, map_type))
    pyro.clear_param_store()

    def model():
        loc_latent = pyro.sample(
            "loc_latent",
            dist.Normal(loc0, torch.pow(lam0, -0.5)).to_event(1))
        if map_type == "iplate":
            for i in pyro.plate("aaa", len(data), batch_size):
                # BUGFIX: removed a stray trailing comma that wrapped this
                # (discarded) sample statement in a one-element tuple.
                pyro.sample("obs_%d" % i,
                            dist.Normal(loc_latent,
                                        torch.pow(lam, -0.5)).to_event(1),
                            obs=data[i])
        elif map_type == "plate":
            with pyro.plate("aaa", len(data), batch_size) as ind:
                # BUGFIX: same stray trailing comma removed here.
                pyro.sample("obs",
                            dist.Normal(loc_latent,
                                        torch.pow(lam, -0.5)).to_event(1),
                            obs=data[ind])
        else:
            for i, x in enumerate(data):
                pyro.sample('obs_%d' % i,
                            dist.Normal(loc_latent,
                                        torch.pow(lam, -0.5)).to_event(1),
                            obs=x)
        return loc_latent

    def guide():
        # initialize the variational parameters offset from the true answer
        loc_q = pyro.param(
            "loc_q",
            analytic_loc_n.detach().clone() + torch.tensor([-0.18, 0.23]))
        log_sig_q = pyro.param(
            "log_sig_q",
            analytic_log_sig_n.detach().clone() - torch.tensor([-0.18, 0.23]))
        sig_q = torch.exp(log_sig_q)
        pyro.sample("loc_latent", dist.Normal(loc_q, sig_q).to_event(1))
        if map_type == "iplate" or map_type is None:
            # mirror the model's plate so subsampling indices agree
            for i in pyro.plate("aaa", len(data), batch_size):
                pass
        elif map_type == "plate":
            # dummy plate to do subsampling for observe
            with pyro.plate("aaa", len(data), batch_size):
                pass
        else:
            pass

    adam = optim.Adam({"lr": 0.0008, "betas": (0.95, 0.999)})
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())

    for k in range(n_steps):
        svi.step()

        # squared errors of the variational parameters vs the analytic posterior
        loc_error = torch.sum(
            torch.pow(analytic_loc_n - pyro.param("loc_q"), 2.0))
        log_sig_error = torch.sum(
            torch.pow(analytic_log_sig_n - pyro.param("log_sig_q"), 2.0))

        if k % 500 == 0:
            logger.debug("errors - {}, {}".format(loc_error, log_sig_error))

    assert_equal(loc_error.item(), 0, prec=0.05)
    assert_equal(log_sig_error.item(), 0, prec=0.06)
Example #39
0
def test_persistent_independent_subproblems(num_objects, num_frames, num_detections, bp_iters):
    """Solving two independent assignment problems jointly (as one
    block-diagonal problem) must reproduce the separate solutions."""
    def draw_problem():
        # draw a random subproblem and solve it
        e_logits = -2 * torch.rand(num_objects)
        a_logits = 2 * torch.rand(num_frames, num_detections, num_objects) - 1
        solved = MarginalAssignmentPersistent(e_logits, a_logits, bp_iters)
        return e_logits, a_logits, solved.exists_dist.probs, solved.assign_dist.probs

    exists_logits_1, assign_logits_1, exists_probs_1, assign_probs_1 = draw_problem()
    exists_logits_2, assign_logits_2, exists_probs_2, assign_probs_2 = draw_problem()

    # union of the two problems: block-diagonal logits, with -INF
    # forbidding any cross-subproblem assignment
    joint_exists_logits = torch.cat([exists_logits_1, exists_logits_2])
    joint_assign_logits = torch.full((num_frames, num_detections * 2, num_objects * 2), -INF)
    joint_assign_logits[:, :num_detections, :num_objects] = assign_logits_1
    joint_assign_logits[:, num_detections:, num_objects:] = assign_logits_2
    joint = MarginalAssignmentPersistent(joint_exists_logits, joint_assign_logits, bp_iters)
    exists_probs = joint.exists_dist.probs
    assign_probs = joint.assign_dist.probs

    # the joint marginals must agree with the independent solutions
    # (final column is the spurious/none assignment)
    assert_equal(exists_probs_1, exists_probs[:num_objects])
    assert_equal(exists_probs_2, exists_probs[num_objects:])
    assert_equal(assign_probs_1[:, :, :-1], assign_probs[:, :num_detections, :num_objects])
    assert_equal(assign_probs_1[:, :, -1], assign_probs[:, :num_detections, -1])
    assert_equal(assign_probs_2[:, :, :-1], assign_probs[:, num_detections:, num_objects:-1])
    assert_equal(assign_probs_2[:, :, -1], assign_probs[:, num_detections:, -1])
Example #40
0
 def test_replay_full(self):
     """Replaying the model against a guide trace must reuse every guide sample."""
     tr_guide = poutine.trace(self.guide).get_trace()
     replayed_model = poutine.replay(self.model, tr_guide)
     tr_model = poutine.trace(replayed_model).get_trace()
     for site in self.full_sample_sites:
         assert_equal(tr_model.nodes[site]["value"], tr_guide.nodes[site]["value"])
Example #41
0
def test_generic_lgssm_forecast(model_class, state_dim, obs_dim, T):
    """Check 3-step-ahead forecast means and covariances of the generic
    linear-Gaussian state space models against hand-computed formulas.

    For 'lgssmgp' the GP kernel is effectively disabled (tiny scales) so
    only the linear state-space portion contributes, and the comparison is
    restricted to the trailing state_dim block of the filtering state.
    """
    torch.set_default_tensor_type('torch.DoubleTensor')

    if model_class == 'lgssm':
        model = GenericLGSSM(state_dim=state_dim, obs_dim=obs_dim,
                             obs_noise_scale_init=0.1 + torch.rand(obs_dim))
    elif model_class == 'lgssmgp':
        model = GenericLGSSMWithGPNoiseModel(state_dim=state_dim, obs_dim=obs_dim, nu=1.5,
                                             obs_noise_scale_init=0.1 + torch.rand(obs_dim))
        # with these hyperparameters we essentially turn off the GP contributions
        model.kernel.length_scale = 1.0e-6 * torch.ones(obs_dim)
        model.kernel.kernel_scale = 1.0e-6 * torch.ones(obs_dim)

    targets = torch.randn(T, obs_dim)
    filtering_state = model._filter(targets)

    # forecast 3 steps ahead, excluding observation noise so the
    # hand-computed process-noise formulas below apply
    actual_loc, actual_cov = model._forecast(3, filtering_state, include_observation_noise=False)

    obs_matrix = model.obs_matrix if model_class == 'lgssm' else model.z_obs_matrix
    trans_matrix = model.trans_matrix if model_class == 'lgssm' else model.z_trans_matrix
    trans_matrix_sq = torch.mm(trans_matrix, trans_matrix)
    trans_matrix_cubed = torch.mm(trans_matrix_sq, trans_matrix)

    # A^k B for k = 1, 2, 3: maps filtered state to k-step-ahead observation
    trans_obs = torch.mm(trans_matrix, obs_matrix)
    trans_trans_obs = torch.mm(trans_matrix_sq, obs_matrix)
    trans_trans_trans_obs = torch.mm(trans_matrix_cubed, obs_matrix)

    # we only compute contributions for the state space portion for lgssmgp
    fs_loc = filtering_state.loc if model_class == 'lgssm' else filtering_state.loc[-state_dim:]

    predicted_mean1 = torch.mm(fs_loc.unsqueeze(-2), trans_obs).squeeze(-2)
    predicted_mean2 = torch.mm(fs_loc.unsqueeze(-2), trans_trans_obs).squeeze(-2)
    predicted_mean3 = torch.mm(fs_loc.unsqueeze(-2), trans_trans_trans_obs).squeeze(-2)

    # check predicted means for 3 timesteps
    assert_equal(actual_loc[0], predicted_mean1)
    assert_equal(actual_loc[1], predicted_mean2)
    assert_equal(actual_loc[2], predicted_mean3)

    # check predicted covariances for 3 timesteps
    fs_covar, process_covar = None, None
    if model_class == 'lgssm':
        process_covar = model._get_trans_dist().covariance_matrix
        fs_covar = filtering_state.covariance_matrix
    elif model_class == 'lgssmgp':
        # we only compute contributions for the state space portion
        process_covar = model.trans_noise_scale_sq.diag_embed()
        fs_covar = filtering_state.covariance_matrix[-state_dim:, -state_dim:]

    # k-step forecast covariance: propagated filtering covariance plus one
    # process-noise term per intermediate transition
    predicted_covar1 = torch.mm(trans_obs.t(), torch.mm(fs_covar, trans_obs)) + \
        torch.mm(obs_matrix.t(), torch.mm(process_covar, obs_matrix))

    predicted_covar2 = torch.mm(trans_trans_obs.t(), torch.mm(fs_covar, trans_trans_obs)) + \
        torch.mm(trans_obs.t(), torch.mm(process_covar, trans_obs)) + \
        torch.mm(obs_matrix.t(), torch.mm(process_covar, obs_matrix))

    predicted_covar3 = torch.mm(trans_trans_trans_obs.t(), torch.mm(fs_covar, trans_trans_trans_obs)) + \
        torch.mm(trans_trans_obs.t(), torch.mm(process_covar, trans_trans_obs)) + \
        torch.mm(trans_obs.t(), torch.mm(process_covar, trans_obs)) + \
        torch.mm(obs_matrix.t(), torch.mm(process_covar, obs_matrix))

    assert_equal(actual_cov[0], predicted_covar1)
    assert_equal(actual_cov[1], predicted_covar2)
    assert_equal(actual_cov[2], predicted_covar3)
Example #42
0
 def test_log_prob_sum(self):
     """Summed Categorical log-prob should match scipy's multinomial logpmf."""
     torch_lp = dist.Categorical(self.probs).log_prob(self.test_data).sum().item()
     probs_np = self.probs.detach().cpu().numpy()
     scipy_lp = float(sp.multinomial.logpmf(np.array([0, 0, 1]), 1, probs_np))
     assert_equal(torch_lp, scipy_lp, prec=1e-4)
Example #43
0
def test_mask(batch_dim, event_dim, mask_dim):
    """Masking a distribution must preserve shapes and moments, and must
    scale/mask log_prob and score_parts consistently.

    batch_dim/event_dim select the base distribution's shape; mask_dim is
    how many rightmost batch dims the checkerboard mask spans (it
    broadcasts over the rest).
    """
    # Construct base distribution.
    shape = torch.Size([2, 3, 4, 5, 6][:batch_dim + event_dim])
    batch_shape = shape[:batch_dim]
    mask_shape = batch_shape[batch_dim - mask_dim:]
    base_dist = Bernoulli(0.1).expand_by(shape).to_event(event_dim)

    # Construct masked distribution.
    mask = checker_mask(mask_shape)
    dist = base_dist.mask(mask)

    # Check shape.
    sample = base_dist.sample()
    assert dist.batch_shape == base_dist.batch_shape
    assert dist.event_shape == base_dist.event_shape
    # BUGFIX: the original compared sample.shape with itself (a tautology);
    # the intent is that masking does not change the sample shape.
    assert dist.sample().shape == sample.shape
    assert dist.log_prob(sample).shape == base_dist.log_prob(sample).shape

    # Check values.
    assert_equal(dist.mean, base_dist.mean)
    assert_equal(dist.variance, base_dist.variance)
    assert_equal(dist.log_prob(sample),
                 scale_and_mask(base_dist.log_prob(sample), mask=mask))
    assert_equal(dist.score_parts(sample),
                 base_dist.score_parts(sample).scale_and_mask(mask=mask),
                 prec=0)
    # enumerate_support is only defined for distributions with empty
    # event shape
    if not dist.event_shape:
        assert_equal(dist.enumerate_support(), base_dist.enumerate_support())
        assert_equal(dist.enumerate_support(expand=True),
                     base_dist.enumerate_support(expand=True))
        assert_equal(dist.enumerate_support(expand=False),
                     base_dist.enumerate_support(expand=False))
Example #44
0
def test_tmc_normals_chain_iwae(depth, num_samples, max_plate_nesting,
                                reparameterized, guide_type, expand,
                                tmc_strategy):
    """Compare TMC (tensor Monte Carlo) loss and gradients against a flat
    IWAE estimate on a chain of Normals.

    depth: length of the latent chain x0 -> x1 -> ... -> y.
    guide_type: "factorized", "nonfactorized", or anything else to use the
    prior (the model with observations blocked) as the guide.
    """
    # compare iwae and tmc
    # single learnable parameter whose gradient we compare
    q2 = pyro.param("q2", torch.tensor(0.5, requires_grad=True))
    qs = (q2.unconstrained(), )

    def model(reparameterized):
        Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
        x = pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. / depth)))
        for i in range(1, depth):
            x = pyro.sample("x{}".format(i), Normal(x, math.sqrt(1. / depth)))
        pyro.sample("y", Normal(x, 1.), obs=torch.tensor(float(1)))

    def factorized_guide(reparameterized):
        # mean-field guide: each latent sampled independently
        Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
        pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. / depth)))
        for i in range(1, depth):
            pyro.sample("x{}".format(i),
                        Normal(0., math.sqrt(float(i + 1) / depth)))

    def nonfactorized_guide(reparameterized):
        # autoregressive guide mirroring the model's chain structure
        Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
        x = pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. / depth)))
        for i in range(1, depth):
            x = pyro.sample("x{}".format(i), Normal(x, math.sqrt(1. / depth)))

    # fall back to the model itself (observations hidden) as the guide
    guide = factorized_guide if guide_type == "factorized" else \
        nonfactorized_guide if guide_type == "nonfactorized" else \
        poutine.block(model, hide_fn=lambda msg: msg["type"] == "sample" and msg["is_observed"])
    flat_num_samples = num_samples**min(depth,
                                        2)  # don't use too many, expensive
    # reference: flat importance weights -> IWAE objective
    vectorized_log_weights, _, _ = vectorized_importance_weights(
        model,
        guide,
        True,
        max_plate_nesting=max_plate_nesting,
        num_samples=flat_num_samples)
    assert vectorized_log_weights.shape == (flat_num_samples, )
    expected_loss = (vectorized_log_weights.logsumexp(dim=-1) -
                     math.log(float(flat_num_samples))).exp()
    expected_grads = grad(expected_loss, qs)

    # TMC: enumerate num_samples particles per site in parallel
    tmc = TraceTMC_ELBO(max_plate_nesting=max_plate_nesting)
    tmc_model = config_enumerate(model,
                                 default="parallel",
                                 expand=expand,
                                 num_samples=num_samples,
                                 tmc=tmc_strategy)
    tmc_guide = config_enumerate(guide,
                                 default="parallel",
                                 expand=expand,
                                 num_samples=num_samples,
                                 tmc=tmc_strategy)
    actual_loss = (
        -tmc.differentiable_loss(tmc_model, tmc_guide, reparameterized)).exp()
    actual_grads = grad(actual_loss, qs)

    assert_equal(actual_loss,
                 expected_loss,
                 prec=0.05,
                 msg="".join([
                     "\nexpected loss = {}".format(expected_loss),
                     "\n  actual loss = {}".format(actual_loss),
                 ]))

    # score-function gradients are noisier, so loosen the tolerance
    grad_prec = 0.05 if reparameterized else 0.1
    for actual_grad, expected_grad in zip(actual_grads, expected_grads):
        assert_equal(actual_grad,
                     expected_grad,
                     prec=grad_prec,
                     msg="".join([
                         "\nexpected grad = {}".format(
                             expected_grad.detach().cpu().numpy()),
                         "\n  actual grad = {}".format(
                             actual_grad.detach().cpu().numpy()),
                     ]))
Example #45
0
    def do_elbo_test(
        self,
        reparameterized,
        n_steps,
        lr,
        prec,
        beta1,
        difficulty=1.0,
        model_permutation=False,
    ):
        """Run SVI on the Gaussian pyramid fixture and assert the
        variational parameters converge to the known analytic targets.

        reparameterized: use reparameterized gradients for all nodes (else
            the fixture's which_nodes_reparam mask applies).
        n_steps/lr/beta1: optimization schedule; prec: assertion tolerance.
        difficulty, model_permutation: forwarded to the fixture's
            model/guide.
        """
        n_repa_nodes = (torch.sum(self.which_nodes_reparam)
                        if not reparameterized else len(self.q_topo_sort))
        logger.info((
            " - - - DO GAUSSIAN %d-LAYERED PYRAMID ELBO TEST " +
            "(with a total of %d RVs) [reparameterized=%s; %d/%d; perm=%s] - - -"
        ) % (
            self.N,
            (2**self.N) - 1,
            reparameterized,
            n_repa_nodes,
            len(self.q_topo_sort),
            model_permutation,
        ))
        pyro.clear_param_store()

        # check graph structure is as expected but only for N=2
        if self.N == 2:
            guide_trace = pyro.poutine.trace(
                self.guide, graph_type="dense").get_trace(
                    reparameterized=reparameterized,
                    model_permutation=model_permutation,
                    difficulty=difficulty,
                )
            expected_nodes = set([
                "log_sig_1R",
                "kappa_1_1L",
                "_INPUT",
                "constant_term_loc_latent_1R",
                "_RETURN",
                "loc_latent_1R",
                "loc_latent_1",
                "constant_term_loc_latent_1",
                "loc_latent_1L",
                "constant_term_loc_latent_1L",
                "log_sig_1L",
                "kappa_1_1R",
                "kappa_1R_1L",
                "log_sig_1",
            ])
            expected_edges = set([
                ("loc_latent_1R", "loc_latent_1"),
                ("loc_latent_1L", "loc_latent_1R"),
                ("loc_latent_1L", "loc_latent_1"),
            ])
            assert expected_nodes == set(guide_trace.nodes)
            assert expected_edges == set(guide_trace.edges)

        adam = optim.Adam({"lr": lr, "betas": (beta1, 0.999)})
        svi = SVI(self.model, self.guide, adam, loss=TraceGraph_ELBO())

        for step in range(n_steps):
            t0 = time.time()
            svi.step(
                reparameterized=reparameterized,
                model_permutation=model_permutation,
                difficulty=difficulty,
            )

            # periodically log parameter errors vs. the analytic targets
            if step % 5000 == 0 or step == n_steps - 1:
                log_sig_errors = []
                for node in self.target_lambdas:
                    # target log-scale from the target precision lambda
                    target_log_sig = -0.5 * torch.log(
                        self.target_lambdas[node])
                    log_sig_error = param_mse("log_sig_" + node,
                                              target_log_sig)
                    log_sig_errors.append(log_sig_error)
                max_log_sig_error = np.max(log_sig_errors)
                min_log_sig_error = np.min(log_sig_errors)
                mean_log_sig_error = np.mean(log_sig_errors)
                leftmost_node = self.q_topo_sort[0]
                leftmost_constant_error = param_mse(
                    "constant_term_" + leftmost_node,
                    self.target_leftmost_constant)
                # sibling of the leftmost node (swap trailing L for R)
                almost_leftmost_constant_error = param_mse(
                    "constant_term_" + leftmost_node[:-1] + "R",
                    self.target_almost_leftmost_constant,
                )

                logger.debug(
                    "[mean function constant errors (partial)]   %.4f  %.4f" %
                    (leftmost_constant_error, almost_leftmost_constant_error))
                logger.debug(
                    "[min/mean/max log(scale) errors]   %.4f  %.4f   %.4f" %
                    (min_log_sig_error, mean_log_sig_error, max_log_sig_error))
                logger.debug("[step time = %.3f;  N = %d;  step = %d]\n" %
                             (time.time() - t0, self.N, step))

        # final-iteration errors must be within tolerance
        assert_equal(0.0, max_log_sig_error, prec=prec)
        assert_equal(0.0, leftmost_constant_error, prec=prec)
        assert_equal(0.0, almost_leftmost_constant_error, prec=prec)
Example #46
0
 def test_log_pdf(self):
     """Delta log-pdf evaluated at its support point must be exactly zero."""
     lp = dist.delta.log_pdf(self.test_data, self.v).data
     assert_equal(lp.sum(), 0)
Example #47
0
 def test_undo_uncondition(self):
     """Re-conditioning an unconditioned model restores the observed value."""
     freed = poutine.uncondition(self.model)
     reconditioned = pyro.condition(freed, {"obs": torch.ones(2)})
     trace = poutine.trace(reconditioned).get_trace()
     assert_equal(trace.nodes["obs"]["value"], torch.ones(2))
Example #48
0
 def test_uncondition(self):
     """uncondition() should resample observed sites instead of fixing them."""
     free_model = poutine.uncondition(self.model)
     free_trace = poutine.trace(free_model).get_trace()
     cond_trace = poutine.trace(self.model).get_trace()
     # the conditioned model keeps the observation; the freed one resamples it
     assert_equal(cond_trace.nodes["obs"]["value"], torch.ones(2))
     assert_not_equal(free_trace.nodes["obs"]["value"], torch.ones(2))
Example #49
0
def test_timeseries_models(model, nu_statedim, obs_dim, T):
    """Smoke + correctness test for the contrib.timeseries models.

    Checks log_prob shapes, compares IndependentMaternGP log-probs against
    a vanilla GP multivariate normal, and checks forecast shapes plus
    monotonically growing and mean-reverting predictive uncertainty.

    nu_statedim doubles as the Matern smoothness nu for GP models and as
    the state dimension for 'glgssm'.
    """
    torch.set_default_tensor_type('torch.DoubleTensor')
    dt = 0.1 + torch.rand(1).item()

    if model == 'lcmgp':
        num_gps = 2
        gp = LinearlyCoupledMaternGP(nu=nu_statedim, obs_dim=obs_dim, dt=dt, num_gps=num_gps,
                                     log_length_scale_init=torch.randn(num_gps),
                                     log_kernel_scale_init=torch.randn(num_gps),
                                     log_obs_noise_scale_init=torch.randn(obs_dim))
    elif model == 'imgp':
        gp = IndependentMaternGP(nu=nu_statedim, obs_dim=obs_dim, dt=dt,
                                 log_length_scale_init=torch.randn(obs_dim),
                                 log_kernel_scale_init=torch.randn(obs_dim),
                                 log_obs_noise_scale_init=torch.randn(obs_dim))
    elif model == 'glgssm':
        gp = GenericLGSSM(state_dim=nu_statedim, obs_dim=obs_dim,
                          log_obs_noise_scale_init=torch.randn(obs_dim))
    elif model == 'ssmgp':
        # state dim chosen per Matern smoothness nu
        state_dim = {0.5: 4, 1.5: 3, 2.5: 2}[nu_statedim]
        gp = GenericLGSSMWithGPNoiseModel(nu=nu_statedim, state_dim=state_dim, obs_dim=obs_dim,
                                          log_obs_noise_scale_init=torch.randn(obs_dim))
    elif model == 'dmgp':
        gp = DependentMaternGP(nu=nu_statedim, obs_dim=obs_dim, dt=dt,
                               log_length_scale_init=torch.randn(obs_dim))

    targets = torch.randn(T, obs_dim)
    gp_log_prob = gp.log_prob(targets)
    # imgp returns one log-prob per independent output dimension
    if model == 'imgp':
        assert gp_log_prob.shape == (obs_dim,)
    else:
        assert gp_log_prob.dim() == 0

    # compare matern log probs to vanilla GP result via multivariate normal
    if model == 'imgp':
        times = dt * torch.arange(T).double()
        for dim in range(obs_dim):
            lengthscale = gp.kernel.log_length_scale.exp()[dim]
            variance = (2.0 * gp.kernel.log_kernel_scale).exp()[dim]
            obs_noise = (2.0 * gp.log_obs_noise_scale).exp()[dim]

            # build the equivalent stationary kernel for this nu
            kernel = {0.5: pyro.contrib.gp.kernels.Exponential,
                      1.5: pyro.contrib.gp.kernels.Matern32,
                      2.5: pyro.contrib.gp.kernels.Matern52}[nu_statedim]
            kernel = kernel(input_dim=1, lengthscale=lengthscale, variance=variance)
            kernel = kernel(times) + obs_noise * torch.eye(T)

            mvn = torch.distributions.MultivariateNormal(torch.zeros(T), kernel)
            mvn_log_prob = mvn.log_prob(targets[:, dim])
            assert_equal(mvn_log_prob, gp_log_prob[dim], prec=1e-4)

    # forecast S steps ahead for a short and a longer horizon
    for S in [1, 5]:
        if model in ['imgp', 'lcmgp', 'dmgp']:
            # continuous-time models take explicit (increasing) time offsets
            dts = torch.rand(S).cumsum(dim=-1)
            predictive = gp.forecast(targets, dts)
        else:
            predictive = gp.forecast(targets, S)
        assert predictive.loc.shape == (S, obs_dim)
        if model == 'imgp':
            assert predictive.scale.shape == (S, obs_dim)
            # assert monotonic increase of predictive noise
            if S > 1:
                delta = predictive.scale[1:S, :] - predictive.scale[0:S-1, :]
                assert (delta > 0.0).sum() == (S - 1) * obs_dim
        else:
            assert predictive.covariance_matrix.shape == (S, obs_dim, obs_dim)
            # assert monotonic increase of predictive noise
            if S > 1:
                dets = predictive.covariance_matrix.det()
                delta = dets[1:S] - dets[0:S-1]
                assert (delta > 0.0).sum() == (S - 1)

    if model in ['imgp', 'lcmgp', 'dmgp']:
        # the distant future
        dts = torch.tensor([500.0])
        predictive = gp.forecast(targets, dts)
        # assert mean reverting behavior for GP models
        assert_equal(predictive.loc, torch.zeros(1, obs_dim))
 def test_support(self):
     """enumerate_support should cover every one-hot configuration."""
     support = dist.one_hot_categorical.enumerate_support(self.d_ps)
     assert_equal(support.data, self.support_one_hot)
Example #51
0
def test_batch_log_dims(dim, probs):
    """log_prob over the enumerated support keeps the expected batch shape."""
    probs = modify_params_using_dims(probs, dim)
    categorical = dist.Categorical(probs)
    expected_shape = torch.Size((3,) + categorical.batch_shape)
    log_prob = categorical.log_prob(categorical.enumerate_support())
    assert_equal(log_prob.size(), expected_shape)
 def test_support_non_vectorized(self):
     """Support enumeration for a single, non-batched parameter vector."""
     single_ps = self.d_ps[0].squeeze(0)
     support = dist.one_hot_categorical.enumerate_support(single_ps)
     assert_equal(support.data, self.support_one_hot_non_vec)
Example #53
0
def test_support_dims(dim, probs):
    """enumerate_support shape must be (num_categories,) + batch shape."""
    probs = modify_params_using_dims(probs, dim)
    support = dist.Categorical(probs).enumerate_support()
    expected = torch.Size((probs.size(-1),) + probs.size()[:-1])
    assert_equal(support.size(), expected)
Example #54
0
def test_compute_downstream_costs_big_model_guide_pair(include_inner_1,
                                                       include_single,
                                                       flip_c23,
                                                       include_triple,
                                                       include_z1):
    """Compare ``_compute_downstream_costs`` against a brute-force reference.

    Traces the big model/guide pair under several structural flags, computes
    per-node downstream costs with both implementations, and checks that the
    downstream-node sets, the cost tensors, and their shapes all agree.  For
    selected flag combinations the costs are additionally checked against
    expressions assembled by hand from the traces' site log-probs.
    """
    # Run the guide first, then replay the model against the guide's trace so
    # both traces share the same latent sample values.
    guide_trace = poutine.trace(big_model_guide, graph_type="dense").get_trace(
        include_obs=False,
        include_inner_1=include_inner_1,
        include_single=include_single,
        flip_c23=flip_c23,
        include_triple=include_triple,
        include_z1=include_z1)
    model_trace = poutine.trace(poutine.replay(big_model_guide,
                                               trace=guide_trace),
                                graph_type="dense").get_trace(
                                    include_obs=True,
                                    include_inner_1=include_inner_1,
                                    include_single=include_single,
                                    flip_c23=flip_c23,
                                    include_triple=include_triple,
                                    include_z1=include_z1)

    # Drop subsample sites and populate 'log_prob' at every remaining site.
    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()
    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)

    # Implementation under test vs. brute-force reference.
    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace,
                                             non_reparam_nodes)

    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    # Expected downstream-node sets when every optional part of the model is
    # enabled except the 'triple' extension and the c2/c3 flip.
    expected_nodes_full_model = {
        'a1': {'c2', 'a1', 'd1', 'c1', 'obs', 'b1', 'd2', 'c3', 'b0'},
        'd2': {'obs', 'd2'},
        'd1': {'obs', 'd1', 'd2'},
        'c3': {'d2', 'obs', 'd1', 'c3'},
        'b0': {'b0', 'd1', 'c1', 'obs', 'b1', 'd2', 'c3', 'c2'},
        'b1': {'obs', 'b1', 'd1', 'd2', 'c3', 'c1', 'c2'},
        'c1': {'d1', 'c1', 'obs', 'd2', 'c3', 'c2'},
        'c2': {'obs', 'd1', 'c3', 'd2', 'c2'}
    }
    if not include_triple and include_inner_1 and include_single and not flip_c23:
        assert (dc_nodes == expected_nodes_full_model)

    # Hand-built expected cost for b1: its own ELBO term plus all downstream
    # terms, with .sum(0) collapsing the leading batch dimension of the
    # downstream sites (presumably an enclosing plate dim -- see shape
    # assertions at the bottom).
    expected_b1 = (model_trace.nodes['b1']['log_prob'] -
                   guide_trace.nodes['b1']['log_prob'])
    expected_b1 += (model_trace.nodes['d2']['log_prob'] -
                    guide_trace.nodes['d2']['log_prob']).sum(0)
    expected_b1 += (model_trace.nodes['d1']['log_prob'] -
                    guide_trace.nodes['d1']['log_prob']).sum(0)
    expected_b1 += model_trace.nodes['obs']['log_prob'].sum(0, keepdim=False)
    if include_inner_1:
        expected_b1 += (model_trace.nodes['c1']['log_prob'] -
                        guide_trace.nodes['c1']['log_prob']).sum(0)
        expected_b1 += (model_trace.nodes['c2']['log_prob'] -
                        guide_trace.nodes['c2']['log_prob']).sum(0)
        expected_b1 += (model_trace.nodes['c3']['log_prob'] -
                        guide_trace.nodes['c3']['log_prob']).sum(0)
    assert_equal(expected_b1, dc['b1'], prec=1.0e-6)

    if include_single:
        # b0 is fully reduced (scalar-per-element costs summed over all dims).
        expected_b0 = (model_trace.nodes['b0']['log_prob'] -
                       guide_trace.nodes['b0']['log_prob'])
        expected_b0 += (model_trace.nodes['b1']['log_prob'] -
                        guide_trace.nodes['b1']['log_prob']).sum()
        expected_b0 += (model_trace.nodes['d2']['log_prob'] -
                        guide_trace.nodes['d2']['log_prob']).sum()
        expected_b0 += (model_trace.nodes['d1']['log_prob'] -
                        guide_trace.nodes['d1']['log_prob']).sum()
        expected_b0 += model_trace.nodes['obs']['log_prob'].sum()
        if include_inner_1:
            expected_b0 += (model_trace.nodes['c1']['log_prob'] -
                            guide_trace.nodes['c1']['log_prob']).sum()
            expected_b0 += (model_trace.nodes['c2']['log_prob'] -
                            guide_trace.nodes['c2']['log_prob']).sum()
            expected_b0 += (model_trace.nodes['c3']['log_prob'] -
                            guide_trace.nodes['c3']['log_prob']).sum()
        assert_equal(expected_b0, dc['b0'], prec=1.0e-6)
        assert dc['b0'].size() == (5, )

    if include_inner_1:
        expected_c3 = (model_trace.nodes['c3']['log_prob'] -
                       guide_trace.nodes['c3']['log_prob'])
        expected_c3 += (model_trace.nodes['d1']['log_prob'] -
                        guide_trace.nodes['d1']['log_prob']).sum(0)
        expected_c3 += (model_trace.nodes['d2']['log_prob'] -
                        guide_trace.nodes['d2']['log_prob']).sum(0)
        expected_c3 += model_trace.nodes['obs']['log_prob'].sum(0)

        expected_c2 = (model_trace.nodes['c2']['log_prob'] -
                       guide_trace.nodes['c2']['log_prob'])
        expected_c2 += (model_trace.nodes['d1']['log_prob'] -
                        guide_trace.nodes['d1']['log_prob']).sum(0)
        expected_c2 += (model_trace.nodes['d2']['log_prob'] -
                        guide_trace.nodes['d2']['log_prob']).sum(0)
        expected_c2 += model_trace.nodes['obs']['log_prob'].sum(0)

        expected_c1 = (model_trace.nodes['c1']['log_prob'] -
                       guide_trace.nodes['c1']['log_prob'])

        # NOTE(review): the two branches below are not mirror images of each
        # other -- in the flip branch c2 picks up only the *model* side of the
        # c3 term, while the non-flip branch adds the c2 term on top of
        # expected_c2's own initialization.  This presumably mirrors the exact
        # semantics of _compute_downstream_costs under the flipped dependency
        # ordering -- confirm against the implementation before changing.
        if flip_c23:
            expected_c3 += model_trace.nodes['c2'][
                'log_prob'] - guide_trace.nodes['c2']['log_prob']
            expected_c2 += model_trace.nodes['c3']['log_prob']
        else:
            expected_c2 += model_trace.nodes['c3'][
                'log_prob'] - guide_trace.nodes['c3']['log_prob']
            expected_c2 += model_trace.nodes['c2'][
                'log_prob'] - guide_trace.nodes['c2']['log_prob']
        expected_c1 += expected_c3

        assert_equal(expected_c1, dc['c1'], prec=1.0e-6)
        assert_equal(expected_c2, dc['c2'], prec=1.0e-6)
        assert_equal(expected_c3, dc['c3'], prec=1.0e-6)

    expected_d1 = model_trace.nodes['d1']['log_prob'] - guide_trace.nodes[
        'd1']['log_prob']
    expected_d1 += model_trace.nodes['d2']['log_prob'] - guide_trace.nodes[
        'd2']['log_prob']
    expected_d1 += model_trace.nodes['obs']['log_prob']

    expected_d2 = (model_trace.nodes['d2']['log_prob'] -
                   guide_trace.nodes['d2']['log_prob'])
    expected_d2 += model_trace.nodes['obs']['log_prob']

    if include_triple:
        expected_z0 = dc['a1'] + model_trace.nodes['z0'][
            'log_prob'] - guide_trace.nodes['z0']['log_prob']
        assert_equal(expected_z0, dc['z0'], prec=1.0e-6)
    assert_equal(expected_d2, dc['d2'], prec=1.0e-6)
    assert_equal(expected_d1, dc['d1'], prec=1.0e-6)

    assert dc['b1'].size() == (2, )
    assert dc['d2'].size() == (4, 2)

    # Every cost must match the guide site's log_prob shape and the
    # brute-force value exactly.
    for k in dc:
        assert (guide_trace.nodes[k]['log_prob'].size() == dc[k].size())
        assert_equal(dc[k], dc_brute[k])
Example #55
0
 def test_mean_and_var(self):
     """Empirical frequency of the first category matches the analytic mean."""
     samples = []
     for _ in range(self.n_samples):
         samples.append(dist.Categorical(self.probs).sample().detach().cpu().numpy())
     _, counts = np.unique(samples, return_counts=True)
     empirical_mean = counts[0] / float(self.n_samples)
     assert_equal(empirical_mean, self.analytic_mean.detach().cpu().numpy()[0], prec=0.05)
Example #56
0
def test_NcpContinuous():
    """Smoke-test the NcpContinuous (nearly-constant-position) dynamic model."""
    framerate = 100  # Hz
    dt = 1.0 / framerate
    d = 3
    ncp = NcpContinuous(dimension=d, sv2=2.0)
    assert ncp.dimension == d
    assert ncp.dimension_pv == 2 * d
    assert ncp.num_process_noise_parameters == 1

    # Constant-position dynamics: propagation leaves the state unchanged,
    # so the geodesic difference between input and output is zero.
    state = torch.rand(d)
    propagated = ncp(state, dt)
    assert_equal(propagated, state)
    assert_equal(ncp.geodesic_difference(state, propagated), torch.zeros(d))

    # Position/velocity embedding: position is copied, velocity is zero.
    state_pv = ncp.mean2pv(state)
    assert len(state_pv) == 6
    assert_equal(state, state_pv[:d])
    assert_equal(torch.zeros(d), state_pv[d:])

    # Covariance embedding: position block copied, velocity blocks zero.
    cov = torch.eye(d)
    cov_pv = ncp.cov2pv(cov)
    assert cov_pv.shape == (2 * d, 2 * d)
    expected_cov_pv = torch.zeros((2 * d, 2 * d))
    expected_cov_pv[:d, :d] = cov
    assert_equal(expected_cov_pv, cov_pv)

    # Process noise covariance: a repeated call exercises the cache and must
    # return the same (valid) covariance.
    Q = ncp.process_noise_cov(dt)
    Q_cached = ncp.process_noise_cov(dt)
    assert_equal(Q, Q_cached)
    assert Q_cached.shape == (d, d)
    assert_cov_validity(Q_cached)

    noise_sample = ncp.process_noise_dist(dt).sample()
    assert noise_sample.shape == (ncp.dimension,)
    def do_elbo_test(self, reparameterized, n_steps, lr, prec, difficulty=1.0):
        """Run SVI on a Gaussian Markov chain and check the learned guide params.

        Trains a mean-field-style guide (run in reverse order with optional
        kappa couplings) against an N-step Gaussian chain model, then asserts
        that the worst-case errors of mu, log(sigma) and kappa against their
        analytic posterior targets are below ``prec``.

        :param reparameterized: use reparameterized normals everywhere; when
            False, only nodes flagged in ``self.which_nodes_reparam`` are.
        :param n_steps: number of SVI steps.
        :param lr: Adam learning rate.
        :param prec: tolerance for the final max-error assertions.
        :param difficulty: scale of the random perturbation applied to the
            parameter initializations (larger = further from the targets).
        """
        n_repa_nodes = torch.sum(self.which_nodes_reparam) if not reparameterized else self.N
        logger.info(" - - - - - DO GAUSSIAN %d-CHAIN ELBO TEST  [reparameterized = %s; %d/%d] - - - - - " %
                    (self.N, reparameterized, n_repa_nodes, self.N))
        if self.N < 0:
            # NOTE(review): this branch is unreachable for any non-negative
            # chain length, so the debug logging (and the clear_param_store
            # call at its end) never runs -- confirm whether the guard and the
            # placement of clear_param_store are intentional.
            def array_to_string(y):
                return str(map(lambda x: "%.3f" % x.data.cpu().numpy()[0], y))

            logger.debug("lambdas: " + array_to_string(self.lambdas))
            logger.debug("target_mus: " + array_to_string(self.target_mus[1:]))
            # The next two lines were garbled in the original source (a
            # redaction artifact "******" fused them into a syntax error);
            # restored to the apparent intent -- TODO confirm against history.
            logger.debug("target_kappas: " + array_to_string(self.target_kappas[1:]))
            logger.debug("lambda_posts: " + array_to_string(self.lambda_posts[1:]))
            logger.debug("lambda_tilde_posts: " + array_to_string(self.lambda_tilde_posts))
            pyro.clear_param_store()

        def model(*args, **kwargs):
            # Gaussian chain: mu_1 -> mu_2 -> ... -> mu_N -> observations.
            next_mean = self.mu0
            for k in range(1, self.N + 1):
                latent_dist = dist.Normal(next_mean, torch.pow(self.lambdas[k - 1], -0.5))
                mu_latent = pyro.sample("mu_latent_%d" % k, latent_dist)
                next_mean = mu_latent

            mu_N = next_mean
            for i, x in enumerate(self.data):
                pyro.observe("obs_%d" % i, dist.normal, x, mu_N,
                             torch.pow(self.lambdas[self.N], -0.5))
            return mu_N

        def guide(*args, **kwargs):
            # Guide runs the chain in reverse; each node's mean is an affine
            # function (kappa * previous + mu) of the previously sampled node.
            previous_sample = None
            for k in reversed(range(1, self.N + 1)):
                mu_q = pyro.param("mu_q_%d" % k, Variable(self.target_mus[k].data +
                                                          difficulty * (0.1 * torch.randn(1) - 0.53),
                                                          requires_grad=True))
                log_sig_q = pyro.param("log_sig_q_%d" % k,
                                       Variable(-0.5 * torch.log(self.lambda_posts[k]).data +
                                                difficulty * (0.1 * torch.randn(1) - 0.53),
                                                requires_grad=True))
                sig_q = torch.exp(log_sig_q)
                # The last node (k == N) has no downstream coupling, so no kappa.
                kappa_q = None if k == self.N \
                    else pyro.param("kappa_q_%d" % k,
                                    Variable(self.target_kappas[k].data +
                                             difficulty * (0.1 * torch.randn(1) - 0.53),
                                             requires_grad=True))
                mean_function = mu_q if k == self.N else kappa_q * previous_sample + mu_q
                node_flagged = True if self.which_nodes_reparam[k - 1] == 1.0 else False
                normal = dist.normal if reparameterized or node_flagged else fakes.nonreparameterized_normal
                mu_latent = pyro.sample("mu_latent_%d" % k, normal, mean_function, sig_q,
                                        baseline=dict(use_decaying_avg_baseline=True))
                previous_sample = mu_latent
            return previous_sample

        adam = optim.Adam({"lr": lr, "betas": (0.95, 0.999)})
        svi = SVI(model, guide, adam, loss="ELBO", trace_graph=True)

        for step in range(n_steps):
            t0 = time.time()
            svi.step()

            # Periodically (and on the last step) measure parameter errors
            # against the analytic targets.
            if step % 5000 == 0 or step == n_steps - 1:
                kappa_errors, log_sig_errors, mu_errors = [], [], []
                for k in range(1, self.N + 1):
                    if k != self.N:
                        kappa_error = param_mse("kappa_q_%d" % k, self.target_kappas[k])
                        kappa_errors.append(kappa_error)

                    mu_errors.append(param_mse("mu_q_%d" % k, self.target_mus[k]))
                    log_sig_error = param_mse("log_sig_q_%d" % k, -0.5 * torch.log(self.lambda_posts[k]))
                    log_sig_errors.append(log_sig_error)

                max_errors = (np.max(mu_errors), np.max(log_sig_errors), np.max(kappa_errors))
                min_errors = (np.min(mu_errors), np.min(log_sig_errors), np.min(kappa_errors))
                mean_errors = (np.mean(mu_errors), np.mean(log_sig_errors), np.mean(kappa_errors))
                logger.debug("[max errors]   (mu, log_sigma, kappa) = (%.4f, %.4f, %.4f)" % max_errors)
                logger.debug("[min errors]   (mu, log_sigma, kappa) = (%.4f, %.4f, %.4f)" % min_errors)
                logger.debug("[mean errors]  (mu, log_sigma, kappa) = (%.4f, %.4f, %.4f)" % mean_errors)
                logger.debug("[step time = %.3f;  N = %d;  step = %d]\n" % (time.time() - t0, self.N, step))

        # Worst-case errors from the final measurement must all be small.
        assert_equal(0.0, max_errors[0], prec=prec)
        assert_equal(0.0, max_errors[1], prec=prec)
        assert_equal(0.0, max_errors[2], prec=prec)
Example #58
0
def test_tmc_normals_chain_gradient(depth, num_samples, max_plate_nesting,
                                    expand, guide_type, reparameterized,
                                    tmc_strategy):
    """Check TMC gradient estimates on a Gaussian chain against gold values.

    Builds a depth-``depth`` chain of normals with a learnable starting
    parameter ``q2``, runs TraceTMC_ELBO with the configured enumeration /
    sampling strategy and guide, and compares the gradient of the
    (exponentiated) loss w.r.t. ``q2`` against reference values computed
    with Funsor.
    """
    # compare reparameterized and nonreparameterized gradient estimates
    q2 = pyro.param("q2", torch.tensor(0.5, requires_grad=True))
    qs = (q2.unconstrained(), )

    def model(reparameterized):
        Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
        x = pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. / depth)))
        for i in range(1, depth):
            x = pyro.sample("x{}".format(i), Normal(x, math.sqrt(1. / depth)))
        pyro.sample("y", Normal(x, 1.), obs=torch.tensor(float(1)))

    def factorized_guide(reparameterized):
        # Mean-field guide: each latent sampled independently.
        Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
        pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. / depth)))
        for i in range(1, depth):
            pyro.sample("x{}".format(i),
                        Normal(0., math.sqrt(float(i + 1) / depth)))

    def nonfactorized_guide(reparameterized):
        # Autoregressive guide mirroring the model's chain structure.
        Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
        x = pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. / depth)))
        for i in range(1, depth):
            x = pyro.sample("x{}".format(i), Normal(x, math.sqrt(1. / depth)))

    tmc = TraceTMC_ELBO(max_plate_nesting=max_plate_nesting)
    tmc_model = config_enumerate(model,
                                 default="parallel",
                                 expand=expand,
                                 num_samples=num_samples,
                                 tmc=tmc_strategy)
    guide = factorized_guide if guide_type == "factorized" else \
        nonfactorized_guide if guide_type == "nonfactorized" else \
        lambda *args: None
    tmc_guide = config_enumerate(guide,
                                 default="parallel",
                                 expand=expand,
                                 num_samples=num_samples,
                                 tmc=tmc_strategy)

    # gold values from Funsor
    expected_grads = (torch.tensor({
        1: 0.0999,
        2: 0.0860,
        3: 0.0802,
        4: 0.0771
    }[depth]), )

    # convert to linear space for unbiasedness
    actual_loss = (
        -tmc.differentiable_loss(tmc_model, tmc_guide, reparameterized)).exp()
    actual_grads = grad(actual_loss, qs)

    # Nonreparameterized estimators are noisier, so allow a looser tolerance.
    grad_prec = 0.05 if reparameterized else 0.1

    for actual_grad, expected_grad in zip(actual_grads, expected_grads):
        assert_equal(actual_grad,
                     expected_grad,
                     prec=grad_prec,
                     msg="".join([
                         "\nexpected grad = {}".format(
                             expected_grad.detach().cpu().numpy()),
                         "\n  actual grad = {}".format(
                             actual_grad.detach().cpu().numpy()),
                     ]))
    def do_elbo_test(self, reparameterized, n_steps, lr, prec, beta1,
                     difficulty=1.0, model_permutation=False):
        """Run SVI on a binary-tree ("pyramid") Gaussian model and check the guide.

        The model is an N-layer binary tree of Gaussians rooted at
        ``mu_latent_1``; the guide samples the same latents in the (reversed)
        topological order ``self.q_topo_sort`` with affine couplings (kappa
        parameters) to its DAG predecessors.  After training, the max
        log(sigma) error and two representative mean-function constant errors
        must fall below ``prec``.

        :param reparameterized: use reparameterized normals everywhere; when
            False, only nodes flagged in ``self.which_nodes_reparam`` are.
        :param n_steps: number of SVI steps.
        :param lr: Adam learning rate.
        :param prec: tolerance for the final error assertions.
        :param beta1: Adam first-moment decay.
        :param difficulty: scale of the perturbation applied to parameter
            initializations.
        :param model_permutation: if True, sample each model layer's nodes in
            a permuted order (exercises order-independence of the ELBO).
        """
        n_repa_nodes = torch.sum(self.which_nodes_reparam) if not reparameterized \
            else len(self.q_topo_sort)
        logger.info((" - - - DO GAUSSIAN %d-LAYERED PYRAMID ELBO TEST " +
                     "(with a total of %d RVs) [reparameterized=%s; %d/%d; perm=%s] - - -") %
                    (self.N, (2 ** self.N) - 1, reparameterized, n_repa_nodes,
                     len(self.q_topo_sort), model_permutation))
        pyro.clear_param_store()

        def model(*args, **kwargs):
            top_latent_dist = dist.Normal(self.mu0, torch.pow(self.lambdas[0], -0.5))
            previous_names = ["mu_latent_1"]
            top_latent = pyro.sample(previous_names[0], top_latent_dist)
            previous_latents_and_names = list(zip([top_latent], previous_names))

            # for sampling model variables in different sequential orders
            def permute(x, n):
                if model_permutation:
                    return [x[self.model_permutations[n - 1][i]] for i in range(len(x))]
                return x

            def unpermute(x, n):
                if model_permutation:
                    return [x[self.model_unpermutations[n - 1][i]] for i in range(len(x))]
                return x

            # Expand the tree layer by layer: each node spawns an 'L' and an
            # 'R' child, named by appending the letter to the parent's name.
            for n in range(2, self.N + 1):
                new_latents_and_names = []
                for prev_latent, prev_name in permute(previous_latents_and_names, n - 1):
                    latent_dist = dist.Normal(prev_latent, torch.pow(self.lambdas[n - 1], -0.5))
                    couple = []
                    for LR in ['L', 'R']:
                        new_name = prev_name + LR
                        mu_latent_LR = pyro.sample(new_name, latent_dist)
                        couple.append([mu_latent_LR, new_name])
                    new_latents_and_names.append(couple)
                # Undo the permutation and flatten the (L, R) couples back
                # into a single per-layer list.
                _previous_latents_and_names = unpermute(new_latents_and_names, n - 1)
                previous_latents_and_names = []
                for x in _previous_latents_and_names:
                    previous_latents_and_names.append(x[0])
                    previous_latents_and_names.append(x[1])

            # Observations hang off the leaf layer, one data vector per leaf.
            for i, data_i in enumerate(self.data):
                for k, x in enumerate(data_i):
                    pyro.observe("obs_%s_%d" % (previous_latents_and_names[i][1], k),
                                 dist.normal, x, previous_latents_and_names[i][0],
                                 torch.pow(self.lambdas[-1], -0.5))
            return top_latent

        def guide(*args, **kwargs):
            latents_dict = {}

            # Sample the latents in topological order of the guide DAG; each
            # node's mean is a constant plus kappa-weighted contributions from
            # its already-sampled predecessors.
            n_nodes = len(self.q_topo_sort)
            for i, node in enumerate(self.q_topo_sort):
                deps = self.q_dag.predecessors(node)
                # node names look like 'mu_latent_<suffix>'; strip the prefix.
                node_suffix = node[10:]
                log_sig_node = pyro.param("log_sig_" + node_suffix,
                                          Variable(-0.5 * torch.log(self.target_lambdas[node_suffix]).data +
                                                   difficulty * (torch.Tensor([-0.3]) -
                                                                 0.3 * (torch.randn(1) ** 2)),
                                                   requires_grad=True))
                mean_function_node = pyro.param("constant_term_" + node,
                                                Variable(self.mu0.data +
                                                         torch.Tensor([difficulty * i / n_nodes]),
                                                         requires_grad=True))
                for dep in deps:
                    kappa_dep = pyro.param("kappa_" + node_suffix + '_' + dep[10:],
                                           Variable(torch.Tensor([0.5 + difficulty * i / n_nodes]),
                                                    requires_grad=True))
                    mean_function_node = mean_function_node + kappa_dep * latents_dict[dep]
                node_flagged = True if self.which_nodes_reparam[i] == 1.0 else False
                normal = dist.normal if reparameterized or node_flagged else fakes.nonreparameterized_normal
                latent_node = pyro.sample(node, normal, mean_function_node, torch.exp(log_sig_node),
                                          baseline=dict(use_decaying_avg_baseline=True,
                                                        baseline_beta=0.96))
                latents_dict[node] = latent_node

            return latents_dict['mu_latent_1']

        # check graph structure is as expected but only for N=2
        if self.N == 2:
            guide_trace = pyro.poutine.trace(guide, graph_type="dense").get_trace()
            expected_nodes = set(['log_sig_1R', 'kappa_1_1L', '_INPUT', 'constant_term_mu_latent_1R', '_RETURN',
                                  'mu_latent_1R', 'mu_latent_1', 'constant_term_mu_latent_1', 'mu_latent_1L',
                                  'constant_term_mu_latent_1L', 'log_sig_1L', 'kappa_1_1R', 'kappa_1R_1L', 'log_sig_1'])
            expected_edges = set([('mu_latent_1R', 'mu_latent_1'), ('mu_latent_1L', 'mu_latent_1R'),
                                  ('mu_latent_1L', 'mu_latent_1')])
            assert expected_nodes == set(guide_trace.nodes)
            assert expected_edges == set(guide_trace.edges)

        adam = optim.Adam({"lr": lr, "betas": (beta1, 0.999)})
        svi = SVI(model, guide, adam, loss="ELBO", trace_graph=True)

        for step in range(n_steps):
            t0 = time.time()
            svi.step()

            # Periodically (and on the last step) measure parameter errors
            # against the analytic targets.
            if step % 5000 == 0 or step == n_steps - 1:
                log_sig_errors = []
                for node in self.target_lambdas:
                    target_log_sig = -0.5 * torch.log(self.target_lambdas[node])
                    log_sig_error = param_mse('log_sig_' + node, target_log_sig)
                    log_sig_errors.append(log_sig_error)
                max_log_sig_error = np.max(log_sig_errors)
                min_log_sig_error = np.min(log_sig_errors)
                mean_log_sig_error = np.mean(log_sig_errors)
                # Spot-check two mean-function constants: the first node in
                # topological order and its 'R' sibling.
                leftmost_node = self.q_topo_sort[0]
                leftmost_constant_error = param_mse('constant_term_' + leftmost_node,
                                                    self.target_leftmost_constant)
                almost_leftmost_constant_error = param_mse('constant_term_' + leftmost_node[:-1] + 'R',
                                                           self.target_almost_leftmost_constant)

                logger.debug("[mean function constant errors (partial)]   %.4f  %.4f" %
                             (leftmost_constant_error, almost_leftmost_constant_error))
                logger.debug("[min/mean/max log(sigma) errors]   %.4f  %.4f   %.4f" %
                             (min_log_sig_error, mean_log_sig_error, max_log_sig_error))
                logger.debug("[step time = %.3f;  N = %d;  step = %d]\n" % (time.time() - t0, self.N, step))

        assert_equal(0.0, max_log_sig_error, prec=prec)
        assert_equal(0.0, leftmost_constant_error, prec=prec)
        assert_equal(0.0, almost_leftmost_constant_error, prec=prec)
Example #60
0
    def _test_plate_in_elbo(self, n_superfluous_top, n_superfluous_bottom, n_steps):
        """Train a plated Gaussian model with superfluous latents and baselines.

        Builds a model with one global latent and three irregularly-sized
        inner plates (sizes 4, 3, 2), optionally padded with "superfluous"
        zero-information latents before and after the observation site.  Runs
        SVI with per-site neural baselines and asserts that (a) loc/log-sigma
        converge to their analytic posteriors and (b) the superfluous latent
        means are driven to zero.
        """
        pyro.clear_param_store()
        # Flatten self.data (n_outer x n_inner vectors) into a 9x2 tensor,
        # then reslice it into ragged chunks of sizes 4, 3 and 2.
        self.data_tensor = torch.zeros(9, 2)
        for _out in range(self.n_outer):
            for _in in range(self.n_inner):
                self.data_tensor[3 * _out + _in, :] = self.data[_out][_in]
        self.data_as_list = [self.data_tensor[0:4, :], self.data_tensor[4:7, :],
                             self.data_tensor[7:9, :]]

        def model():
            loc_latent = pyro.sample("loc_latent",
                                     fakes.NonreparameterizedNormal(self.loc0, torch.pow(self.lam0, -0.5))
                                     .to_event(1))

            # Inner plate i has size 4 - i, matching the ragged data chunks.
            for i in pyro.plate("outer", 3):
                x_i = self.data_as_list[i]
                with pyro.plate("inner_%d" % i, x_i.size(0)):
                    # Superfluous latents placed *before* the observation.
                    for k in range(n_superfluous_top):
                        z_i_k = pyro.sample("z_%d_%d" % (i, k),
                                            fakes.NonreparameterizedNormal(0, 1).expand_by([4 - i]))
                        assert z_i_k.shape == (4 - i,)
                    obs_i = pyro.sample("obs_%d" % i, dist.Normal(loc_latent, torch.pow(self.lam, -0.5))
                                                          .to_event(1), obs=x_i)
                    assert obs_i.shape == (4 - i, 2)
                    # Superfluous latents placed *after* the observation.
                    for k in range(n_superfluous_top, n_superfluous_top + n_superfluous_bottom):
                        z_i_k = pyro.sample("z_%d_%d" % (i, k),
                                            fakes.NonreparameterizedNormal(0, 1).expand_by([4 - i]))
                        assert z_i_k.shape == (4 - i,)

        # One trivial baseline for the global latent, plus one linear baseline
        # per (superfluous latent, plate) pair with output size 4, 3 or 2.
        pt_loc_baseline = torch.nn.Linear(1, 1)
        pt_superfluous_baselines = []
        for k in range(n_superfluous_top + n_superfluous_bottom):
            pt_superfluous_baselines.extend([torch.nn.Linear(2, 4), torch.nn.Linear(2, 3),
                                             torch.nn.Linear(2, 2)])

        def guide():
            loc_q = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.094)
            log_sig_q = pyro.param("log_sig_q",
                                   self.analytic_log_sig_n.expand(2) - 0.07)
            sig_q = torch.exp(log_sig_q)
            trivial_baseline = pyro.module("loc_baseline", pt_loc_baseline)
            baseline_value = trivial_baseline(torch.ones(1)).squeeze()
            loc_latent = pyro.sample("loc_latent",
                                     fakes.NonreparameterizedNormal(loc_q, sig_q).to_event(1),
                                     infer=dict(baseline=dict(baseline_value=baseline_value)))

            for i in pyro.plate("outer", 3):
                with pyro.plate("inner_%d" % i, 4 - i):
                    for k in range(n_superfluous_top + n_superfluous_bottom):
                        # Baseline is a function of the (detached) global
                        # latent; baselines are indexed plate-major.
                        z_baseline = pyro.module("z_baseline_%d_%d" % (i, k),
                                                 pt_superfluous_baselines[3 * k + i])
                        baseline_value = z_baseline(loc_latent.detach())
                        mean_i = pyro.param("mean_%d_%d" % (i, k),
                                            0.5 * torch.ones(4 - i))
                        z_i_k = pyro.sample("z_%d_%d" % (i, k),
                                            fakes.NonreparameterizedNormal(mean_i, 1),
                                            infer=dict(baseline=dict(baseline_value=baseline_value)))
                        assert z_i_k.shape == (4 - i,)

        # Baseline parameters get a larger learning rate than model/guide ones.
        def per_param_callable(module_name, param_name):
            if 'baseline' in param_name or 'baseline' in module_name:
                return {"lr": 0.010, "betas": (0.95, 0.999)}
            else:
                return {"lr": 0.0012, "betas": (0.95, 0.999)}

        adam = optim.Adam(per_param_callable)
        svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())

        for step in range(n_steps):
            svi.step()

            loc_error = param_abs_error("loc_q", self.analytic_loc_n)
            log_sig_error = param_abs_error("log_sig_q", self.analytic_log_sig_n)

            # Superfluous-latent means should shrink to zero; track the worst
            # squared norm across the three plates for each k.
            if n_superfluous_top > 0 or n_superfluous_bottom > 0:
                superfluous_errors = []
                for k in range(n_superfluous_top + n_superfluous_bottom):
                    mean_0_error = torch.sum(torch.pow(pyro.param("mean_0_%d" % k), 2.0))
                    mean_1_error = torch.sum(torch.pow(pyro.param("mean_1_%d" % k), 2.0))
                    mean_2_error = torch.sum(torch.pow(pyro.param("mean_2_%d" % k), 2.0))
                    superfluous_error = torch.max(torch.max(mean_0_error, mean_1_error), mean_2_error)
                    superfluous_errors.append(superfluous_error.detach().cpu().numpy())

            if step % 500 == 0:
                logger.debug("loc error, log(scale) error:  %.4f, %.4f" % (loc_error, log_sig_error))
                if n_superfluous_top > 0 or n_superfluous_bottom > 0:
                    logger.debug("superfluous error: %.4f" % np.max(superfluous_errors))

        assert_equal(0.0, loc_error, prec=0.04)
        assert_equal(0.0, log_sig_error, prec=0.05)
        if n_superfluous_top > 0 or n_superfluous_bottom > 0:
            assert_equal(0.0, np.max(superfluous_errors), prec=0.04)