def test_detach_transformed(shape):
    loc = torch.tensor(0.0, requires_grad=True)
    scale = torch.tensor(1.0, requires_grad=True)
    a = torch.tensor(2.0, requires_grad=True)
    b = torch.tensor(3.0, requires_grad=True)
    d1 = dist.TransformedDistribution(
        dist.Normal(loc, scale), dist.transforms.AffineTransform(a, b)
    )
    if shape is not None:
        d1 = d1.expand(shape)
    d2 = detach(d1)
    assert type(d1) is type(d2)
    assert d2.event_shape == d1.event_shape
    assert d2.batch_shape == d1.batch_shape
    assert type(d1.base_dist) is type(d2.base_dist)
    assert len(d1.transforms) == len(d2.transforms)
    assert_equal(d1.base_dist.loc, d2.base_dist.loc)
    assert_equal(d1.base_dist.scale, d2.base_dist.scale)
    assert_equal(d1.transforms[0].loc, d2.transforms[0].loc)
    assert_equal(d1.transforms[0].scale, d2.transforms[0].scale)
    assert not d2.base_dist.loc.requires_grad
    assert not d2.base_dist.scale.requires_grad
    assert not d2.transforms[0].loc.requires_grad
    assert not d2.transforms[0].scale.requires_grad

def test_detach():
    import torch

    try:
        from pyro.distributions.util import detach
    except ImportError:
        pytest.skip("detach() is not available")

    x = Tensor(torch.randn(2, 3, requires_grad=True))
    y = detach(x)
    assert_close(x, y)
    assert x.data.requires_grad
    assert not y.data.requires_grad

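# A minimal sketch of what a detach() helper like the one imported above might
# look like (illustrative only; the real pyro.distributions.util.detach may be
# implemented differently). The idea: a functools.singledispatch function that
# rebuilds an object with all tensor parameters detached from the autograd
# graph. Named _detach_sketch to avoid shadowing the helper under test.
import functools


@functools.singledispatch
def _detach_sketch(obj):
    raise NotImplementedError("detach() not implemented for {}".format(type(obj)))


@_detach_sketch.register(torch.Tensor)
def _(x):
    # Plain tensors detach directly.
    return x.detach()


@_detach_sketch.register(torch.distributions.Normal)
def _(d):
    # Distributions are rebuilt from detached parameters; the parameters of an
    # expanded distribution are already broadcast, so batch_shape is preserved.
    return type(d)(d.loc.detach(), d.scale.detach())
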
def _construct_baseline(node, guide_site, downstream_cost):
    # XXX should the average baseline be in the param store as below?
    baseline = 0.0
    baseline_loss = 0.0

    (
        nn_baseline,
        nn_baseline_input,
        use_decaying_avg_baseline,
        baseline_beta,
        baseline_value,
    ) = _get_baseline_options(guide_site)

    use_nn_baseline = nn_baseline is not None
    use_baseline_value = baseline_value is not None
    use_baseline = use_nn_baseline or use_decaying_avg_baseline or use_baseline_value

    assert not (
        use_nn_baseline and use_baseline_value
    ), "cannot use baseline_value and nn_baseline simultaneously"

    if use_decaying_avg_baseline:
        dc_shape = downstream_cost.shape
        param_name = "__baseline_avg_downstream_cost_" + node
        with torch.no_grad():
            avg_downstream_cost_old = pyro.param(
                param_name, torch.zeros(dc_shape, device=guide_site["value"].device)
            )
            avg_downstream_cost_new = (
                1 - baseline_beta
            ) * downstream_cost + baseline_beta * avg_downstream_cost_old
        pyro.get_param_store()[param_name] = avg_downstream_cost_new
        baseline += avg_downstream_cost_old

    if use_nn_baseline:
        # block nn_baseline_input gradients except in baseline loss
        baseline += nn_baseline(detach(nn_baseline_input))
    elif use_baseline_value:
        # it's on the user to make sure baseline_value tape only points to baseline params
        baseline += baseline_value

    if use_nn_baseline or use_baseline_value:
        # accumulate baseline loss
        baseline_loss += torch.pow(downstream_cost.detach() - baseline, 2.0).sum()

    if use_baseline:
        if downstream_cost.shape != baseline.shape:
            raise ValueError(
                "Expected baseline at site {} to be {} instead got {}".format(
                    node, downstream_cost.shape, baseline.shape
                )
            )

    return use_baseline, baseline_loss, baseline

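# For reference, the baseline options unpacked by _get_baseline_options above
# are attached to a non-reparameterized sample site in the guide through the
# ``infer`` argument, following the pattern in Pyro's SVI tutorials. A sketch
# (site and parameter names are illustrative):
def _example_baseline_site(probs):
    return pyro.sample(
        "z",
        dist.Bernoulli(probs),
        infer=dict(
            baseline={
                "use_decaying_avg_baseline": True,
                "baseline_beta": 0.95,  # decay rate of the running average
            }
        ),
    )
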
def test_detach_normal(shape):
    loc = torch.tensor(0.0, requires_grad=True)
    scale = torch.tensor(1.0, requires_grad=True)
    d1 = dist.Normal(loc, scale)
    if shape is not None:
        d1 = d1.expand(shape)
    d2 = detach(d1)
    assert type(d1) is type(d2)
    assert_equal(d1.loc, d2.loc)
    assert_equal(d1.scale, d2.scale)
    assert not d2.loc.requires_grad
    assert not d2.scale.requires_grad

def test_detach_beta(shape):
    concentration1 = torch.tensor(0.5, requires_grad=True)
    concentration0 = torch.tensor(2.0, requires_grad=True)
    d1 = dist.Beta(concentration1, concentration0)
    if shape is not None:
        d1 = d1.expand(shape)
    d2 = detach(d1)
    assert type(d1) is type(d2)
    assert d2.batch_shape == d1.batch_shape
    assert_equal(d1.concentration1, d2.concentration1)
    assert_equal(d1.concentration0, d2.concentration0)
    assert not d2.concentration1.requires_grad
    assert not d2.concentration0.requires_grad

def fn(loc, scale, data):
    d = dist.Normal(loc, scale, validate_args=False)
    if shape is not None:
        d = d.expand(shape)
    return detach(d).log_prob(data)

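# The closure above returns the log-density of a detached distribution; a
# minimal standalone check of the intended behavior (illustrative, not part of
# the test suite): gradients flow back to the data but not to loc or scale.
def _check_detach_blocks_grads():
    loc = torch.tensor(0.0, requires_grad=True)
    scale = torch.tensor(1.0, requires_grad=True)
    data = torch.tensor(0.5, requires_grad=True)
    detach(dist.Normal(loc, scale)).log_prob(data).backward()
    assert loc.grad is None and scale.grad is None
    assert data.grad is not None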