示例#1
0
def test_subsequent_expands_ok(dist, sample_shapes, default):
    """Expand a distribution repeatedly and validate the shape after each step.

    For every parameter set of ``dist``, each entry of ``sample_shapes`` is
    prepended to the *original* batch shape and the current distribution is
    expanded to it. When ``default`` is truthy the expansion goes through
    ``TorchDistribution.expand``; otherwise the distribution's own ``expand``
    is called inside ``xfail_if_not_implemented()``.
    """
    for idx in range(dist.get_num_test_data()):
        current = dist.pyro_dist(**dist.get_dist_params(idx))
        base_shape = current.batch_shape
        for extra_dims in sample_shapes:
            target_shape = torch.Size(extra_dims) + base_shape
            if not default:
                with xfail_if_not_implemented():
                    expanded = current.expand(target_shape)
            else:
                expanded = TorchDistribution.expand(current, target_shape)
            assert expanded.batch_shape == target_shape
            with xfail_if_not_implemented():
                check_sample_shapes(current, expanded)
            # The next expansion starts from the distribution just produced.
            current = expanded
示例#2
0
def test_gof(continuous_dist):
    """Goodness-of-fit test of drawn samples against ``exp(log_prob)``.

    For every parameter set, 50000 samples are drawn, their densities are
    evaluated, and ``auto_goodness_of_fit`` must exceed
    ``TEST_FAILURE_RATE`` for each flattened batch element. ``LKJ`` and
    ``LKJCholesky`` are marked as expected failures up front.
    """
    Dist = continuous_dist.pyro_dist
    if Dist in (dist.LKJ, dist.LKJCholesky):
        pytest.xfail(reason="incorrect submanifold scaling")

    num_samples = 50000
    dist_name = Dist.__name__
    for case in range(continuous_dist.get_num_test_data()):
        d = Dist(**continuous_dist.get_dist_params(case))
        draws = d.sample(torch.Size([num_samples]))
        with xfail_if_not_implemented():
            density = d.log_prob(draws).exp()

        if "ProjectedNormal" in dist_name:
            dim = draws.size(-1) - 1
        else:
            dim = None

        # Collapse all batch dims into one axis so each batch is tested
        # independently.
        density = density.reshape(num_samples, -1)
        draws = draws.reshape(density.shape + d.event_shape)
        if "Dirichlet" in dist_name:
            # The Dirichlet density is over all but one of the probs.
            draws = draws[..., :-1]
        for column in range(density.size(-1)):
            gof = auto_goodness_of_fit(draws[:, column], density[:, column], dim=dim)
            assert gof > TEST_FAILURE_RATE
示例#3
0
def test_bern_elbo_gradient(enum_discrete, trace_graph):
    """Compare surrogate-loss ELBO gradients with finite-difference gradients.

    A Bernoulli model/guide pair with a single parameter ``"p"`` is used.
    Gradients from ``loss_and_grads`` must match those returned by
    ``finite_difference`` on ``Trace_ELBO.loss`` to within ``prec=0.1``.
    """
    pyro.clear_param_store()
    num_particles = 2000

    def model():
        prior_p = Variable(torch.Tensor([0.25]))
        pyro.sample("z", dist.Bernoulli(prior_p))

    def guide():
        init_p = Variable(torch.Tensor([0.5]), requires_grad=True)
        q = pyro.param("p", init_p)
        pyro.sample("z", dist.Bernoulli(q))

    print("Computing gradients using surrogate loss")
    if trace_graph:
        Elbo = TraceGraph_ELBO
    else:
        Elbo = Trace_ELBO
    # A single particle is used when discrete sites are enumerated.
    surrogate_particles = 1 if enum_discrete else num_particles
    elbo = Elbo(enum_discrete=enum_discrete, num_particles=surrogate_particles)
    with xfail_if_not_implemented():
        elbo.loss_and_grads(model, guide)
    params = sorted(pyro.get_param_store().get_all_param_names())
    assert params, "no params found"
    actual_grads = {}
    for name in params:
        actual_grads[name] = pyro.param(name).grad.clone()

    print("Computing gradients using finite difference")
    elbo = Trace_ELBO(num_particles=num_particles)
    expected_grads = finite_difference(lambda: elbo.loss(model, guide))

    for name in params:
        print("{} {}{}{}".format(name, "-" * 30, actual_grads[name].data,
                                 expected_grads[name].data))
    assert_equal(actual_grads, expected_grads, prec=0.1)
示例#4
0
def test_expand_new_dim(dist, sample_shape, shape_type):
    """Check expansion that adds new leading dimensions.

    Each parameter set of ``dist`` is expanded to
    ``sample_shape + batch_shape`` (converted with ``shape_type``); the
    resulting batch shape is asserted and ``check_sample_shapes`` compares
    the small and large distributions.
    """
    for case in range(dist.get_num_test_data()):
        params = dist.get_dist_params(case)
        small = dist.pyro_dist(**params)
        with xfail_if_not_implemented():
            target = sample_shape + small.batch_shape
            large = small.expand(shape_type(target))
            assert large.batch_shape == target
            check_sample_shapes(small, large)
示例#5
0
def test_pyrocov_reparam(model, Guide, backend):
    """Smoke-test SVI and Predictive on a reparametrized model.

    Builds a small random dataset, wraps ``model`` with ``poutine.reparam``,
    runs two SVI steps, then calls the guide and a ``Predictive`` once each.
    """
    T, P, S, F = 2, 3, 4, 5
    features = torch.randn(S, F)
    local_time = torch.randn(T, P)
    weekly_strains = torch.randn(T, P, S).exp().round()
    dataset = {
        "features": features,
        "local_time": local_time,
        "weekly_strains": weekly_strains,
    }

    # Reparametrize the model; "rate_loc" is left alone for pyrocov_model.
    if model is pyrocov_model:
        rate_loc_reparam = None
    else:
        rate_loc_reparam = LocScaleReparam()
    config = {
        "coef": LocScaleReparam(),
        "rate_loc": rate_loc_reparam,
        "rate": LocScaleReparam(),
        "init_loc": LocScaleReparam(),
        "init": LocScaleReparam(),
    }
    model = poutine.reparam(model, config)
    guide = Guide(model, backend=backend)
    optimizer = ClippedAdam({"lr": 1e-8})
    svi = SVI(model, guide, optimizer, Trace_ELBO())
    for _ in range(2):
        with xfail_if_not_implemented():
            svi.step(dataset)
    guide(dataset)
    predictive = Predictive(model, guide=guide, num_samples=2)
    predictive(dataset)
示例#6
0
def test_sample_shape_smoke(num_nodes, sample_shape, dtype, bp_iters):
    """Check the shape and support membership of ``OneOneMatching`` samples."""
    logits = torch.randn(num_nodes, num_nodes, dtype=dtype)
    matching = dist.OneOneMatching(logits, bp_iters=bp_iters)
    with xfail_if_not_implemented():
        values = matching.sample(sample_shape)
    expected_shape = sample_shape + (num_nodes, )
    assert values.shape == expected_shape
    assert matching.support.check(values).all()
示例#7
0
def test_batch_log_pdf_mask(dist):
    """Check ``batch_log_pdf`` under all-zero, all-one and one-half masks.

    Only runs for Normal, Bernoulli and Categorical; other distributions are
    skipped. A ones mask must leave the result unchanged, a zeros mask must
    produce zeros of the broadcast shape, and a 0.5 mask must halve it.
    """
    supported = ('Normal', 'Bernoulli', 'Categorical')
    if dist.get_test_distribution_name() not in supported:
        pytest.skip('Batch pdf masking not supported for the distribution.')
    d = dist.pyro_dist
    for idx in range(dist.get_num_test_data()):
        params = dist.get_dist_params(idx)
        x = dist.get_test_data(idx)
        with xfail_if_not_implemented():
            unmasked_shape = d.batch_shape(**params) + (1, )
            broadcast_shape = d.batch_shape(x, **params) + (1, )
            # should be broadcasted to data dims
            mask_zero = ng_zeros(1)
            # should be broadcasted to data dims
            mask_one = ng_ones(unmasked_shape)
            mask_half = ng_ones(1) * 0.5
            plain = d.batch_log_pdf(x, **params)
            with_zero = d.batch_log_pdf(x, log_pdf_mask=mask_zero, **params)
            with_one = d.batch_log_pdf(x, log_pdf_mask=mask_one, **params)
            with_half = d.batch_log_pdf(x, log_pdf_mask=mask_half, **params)
            assert_equal(with_one, plain)
            assert_equal(with_zero, ng_zeros(broadcast_shape))
            assert_equal(with_half, 0.5 * plain)
示例#8
0
def test_bern_elbo_gradient(enum_discrete, trace_graph):
    """Verify ELBO gradients of a Bernoulli guide against finite differences.

    The surrogate-loss gradients written onto parameter ``"p"`` by
    ``loss_and_grads`` are compared (``prec=0.1``) with the output of
    ``finite_difference`` applied to a ``Trace_ELBO`` loss.
    """
    pyro.clear_param_store()
    num_particles = 2000

    def model():
        pyro.sample("z", dist.Bernoulli(Variable(torch.Tensor([0.25]))))

    def guide():
        start = Variable(torch.Tensor([0.5]), requires_grad=True)
        pyro.sample("z", dist.Bernoulli(pyro.param("p", start)))

    print("Computing gradients using surrogate loss")
    Elbo = Trace_ELBO
    if trace_graph:
        Elbo = TraceGraph_ELBO
    if enum_discrete:
        particles = 1
    else:
        particles = num_particles
    surrogate_elbo = Elbo(enum_discrete=enum_discrete, num_particles=particles)
    with xfail_if_not_implemented():
        surrogate_elbo.loss_and_grads(model, guide)
    params = sorted(pyro.get_param_store().get_all_param_names())
    assert params, "no params found"
    actual_grads = dict((name, pyro.param(name).grad.clone()) for name in params)

    print("Computing gradients using finite difference")
    fd_elbo = Trace_ELBO(num_particles=num_particles)

    def fd_loss():
        return fd_elbo.loss(model, guide)

    expected_grads = finite_difference(fd_loss)

    for name in params:
        print("{} {}{}{}".format(name, "-" * 30, actual_grads[name].data,
                                 expected_grads[name].data))
    assert_equal(actual_grads, expected_grads, prec=0.1)
示例#9
0
def test_shape(dist):
    """Check that ``shape`` equals ``batch_shape + event_shape``."""
    d = dist.pyro_dist
    for idx in dist.get_test_data_indices():
        params = dist.get_dist_params(idx)
        with xfail_if_not_implemented():
            full_shape = d.shape(**params)
            batch_part = d.batch_shape(**params)
            event_part = d.event_shape(**params)
            assert full_shape == batch_part + event_part
示例#10
0
def test_expand_error(dist, initial_shape, proposed_shape):
    """Check that an invalid second expansion raises ``RuntimeError``.

    Each distribution is first expanded with ``initial_shape`` prepended to
    its batch shape, then expanding to ``proposed_shape + batch_shape`` is
    expected to fail.
    """
    for idx in range(dist.get_num_test_data()):
        params = dist.get_dist_params(idx)
        small = dist.pyro_dist(**params)
        with xfail_if_not_implemented():
            first_shape = torch.Size(initial_shape) + small.batch_shape
            large = small.expand(first_shape)
            bad_shape = torch.Size(proposed_shape) + small.batch_shape
            with pytest.raises(RuntimeError):
                large.expand(bad_shape)
示例#11
0
def test_trace_handler(model, backend):
    """Run ``handlers.trace`` on a model looked up from ``MODELS``."""
    with pyro_backend(backend), handlers.seed(rng_seed=2), xfail_if_not_implemented():
        spec = MODELS[model]()
        model_fn = spec['model']
        model_args = spec.get('model_args', ())
        model_kwargs = spec.get('model_kwargs', {})
        # should be implemented
        tracer = handlers.trace(model_fn)
        tracer.get_trace(*model_args, **model_kwargs)
示例#12
0
def test_batch_entropy_shape(dist):
    """Check that ``entropy()`` has the shape given by ``_log_prob_shape``."""
    for idx in range(dist.get_num_test_data()):
        d = dist.pyro_dist(**dist.get_dist_params(idx))
        with xfail_if_not_implemented():
            # Get entropy shape after broadcasting.
            expected_shape = _log_prob_shape(d)
            entropy = d.entropy()
            assert entropy.size() == expected_shape
示例#13
0
def test_sample_shape(dist):
    """Check sample sizes of the functional and object forms agree.

    The functional sample's size must equal the object sample's size and
    also ``d.shape(x_func, **params)``.
    """
    d = dist.pyro_dist
    for idx in range(dist.get_num_test_data()):
        params = dist.get_dist_params(idx)
        from_func = dist.pyro_dist.sample(**params)
        from_obj = dist.pyro_dist_obj(**params).sample()
        assert_equal(from_obj.size(), from_func.size())
        with xfail_if_not_implemented():
            func_size = from_func.size()
            expected_size = d.shape(from_func, **params)
            assert func_size == expected_size
示例#14
0
def test_sample_shape(dist):
    """Compare the size of functional and object-style samples.

    Both sample sizes must match each other, and the functional sample's
    size must match the shape reported by ``d.shape`` for that sample.
    """
    d = dist.pyro_dist
    num_cases = dist.get_num_test_data()
    for case in range(num_cases):
        kwargs = dist.get_dist_params(case)
        x_func = dist.pyro_dist.sample(**kwargs)
        x_obj = dist.pyro_dist_obj(**kwargs).sample()
        assert_equal(x_obj.size(), x_func.size())
        with xfail_if_not_implemented():
            sizes_match = x_func.size() == d.shape(x_func, **kwargs)
            assert sizes_match
示例#15
0
def test_svi_step_smoke(model, guide, enum_discrete, trace_graph):
    """Smoke-test a single SVI step on a three-element data tensor."""
    pyro.clear_param_store()
    data = Variable(torch.Tensor([0, 1, 9]))

    optimizer = pyro.optim.Adam({"lr": .001})
    svi_options = dict(loss="ELBO",
                       trace_graph=trace_graph,
                       enum_discrete=enum_discrete)
    inference = SVI(model, guide, optimizer, **svi_options)
    with xfail_if_not_implemented():
        inference.step(data)
示例#16
0
def test_autocovariance():
    """Check ``autocovariance`` of ``arange(10.)`` against reference values."""
    expected = torch.tensor([
        8.25, 6.42, 4.25, 1.75, -1.08, -4.25, -7.75, -11.58,
        -15.75, -20.25
    ])
    series = torch.arange(10.)
    with xfail_if_not_implemented():
        actual = autocovariance(series)
    assert_equal(actual, expected, prec=0.01)
示例#17
0
def test_autocorrelation():
    """Check ``autocorrelation`` of ``arange(10.0)`` against reference values."""
    reference = [1, 0.78, 0.52, 0.21, -0.13, -0.52, -0.94, -1.4, -1.91, -2.45]
    series = torch.arange(10.0)
    with xfail_if_not_implemented():
        actual = autocorrelation(series)
    assert_equal(actual, torch.tensor(reference), prec=0.01)
示例#18
0
def test_batch_log_prob_shape(dist):
    """Check that ``log_prob`` has the shape given by ``_log_prob_shape``."""
    for idx in range(dist.get_num_test_data()):
        d = dist.pyro_dist(**dist.get_dist_params(idx))
        x = dist.get_test_data(idx)
        with xfail_if_not_implemented():
            # Get log_prob shape after broadcasting.
            expected_shape = _log_prob_shape(d, x.size())
            log_p = d.log_prob(x)
            assert log_p.size() == expected_shape
示例#19
0
def test_expand_twice(dist):
    """Expand to ``(2, 1) + batch`` and then to ``(2, 3) + batch``.

    The doubly expanded distribution must report the final batch shape and
    pass ``check_sample_shapes`` against both earlier distributions.
    """
    for idx in range(dist.get_num_test_data()):
        small = dist.pyro_dist(**dist.get_dist_params(idx))
        medium = small.expand(torch.Size((2, 1)) + small.batch_shape)
        final_shape = torch.Size((2, 3)) + small.batch_shape
        with xfail_if_not_implemented():
            large = medium.expand(final_shape)
        assert large.batch_shape == final_shape
        for source in (small, medium):
            check_sample_shapes(source, large)
示例#20
0
def test_batch_log_prob_shape(dist):
    """Verify the size of ``log_prob(x)`` for every test-data entry.

    The expected size is computed by ``_log_prob_shape`` from the
    distribution and the size of the test data.
    """
    num_cases = dist.get_num_test_data()
    for case in range(num_cases):
        kwargs = dist.get_dist_params(case)
        distribution = dist.pyro_dist(**kwargs)
        data = dist.get_test_data(case)
        with xfail_if_not_implemented():
            # Get log_prob shape after broadcasting.
            expected_shape = _log_prob_shape(distribution, data.size())
            actual_shape = distribution.log_prob(data).size()
            assert actual_shape == expected_shape
示例#21
0
def test_expand_twice(dist):
    """Check that a distribution can be expanded two times in a row.

    The first expansion prepends ``(2, 1)`` to the batch shape; the second
    targets ``(2, 3)`` prepended to the original batch shape.
    """
    first_prefix = torch.Size((2, 1))
    second_prefix = torch.Size((2, 3))
    for case in range(dist.get_num_test_data()):
        kwargs = dist.get_dist_params(case)
        small = dist.pyro_dist(**kwargs)
        medium = small.expand(first_prefix + small.batch_shape)
        target = second_prefix + small.batch_shape
        with xfail_if_not_implemented():
            large = medium.expand(target)
        assert large.batch_shape == target
        check_sample_shapes(small, large)
        check_sample_shapes(medium, large)
示例#22
0
def test_expand_error(dist, initial_shape, proposed_shape, default):
    """Check that an incompatible re-expansion raises an error.

    The first expansion uses ``TorchDistribution.expand`` when ``default``
    is truthy, otherwise the distribution's own ``expand``. Expanding the
    result to ``proposed_shape + batch_shape`` must then raise
    ``RuntimeError`` or ``ValueError``.
    """
    for idx in range(dist.get_num_test_data()):
        small = dist.pyro_dist(**dist.get_dist_params(idx))
        if not default:
            with xfail_if_not_implemented():
                large = small.expand(torch.Size(initial_shape) + small.batch_shape)
        else:
            large = TorchDistribution.expand(small, initial_shape + small.batch_shape)
        bad_shape = torch.Size(proposed_shape) + small.batch_shape
        expected_errors = (RuntimeError, ValueError)
        with pytest.raises(expected_errors):
            large.expand(bad_shape)
示例#23
0
def test_cdf_icdf(continuous_dist):
    """Check that ``cdf(icdf(u))`` reproduces ``u`` for uniform ``u``.

    Parameter sets whose event shape does not have exactly one element are
    skipped.
    """
    Dist = continuous_dist.pyro_dist
    for case in range(continuous_dist.get_num_test_data()):
        d = Dist(**continuous_dist.get_dist_params(case))
        # only valid for univariate distributions
        if d.event_shape.numel() == 1:
            u = torch.empty((100, ) + d.shape()).uniform_()
            with xfail_if_not_implemented():
                quantiles = d.icdf(u)
                round_trip = d.cdf(quantiles)
            assert_equal(u, round_trip)
示例#24
0
def test_von_mises_3d_gof(scale):
    """Goodness-of-fit test for ``VonMises3D`` with concentration norm ``scale``.

    A random 3-vector is rescaled to have 2-norm ``scale``; 2000 samples are
    drawn and ``auto_goodness_of_fit`` must exceed ``TEST_FAILURE_RATE``.
    """
    direction = torch.randn(3)
    concentration = direction * (scale / direction.norm(2))
    d = VonMises3D(concentration, validate_args=True)

    with xfail_if_not_implemented():
        samples = d.sample(torch.Size([2000]))
    log_density = d.log_prob(samples)
    probs = log_density.exp()

    gof = auto_goodness_of_fit(samples, probs, dim=2)
    assert gof > TEST_FAILURE_RATE
示例#25
0
def test_expand_error(dist, initial_shape, proposed_shape):
    """Check that expanding to an incompatible batch shape raises ``RuntimeError``.

    For every parameter set of ``dist`` the distribution is first expanded
    to ``initial_shape + batch_shape``. Expanding that result to
    ``proposed_shape + batch_shape`` is then expected to raise
    ``RuntimeError``. The test is skipped for the distribution named
    ``'LKJCorrCholesky'``. The first expansion and the check both run inside
    ``xfail_if_not_implemented()``.
    """
    for idx in range(dist.get_num_test_data()):
        small = dist.pyro_dist(**dist.get_dist_params(idx))
        with xfail_if_not_implemented():
            large = small.expand(torch.Size(initial_shape) + small.batch_shape)
            proposed_batch_shape = torch.Size(proposed_shape) + small.batch_shape
            if dist.get_test_distribution_name() == 'LKJCorrCholesky':
                # The two literals are joined verbatim, so the first must end
                # with a space; otherwise the reason reads "notbroadcastable".
                pytest.skip('LKJCorrCholesky can expand to a shape not '
                            'broadcastable with its original batch_shape.')
            with pytest.raises(RuntimeError):
                large.expand(proposed_batch_shape)
示例#26
0
def test_enumerate_support(num_edges):
    """Check the shape and size of ``SpanningTree.enumerate_support()``.

    The support must be 3-dimensional, its trailing dims must equal the
    event shape, and its length must equal ``NUM_SPANNING_TREES[V]`` where
    ``V = num_edges + 1``.
    """
    pyro.set_rng_seed(2**32 - num_edges)
    num_vertices = 1 + num_edges
    num_pairs = num_vertices * (num_vertices - 1) // 2
    edge_logits = torch.randn(num_pairs)
    d = SpanningTree(edge_logits)
    with xfail_if_not_implemented():
        support = d.enumerate_support()
    assert support.dim() == 3
    assert support.shape[1:] == d.event_shape
    expected_count = NUM_SPANNING_TREES[num_vertices]
    assert support.size(0) == expected_count
示例#27
0
def test_expand_new_dim(dist, sample_shape, shape_type, default):
    """Check expansion with new leading dims via the default or own ``expand``.

    ``TorchDistribution.expand`` is used when ``default`` is truthy,
    otherwise the distribution's own ``expand``. After the batch-shape
    assertion the test is skipped for ``'Stable'`` before sample shapes are
    compared.
    """
    for idx in range(dist.get_num_test_data()):
        small = dist.pyro_dist(**dist.get_dist_params(idx))
        if not default:
            with xfail_if_not_implemented():
                large = small.expand(shape_type(sample_shape + small.batch_shape))
        else:
            target = shape_type(sample_shape + small.batch_shape)
            large = TorchDistribution.expand(small, target)
        expected_batch_shape = sample_shape + small.batch_shape
        assert large.batch_shape == expected_batch_shape
        if dist.get_test_distribution_name() == 'Stable':
            pytest.skip('Stable does not implement a log_prob method.')
        check_sample_shapes(small, large)
示例#28
0
def test_rng_seed(model, backend):
    """Check that running a model twice under the same seed gives equal output.

    The test is skipped when the model returns ``None``.
    """
    with pyro_backend(backend), handlers.seed(rng_seed=2), xfail_if_not_implemented():
        spec = MODELS[model]()
        model_fn = spec['model']
        model_args = spec.get('model_args', ())

        def run_seeded():
            # Each run gets its own fresh seed handler with rng_seed=0.
            with handlers.seed(rng_seed=0):
                return model_fn(*model_args)

        expected = run_seeded()
        if expected is None:
            pytest.skip()
        actual = run_seeded()
        assert ops.allclose(actual, expected)
示例#29
0
def test_subsample_gradient(Elbo, reparameterized, has_rsample, subsample, local_samples, scale):
    """Compare ELBO gradients of a two-point Normal model with fixed expected values.

    The model has a latent ``z`` and an observation ``x`` inside a ``"data"``
    plate over two data points. The guide has parameters ``"loc"`` (one per
    data point) and ``"scale"``. Gradients obtained through
    ``loss_and_grads`` are compared with hard-coded expected gradients,
    both multiplied by ``scale``, to within ``0.06 * scale``.

    Parameters are supplied by pytest (NOTE(review): presumably via
    parametrization defined outside this snippet):
      Elbo: ELBO class to instantiate.
      reparameterized: choose ``dist.Normal`` or the non-reparameterized fake.
      has_rsample: if not None, forwarded to ``z_dist.has_rsample_``.
      subsample: if truthy, each call sees a single data point.
      local_samples: if truthy, particles are drawn via ``config_enumerate``.
      scale: factor applied to model and guide with ``poutine.scale``.
    """
    pyro.clear_param_store()
    data = torch.tensor([-0.5, 2.0])
    subsample_size = 1 if subsample else len(data)
    # Tolerance grows with the scale applied to model and guide.
    precision = 0.06 * scale
    Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal

    def model(subsample):
        # `subsample` here is the index tensor passed by keyword below; it
        # shadows the outer boolean of the same name.
        with pyro.plate("data", len(data), subsample_size, subsample) as ind:
            x = data[ind]
            z = pyro.sample("z", Normal(0, 1))
            pyro.sample("x", Normal(z, 1), obs=x)

    def guide(subsample):
        # Local `scale` is the guide parameter and shadows the outer test
        # argument inside this function only.
        scale = pyro.param("scale", lambda: torch.tensor([1.0]))
        with pyro.plate("data", len(data), subsample_size, subsample):
            loc = pyro.param("loc", lambda: torch.zeros(len(data)), event_dim=0)
            z_dist = Normal(loc, scale)
            if has_rsample is not None:
                z_dist.has_rsample_(has_rsample)
            pyro.sample("z", z_dist)

    if scale != 1.0:
        model = poutine.scale(model, scale=scale)
        guide = poutine.scale(guide, scale=scale)

    num_particles = 50000
    if local_samples:
        # Draw the particles inside the guide instead of in the ELBO.
        guide = config_enumerate(guide, num_samples=num_particles)
        num_particles = 1

    optim = Adam({"lr": 0.1})
    elbo = Elbo(max_plate_nesting=1,  # set this to ensure rng agrees across runs
                num_particles=num_particles,
                vectorize_particles=True,
                strict_enumeration_warning=False)
    inference = SVI(model, guide, optim, loss=elbo)
    with xfail_if_not_implemented():
        if subsample_size == 1:
            # One call per data point, so that both points contribute.
            inference.loss_and_grads(model, guide, subsample=torch.tensor([0], dtype=torch.long))
            inference.loss_and_grads(model, guide, subsample=torch.tensor([1], dtype=torch.long))
        else:
            inference.loss_and_grads(model, guide, subsample=torch.tensor([0, 1], dtype=torch.long))
    params = dict(pyro.get_param_store().named_parameters())
    # NOTE(review): presumably gradients accumulate over the two subsampled
    # calls above, hence the division by 2 -- confirm.
    normalizer = 2 if subsample else 1
    actual_grads = {name: param.grad.detach().cpu().numpy() / normalizer for name, param in params.items()}

    expected_grads = {'loc': scale * np.array([0.5, -2.0]), 'scale': scale * np.array([2.0])}
    for name in sorted(params):
        logger.info('expected {} = {}'.format(name, expected_grads[name]))
        logger.info('actual   {} = {}'.format(name, actual_grads[name]))
    assert_equal(actual_grads, expected_grads, prec=precision)
示例#30
0
def test_log_prob(num_edges):
    """Check that ``SpanningTree`` log-probabilities over the support sum to one.

    The log of the total probability over the enumerated support must be
    within ``1e-6`` of zero.
    """
    pyro.set_rng_seed(2**32 - num_edges)
    num_vertices = 1 + num_edges
    num_pairs = num_vertices * (num_vertices - 1) // 2
    edge_logits = torch.randn(num_pairs)
    d = SpanningTree(edge_logits)
    with xfail_if_not_implemented():
        support = d.enumerate_support()
    log_probs = d.log_prob(support)
    expected_shape = (len(support), )
    assert log_probs.shape == expected_shape
    log_total = log_probs.logsumexp(0).item()
    assert abs(log_total) < 1e-6, log_total
示例#31
0
def test_expand_existing_dim(dist, shape_type):
    """Expand each size-1 batch dimension to size 5 and check the result.

    Batch dimensions whose size is not 1 are left alone.
    """
    for idx in range(dist.get_num_test_data()):
        small = dist.pyro_dist(**dist.get_dist_params(idx))
        for dim, size in enumerate(small.batch_shape):
            if size == 1:
                widened = list(small.batch_shape)
                widened[dim] = 5
                target = torch.Size(widened)
                with xfail_if_not_implemented():
                    large = small.expand(shape_type(target))
                assert large.batch_shape == target
                check_sample_shapes(small, large)
示例#32
0
def test_svi_step_smoke(model, guide, enum_discrete, trace_graph):
    """Run one ``SVI.step`` with the ``"ELBO"`` loss as a smoke test."""
    pyro.clear_param_store()
    observations = Variable(torch.Tensor([0, 1, 9]))

    adam = pyro.optim.Adam({"lr": .001})
    inference = SVI(model, guide, adam, loss="ELBO",
                    trace_graph=trace_graph, enum_discrete=enum_discrete)
    with xfail_if_not_implemented():
        inference.step(observations)
示例#33
0
def test_expand_existing_dim(dist, shape_type):
    """Check expansion of singleton batch dimensions.

    For every batch dimension of size 1, the distribution is expanded so
    that this dimension becomes 5; the new batch shape and sample shapes
    are then verified.
    """
    new_size = 5
    for case in range(dist.get_num_test_data()):
        kwargs = dist.get_dist_params(case)
        small = dist.pyro_dist(**kwargs)
        singleton_dims = [dim for dim, size in enumerate(small.batch_shape)
                          if size == 1]
        for dim in singleton_dims:
            sizes = list(small.batch_shape)
            sizes[dim] = new_size
            batch_shape = torch.Size(sizes)
            with xfail_if_not_implemented():
                large = small.expand(shape_type(batch_shape))
            assert large.batch_shape == batch_shape
            check_sample_shapes(small, large)
示例#34
0
def test_batch_log_pdf_shape(dist):
    """Check the size of ``batch_log_pdf`` in functional and object form.

    Both forms must return the same size, equal to
    ``batch_shape(x, **params) + (1,)``.
    """
    d = dist.pyro_dist
    for idx in range(dist.get_num_test_data()):
        params = dist.get_dist_params(idx)
        x = dist.get_test_data(idx)
        with xfail_if_not_implemented():
            # Get batch pdf shape after broadcasting.
            expected_shape = d.batch_shape(x, **params) + (1,)
            from_func = d.batch_log_pdf(x, **params)
            from_obj = dist.pyro_dist_obj(**params).batch_log_pdf(x)
            # assert that the functional and object forms return
            # the same batch pdf.
            assert_equal(from_func.size(), from_obj.size())
            assert from_func.size() == expected_shape
示例#35
0
def test_broken_plates_smoke(backend):
    """Smoke-test ``AutoGaussian`` on a model that uses a plated site outside its plate.

    Runs two SVI steps, then calls the guide and a ``Predictive`` once each.
    """
    def model():
        with pyro.plate("i", 2):
            a = pyro.sample("a", dist.Normal(0, 1))
        # "b" depends on the plated site "a" but sits outside the plate.
        pyro.sample("b", dist.Normal(a.mean(-1), 1), obs=torch.tensor(0.0))

    guide = AutoGaussian(model, backend=backend)
    optimizer = ClippedAdam({"lr": 1e-8})
    svi = SVI(model, guide, optimizer, Trace_ELBO())
    for _ in range(2):
        with xfail_if_not_implemented():
            svi.step()
    guide()
    predictive = Predictive(model, guide=guide, num_samples=2)
    predictive()
示例#36
0
def test_batch_log_pdf_shape(dist):
    """Verify that ``batch_log_pdf`` returns the expected broadcast size.

    The functional call and the object call must agree on size, and that
    size must equal the broadcast batch shape with a trailing 1.
    """
    d = dist.pyro_dist
    num_cases = dist.get_num_test_data()
    for case in range(num_cases):
        kwargs = dist.get_dist_params(case)
        data = dist.get_test_data(case)
        with xfail_if_not_implemented():
            # Get batch pdf shape after broadcasting.
            broadcast_batch = d.batch_shape(data, **kwargs)
            expected_shape = broadcast_batch + (1, )
            log_p_func = d.batch_log_pdf(data, **kwargs)
            dist_obj = dist.pyro_dist_obj(**kwargs)
            log_p_obj = dist_obj.batch_log_pdf(data)
            # assert that the functional and object forms return
            # the same batch pdf.
            assert_equal(log_p_func.size(), log_p_obj.size())
            assert log_p_func.size() == expected_shape
示例#37
0
def test_partition_function(num_edges):
    """Compare ``log_partition_function`` with a brute-force sum over the support.

    The expected value is the logsumexp, over all enumerated trees, of the
    summed edge logits of each tree.
    """
    pyro.set_rng_seed(2**32 - num_edges)
    num_vertices = 1 + num_edges
    num_pairs = num_vertices * (num_vertices - 1) // 2
    edge_logits = torch.randn(num_pairs)
    d = SpanningTree(edge_logits)
    with xfail_if_not_implemented():
        support = d.enumerate_support()
    v1, v2 = support[..., 0], support[..., 1]
    # Position of each edge's logit within the flat ``edge_logits`` vector.
    edge_index = v1 + v2 * (v2 - 1) // 2
    tree_logits = edge_logits[edge_index].sum(-1)
    expected = tree_logits.logsumexp(0)
    actual = d.log_partition_function
    assert (actual - expected).abs() < 1e-6, (actual, expected)
示例#38
0
def test_iter_discrete_traces_nan(enum_discrete, trace_graph):
    """Check ELBO losses are non-NaN floats for Bernoulli probs 0, 0.5 and 1.

    Both ``loss`` and ``loss_and_grads`` must return a ``float`` that is
    not NaN.
    """
    pyro.clear_param_store()

    def model():
        probs = Variable(torch.Tensor([0.0, 0.5, 1.0]))
        pyro.sample("z", dist.Bernoulli(probs))

    def guide():
        init = Variable(torch.Tensor([0.0, 0.5, 1.0]), requires_grad=True)
        probs = pyro.param("p", init)
        pyro.sample("z", dist.Bernoulli(probs))

    def is_non_nan_float(value):
        return isinstance(value, float) and not math.isnan(value)

    if trace_graph:
        Elbo = TraceGraph_ELBO
    else:
        Elbo = Trace_ELBO
    elbo = Elbo(enum_discrete=enum_discrete)
    with xfail_if_not_implemented():
        loss = elbo.loss(model, guide)
        assert is_non_nan_float(loss), loss
        loss = elbo.loss_and_grads(model, guide)
        assert is_non_nan_float(loss), loss
示例#39
0
def test_log_prob(dist):
    """Check that ``log_prob`` agrees between CPU and CUDA tensors.

    Data and parameters are created once under each tensor default; the
    CUDA result, moved to CPU, must equal the CPU result.
    """
    for idx in range(len(dist.dist_params)):

        # Compute CPU value.
        with tensors_default_to("cpu"):
            cpu_data = dist.get_test_data(idx)
            cpu_params = dist.get_dist_params(idx)
        with xfail_if_not_implemented():
            cpu_value = dist.pyro_dist(**cpu_params).log_prob(cpu_data)
        assert not cpu_value.is_cuda

        # Compute GPU value.
        with tensors_default_to("cuda"):
            cuda_data = dist.get_test_data(idx)
            cuda_params = dist.get_dist_params(idx)
        cuda_value = dist.pyro_dist(**cuda_params).log_prob(cuda_data)
        assert cuda_value.is_cuda

        assert_equal(cpu_value, cuda_value.cpu())
示例#40
0
def test_sample(dist):
    """Check that ``sample`` gives equally sized results on CPU and CUDA.

    A ``ValueError`` while sampling on CPU marks the test as an expected
    failure.
    """
    for idx in range(len(dist.dist_params)):

        # Compute CPU value.
        with tensors_default_to("cpu"):
            cpu_params = dist.get_dist_params(idx)
        try:
            with xfail_if_not_implemented():
                cpu_value = dist.pyro_dist(**cpu_params).sample()
        except ValueError as e:
            pytest.xfail('CPU version fails: {}'.format(e))
        assert not cpu_value.is_cuda

        # Compute GPU value.
        with tensors_default_to("cuda"):
            cuda_params = dist.get_dist_params(idx)
        cuda_dist = dist.pyro_dist(**cuda_params)
        cuda_value = cuda_dist.sample()
        assert cuda_value.is_cuda

        assert_equal(cpu_value.size(), cuda_value.size())
示例#41
0
def test_batch_log_pdf_mask(dist):
    """Verify the effect of ``log_pdf_mask`` on ``batch_log_pdf``.

    Distributions other than Normal, Bernoulli and Categorical are skipped.
    Masks of zeros, ones and 0.5 must yield zeros, the unmasked value and
    half the unmasked value respectively.
    """
    name = dist.get_test_distribution_name()
    if name not in ('Normal', 'Bernoulli', 'Categorical'):
        pytest.skip('Batch pdf masking not supported for the distribution.')
    d = dist.pyro_dist
    for idx in range(dist.get_num_test_data()):
        dist_params = dist.get_dist_params(idx)
        x = dist.get_test_data(idx)
        with xfail_if_not_implemented():

            def masked_log_pdf(mask):
                return d.batch_log_pdf(x, log_pdf_mask=mask, **dist_params)

            batch_pdf_shape = d.batch_shape(**dist_params) + (1,)
            batch_pdf_shape_broadcasted = d.batch_shape(x, **dist_params) + (1,)
            zeros_mask = ng_zeros(1)  # should be broadcasted to data dims
            ones_mask = ng_ones(batch_pdf_shape)  # should be broadcasted to data dims
            half_mask = ng_ones(1) * 0.5
            unmasked = d.batch_log_pdf(x, **dist_params)
            zeros_result = masked_log_pdf(zeros_mask)
            ones_result = masked_log_pdf(ones_mask)
            half_result = masked_log_pdf(half_mask)
            assert_equal(ones_result, unmasked)
            assert_equal(zeros_result, ng_zeros(batch_pdf_shape_broadcasted))
            assert_equal(half_result, 0.5 * unmasked)
示例#42
0
def test_rsample(dist):
    """Check ``rsample`` values and gradients have equal sizes on CPU and CUDA.

    Returns immediately when the distribution class has no ``rsample``.
    Float32/float64 tensor parameters are cloned with ``requires_grad`` set
    so that gradients of the summed sample can be taken with ``grad``.
    """
    if not dist.pyro_dist.has_rsample:
        return

    def require_grads(params, keys):
        # Replace each selected parameter with a clone that tracks gradients.
        for key in keys:
            leaf = params[key].clone()
            leaf.requires_grad = True
            params[key] = leaf

    float_dtypes = (torch.float32, torch.float64)
    for idx in range(len(dist.dist_params)):

        # Compute CPU value.
        with tensors_default_to("cpu"):
            params = dist.get_dist_params(idx)
            grad_params = []
            for key, val in params.items():
                if torch.is_tensor(val) and val.dtype in float_dtypes:
                    grad_params.append(key)
            require_grads(params, grad_params)
        try:
            with xfail_if_not_implemented():
                cpu_value = dist.pyro_dist(**params).rsample()
                cpu_inputs = [params[key] for key in grad_params]
                cpu_grads = grad(cpu_value.sum(), cpu_inputs)
        except ValueError as e:
            pytest.xfail('CPU version fails: {}'.format(e))
        assert not cpu_value.is_cuda

        # Compute GPU value.
        with tensors_default_to("cuda"):
            params = dist.get_dist_params(idx)
            require_grads(params, grad_params)
        cuda_value = dist.pyro_dist(**params).rsample()
        assert cuda_value.is_cuda
        assert_equal(cpu_value.size(), cuda_value.size())

        cuda_inputs = [params[key] for key in grad_params]
        cuda_grads = grad(cuda_value.sum(), cuda_inputs)
        for cpu_grad, cuda_grad in zip(cpu_grads, cuda_grads):
            assert_equal(cpu_grad.size(), cuda_grad.size())
示例#43
0
def test_shape(dist):
    """Verify ``shape`` is the concatenation of batch and event shapes."""
    d = dist.pyro_dist
    for case in dist.get_test_data_indices():
        kwargs = dist.get_dist_params(case)
        with xfail_if_not_implemented():
            total = d.shape(**kwargs)
            expected = d.batch_shape(**kwargs) + d.event_shape(**kwargs)
            assert total == expected