Example #1
def test_systematic_sample(size):
    pyro.set_rng_seed(size)
    probs = torch.randn(size).exp()
    probs /= probs.sum()

    num_samples = 20000
    index = _systematic_sample(probs.expand(num_samples, size))
    histogram = torch.zeros_like(probs)
    histogram.scatter_add_(-1, index.reshape(-1),
                           probs.new_ones(1).expand(num_samples * size))

    expected = probs * size
    actual = histogram / num_samples
    assert_close(actual, expected, atol=0.01)
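These snippets come from Pyro's test suite, so they assume the usual module-level imports (torch, pyro, pyro.distributions as dist, and an assert_close helper) plus pytest parametrization of the arguments. As a rough illustration of what _systematic_sample is expected to do, here is a minimal, self-contained sketch of systematic resampling (an assumed reference algorithm, not Pyro's implementation); the histogram check above verifies exactly this behavior, namely that each index i is drawn close to probs[i] * size times per row.

def systematic_sample_sketch(probs):
    # probs: (..., size), rows summing to 1; returns integer indices of the same shape.
    size = probs.size(-1)
    # One uniform offset per row, then `size` evenly spaced points in (0, 1).
    u = torch.rand(probs.shape[:-1] + (1,))
    points = (u + torch.arange(size)) / size
    # Invert the CDF: first index whose cumulative probability reaches each point.
    index = torch.searchsorted(probs.cumsum(-1), points)
    return index.clamp(max=size - 1)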
Example #2
def assert_dist_close(d1, d2):
    x = torch.arange(float(200))
    p1 = d1.log_prob(x).exp()
    p2 = d2.log_prob(x).exp()

    assert (p1.sum() - 1).abs() < 1e-3, "incomplete mass"
    assert (p2.sum() - 1).abs() < 1e-3, "incomplete mass"

    mean1 = (p1 * x).sum()
    mean2 = (p2 * x).sum()
    assert_close(mean1, mean2, rtol=0.05)

    max_prob = torch.max(p1.max(), p2.max())
    assert (p1 - p2).abs().max() / max_prob < 0.05
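assert_dist_close compares two discrete distributions on the support {0, ..., 199} by total mass, mean, and pointwise probabilities. A hedged usage sketch (not taken from the test file): a Gamma-Poisson mixture and the matching NegativeBinomial describe the same pmf, so the helper should accept them.

concentration, rate = torch.tensor(4.0), torch.tensor(2.0)
d1 = dist.GammaPoisson(concentration, rate)
# Equivalent parameterization: total_count = concentration, probs = 1 / (1 + rate).
d2 = dist.NegativeBinomial(total_count=concentration, probs=1.0 / (1.0 + rate))
assert_dist_close(d1, d2)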
Example #3
def test_log_prob_constant_rate_2(num_leaves, num_steps, batch_shape,
                                  sample_shape):
    rate = torch.randn(batch_shape).exp()
    rate_grid = rate.unsqueeze(-1).expand(batch_shape + (num_steps, ))
    leaf_times = torch.rand(batch_shape + (num_leaves, )).pow(0.5) * num_steps

    d1 = CoalescentTimes(leaf_times, rate)
    coal_times = d1.sample(sample_shape)
    log_prob_1 = d1.log_prob(coal_times)

    d2 = CoalescentTimesWithRate(leaf_times, rate_grid)
    log_prob_2 = d2.log_prob(coal_times)

    assert_close(log_prob_1, log_prob_2)
Example #4
def test_condition(sample_shape, batch_shape, left, right):
    dim = left + right
    gaussian = random_gaussian(batch_shape, dim)
    gaussian.precision += torch.eye(dim) * 0.1
    value = torch.randn(sample_shape + (1, ) * len(batch_shape) + (dim, ))
    left_value, right_value = value[..., :left], value[..., left:]

    conditioned = gaussian.condition(right_value)
    assert conditioned.batch_shape == sample_shape + gaussian.batch_shape
    assert conditioned.dim() == left

    actual = conditioned.log_density(left_value)
    expected = gaussian.log_density(value)
    assert_close(actual, expected)
Example #5
def test_zinb_mean_variance(gate, total_count, logits):
    num_samples = 1000000
    zinb_ = ZeroInflatedNegativeBinomial(
        torch.tensor(gate),
        total_count=torch.tensor(total_count),
        logits=torch.tensor(logits),
    )
    s = zinb_.sample((num_samples,))
    expected_mean = zinb_.mean
    estimated_mean = s.mean()
    expected_std = zinb_.stddev
    estimated_std = s.std()
    assert_close(expected_mean, estimated_mean, atol=1e-01)
    assert_close(expected_std, estimated_std, atol=1e-1)
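The exact moments being estimated above have a simple closed form: zero-inflating a base distribution with gate g shrinks the mean by (1 - g) and adds a g * mean**2 term to the variance. A short sketch of that arithmetic, reusing the same scalar arguments (for illustration only, not part of the test):

nb = NegativeBinomial(total_count=torch.tensor(total_count),
                      logits=torch.tensor(logits))
g = torch.tensor(gate)
zinb_mean = (1 - g) * nb.mean                          # equals zinb_.mean
zinb_var = (1 - g) * (nb.variance + g * nb.mean ** 2)  # equals zinb_.variance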
Example #6
def test_beta_binomial(sample_shape, batch_shape):
    concentration1 = torch.randn(batch_shape).exp()
    concentration0 = torch.randn(batch_shape).exp()
    total = 10
    obs = dist.Binomial(total, 0.2).sample(sample_shape + batch_shape)

    f = dist.Beta(concentration1, concentration0)
    g = dist.Beta(1 + obs, 1 + total - obs)
    fg, log_normalizer = f.conjugate_update(g)

    x = fg.sample(sample_shape)
    assert_close(
        f.log_prob(x) + g.log_prob(x),
        fg.log_prob(x) + log_normalizer)
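conjugate_update returns a distribution fg proportional to the product of the two densities, with log_normalizer absorbing the constant, which is exactly the identity asserted above. For two Beta factors that product has a simple closed form, sketched here (an illustration, not Pyro's implementation):

# Beta(c1, c0) * Beta(1 + obs, 1 + total - obs) is proportional to
# Beta(c1 + obs, c0 + total - obs); broadcasting matches the shapes above.
fg_manual = dist.Beta(concentration1 + obs, concentration0 + total - obs)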
Example #7
def test_dirichlet_multinomial(sample_shape, batch_shape):
    concentration = torch.randn(batch_shape + (3, )).exp()
    total = 10
    probs = torch.tensor([0.2, 0.3, 0.5])
    obs = dist.Multinomial(total, probs).sample(sample_shape + batch_shape)

    f = dist.Dirichlet(concentration)
    g = dist.Dirichlet(1 + obs)
    fg, log_normalizer = f.conjugate_update(g)

    x = fg.sample(sample_shape)
    assert_close(
        f.log_prob(x) + g.log_prob(x),
        fg.log_prob(x) + log_normalizer)
Example #8
def test_symmetric_stable(shape):
    stability = torch.empty(shape).uniform_(1.6, 1.9).requires_grad_()
    scale = torch.empty(shape).uniform_(0.5, 1.0).requires_grad_()
    loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_()
    params = [stability, scale, loc]

    def model():
        with pyro.plate_stack("plates", shape):
            with pyro.plate("particles", 200000):
                return pyro.sample("x", dist.Stable(stability, 0, scale, loc))

    value = model()
    expected_moments = get_moments(value)

    reparam_model = poutine.reparam(model, {"x": SymmetricStableReparam()})
    trace = poutine.trace(reparam_model).get_trace()
    assert isinstance(trace.nodes["x"]["fn"], dist.Normal)
    trace.compute_log_prob()  # smoke test only
    value = trace.nodes["x"]["value"]
    actual_moments = get_moments(value)
    assert_close(actual_moments, expected_moments, atol=0.05)

    for actual_m, expected_m in zip(actual_moments, expected_moments):
        expected_grads = grad(expected_m.sum(), params, retain_graph=True)
        actual_grads = grad(actual_m.sum(), params, retain_graph=True)
        assert_close(actual_grads[0], expected_grads[0], atol=0.2)
        assert_close(actual_grads[1], expected_grads[1], atol=0.1)
        assert_close(actual_grads[2], expected_grads[2], atol=0.1)
Example #9
def test_kl_independent_normal_mvn(batch_shape, size):
    loc = torch.randn(batch_shape + (size, ))
    scale = torch.randn(batch_shape + (size, )).exp()
    p1 = dist.Normal(loc, scale).to_event(1)
    p2 = dist.MultivariateNormal(loc, scale_tril=scale.diag_embed())

    loc = torch.randn(batch_shape + (size, ))
    cov = torch.randn(batch_shape + (size, size))
    cov = cov @ cov.transpose(-1, -2) + 0.01 * torch.eye(size)
    q = dist.MultivariateNormal(loc, covariance_matrix=cov)

    actual = kl_divergence(p1, q)
    expected = kl_divergence(p2, q)
    assert_close(actual, expected)
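p1 and p2 above parameterize the same diagonal Gaussian in two ways (an Independent Normal versus a MultivariateNormal with diagonal scale_tril), so their KL divergences against q must agree. The equivalence can also be checked pointwise with a line one could append inside the test (a sketch, not part of the original):

x = torch.randn(batch_shape + (size,))
assert_close(p1.log_prob(x), p2.log_prob(x))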
Example #10
def test_polya_gamma(batch_shape, num_points=20000):
    d = TruncatedPolyaGamma(prototype=torch.ones(1)).expand(batch_shape)

    # test density approximately normalized
    x = torch.linspace(1.0e-6, d.truncation_point,
                       num_points).expand(batch_shape + (num_points, ))
    prob = (d.truncation_point / num_points) * torch.logsumexp(d.log_prob(x),
                                                               dim=-1).exp()
    assert_close(prob, torch.tensor(1.0).expand(batch_shape), rtol=1.0e-4)

    # test mean of approximate sampler
    z = d.sample(sample_shape=(3000, ))
    mean = z.mean(-1)
    assert_close(mean, torch.tensor(0.25).expand(batch_shape), rtol=0.07)
Example #11
def test_logsumexp(batch_shape, dim):
    g = random_gamma_gaussian(batch_shape, dim)
    g.info_vec *= 0.1  # approximately centered
    g.precision += torch.eye(dim) * 0.1
    s = torch.randn(batch_shape).exp() + 0.2

    num_samples = 200000
    scale = 10
    samples = torch.rand((num_samples, ) + (1, ) * len(batch_shape) +
                         (dim, )) * scale - scale / 2
    expected = g.log_density(samples, s).logsumexp(0) + math.log(
        scale**dim / num_samples)
    actual = g.event_logsumexp().log_density(s)
    assert_close(actual, expected, atol=0.05, rtol=0.05)
Example #12
def test_gamma_and_mvn_to_gamma_gaussian(sample_shape, batch_shape, dim):
    gamma = random_gamma(batch_shape)
    mvn = random_mvn(batch_shape, dim)
    g = gamma_and_mvn_to_gamma_gaussian(gamma, mvn)
    value = mvn.sample(sample_shape)
    s = gamma.sample(sample_shape)
    actual_log_prob = g.log_density(value, s)

    s_log_prob = gamma.log_prob(s)
    scaled_prec = mvn.precision_matrix * s.unsqueeze(-1).unsqueeze(-1)
    mvn_log_prob = dist.MultivariateNormal(
        mvn.loc, precision_matrix=scaled_prec).log_prob(value)
    expected_log_prob = s_log_prob + mvn_log_prob
    assert_close(actual_log_prob, expected_log_prob)
Example #13
File: test_hmm.py  Project: yufengwa/pyro
def test_gaussian_mrf_log_prob(sample_shape, batch_shape, num_steps,
                               hidden_dim, obs_dim):
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_dist = random_mvn(batch_shape + (num_steps, ),
                            hidden_dim + hidden_dim)
    obs_dist = random_mvn(batch_shape + (num_steps, ), hidden_dim + obs_dim)
    d = dist.GaussianMRF(init_dist, trans_dist, obs_dist)
    data = obs_dist.sample(sample_shape)[..., hidden_dim:]
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussians with shapes:
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    #   ------+-----------------------------------------
    #    init | H
    #   trans | H H H H            H = hidden
    #     obs |   H H H O O O      O = observed
    # and then combine these using gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = mvn_to_gaussian(trans_dist)
    obs = mvn_to_gaussian(obs_dist)

    unrolled_trans = reduce(operator.add, [
        trans[..., t].event_pad(left=t * hidden_dim,
                                right=(T - t - 1) * hidden_dim)
        for t in range(T)
    ])
    unrolled_obs = reduce(operator.add, [
        obs[..., t].event_pad(left=t * obs.dim(),
                              right=(T - t - 1) * obs.dim()) for t in range(T)
    ])
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat(
        [torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
        [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim, ))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    logp_h = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp_oh = gaussian_tensordot(logp_h, unrolled_obs, T * hidden_dim)
    logp_h += unrolled_obs.marginalize(right=T * obs_dim)
    expected_log_prob = logp_oh.log_density(
        unrolled_data) - logp_h.event_logsumexp()
    assert_close(actual_log_prob, expected_log_prob)
Example #14
def test_relaxed_overdispersed_beta_binomial(overdispersion):
    total_count = torch.arange(1, 17)
    concentration1 = torch.logspace(-1, 2, 8).unsqueeze(-1)
    concentration0 = concentration1.unsqueeze(-1)

    d1 = beta_binomial_dist(concentration1, concentration0, total_count,
                            overdispersion=overdispersion)
    assert isinstance(d1, dist.ExtendedBetaBinomial)

    with set_relaxed_distributions():
        d2 = beta_binomial_dist(concentration1, concentration0, total_count,
                                overdispersion=overdispersion)
    assert isinstance(d2, dist.Normal)
    assert_close(d2.mean, d1.mean)
    assert_close(d2.variance, d1.variance.clamp(min=_RELAX_MIN_VARIANCE))
Example #15
def test_posterior_predictive_svi_auto_diag_normal_guide(return_trace):
    true_probs = torch.ones(5) * 0.7
    num_trials = torch.ones(5) * 1000
    num_success = dist.Binomial(num_trials, true_probs).sample()
    conditioned_model = poutine.condition(model, data={"obs": num_success})
    guide = AutoDiagonalNormal(conditioned_model)
    svi = SVI(conditioned_model, guide, optim.Adam(dict(lr=0.1)), Trace_ELBO())
    for i in range(1000):
        svi.step(num_trials)
    posterior_predictive = Predictive(model, guide=guide, num_samples=10000,
                                      parallel=True)
    if return_trace:
        marginal_return_vals = posterior_predictive.get_vectorized_trace(
            num_trials).nodes["obs"]["value"]
    else:
        marginal_return_vals = posterior_predictive.get_samples(
            num_trials)["obs"]
    assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 700,
                 rtol=0.05)
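This and the other posterior-predictive tests below (Examples #24, #27, #28) condition a model that is defined elsewhere in the test module. A plausible sketch of that model (an assumption made here for readability, not the verbatim source) is a Binomial likelihood with a uniform prior over each site's success probability, which is why the predictive mean is compared against 0.7 * 1000 = 700:

def model(num_trials):
    with pyro.plate("data", num_trials.size(0)):
        phi = pyro.sample("phi", dist.Uniform(0.0, 1.0))
        return pyro.sample("obs", dist.Binomial(num_trials, phi))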
Example #16
def test_log_prob_scale(num_leaves, num_steps, batch_shape, sample_shape):
    rate = torch.randn(batch_shape).exp()

    leaf_times_1 = torch.rand(batch_shape + (num_leaves,)).pow(0.5) * num_steps
    d1 = CoalescentTimes(leaf_times_1)
    coal_times_1 = d1.sample(sample_shape)
    log_prob_1 = d1.log_prob(coal_times_1)

    leaf_times_2 = leaf_times_1 / rate.unsqueeze(-1)
    coal_times_2 = coal_times_1 / rate.unsqueeze(-1)
    d2 = CoalescentTimes(leaf_times_2, rate)
    log_prob_2 = d2.log_prob(coal_times_2)

    log_abs_det_jacobian = -coal_times_2.size(-1) * rate.log()
    assert_close(log_prob_1 - log_abs_det_jacobian, log_prob_2)
Example #17
def test_gamma_poisson(sample_shape, batch_shape):
    concentration = torch.randn(batch_shape).exp()
    rate = torch.randn(batch_shape).exp()
    nobs = 5
    obs = dist.Poisson(10.0).sample((nobs, ) + sample_shape +
                                    batch_shape).sum(0)

    f = dist.Gamma(concentration, rate)
    g = dist.Gamma(1 + obs, nobs)
    fg, log_normalizer = f.conjugate_update(g)

    x = fg.sample(sample_shape)
    assert_close(
        f.log_prob(x) + g.log_prob(x),
        fg.log_prob(x) + log_normalizer)
Example #18
def test_zinb_1_gate(total_count, probs):
    # if gate is 1 ZINB is Delta(0)
    zinb1 = ZeroInflatedNegativeBinomial(total_count=torch.tensor(total_count),
                                         gate=torch.ones(1),
                                         probs=torch.tensor(probs))
    zinb2 = ZeroInflatedNegativeBinomial(total_count=torch.tensor(total_count),
                                         gate_logits=torch.tensor(math.inf),
                                         probs=torch.tensor(probs))
    delta = Delta(torch.zeros(1))
    s = torch.tensor([0.0, 1.0])
    zinb1_prob = zinb1.log_prob(s)
    zinb2_prob = zinb2.log_prob(s)
    delta_prob = delta.log_prob(s)
    assert_close(zinb1_prob, delta_prob)
    assert_close(zinb2_prob, delta_prob)
Example #19
def test_pickling(wrapper):
    wrapped = wrapper(_model)
    buffer = io.BytesIO()
    # default protocol cannot serialize torch.Size objects (see https://github.com/pytorch/pytorch/issues/20823)
    torch.save(wrapped, buffer, pickle_protocol=pickle.HIGHEST_PROTOCOL)
    buffer.seek(0)
    deserialized = torch.load(buffer)
    obs = torch.tensor(0.5)
    pyro.set_rng_seed(0)
    actual_trace = poutine.trace(deserialized).get_trace(obs)
    pyro.set_rng_seed(0)
    expected_trace = poutine.trace(wrapped).get_trace(obs)
    assert tuple(actual_trace.nodes) == tuple(expected_trace.nodes)
    assert_close([actual_trace.nodes[site]['value']
                  for site in actual_trace.stochastic_nodes],
                 [expected_trace.nodes[site]['value']
                  for site in expected_trace.stochastic_nodes])
Example #20
def test_sequential_logmatmulexp(batch_shape, state_dim, num_steps):
    logits = torch.randn(batch_shape + (num_steps, state_dim, state_dim))
    actual = _sequential_logmatmulexp(logits)
    assert actual.shape == batch_shape + (state_dim, state_dim)

    # Check against einsum.
    operands = list(logits.unbind(-3))
    symbol = (opt_einsum.get_symbol(i) for i in range(1000))
    batch_symbols = ''.join(next(symbol) for _ in batch_shape)
    state_symbols = [next(symbol) for _ in range(num_steps + 1)]
    equation = (','.join(batch_symbols + state_symbols[t] + state_symbols[t + 1]
                         for t in range(num_steps)) +
                '->' + batch_symbols + state_symbols[0] + state_symbols[-1])
    expected = opt_einsum.contract(equation, *operands, backend='pyro.ops.einsum.torch_log')
    assert_close(actual, expected)
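_sequential_logmatmulexp composes a chain of transition matrices in log space, i.e. repeated log(exp(A) @ exp(B)) contractions, which is what the einsum check above encodes. A naive sequential sketch of the same operation (for illustration only; the function under test may use a different contraction order):

def naive_sequential_logmatmulexp(logits):
    # logits: (..., num_steps, state_dim, state_dim)
    result = logits[..., 0, :, :]
    for t in range(1, logits.size(-3)):
        step = logits[..., t, :, :]
        # Stable log(exp(result) @ exp(step)).
        result = torch.logsumexp(result.unsqueeze(-1) + step.unsqueeze(-3),
                                 dim=-2)
    return result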
Example #21
def test_zinb_0_gate(total_count, probs):
    # if gate is 0 ZINB is NegativeBinomial
    zinb1 = ZeroInflatedNegativeBinomial(total_count=torch.tensor(total_count),
                                         gate=torch.zeros(1),
                                         probs=torch.tensor(probs))
    zinb2 = ZeroInflatedNegativeBinomial(total_count=torch.tensor(total_count),
                                         gate_logits=torch.tensor(-99.9),
                                         probs=torch.tensor(probs))
    neg_bin = NegativeBinomial(torch.tensor(total_count),
                               probs=torch.tensor(probs))
    s = neg_bin.sample((20, ))
    zinb1_prob = zinb1.log_prob(s)
    zinb2_prob = zinb2.log_prob(s)
    neg_bin_prob = neg_bin.log_prob(s)
    assert_close(zinb1_prob, neg_bin_prob)
    assert_close(zinb2_prob, neg_bin_prob)
Example #22
def test_posterior_predictive_svi_one_hot():
    pseudocounts = torch.ones(3) * 0.1
    true_probs = torch.tensor([0.15, 0.6, 0.25])
    classes = dist.OneHotCategorical(true_probs).sample((10000, ))
    guide = AutoDelta(one_hot_model)
    svi = SVI(one_hot_model, guide, optim.Adam(dict(lr=0.1)), Trace_ELBO())
    for i in range(1000):
        svi.step(pseudocounts, classes=classes)
    posterior_samples = Predictive(guide,
                                   num_samples=10000).get_samples(pseudocounts)
    posterior_predictive = Predictive(one_hot_model, posterior_samples)
    marginal_return_vals = posterior_predictive.get_samples(
        pseudocounts)["obs"]
    assert_close(marginal_return_vals.mean(dim=0),
                 true_probs.unsqueeze(0),
                 rtol=0.1)
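This test and Example #29 rely on a one_hot_model defined elsewhere in the test module. A plausible sketch (an assumption, not the verbatim source) is a Dirichlet prior over class probabilities with a OneHotCategorical likelihood, which is why the predictive mean is compared against true_probs:

def one_hot_model(pseudocounts, classes=None):
    probs = pyro.sample("probs", dist.Dirichlet(pseudocounts))
    num_obs = classes.size(0) if classes is not None else 1
    with pyro.plate("data", num_obs):
        return pyro.sample("obs", dist.OneHotCategorical(probs), obs=classes)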
Example #23
def test_matrix_and_mvn_to_gamma_gaussian(sample_shape, batch_shape, x_dim,
                                          y_dim):
    matrix = torch.randn(batch_shape + (x_dim, y_dim))
    y_mvn = random_mvn(batch_shape, y_dim)
    g = matrix_and_mvn_to_gamma_gaussian(matrix, y_mvn)
    xy = torch.randn(sample_shape + batch_shape + (x_dim + y_dim, ))
    s = torch.rand(sample_shape + batch_shape)
    actual_log_prob = g.log_density(xy, s)

    x, y = xy[..., :x_dim], xy[..., x_dim:]
    y_pred = x.unsqueeze(-2).matmul(matrix).squeeze(-2)
    loc = y_pred + y_mvn.loc
    scaled_prec = y_mvn.precision_matrix * s.unsqueeze(-1).unsqueeze(-1)
    expected_log_prob = dist.MultivariateNormal(
        loc, precision_matrix=scaled_prec).log_prob(y)
    assert_close(actual_log_prob, expected_log_prob)
Example #24
def test_posterior_predictive_svi_auto_delta_guide(parallel):
    true_probs = torch.ones(5) * 0.7
    num_trials = torch.ones(5) * 1000
    num_success = dist.Binomial(num_trials, true_probs).sample()
    conditioned_model = poutine.condition(model, data={"obs": num_success})
    guide = AutoDelta(conditioned_model)
    svi = SVI(conditioned_model, guide, optim.Adam(dict(lr=1.0)), Trace_ELBO())
    for i in range(1000):
        svi.step(num_trials)
    posterior_predictive = Predictive(model,
                                      guide=guide,
                                      num_samples=10000,
                                      parallel=parallel)
    marginal_return_vals = posterior_predictive.get_samples(num_trials)["obs"]
    assert_close(marginal_return_vals.mean(dim=0),
                 torch.ones(5) * 700,
                 rtol=0.05)
Example #25
def test_reparam_log_joint(model, kwargs):
    guide = AutoIAFNormal(model)
    guide(**kwargs)
    neutra = NeuTraReparam(guide)
    reparam_model = neutra.reparam(model)
    _, pe_fn, transforms, _ = initialize_model(model, model_kwargs=kwargs)
    init_params, pe_fn_neutra, _, _ = initialize_model(
        reparam_model, model_kwargs=kwargs
    )
    latent_x = list(init_params.values())[0]
    transformed_params = neutra.transform_sample(latent_x)
    pe_transformed = pe_fn_neutra(init_params)
    neutra_transform = ComposeTransform(guide.get_posterior(**kwargs).transforms)
    latent_y = neutra_transform(latent_x)
    log_det_jacobian = neutra_transform.log_abs_det_jacobian(latent_x, latent_y)
    pe = pe_fn({k: transforms[k](v) for k, v in transformed_params.items()})
    assert_close(pe_transformed, pe - log_det_jacobian)
Example #26
def test_dirichlet_multinomial_log_prob(total_count, batch_shape, is_sparse):
    event_shape = (3, )
    concentration = torch.rand(batch_shape + event_shape).exp()
    # test on one-hots
    value = total_count * torch.eye(3).reshape(event_shape +
                                               (1, ) * len(batch_shape) +
                                               event_shape)

    num_samples = 100000
    probs = dist.Dirichlet(concentration).sample((num_samples, 1))
    log_probs = dist.Multinomial(total_count, probs).log_prob(value)
    assert log_probs.shape == (num_samples, ) + event_shape + batch_shape
    expected = log_probs.logsumexp(0) - math.log(num_samples)

    actual = DirichletMultinomial(concentration, total_count,
                                  is_sparse).log_prob(value)
    assert_close(actual, expected, atol=0.05)
Example #27
def test_posterior_predictive_svi_manual_guide(parallel):
    true_probs = torch.ones(5) * 0.7
    num_trials = torch.ones(5) * 1000
    num_success = dist.Binomial(num_trials, true_probs).sample()
    conditioned_model = poutine.condition(model, data={"obs": num_success})
    elbo = Trace_ELBO(num_particles=100, vectorize_particles=True)
    svi = SVI(conditioned_model, beta_guide, optim.Adam(dict(lr=1.0)), elbo)
    for i in range(1000):
        svi.step(num_trials)
    posterior_predictive = Predictive(
        model,
        guide=beta_guide,
        num_samples=10000,
        parallel=parallel,
        return_sites=["_RETURN"],
    )
    marginal_return_vals = posterior_predictive(num_trials)["_RETURN"]
    assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 700, rtol=0.05)
Example #28
def test_posterior_predictive_svi_auto_diag_normal_guide():
    true_probs = torch.ones(5) * 0.7
    num_trials = torch.ones(5) * 1000
    num_success = dist.Binomial(num_trials, true_probs).sample()
    conditioned_model = poutine.condition(model, data={"obs": num_success})
    opt = optim.Adam(dict(lr=0.1))
    loss = Trace_ELBO()
    guide = AutoDiagonalNormal(conditioned_model)
    svi_run = SVI(conditioned_model,
                  guide,
                  opt,
                  loss,
                  num_steps=1000,
                  num_samples=100).run(num_trials)
    posterior_predictive = TracePredictive(model, svi_run,
                                           num_samples=10000).run(num_trials)
    marginal_return_vals = posterior_predictive.marginal().empirical["_RETURN"]
    assert_close(marginal_return_vals.mean, torch.ones(5) * 700, rtol=0.05)
Example #29
def test_posterior_predictive_svi_one_hot():
    pseudocounts = torch.ones(3) * 0.1
    true_probs = torch.tensor([0.15, 0.6, 0.25])
    classes = dist.OneHotCategorical(true_probs).sample((10000, ))
    opt = optim.Adam(dict(lr=0.1))
    loss = Trace_ELBO()
    guide = AutoDelta(one_hot_model)
    svi_run = SVI(one_hot_model,
                  guide,
                  opt,
                  loss,
                  num_steps=1000,
                  num_samples=1000).run(pseudocounts, classes=classes)
    posterior_predictive = TracePredictive(one_hot_model,
                                           svi_run,
                                           num_samples=10000).run(pseudocounts)
    marginal_return_vals = posterior_predictive.marginal().empirical["_RETURN"]
    assert_close(marginal_return_vals.mean, true_probs.unsqueeze(0), rtol=0.1)
Example #30
def test_uniform(shape, dim, smooth):
    def model():
        with pyro.plate_stack("plates", shape[:dim]):
            with pyro.plate("particles", 10000):
                pyro.sample("x",
                            dist.Uniform(0, 1).expand(shape).to_event(-dim))

    value = poutine.trace(model).get_trace().nodes["x"]["value"]
    expected_probe = get_moments(value)

    reparam_model = poutine.reparam(
        model, {"x": DiscreteCosineReparam(dim=dim, smooth=smooth)})
    trace = poutine.trace(reparam_model).get_trace()
    assert isinstance(trace.nodes["x_dct"]["fn"], dist.TransformedDistribution)
    assert isinstance(trace.nodes["x"]["fn"], dist.Delta)
    value = trace.nodes["x"]["value"]
    actual_probe = get_moments(value)
    assert_close(actual_probe, expected_probe, atol=0.1)