# Example no. 1 (score: 0)
def scale_and_mask(tensor, scale=1.0, mask=None):
    """
    Scale and mask a packed tensor, broadcasting and avoiding unnecessary ops.

    :param torch.Tensor tensor: a packed tensor
    :param scale: a positive scale
    :type scale: torch.Tensor or number
    :param mask: an optional packed tensor mask
    :type mask: torch.BoolTensor, bool, or None
    """
    if isinstance(scale, torch.Tensor) and scale.dim():
        raise NotImplementedError('non-scalar scale is not supported')
    if mask is None or mask is True:
        # No masking: scale only, skipping the multiply when scale == 1.
        if is_identically_one(scale):
            return tensor
        result = tensor * scale
        result._pyro_dims = tensor._pyro_dims
        return result
    if mask is False:
        # Everything masked out: result is identically zero.
        result = torch.zeros_like(tensor)
        result._pyro_dims = tensor._pyro_dims
        return result
    tensor, mask = broadcast_all(tensor, mask)
    # BUGFIX: apply scale in the masked branch too (it was previously dropped),
    # matching the docstring and the masked_fill_-based variant of this function.
    result = torch.where(mask, tensor * scale, tensor.new_zeros(()))
    result._pyro_dims = tensor._pyro_dims
    return result
# Example no. 2 (score: 0)
    def __call__(self, name, fn, obs):
        """
        Reparameterize a ``loc``-``scale`` site named ``name`` with distribution
        ``fn``, returning a ``(new_fn, value)`` pair for the transformed site.

        :param str name: the site name
        :param fn: a distribution with ``loc`` and ``scale`` parameters
        :param obs: must be None; observations are not supported
        :returns: a pair ``(new_fn, value)``
        """
        assert obs is None, "LocScaleReparam does not support observe statements"
        centered = self.centered
        if is_identically_one(centered):
            # Fully centered: nothing to do.
            # BUGFIX: return a (fn, value) pair like the main path below,
            # rather than the 3-tuple (name, fn, obs).
            return fn, obs
        event_shape = fn.event_shape
        fn, event_dim = self._unwrap(fn)

        # Apply a partial decentering transform.
        params = {key: getattr(fn, key) for key in self.shape_params}
        if self.centered is None:
            # Learn the degree of centering per coordinate, initialized to 0.5.
            # BUGFIX: format the param name with the site name so that
            # different sites do not collide on the literal "{}_centered".
            centered = pyro.param("{}_centered".format(name),
                                  lambda: fn.loc.new_full(event_shape, 0.5),
                                  constraint=constraints.unit_interval)
        params["loc"] = fn.loc * centered
        params["scale"] = fn.scale ** centered
        decentered_fn = type(fn)(**params)

        # Draw decentered noise.
        decentered_value = pyro.sample("{}_decentered".format(name),
                                       self._wrap(decentered_fn, event_dim))

        # Differentiably transform.
        delta = decentered_value - centered * fn.loc
        value = fn.loc + fn.scale.pow(1 - centered) * delta

        # Simulate a pyro.deterministic() site.
        new_fn = dist.Delta(value, event_dim=event_dim).mask(False)
        return new_fn, value
# Example no. 3 (score: 0)
 def __mul__(self, scale):
     """
     Scale appropriate terms of a gradient estimator by a data multiplicity factor.
     Note that the `score_function` term should not be scaled.
     """
     # Multiplying by 1 is a no-op; reuse this instance.
     if is_identically_one(scale):
         return self
     # Scale log_prob and entropy_term normally; the score_function term only
     # picks up the sign of the scale.
     return ScoreParts(
         scale_tensor(self.log_prob, scale),
         scale_tensor(self.score_function, torch_sign(scale)),
         scale_tensor(self.entropy_term, scale),
     )
# Example no. 4 (score: 0)
    def apply(self, msg):
        """
        Reparameterize the loc-scale site described by ``msg``, replacing it
        with a decentered auxiliary sample plus a deterministic Delta site.

        :param dict msg: a Pyro message with ``name``, ``fn``, ``value`` and
            ``is_observed`` entries
        :returns: an updated message dict with the transformed ``fn``/``value``
        """
        name = msg["name"]
        fn = msg["fn"]
        value = msg["value"]
        is_observed = msg["is_observed"]

        centered = self.centered
        # Fully centered means the identity transform: pass the site through.
        if is_identically_one(centered):
            return msg
        event_shape = fn.event_shape
        fn, event_dim = self._unwrap(fn)

        # Apply a partial decentering transform.
        if self.shape_params is None:
            # Lazily record the distribution's non-loc/scale parameters.
            self.shape_params = tuple(
                param for param in fn.arg_constraints
                if param not in ("loc", "scale")
            )
        params = {param: getattr(fn, param) for param in self.shape_params}
        if centered is None:
            # Learn the degree of centering per coordinate, initialized to 0.5.
            centered = pyro.param(
                "{}_centered".format(name),
                lambda: fn.loc.new_full(event_shape, 0.5),
                constraint=constraints.unit_interval,
            )
        params["loc"] = fn.loc * centered
        params["scale"] = fn.scale**centered
        decentered_fn = type(fn)(**params)

        # Differentiably invert transform.
        decentered_value = None
        if value is not None:
            decentered_value = (
                (value - fn.loc) * fn.scale.pow(centered - 1)
                + centered * fn.loc
            )

        # Draw decentered noise.
        decentered_value = pyro.sample(
            f"{name}_decentered",
            self._wrap(decentered_fn, event_dim),
            obs=decentered_value,
            infer={"is_observed": is_observed},
        )

        # Differentiably transform.
        if value is None:
            value = (
                fn.loc
                + fn.scale.pow(1 - centered)
                * (decentered_value - centered * fn.loc)
            )

        # Simulate a pyro.deterministic() site.
        new_fn = dist.Delta(value, event_dim=event_dim).mask(False)
        return {"fn": new_fn, "value": value, "is_observed": True}
# Example no. 5 (score: 0)
def test_moments(dist_type, centered, shape):
    """
    Check that LocScaleReparam preserves the moments (and their gradients
    w.r.t. loc/scale) of several loc-scale distribution families.

    :param str dist_type: one of "Normal", "StudentT", or any other string
        (which selects AsymmetricLaplace)
    :param centered: the centering coefficient passed to LocScaleReparam
    :param tuple shape: the batch shape of the site
    """
    loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_()
    scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_()
    if isinstance(centered, torch.Tensor):
        centered = centered.expand(shape)

    def model():
        with pyro.plate_stack("plates", shape):
            with pyro.plate("particles", 200000):
                # BUGFIX: compare the dist_type variable, not the literal
                # string "dist_type" (which made these branches dead code).
                if dist_type == "Normal":
                    pyro.sample("x", dist.Normal(loc, scale))
                elif dist_type == "StudentT":
                    pyro.sample("x", dist.StudentT(10.0, loc, scale))
                else:
                    pyro.sample("x", dist.AsymmetricLaplace(loc, scale, 1.5))

    value = poutine.trace(model).get_trace().nodes["x"]["value"]
    expected_probe = get_moments(value)

    reparam = LocScaleReparam(centered)
    reparam_model = poutine.reparam(model, {"x": reparam})
    value = poutine.trace(reparam_model).get_trace().nodes["x"]["value"]
    actual_probe = get_moments(value)

    if not is_identically_one(centered):
        # BUGFIX: same literal-vs-variable comparison as in model() above;
        # previously the shape_params assertions were unreachable.
        if dist_type == "Normal":
            assert reparam.shape_params == ()
        elif dist_type == "StudentT":
            assert reparam.shape_params == ("df", )
        else:
            assert reparam.shape_params == ("asymmetry", )

    assert_close(actual_probe, expected_probe, atol=0.1, rtol=0.05)

    # Moments must match not only in value but in gradient w.r.t. loc/scale.
    for actual_m, expected_m in zip(actual_probe, expected_probe):
        expected_grads = grad(expected_m.sum(), [loc, scale],
                              retain_graph=True)
        actual_grads = grad(actual_m.sum(), [loc, scale], retain_graph=True)
        assert_close(actual_grads[0], expected_grads[0], atol=0.1, rtol=0.05)
        assert_close(actual_grads[1], expected_grads[1], atol=0.1, rtol=0.05)
# Example no. 6 (score: 0)
def scale_and_mask(tensor, scale=1.0, mask=None):
    """
    Scale and mask a packed tensor, broadcasting and avoiding unnecessary ops.

    :param torch.Tensor tensor: a packed tensor
    :param scale: a positive scale
    :type scale: torch.Tensor or number
    :param mask: an optional packed tensor mask
    :type mask: torch.ByteTensor or None
    """
    if isinstance(scale, torch.Tensor) and scale.dim():
        raise NotImplementedError('non-scalar scale is not supported')
    if mask is None:
        # No mask: either return the input untouched (scale == 1) or scale it.
        if is_identically_one(scale):
            return tensor
        scaled = tensor * scale
        scaled._pyro_dims = tensor._pyro_dims
        return scaled
    tensor, mask = broadcast_all(tensor, mask)
    # The multiplication always produces a fresh tensor, so the in-place
    # masked_fill_ below cannot clobber the caller's data.
    scaled = tensor * scale
    scaled.masked_fill_(~mask, 0.)
    scaled._pyro_dims = tensor._pyro_dims
    return scaled