def scale_and_mask(tensor, scale=1.0, mask=None):
    """
    Scale and mask a packed tensor, broadcasting and avoiding unnecessary ops.

    :param torch.Tensor tensor: a packed tensor
    :param scale: a positive scale
    :type scale: torch.Tensor or number
    :param mask: an optional packed tensor mask
    :type mask: torch.BoolTensor, bool, or None
    """
    if isinstance(scale, torch.Tensor) and scale.dim():
        raise NotImplementedError('non-scalar scale is not supported')
    if mask is None or mask is True:
        if is_identically_one(scale):
            return tensor
        result = tensor * scale
        result._pyro_dims = tensor._pyro_dims
        return result
    if mask is False:
        result = torch.zeros_like(tensor)
        result._pyro_dims = tensor._pyro_dims
        return result
    tensor, mask = broadcast_all(tensor, mask)
    # Apply the scale before masking, so masked entries are exactly zero.
    result = torch.where(mask, tensor * scale, tensor.new_zeros(()))
    result._pyro_dims = tensor._pyro_dims
    return result
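# Usage sketch (illustrative, not from the library): packed tensors in Pyro
# carry a `_pyro_dims` attribute naming their dims, which we attach by hand
# here. Assumes `scale_and_mask` and the helpers it references are in scope.
import torch

x = torch.randn(2, 3)
x._pyro_dims = "ab"  # hypothetical dim names for this sketch
y = scale_and_mask(x, scale=0.5)  # scaled copy, _pyro_dims preserved
z = scale_and_mask(x, scale=0.5, mask=False)  # all entries zeroed
assert (z == 0).all() and z._pyro_dims == "ab"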
def __call__(self, name, fn, obs):
    assert obs is None, "LocScaleReparam does not support observe statements"
    centered = self.centered
    if is_identically_one(centered):
        return fn, obs
    event_shape = fn.event_shape
    fn, event_dim = self._unwrap(fn)

    # Apply a partial decentering transform.
    params = {key: getattr(fn, key) for key in self.shape_params}
    if self.centered is None:
        centered = pyro.param("{}_centered".format(name),
                              lambda: fn.loc.new_full(event_shape, 0.5),
                              constraint=constraints.unit_interval)
    params["loc"] = fn.loc * centered
    params["scale"] = fn.scale ** centered
    decentered_fn = type(fn)(**params)

    # Draw decentered noise.
    decentered_value = pyro.sample("{}_decentered".format(name),
                                   self._wrap(decentered_fn, event_dim))

    # Differentiably transform.
    delta = decentered_value - centered * fn.loc
    value = fn.loc + fn.scale.pow(1 - centered) * delta

    # Simulate a pyro.deterministic() site.
    new_fn = dist.Delta(value, event_dim=event_dim).mask(False)
    return new_fn, value
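# A quick numerical sketch (not part of the source) of the identity this
# reparameterizer relies on: if z ~ Normal(c * loc, scale ** c), then
# loc + scale ** (1 - c) * (z - c * loc) ~ Normal(loc, scale).
import torch

loc, scale, c = 2.0, 3.0, 0.5
z = torch.distributions.Normal(c * loc, scale ** c).sample((100000,))
x = loc + scale ** (1 - c) * (z - c * loc)
assert abs(x.mean().item() - loc) < 0.1  # recovered mean
assert abs(x.std().item() - scale) < 0.1  # recovered scale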
def __mul__(self, scale):
    """
    Scale appropriate terms of a gradient estimator by a data multiplicity
    factor. Note that the `score_function` term should not be scaled.
    """
    if is_identically_one(scale):
        return self
    log_prob = scale_tensor(self.log_prob, scale)
    score_function = scale_tensor(self.score_function, torch_sign(scale))
    entropy_term = scale_tensor(self.entropy_term, scale)
    return ScoreParts(log_prob, score_function, entropy_term)
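# Illustrative sketch with made-up values, assuming `scale_tensor` multiplies
# elementwise and `torch_sign` returns the sign of its argument, as their
# names suggest: scaling by a multiplicity factor of 10 rescales log_prob and
# entropy_term, while score_function keeps its magnitude (only its sign can
# change).
import torch

parts = ScoreParts(
    log_prob=torch.tensor(-1.5),
    score_function=torch.tensor(0.7),
    entropy_term=torch.tensor(-1.5),
)
scaled = parts * 10.0
assert torch.allclose(scaled.log_prob, torch.tensor(-15.0))
assert torch.allclose(scaled.score_function, torch.tensor(0.7))  # not rescaled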
def apply(self, msg):
    name = msg["name"]
    fn = msg["fn"]
    value = msg["value"]
    is_observed = msg["is_observed"]

    centered = self.centered
    if is_identically_one(centered):
        return msg
    event_shape = fn.event_shape
    fn, event_dim = self._unwrap(fn)

    # Apply a partial decentering transform.
    if self.shape_params is None:
        self.shape_params = tuple(
            k for k in fn.arg_constraints if k not in ("loc", "scale")
        )
    params = {key: getattr(fn, key) for key in self.shape_params}
    if centered is None:
        centered = pyro.param(
            "{}_centered".format(name),
            lambda: fn.loc.new_full(event_shape, 0.5),
            constraint=constraints.unit_interval,
        )
    params["loc"] = fn.loc * centered
    params["scale"] = fn.scale**centered
    decentered_fn = type(fn)(**params)

    # Differentiably invert transform.
    decentered_value = None
    if value is not None:
        delta = (value - fn.loc) * fn.scale.pow(centered - 1)
        decentered_value = delta + centered * fn.loc

    # Draw decentered noise.
    decentered_value = pyro.sample(
        f"{name}_decentered",
        self._wrap(decentered_fn, event_dim),
        obs=decentered_value,
        infer={"is_observed": is_observed},
    )

    # Differentiably transform.
    if value is None:
        delta = decentered_value - centered * fn.loc
        value = fn.loc + fn.scale.pow(1 - centered) * delta

    # Simulate a pyro.deterministic() site.
    new_fn = dist.Delta(value, event_dim=event_dim).mask(False)
    return {"fn": new_fn, "value": value, "is_observed": True}
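# Usage sketch: applying LocScaleReparam through `poutine.reparam`. The toy
# model below is hypothetical; the handler rewrites the "x" site into an
# auxiliary "x_decentered" sample plus a deterministic transform back to "x".
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import LocScaleReparam

def model():
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    return pyro.sample("x", dist.Normal(loc, 2.0))

# centered=None learns an interpolation parameter "x_centered" in [0, 1].
decentered_model = poutine.reparam(model, {"x": LocScaleReparam(centered=None)})
trace = poutine.trace(decentered_model).get_trace()
assert "x_decentered" in trace.nodes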
def test_moments(dist_type, centered, shape):
    loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_()
    scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_()
    if isinstance(centered, torch.Tensor):
        centered = centered.expand(shape)

    def model():
        with pyro.plate_stack("plates", shape):
            with pyro.plate("particles", 200000):
                if dist_type == "Normal":
                    pyro.sample("x", dist.Normal(loc, scale))
                elif dist_type == "StudentT":
                    pyro.sample("x", dist.StudentT(10.0, loc, scale))
                else:
                    pyro.sample("x", dist.AsymmetricLaplace(loc, scale, 1.5))

    value = poutine.trace(model).get_trace().nodes["x"]["value"]
    expected_probe = get_moments(value)

    reparam = LocScaleReparam(centered)
    reparam_model = poutine.reparam(model, {"x": reparam})
    value = poutine.trace(reparam_model).get_trace().nodes["x"]["value"]
    actual_probe = get_moments(value)

    if not is_identically_one(centered):
        if dist_type == "Normal":
            assert reparam.shape_params == ()
        elif dist_type == "StudentT":
            assert reparam.shape_params == ("df",)
        else:
            assert reparam.shape_params == ("asymmetry",)

    assert_close(actual_probe, expected_probe, atol=0.1, rtol=0.05)
    for actual_m, expected_m in zip(actual_probe, expected_probe):
        expected_grads = grad(expected_m.sum(), [loc, scale], retain_graph=True)
        actual_grads = grad(actual_m.sum(), [loc, scale], retain_graph=True)
        assert_close(actual_grads[0], expected_grads[0], atol=0.1, rtol=0.05)
        assert_close(actual_grads[1], expected_grads[1], atol=0.1, rtol=0.05)
def scale_and_mask(tensor, scale=1.0, mask=None):
    """
    Scale and mask a packed tensor, broadcasting and avoiding unnecessary ops.

    :param torch.Tensor tensor: a packed tensor
    :param scale: a positive scale
    :type scale: torch.Tensor or number
    :param mask: an optional packed tensor mask
    :type mask: torch.ByteTensor or None
    """
    if isinstance(scale, torch.Tensor) and scale.dim():
        raise NotImplementedError('non-scalar scale is not supported')
    if mask is None:
        if is_identically_one(scale):
            return tensor
        result = tensor * scale
        result._pyro_dims = tensor._pyro_dims
        return result
    tensor, mask = broadcast_all(tensor, mask)
    result = tensor * scale  # triggers a copy, avoiding in-place op errors
    result.masked_fill_(~mask, 0.)
    result._pyro_dims = tensor._pyro_dims
    return result