Example #1
class DeltaMethodAnalyticalExpectedGrads(chex.TestCase):

  @chex.all_variants
  @parameterized.named_parameters(
      chex.params_product([
          ('_score_function_jacobians', 1.0, 1.0, sge.score_function_jacobians),
          ('_pathwise_jacobians', 1.0, 1.0, sge.pathwise_jacobians),
          ('_measure_valued_jacobians', 1.0, 1.0, sge.measure_valued_jacobians),
      ], [
          ('estimate_cv_coeffs', True),
          ('no_estimate_cv_coeffs', False),
      ],
                          named=True))
  def testQuadraticFunction(self, effective_mean, effective_log_scale,
                            grad_estimator, estimate_cv_coeffs):
    data_dims = 3
    num_samples = 10**3

    mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32)
    log_scale = effective_log_scale * jnp.ones(
        shape=(data_dims), dtype=jnp.float32)

    params = [mean, log_scale]
    function = lambda x: jnp.sum(x**2)
    rng = jax.random.PRNGKey(1)

    jacobians = _cv_jac_variant(self.variant)(
        function,
        control_variates.control_delta_method,
        grad_estimator,
        params,
        utils.multi_normal,  # dist_builder
        rng,
        num_samples,
        None,  # No cv state.
        estimate_cv_coeffs)[0]

    expected_mean_grads = 2 * effective_mean * np.ones(
        data_dims, dtype=np.float32)
    expected_log_scale_grads = 2 * np.exp(2 * effective_log_scale) * np.ones(
        data_dims, dtype=np.float32)

    mean_jacobians = jacobians[0]
    chex.assert_shape(mean_jacobians, (num_samples, data_dims))
    mean_grads_from_jacobian = jnp.mean(mean_jacobians, axis=0)

    log_scale_jacobians = jacobians[1]
    chex.assert_shape(log_scale_jacobians, (num_samples, data_dims))
    log_scale_grads_from_jacobian = jnp.mean(log_scale_jacobians, axis=0)

    _assert_equal(mean_grads_from_jacobian, expected_mean_grads,
                  rtol=1e-1, atol=1e-3)
    _assert_equal(log_scale_grads_from_jacobian, expected_log_scale_grads,
                  rtol=1e-1, atol=1e-3)

  @chex.all_variants
  @parameterized.named_parameters(
      chex.params_product([
          ('_score_function_jacobians', 1.0, 1.0, sge.score_function_jacobians),
          ('_pathwise_jacobians', 1.0, 1.0, sge.pathwise_jacobians),
          ('_measure_valued_jacobians', 1.0, 1.0, sge.measure_valued_jacobians),
      ], [
          ('estimate_cv_coeffs', True),
          ('no_estimate_cv_coeffs', False),
      ],
                          named=True))
  def testCubicFunction(
      self, effective_mean, effective_log_scale, grad_estimator,
      estimate_cv_coeffs):
    data_dims = 1
    num_samples = 10**5

    mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32)
    log_scale = effective_log_scale * jnp.ones(
        shape=(data_dims), dtype=jnp.float32)

    params = [mean, log_scale]
    function = lambda x: jnp.sum(x**3)
    rng = jax.random.PRNGKey(1)

    jacobians = _cv_jac_variant(self.variant)(
        function,
        control_variates.control_delta_method,
        grad_estimator,
        params,
        utils.multi_normal,
        rng,
        num_samples,
        None,  # No cv state.
        estimate_cv_coeffs)[0]

    # The third order uncentered moment of the Gaussian distribution is
    # mu**3 + 3 * mu * sigma**2. We use that to compute the expected value
    # of the gradients. Note: for the log scale we need to use the chain rule
    # (d sigma / d log_scale = sigma).
    expected_mean_grads = (
        3 * effective_mean**2 + 3 * np.exp(effective_log_scale)**2)
    expected_mean_grads *= np.ones(data_dims, dtype=np.float32)
    expected_log_scale_grads = (
        6 * effective_mean * np.exp(effective_log_scale) ** 2)
    expected_log_scale_grads *= np.ones(data_dims, dtype=np.float32)

    mean_jacobians = jacobians[0]
    chex.assert_shape(mean_jacobians, (num_samples, data_dims))
    mean_grads_from_jacobian = jnp.mean(mean_jacobians, axis=0)

    log_scale_jacobians = jacobians[1]
    chex.assert_shape(log_scale_jacobians, (num_samples, data_dims))
    log_scale_grads_from_jacobian = jnp.mean(log_scale_jacobians, axis=0)

    _assert_equal(mean_grads_from_jacobian, expected_mean_grads,
                  rtol=1e-1, atol=1e-3)

    _assert_equal(log_scale_grads_from_jacobian, expected_log_scale_grads,
                  rtol=1e-1, atol=1e-3)

  @chex.all_variants
  @parameterized.named_parameters(
      chex.params_product([
          ('_score_function_jacobians', 1.0, 1.0, sge.score_function_jacobians),
          ('_pathwise_jacobians', 1.0, 1.0, sge.pathwise_jacobians),
          ('_measure_valued_jacobians', 1.0, 1.0, sge.measure_valued_jacobians),
      ], [
          ('estimate_cv_coeffs', True),
          ('no_estimate_cv_coeffs', False),
      ],
                          named=True))
  def testFourthPowerFunction(
      self, effective_mean, effective_log_scale, grad_estimator,
      estimate_cv_coeffs):
    data_dims = 1
    num_samples = 10**5

    mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32)
    log_scale = effective_log_scale * jnp.ones(
        shape=(data_dims), dtype=jnp.float32)

    params = [mean, log_scale]
    function = lambda x: jnp.sum(x**4)
    rng = jax.random.PRNGKey(1)

    jacobians = _cv_jac_variant(self.variant)(
        function,
        control_variates.control_delta_method,
        grad_estimator,
        params,
        utils.multi_normal,
        rng,
        num_samples,
        None,  # No cv state
        estimate_cv_coeffs)[0]

    # The fourth order uncentered moment of the Gaussian distribution is
    # mu**4 + 6 * mu**2 * sigma**2 + 3 * sigma**4. We use that to compute the
    # expected value of the gradients.
    # Note: for the log scale we need to use the chain rule
    # (d sigma / d log_scale = sigma).
    expected_mean_grads = (
        4 * effective_mean**3
        + 12 * effective_mean * np.exp(effective_log_scale)**2)
    expected_mean_grads *= np.ones(data_dims, dtype=np.float32)
    expected_log_scale_grads = 12 * (
        effective_mean**2 * np.exp(effective_log_scale) +
        np.exp(effective_log_scale) ** 3) * np.exp(effective_log_scale)
    expected_log_scale_grads *= np.ones(data_dims, dtype=np.float32)

    mean_jacobians = jacobians[0]
    chex.assert_shape(mean_jacobians, (num_samples, data_dims))
    mean_grads_from_jacobian = jnp.mean(mean_jacobians, axis=0)

    log_scale_jacobians = jacobians[1]
    chex.assert_shape(log_scale_jacobians, (num_samples, data_dims))
    log_scale_grads_from_jacobian = jnp.mean(log_scale_jacobians, axis=0)

    _assert_equal(mean_grads_from_jacobian, expected_mean_grads,
                  rtol=1e-1, atol=1e-3)

    _assert_equal(log_scale_grads_from_jacobian, expected_log_scale_grads,
                  rtol=1e-1, atol=1e-3)
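
The expected gradients hard-coded in these tests come from closed-form Gaussian moments. As a quick standalone sanity check (a sketch, not part of the test file above), they can be reproduced by differentiating the closed-form expectation with jax.grad; shown here for the cubic case:

import jax
import jax.numpy as jnp

def expected_cubic(mean, log_scale):
  # E[x**3] for x ~ N(mean, sigma**2) with sigma = exp(log_scale) is the
  # third uncentered moment: mu**3 + 3 * mu * sigma**2.
  sigma = jnp.exp(log_scale)
  return mean**3 + 3 * mean * sigma**2

# d/d mean = 3 * mu**2 + 3 * sigma**2 and d/d log_scale = 6 * mu * sigma**2,
# matching expected_mean_grads and expected_log_scale_grads above.
print(jax.grad(expected_cubic, argnums=0)(1.0, 1.0))  # ~25.17
print(jax.grad(expected_cubic, argnums=1)(1.0, 1.0))  # ~44.33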
Example #2
class ConsistencyWithStandardEstimators(chex.TestCase):

  @chex.all_variants
  @parameterized.named_parameters(
      chex.params_product([
          ('_score_function_jacobians', 1, 1, sge.score_function_jacobians,
           10**6),
          ('_pathwise_jacobians', 1, 1, sge.pathwise_jacobians, 10**5),
          ('_measure_valued_jacobians', 1, 1, sge.measure_valued_jacobians,
           10**5),
      ], [
          ('control_delta_method', control_variates.control_delta_method),
          ('moving_avg_baseline', control_variates.moving_avg_baseline),
      ],
                          named=True))
  def testWeightedLinearFunction(self, effective_mean, effective_log_scale,
                                 grad_estimator, num_samples,
                                 control_variate_from_function):
    """Check that the gradients are consistent between estimators."""
    weights = jnp.array([1., 2., 3.], dtype=jnp.float32)
    data_dims = len(weights)

    mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32)
    log_scale = effective_log_scale * jnp.ones(
        shape=(data_dims), dtype=jnp.float32)

    params = [mean, log_scale]
    function = lambda x: jnp.sum(weights * x)
    rng = jax.random.PRNGKey(1)
    cv_rng, ge_rng = jax.random.split(rng)

    jacobians = _cv_jac_variant(self.variant)(
        function,
        control_variate_from_function,
        grad_estimator,
        params,
        utils.multi_normal,  # dist_builder
        cv_rng,  # rng
        num_samples,
        (0., 0),  # control_variate_state
        False)[0]

    mean_jacobians = jacobians[0]
    chex.assert_shape(mean_jacobians, (num_samples, data_dims))
    mean_grads = jnp.mean(mean_jacobians, axis=0)

    log_scale_jacobians = jacobians[1]
    chex.assert_shape(log_scale_jacobians, (num_samples, data_dims))
    log_scale_grads = jnp.mean(log_scale_jacobians, axis=0)

    # We use a different random number generator for the gradient estimator
    # without the control variate.
    no_cv_jacobians = grad_estimator(
        function, [mean, log_scale],
        utils.multi_normal, ge_rng, num_samples=num_samples)

    no_cv_mean_jacobians = no_cv_jacobians[0]
    chex.assert_shape(no_cv_mean_jacobians, (num_samples, data_dims))
    no_cv_mean_grads = jnp.mean(no_cv_mean_jacobians, axis=0)

    no_cv_log_scale_jacobians = no_cv_jacobians[1]
    chex.assert_shape(no_cv_log_scale_jacobians, (num_samples, data_dims))
    no_cv_log_scale_grads = jnp.mean(no_cv_log_scale_jacobians, axis=0)

    _assert_equal(mean_grads, no_cv_mean_grads, rtol=1e-1, atol=5e-2)
    _assert_equal(log_scale_grads, no_cv_log_scale_grads, rtol=1, atol=5e-2)

  @chex.all_variants
  @parameterized.named_parameters(
      chex.params_product([
          ('_score_function_jacobians', 1, 1, sge.score_function_jacobians,
           10**5),
          ('_pathwise_jacobians', 1, 1, sge.pathwise_jacobians, 10**5),
          ('_measure_valued_jacobians', 1, 1, sge.measure_valued_jacobians,
           10**5),
      ], [
          ('control_delta_method', control_variates.control_delta_method),
          ('moving_avg_baseline', control_variates.moving_avg_baseline),
      ],
                          named=True))
  def testNonPolynomialFunction(
      self, effective_mean, effective_log_scale,
      grad_estimator, num_samples, control_variate_from_function):
    """Check that the gradients are consistent between estimators."""
    data_dims = 3

    mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32)
    log_scale = effective_log_scale * jnp.ones(
        shape=(data_dims), dtype=jnp.float32)

    params = [mean, log_scale]
    function = lambda x: jnp.log(jnp.sum(x**2))
    rng = jax.random.PRNGKey(1)
    cv_rng, ge_rng = jax.random.split(rng)

    jacobians = _cv_jac_variant(self.variant)(
        function,
        control_variate_from_function,
        grad_estimator,
        params,
        utils.multi_normal,
        cv_rng,
        num_samples,
        (0., 0),  # control_variate_state
        False)[0]

    mean_jacobians = jacobians[0]
    chex.assert_shape(mean_jacobians, (num_samples, data_dims))
    mean_grads = jnp.mean(mean_jacobians, axis=0)

    log_scale_jacobians = jacobians[1]
    chex.assert_shape(log_scale_jacobians, (num_samples, data_dims))
    log_scale_grads = jnp.mean(log_scale_jacobians, axis=0)

    # We use a different random number generator for the gradient estimator
    # without the control variate.
    no_cv_jacobians = grad_estimator(
        function, [mean, log_scale],
        utils.multi_normal, ge_rng, num_samples=num_samples)

    no_cv_mean_jacobians = no_cv_jacobians[0]
    chex.assert_shape(no_cv_mean_jacobians, (num_samples, data_dims))
    no_cv_mean_grads = jnp.mean(no_cv_mean_jacobians, axis=0)

    no_cv_log_scale_jacobians = no_cv_jacobians[1]
    chex.assert_shape(no_cv_log_scale_jacobians, (num_samples, data_dims))
    no_cv_log_scale_grads = jnp.mean(no_cv_log_scale_jacobians, axis=0)

    _assert_equal(mean_grads, no_cv_mean_grads, rtol=1e-1, atol=5e-2)
    _assert_equal(log_scale_grads, no_cv_log_scale_grads, rtol=1e-1, atol=5e-2)
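
These consistency tests rest on the fact that a control variate changes the variance of a Monte Carlo gradient estimate, not its expectation. A minimal standalone sketch of that idea for the score-function estimator with a constant baseline (illustrative pure-JAX code, not the library's implementation):

import jax
import jax.numpy as jnp

def score_function_grads(function, mean, log_scale, rng, num_samples,
                         baseline=0.0):
  # REINFORCE estimate of d/d mean E[function(x)] for x ~ N(mean, sigma**2).
  # Subtracting a constant baseline leaves the expectation unchanged,
  # because E[score] = 0; it only affects the variance.
  sigma = jnp.exp(log_scale)
  x = mean + sigma * jax.random.normal(rng, (num_samples,) + mean.shape)
  score = (x - mean) / sigma**2  # Gradient of log N(x; mean, sigma) in mean.
  f = jax.vmap(function)(x) - baseline
  return jnp.mean(f[:, None] * score, axis=0)

weights = jnp.array([1., 2., 3.])
mean, log_scale = jnp.zeros(3), jnp.zeros(3)
function = lambda x: jnp.sum(weights * x)
rng = jax.random.PRNGKey(0)
# Both estimates converge to `weights`; the baseline only reduces variance.
print(score_function_grads(function, mean, log_scale, rng, 10**5))
print(score_function_grads(function, mean, log_scale, rng, 10**5, baseline=5.0))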
Example #3
class GradientEstimatorsTest(chex.TestCase):
    @chex.all_variants
    @parameterized.named_parameters(
        chex.params_product([
            ('_score_function_jacobians', sge.score_function_jacobians),
            ('_pathwise_jacobians', sge.pathwise_jacobians),
            ('_measure_valued_jacobians', sge.measure_valued_jacobians),
        ], [
            ('0.1', 0.1),
            ('0.5', 0.5),
            ('0.9', 0.9),
        ],
                            named=True))
    def testConstantFunction(self, estimator, constant):
        data_dims = 3
        num_samples = _estimator_to_num_samples[estimator]

        effective_mean = 1.5
        mean = effective_mean * _ones(data_dims)

        effective_log_scale = 0.0
        log_scale = effective_log_scale * _ones(data_dims)
        rng = jax.random.PRNGKey(1)

        jacobians = _estimator_variant(self.variant, estimator)(
            lambda x: jnp.array(constant), [mean, log_scale],
            utils.multi_normal, rng, num_samples)

        # Average over the number of samples.
        mean_jacobians = jacobians[0]
        chex.assert_shape(mean_jacobians, (num_samples, data_dims))
        mean_grads = np.mean(mean_jacobians, axis=0)
        expected_mean_grads = np.zeros(data_dims, dtype=np.float32)

        log_scale_jacobians = jacobians[1]
        chex.assert_shape(log_scale_jacobians, (num_samples, data_dims))
        log_scale_grads = np.mean(log_scale_jacobians, axis=0)
        expected_log_scale_grads = np.zeros(data_dims, dtype=np.float32)

        _assert_equal(mean_grads, expected_mean_grads, atol=5e-3)
        _assert_equal(log_scale_grads, expected_log_scale_grads, atol=5e-3)

    @chex.all_variants
    @parameterized.named_parameters(
        chex.params_product([
            ('_score_function_jacobians', sge.score_function_jacobians),
            ('_pathwise_jacobians', sge.pathwise_jacobians),
            ('_measure_valued_jacobians', sge.measure_valued_jacobians),
        ], [
            ('0.5_-1.', 0.5, -1.),
            ('0.7_0.0', 0.7, 0.0),
            ('0.8_0.1', 0.8, 0.1),
        ],
                            named=True))
    def testLinearFunction(self, estimator, effective_mean,
                           effective_log_scale):
        data_dims = 3
        num_samples = _estimator_to_num_samples[estimator]
        rng = jax.random.PRNGKey(1)

        mean = effective_mean * _ones(data_dims)
        log_scale = effective_log_scale * _ones(data_dims)

        jacobians = _estimator_variant(self.variant,
                                       estimator)(np.sum, [mean, log_scale],
                                                  utils.multi_normal, rng,
                                                  num_samples)

        mean_jacobians = jacobians[0]
        chex.assert_shape(mean_jacobians, (num_samples, data_dims))
        mean_grads = np.mean(mean_jacobians, axis=0)
        expected_mean_grads = np.ones(data_dims, dtype=np.float32)

        log_scale_jacobians = jacobians[1]
        chex.assert_shape(log_scale_jacobians, (num_samples, data_dims))
        log_scale_grads = np.mean(log_scale_jacobians, axis=0)
        expected_log_scale_grads = np.zeros(data_dims, dtype=np.float32)

        _assert_equal(mean_grads, expected_mean_grads)
        _assert_equal(log_scale_grads, expected_log_scale_grads)

    @chex.all_variants
    @parameterized.named_parameters(
        chex.params_product([
            ('_score_function_jacobians', sge.score_function_jacobians),
            ('_pathwise_jacobians', sge.pathwise_jacobians),
            ('_measure_valued_jacobians', sge.measure_valued_jacobians),
        ], [
            ('1.0_0.3', 1.0, 0.3),
        ],
                            named=True))
    def testQuadraticFunction(self, estimator, effective_mean,
                              effective_log_scale):
        data_dims = 3
        num_samples = _estimator_to_num_samples[estimator]
        rng = jax.random.PRNGKey(1)

        mean = effective_mean * _ones(data_dims)
        log_scale = effective_log_scale * _ones(data_dims)

        jacobians = _estimator_variant(self.variant,
                                       estimator)(lambda x: np.sum(x**2) / 2,
                                                  [mean, log_scale],
                                                  utils.multi_normal, rng,
                                                  num_samples)

        mean_jacobians = jacobians[0]
        chex.assert_shape(mean_jacobians, (num_samples, data_dims))
        mean_grads = np.mean(mean_jacobians, axis=0)
        expected_mean_grads = effective_mean * np.ones(data_dims,
                                                       dtype=np.float32)

        log_scale_jacobians = jacobians[1]
        chex.assert_shape(log_scale_jacobians, (num_samples, data_dims))
        log_scale_grads = np.mean(log_scale_jacobians, axis=0)
        expected_log_scale_grads = np.exp(2 * effective_log_scale) * np.ones(
            data_dims, dtype=np.float32)

        _assert_equal(mean_grads, expected_mean_grads, atol=5e-2)
        _assert_equal(log_scale_grads, expected_log_scale_grads, atol=5e-2)

    @chex.all_variants
    @parameterized.named_parameters(
        chex.params_product([
            ('_score_function_jacobians', sge.score_function_jacobians),
            ('_pathwise_jacobians', sge.pathwise_jacobians),
            ('_measure_valued_jacobians', sge.measure_valued_jacobians),
        ], [
            ('case_1', [1.0, 2.0, 3.], [-1., 0.3, -2.], [1., 1., 1.]),
            ('case_2', [1.0, 2.0, 3.], [-1., 0.3, -2.], [4., 2., 3.]),
            ('case_3', [1.0, 2.0, 3.], [0.1, 0.2, 0.1], [10., 5., 1.]),
        ],
                            named=True))
    def testWeightedLinear(self, estimator, effective_mean,
                           effective_log_scale, weights):
        num_samples = _weighted_estimator_to_num_samples[estimator]
        rng = jax.random.PRNGKey(1)

        mean = jnp.array(effective_mean)
        log_scale = jnp.array(effective_log_scale)
        weights = jnp.array(weights)

        data_dims = len(effective_mean)

        function = lambda x: jnp.sum(x * weights)
        jacobians = _estimator_variant(self.variant,
                                       estimator)(function, [mean, log_scale],
                                                  utils.multi_normal, rng,
                                                  num_samples)

        mean_jacobians = jacobians[0]
        chex.assert_shape(mean_jacobians, (num_samples, data_dims))
        mean_grads = np.mean(mean_jacobians, axis=0)

        log_scale_jacobians = jacobians[1]
        chex.assert_shape(log_scale_jacobians, (num_samples, data_dims))
        log_scale_grads = np.mean(log_scale_jacobians, axis=0)

        expected_mean_grads = weights
        expected_log_scale_grads = np.zeros(data_dims, dtype=np.float32)

        _assert_equal(mean_grads, expected_mean_grads, atol=5e-2)
        _assert_equal(log_scale_grads, expected_log_scale_grads, atol=5e-2)

    @chex.all_variants
    @parameterized.named_parameters(
        chex.params_product([
            ('_score_function_jacobians', sge.score_function_jacobians),
            ('_pathwise_jacobians', sge.pathwise_jacobians),
            ('_measure_valued_jacobians', sge.measure_valued_jacobians),
        ], [
            ('case_1', [1.0, 2.0, 3.], [-1., 0.3, -2.], [1., 1., 1.]),
            ('case_2', [1.0, 2.0, 3.], [-1., 0.3, -2.], [4., 2., 3.]),
            ('case_3', [1.0, 2.0, 3.], [0.1, 0.2, 0.1], [3., 5., 1.]),
        ],
                            named=True))
    def testWeightedQuadratic(self, estimator, effective_mean,
                              effective_log_scale, weights):
        num_samples = _weighted_estimator_to_num_samples[estimator]
        rng = jax.random.PRNGKey(1)

        mean = jnp.array(effective_mean, dtype=jnp.float32)
        log_scale = jnp.array(effective_log_scale, dtype=jnp.float32)
        weights = jnp.array(weights, dtype=jnp.float32)

        data_dims = len(effective_mean)

        function = lambda x: jnp.sum(x * weights)**2
        jacobians = _estimator_variant(self.variant,
                                       estimator)(function, [mean, log_scale],
                                                  utils.multi_normal, rng,
                                                  num_samples)

        mean_jacobians = jacobians[0]
        chex.assert_shape(mean_jacobians, (num_samples, data_dims))
        mean_grads = np.mean(mean_jacobians, axis=0)

        log_scale_jacobians = jacobians[1]
        chex.assert_shape(log_scale_jacobians, (num_samples, data_dims))
        log_scale_grads = np.mean(log_scale_jacobians, axis=0)

        expected_mean_grads = 2 * weights * np.sum(weights * mean)
        effective_scale = np.exp(log_scale)
        expected_scale_grads = 2 * weights**2 * effective_scale
        expected_log_scale_grads = expected_scale_grads * effective_scale

        _assert_equal(mean_grads, expected_mean_grads, atol=1e-1, rtol=1e-1)
        _assert_equal(log_scale_grads,
                      expected_log_scale_grads,
                      atol=1e-1,
                      rtol=1e-1)

    @chex.all_variants
    @parameterized.named_parameters(
        chex.params_product(
            [
                ('_sum_cos_x', [1.0], [1.0], lambda x: jnp.sum(jnp.cos(x))),
                # Need to ensure that the mean is not too close to 0.
                ('_sum_log_x', [10.0], [0.0], lambda x: jnp.sum(jnp.log(x))),
                ('_sum_cos_2x', [1.0, 2.0], [1.0, -2],
                 lambda x: jnp.sum(jnp.cos(2 * x))),
                ('_cos_sum_2x', [1.0, 2.0], [1.0, -2],
                 lambda x: jnp.cos(jnp.sum(2 * x))),
            ],
            [
                ('coupling', True),
                ('nocoupling', False),
            ],
            named=True))
    def testNonPolynomialFunctionConsistencyWithPathwise(
            self, effective_mean, effective_log_scale, function, coupling):
        num_samples = 10**5
        rng = jax.random.PRNGKey(1)
        measure_rng, pathwise_rng = jax.random.split(rng)

        mean = jnp.array(effective_mean, dtype=jnp.float32)
        log_scale = jnp.array(effective_log_scale, dtype=jnp.float32)
        data_dims = len(effective_mean)

        measure_valued_jacobians = _measure_valued_variant(
            self.variant)(function, [mean, log_scale], utils.multi_normal,
                          measure_rng, num_samples, coupling)

        measure_valued_mean_jacobians = measure_valued_jacobians[0]
        chex.assert_shape(measure_valued_mean_jacobians,
                          (num_samples, data_dims))
        measure_valued_mean_grads = np.mean(measure_valued_mean_jacobians,
                                            axis=0)

        measure_valued_log_scale_jacobians = measure_valued_jacobians[1]
        chex.assert_shape(measure_valued_log_scale_jacobians,
                          (num_samples, data_dims))
        measure_valued_log_scale_grads = np.mean(
            measure_valued_log_scale_jacobians, axis=0)

        pathwise_jacobians = _estimator_variant(
            self.variant, sge.pathwise_jacobians)(function, [mean, log_scale],
                                                  utils.multi_normal,
                                                  pathwise_rng, num_samples)

        pathwise_mean_jacobians = pathwise_jacobians[0]
        chex.assert_shape(pathwise_mean_jacobians, (num_samples, data_dims))
        pathwise_mean_grads = np.mean(pathwise_mean_jacobians, axis=0)

        pathwise_log_scale_jacobians = pathwise_jacobians[1]
        chex.assert_shape(pathwise_log_scale_jacobians,
                          (num_samples, data_dims))
        pathwise_log_scale_grads = np.mean(pathwise_log_scale_jacobians,
                                           axis=0)

        _assert_equal(pathwise_mean_grads,
                      measure_valued_mean_grads,
                      rtol=5e-1,
                      atol=1e-1)
        _assert_equal(pathwise_log_scale_grads,
                      measure_valued_log_scale_grads,
                      rtol=5e-1,
                      atol=1e-1)
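
The snippets above reference module-level helpers (_ones, _assert_equal, _estimator_variant, _measure_valued_variant, _cv_jac_variant, and the *_num_samples lookup tables) that are defined elsewhere in the original test file. Below is a plausible reconstruction from how they are used, with sge, control_variates, and utils imported as in that file; the exact bodies, static_argnums, and sample counts are assumptions:

import numpy as np
import jax.numpy as jnp

def _ones(dims):
  return jnp.ones(shape=(dims,), dtype=jnp.float32)

def _assert_equal(actual, expected, rtol=1e-2, atol=1e-2):
  np.testing.assert_allclose(actual, expected, rtol=rtol, atol=atol)

# Per-estimator sample counts; the values here are illustrative, chosen to
# reflect the relative variance of each estimator.
_estimator_to_num_samples = {
    sge.score_function_jacobians: 5 * 10**5,
    sge.pathwise_jacobians: 10**4,
    sge.measure_valued_jacobians: 10**5,
}
_weighted_estimator_to_num_samples = dict(_estimator_to_num_samples)

def _estimator_variant(variant, estimator):
  # chex variants (with/without jit) need non-array arguments marked static;
  # the indices match the call sites above: function, dist_builder,
  # num_samples.
  return variant(estimator, static_argnums=(0, 2, 4))

def _measure_valued_variant(variant):
  # As above, plus the boolean `coupling` flag.
  return variant(sge.measure_valued_jacobians, static_argnums=(0, 2, 4, 5))

def _cv_jac_variant(variant):
  # Assumes the wrapped function is control_variates.control_variates_jacobians;
  # static args: function, control variate, grad estimator, dist_builder,
  # num_samples, estimate_cv_coeffs.
  return variant(control_variates.control_variates_jacobians,
                 static_argnums=(0, 1, 2, 4, 6, 8))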