Example #1
File: test_util.py  Project: pyro-ppl/pyro
# Imports assumed by these pyro test snippets (exact paths may vary by version):
import torch

import pyro.distributions as dist
from pyro.distributions.util import detach
from tests.common import assert_equal


def test_detach_transformed(shape):
    loc = torch.tensor(0.0, requires_grad=True)
    scale = torch.tensor(1.0, requires_grad=True)
    a = torch.tensor(2.0, requires_grad=True)
    b = torch.tensor(3.0, requires_grad=True)
    d1 = dist.TransformedDistribution(
        dist.Normal(loc, scale), dist.transforms.AffineTransform(a, b)
    )
    if shape is not None:
        d1 = d1.expand(shape)

    d2 = detach(d1)
    assert type(d1) is type(d2)
    assert d2.event_shape == d1.event_shape
    assert d2.batch_shape == d1.batch_shape
    assert type(d1.base_dist) is type(d2.base_dist)
    assert len(d1.transforms) == len(d2.transforms)
    assert_equal(d1.base_dist.loc, d2.base_dist.loc)
    assert_equal(d1.base_dist.scale, d2.base_dist.scale)
    assert_equal(d1.transforms[0].loc, d2.transforms[0].loc)
    assert_equal(d1.transforms[0].scale, d2.transforms[0].scale)
    assert not d2.base_dist.loc.requires_grad
    assert not d2.base_dist.scale.requires_grad
    assert not d2.transforms[0].loc.requires_grad
    assert not d2.transforms[0].scale.requires_grad
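In the snippets on this page, shape is supplied by pytest parametrization, which the listing omits. A minimal sketch of the assumed setup (the concrete shapes below are illustrative, not the real suite's values):

import pytest
import torch

import pyro.distributions as dist
from pyro.distributions.util import detach


# Hypothetical parametrization: shape is either None (no expand) or a batch shape.
@pytest.mark.parametrize("shape", [None, (), (4,), (3, 2)])
def test_detach_keeps_batch_shape(shape):
    d = dist.Normal(torch.tensor(0.0, requires_grad=True), torch.tensor(1.0))
    if shape is not None:
        d = d.expand(shape)
    assert detach(d).batch_shape == d.batch_shape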
Example #2
def test_detach():
    import pytest
    import torch

    # Tensor and assert_close appear to be funsor helpers (this snippet wraps
    # a torch tensor in a funsor Tensor); the import paths may vary by version.
    from funsor.tensor import Tensor
    from funsor.testing import assert_close

    try:
        from pyro.distributions.util import detach
    except ImportError:
        pytest.skip("detach() is not available")
    x = Tensor(torch.randn(2, 3, requires_grad=True))
    y = detach(x)
    assert_close(x, y)
    assert x.data.requires_grad      # the original still tracks gradients
    assert not y.data.requires_grad  # the detached copy does not
Example #3
def _construct_baseline(node, guide_site, downstream_cost):

    # XXX should the average baseline be in the param store as below?

    baseline = 0.0
    baseline_loss = 0.0

    (
        nn_baseline,
        nn_baseline_input,
        use_decaying_avg_baseline,
        baseline_beta,
        baseline_value,
    ) = _get_baseline_options(guide_site)

    use_nn_baseline = nn_baseline is not None
    use_baseline_value = baseline_value is not None

    use_baseline = use_nn_baseline or use_decaying_avg_baseline or use_baseline_value

    assert not (
        use_nn_baseline and use_baseline_value
    ), "cannot use baseline_value and nn_baseline simultaneously"
    if use_decaying_avg_baseline:
        dc_shape = downstream_cost.shape
        param_name = "__baseline_avg_downstream_cost_" + node
        with torch.no_grad():
            avg_downstream_cost_old = pyro.param(
                param_name,
                torch.zeros(dc_shape, device=guide_site["value"].device))
            avg_downstream_cost_new = (
                1 - baseline_beta
            ) * downstream_cost + baseline_beta * avg_downstream_cost_old
        pyro.get_param_store()[param_name] = avg_downstream_cost_new
        baseline += avg_downstream_cost_old
    if use_nn_baseline:
        # block nn_baseline_input gradients except in baseline loss
        baseline += nn_baseline(detach(nn_baseline_input))
    elif use_baseline_value:
        # it's on the user to make sure baseline_value tape only points to baseline params
        baseline += baseline_value
    if use_nn_baseline or use_baseline_value:
        # accumulate baseline loss
        baseline_loss += torch.pow(downstream_cost.detach() - baseline, 2.0).sum()

    if use_baseline:
        if downstream_cost.shape != baseline.shape:
            raise ValueError(
                "Expected baseline at site {} to be {} instead got {}".format(
                    node, downstream_cost.shape, baseline.shape))

    return use_baseline, baseline_loss, baseline
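The use_decaying_avg_baseline branch above maintains an exponential moving average of the downstream cost, new_avg = (1 - beta) * cost + beta * old_avg, updated outside the autograd graph. A self-contained sketch of just that update rule, decoupled from pyro's param store (the function name is illustrative):

import torch


def update_avg_baseline(old_avg, downstream_cost, beta):
    # Exponential moving average: larger beta means a smoother,
    # slower-moving baseline; gradients are deliberately blocked.
    with torch.no_grad():
        return (1 - beta) * downstream_cost + beta * old_avg


avg = torch.zeros(3)
for cost in [torch.ones(3), 2 * torch.ones(3)]:
    avg = update_avg_baseline(avg, cost, beta=0.9)  # drifts toward the recent costs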
Example #4
File: test_util.py  Project: zeta1999/pyro
def test_detach_normal(shape):
    loc = torch.tensor(0.0, requires_grad=True)
    scale = torch.tensor(1.0, requires_grad=True)
    d1 = dist.Normal(loc, scale)
    if shape is not None:
        d1 = d1.expand(shape)

    d2 = detach(d1)
    assert type(d1) is type(d2)
    assert_equal(d1.loc, d2.loc)
    assert_equal(d1.scale, d2.scale)
    assert not d2.loc.requires_grad
    assert not d2.scale.requires_grad
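Together these tests pin down detach()'s contract: same distribution type and shapes, numerically equal parameters, and no gradient tracking on the copy. One way such a contract could be satisfied for Normal, shown purely as an illustration (not pyro's actual implementation):

import torch

import pyro.distributions as dist


def detach_normal(d):
    # Rebuild the distribution from detached parameters so the copy
    # carries no autograd history.
    return dist.Normal(d.loc.detach(), d.scale.detach())


d1 = dist.Normal(torch.tensor(0.0, requires_grad=True), torch.tensor(1.0))
d2 = detach_normal(d1)
assert type(d2) is type(d1) and not d2.loc.requires_grad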
Example #5
File: test_util.py  Project: zeta1999/pyro
def test_detach_beta(shape):
    concentration1 = torch.tensor(0.5, requires_grad=True)
    concentration0 = torch.tensor(2.0, requires_grad=True)
    d1 = dist.Beta(concentration1, concentration0)
    if shape is not None:
        d1 = d1.expand(shape)

    d2 = detach(d1)
    assert type(d1) is type(d2)
    assert d2.batch_shape == d1.batch_shape
    assert_equal(d1.concentration1, d2.concentration1)
    assert_equal(d1.concentration0, d2.concentration0)
    assert not d2.concentration1.requires_grad
    assert not d2.concentration0.requires_grad
Example #6
File: test_util.py  Project: zeta1999/pyro
def fn(loc, scale, data):
    # shape, dist, and detach are captured from the enclosing test's scope.
    d = dist.Normal(loc, scale, validate_args=False)
    if shape is not None:
        d = d.expand(shape)
    return detach(d).log_prob(data)
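Used this way, detach(d).log_prob(data) blocks gradients to the distribution's parameters while still letting them flow to the data, which a quick check confirms (assuming the same imports as the snippets above):

import torch

import pyro.distributions as dist
from pyro.distributions.util import detach

loc = torch.tensor(0.0, requires_grad=True)
scale = torch.tensor(1.0, requires_grad=True)
data = torch.tensor(0.5, requires_grad=True)

detach(dist.Normal(loc, scale)).log_prob(data).backward()
assert data.grad is not None  # gradient w.r.t. the data survives
assert loc.grad is None and scale.grad is None  # parameter gradients are blocked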