Example #1
0
 def test_approximate_round(self):
     """approximate_round should approach exact rounding and stay differentiable.

     For each floating dtype, builds an evenly spaced grid over [-2.5, 2.5]
     (step 0.5, so both integers and half-integers appear) and checks two
     properties:

     * the approximately rounded values are never further from the exactly
       rounded values than the raw inputs are;
     * backpropagating through approximate_round yields at least one
       nonzero gradient entry.
     """
     for float_dtype in (torch.float, torch.double):
         grid = torch.linspace(
             -2.5, 2.5, 11, device=self.device, dtype=float_dtype
         )
         target = grid.round()
         # Distance of the approximation vs. distance of the raw input, both
         # measured against the exact rounding.
         approx_error = (approximate_round(grid) - target).abs()
         input_error = (grid - target).abs()
         self.assertTrue((approx_error <= input_error).all())
         # Enable autograd only now, so the checks above ran on a plain tensor.
         grid.requires_grad_(True)
         approximate_round(grid).sum().backward()
         self.assertTrue(grid.grad.ne(0).any())
Example #2
0
    def transform(self, X: Tensor) -> Tensor:
        r"""Round the inputs.

        Works on a copy of ``X``; the input tensor itself is not modified.
        Only the last-dimension positions selected by ``self.indices`` are
        rounded. When ``self.approximate`` is truthy they go through
        ``approximate_round`` with temperature ``self.tau``; otherwise
        ``Tensor.round`` is used. All other positions are copied unchanged.

        Args:
            X: A `batch_shape x n x d`-dim tensor of inputs.

        Returns:
            A `batch_shape x n x d`-dim tensor of rounded inputs.
        """
        out = X.clone()
        cols = out[..., self.indices]
        # Choose between the approximate rounding and exact rounding, then
        # write the result back into the copy.
        out[..., self.indices] = (
            approximate_round(cols, tau=self.tau)
            if self.approximate
            else cols.round()
        )
        return out