def test_finite_gradient(self, batch_size: int = 10000):
    """
    Check that both the output of `acos_linear_extrapolation` and its
    gradient w.r.t. the input remain finite for inputs concentrated
    around the extrapolation bounds.
    """
    x = TestAcosLinearExtrapolation.init_acos_boundary_values(batch_size)
    x.requires_grad = True
    # Bound magnitudes approaching 1 from below: 0.9, 0.99, ..., 0.99999.
    bound_magnitudes = 1 - 10.0 ** torch.linspace(-1, -5, 5)
    for neg_bound in -bound_magnitudes:
        for pos_bound in bound_magnitudes:
            if pos_bound < neg_bound:
                # Skip inverted bound pairs.
                continue
            # Reset the accumulated gradient before each backward pass.
            x.grad = None
            y = acos_linear_extrapolation(
                x,
                [float(neg_bound), float(pos_bound)],
            )
            self.assertTrue(torch.isfinite(y).all())
            y.mean().backward()
            self.assertIsNotNone(x.grad)
            self.assertTrue(torch.isfinite(x.grad).all())
 def _one_acos_test(self, x: torch.Tensor, lower_bound: float,
                    upper_bound: float):
     """
     Verify `acos_linear_extrapolation` on `x`: plain `acos` between the
     bounds, and linear continuation with the correct slope below
     `lower_bound` / above `upper_bound`.
     """
     x.requires_grad = True
     x.grad = None
     y = acos_linear_extrapolation(x, [lower_bound, upper_bound])
     # Backprop a vector of ones to obtain d(acos)/dx for every element.
     y.backward(torch.ones_like(y))
     grad = x.grad
     below = x <= lower_bound
     above = x >= upper_bound
     # Elements strictly inside the bounds (De Morgan of the two masks).
     inside = ~(below | above)
     # Within the bounds the function must agree with torch.acos exactly.
     self.assertClose(x[inside].acos(), y[inside])
     # Outside the bounds the function must be linear with the correct
     # slope and continuous at the corresponding bound.
     self._test_acos_outside_bounds(x[above], y[above], grad[above],
                                    upper_bound)
     self._test_acos_outside_bounds(x[below], y[below], grad[below],
                                    lower_bound)
 def compute_acos():
     """
     Run `acos_linear_extrapolation` once and wait for the GPU to finish.
     """
     # NOTE(review): `x` is a free variable — presumably a tensor captured
     # from an enclosing benchmark-setup scope; confirm against the caller.
     acos_linear_extrapolation(x)
     # Block until queued CUDA work completes, so a timing of this closure
     # measures the actual kernel execution rather than the async launch.
     # Assumes `x` lives on a CUDA device — TODO confirm.
     torch.cuda.synchronize()