Example 1
    def _apply_with_distance(self,
                             x1,
                             x2,
                             pairwise_square_distance,
                             example_ndims=0):
        # Use util.sqrt_with_finite_grads to avoid NaN gradients when `x1 == x2`.
        norm = util.sqrt_with_finite_grads(pairwise_square_distance)
        if self.length_scale is not None:
            length_scale = tf.convert_to_tensor(self.length_scale)
            length_scale = util.pad_shape_with_ones(length_scale,
                                                    ndims=example_ndims)
            norm /= length_scale
        df = tf.convert_to_tensor(self.df)
        # Gradients with respect to the order `df` are not taken; the Bessel
        # term below does not support them.
        df = tf.stop_gradient(df)
        df = util.pad_shape_with_ones(df, ndims=example_ndims)
        norm = tf.math.sqrt(2 * df) * norm

        # At norm == 0 the gradient must stay finite and the normalized
        # log-kernel below must vanish, so that k(x, x) == amplitude**2;
        # `safe_norm` keeps the log and Bessel terms off the singular point.
        safe_norm = tf.where(tf.math.equal(norm, 0.),
                             dtype_util.as_numpy_dtype(self.dtype)(1.), norm)
        log_result = tf.where(
            tf.math.equal(norm, 0.),
            # Continuous limit of the branch below: as z -> 0,
            # df * log(z) + log(K_df(z)) -> lgamma(df) + (df - 1) * log(2).
            tf.math.lgamma(df) + (df - 1.) * np.log(2.),
            df * tf.math.log(safe_norm) +
            tfp_math.log_bessel_kve(df, safe_norm) - safe_norm)

        log_result = log_result - tf.math.lgamma(df) + (1. - df) * np.log(2.)

        if self.amplitude is not None:
            amplitude = tf.convert_to_tensor(self.amplitude)
            amplitude = util.pad_shape_with_ones(amplitude,
                                                 ndims=example_ndims)
            log_result = log_result + 2. * tf.math.log(amplitude)
        return tf.exp(log_result)
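
The sqrt(2*df)-scaled distance, the log_bessel_kve term, and the closing lgamma/log-2 normalization identify this as a generalized Matérn kernel of fractional order df: k(x1, x2) = amplitude^2 * 2^(1 - df) / Γ(df) * z^df * K_df(z), with z = sqrt(2*df) * ||x1 - x2|| / length_scale, evaluated in log space and exponentiated once at the end. A minimal usage sketch, assuming the snippet is TensorFlow Probability's GeneralizedMatern kernel (the class name and module path are assumptions, not taken from the snippet):

import numpy as np
import tensorflow_probability as tfp

# Assumed API: tfp.math.psd_kernels.GeneralizedMatern with a fractional
# order `df`; with df = 2.5 it should reduce to the Matern 5/2 kernel.
kernel = tfp.math.psd_kernels.GeneralizedMatern(
    df=2.5, amplitude=1., length_scale=2.)

x = np.random.randn(5, 3).astype(np.float32)  # 5 example points in R^3
print(kernel.matrix(x, x))                    # 5 x 5 kernel matrix
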
Example 2
    def testSqrtWithFiniteGradsHasCorrectGradients(self):
        self.assertTrue(
            np.isnan(self.evaluate(util.sqrt_with_finite_grads(-1.))))
        xs = tf.constant(np.linspace(1e-10, 10., 100))
        _, grad_tf_sqrt = value_and_gradient(tf.sqrt, xs)
        _, grad_safe_sqrt = value_and_gradient(util.sqrt_with_finite_grads, xs)
        self.assertAllEqual(*self.evaluate([grad_tf_sqrt, grad_safe_sqrt]))

        zero = tf.constant(0.)
        _, grad_tf_sqrt = value_and_gradient(tf.sqrt, zero)
        _, grad_safe_sqrt = value_and_gradient(util.sqrt_with_finite_grads,
                                               zero)
        self.assertNotEqual(*self.evaluate([grad_tf_sqrt, grad_safe_sqrt]))
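
The two assertions encode the intent: for x > 0 the gradient must agree exactly with tf.sqrt (derivative 1 / (2*sqrt(x))), while at x == 0, where tf.sqrt's gradient is +inf, the safe version must return something different, i.e. finite. A minimal sketch of how such a function can be written with tf.custom_gradient; an illustration under those assumptions, not TFP's actual implementation:

import numpy as np
import tensorflow as tf

@tf.custom_gradient
def sqrt_with_finite_grads_sketch(x):
    x = tf.convert_to_tensor(x)

    def grad(dy):
        # 1 / (2 * sqrt(x)) away from zero; at exactly zero, substitute a
        # large finite constant for the +inf that tf.sqrt would propagate.
        big = np.sqrt(np.finfo(x.dtype.as_numpy_dtype).max)
        safe_x = tf.where(tf.equal(x, 0.), tf.ones_like(x), x)
        dydx = tf.where(tf.equal(x, 0.), tf.cast(big, x.dtype),
                        0.5 * tf.math.rsqrt(safe_x))
        return dy * dydx

    return tf.sqrt(x), grad
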
Example 3
  def _apply_with_distance(
      self, x1, x2, pairwise_square_distance, example_ndims=0):
    # Use util.sqrt_with_finite_grads to avoid NaN gradients when `x1 == x2`.
    norm = util.sqrt_with_finite_grads(pairwise_square_distance)
    inverse_length_scale = self._inverse_length_scale_parameter()
    if inverse_length_scale is not None:
      inverse_length_scale = util.pad_shape_with_ones(
          inverse_length_scale, ndims=example_ndims)
      norm = norm * inverse_length_scale
    series_term = tf.math.sqrt(tf.constant(5., dtype=norm.dtype)) * norm
    log_result = tf.math.log1p(series_term + series_term**2 / 3.) - series_term

    if self.amplitude is not None:
      amplitude = tf.convert_to_tensor(self.amplitude)
      amplitude = util.pad_shape_with_ones(amplitude, example_ndims)
      log_result = log_result + 2. * tf.math.log(amplitude)
    return tf.exp(log_result)
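
This one is the Matérn 5/2 kernel: with t = sqrt(5) * d and d the (inverse-)length-scale-normalized distance, k = amplitude^2 * (1 + t + t^2/3) * exp(-t); computing log1p(t + t^2/3) - t and exponentiating once keeps intermediate values well-scaled. A quick NumPy check of that algebra (plain NumPy, nothing TFP-specific):

import numpy as np

d = np.linspace(0., 10., 101)                  # normalized distances
t = np.sqrt(5.) * d
via_log = np.exp(np.log1p(t + t**2 / 3.) - t)  # the snippet's log-space route
direct = (1. + t + t**2 / 3.) * np.exp(-t)     # textbook Matern 5/2 form
assert np.allclose(via_log, direct)
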
Example 4
  def _apply_with_distance(
      self, x1, x2, pairwise_square_distance, example_ndims=0):
    # Use util.sqrt_with_finite_grads to avoid NaN gradients when `x1 == x2`.
    norm = util.sqrt_with_finite_grads(pairwise_square_distance)
    if self.length_scale is not None:
      length_scale = tf.convert_to_tensor(self.length_scale)
      length_scale = util.pad_shape_with_ones(
          length_scale, ndims=example_ndims)
      norm = norm / length_scale
    series_term = np.sqrt(3) * norm
    log_result = tf.math.log1p(series_term) - series_term

    if self.amplitude is not None:
      amplitude = tf.convert_to_tensor(self.amplitude)
      amplitude = util.pad_shape_with_ones(amplitude, example_ndims)
      log_result = log_result + 2. * tf.math.log(amplitude)
    return tf.exp(log_result)
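
Example 4 is the Matérn 3/2 analogue: with t = sqrt(3) * d it computes k = amplitude^2 * (1 + t) * exp(-t), the same log-space pattern as Example 3 with the quadratic t^2/3 term dropped. One stylistic difference: sqrt(3) is folded in as a NumPy scalar at graph-construction time rather than built as a tensor in norm's dtype; both work, the tensor form is merely explicit about the dtype.
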
Example 5
    def _apply(self, x1, x2, example_ndims=0):
        # Use util.sqrt_with_finite_grads to avoid NaN gradients when `x1 == x2`.
        norm = util.sqrt_with_finite_grads(
            util.sum_rightmost_ndims_preserving_shape(
                tf.math.squared_difference(x1, x2), self.feature_ndims))
        if self.length_scale is not None:
            length_scale = tf.convert_to_tensor(self.length_scale)
            length_scale = util.pad_shape_with_ones(length_scale,
                                                    ndims=example_ndims)
            norm /= length_scale
        series_term = np.sqrt(3) * norm
        log_result = tf.math.log1p(series_term) - series_term

        if self.amplitude is not None:
            amplitude = tf.convert_to_tensor(self.amplitude)
            amplitude = util.pad_shape_with_ones(amplitude, example_ndims)
            log_result += 2. * tf.math.log(amplitude)
        return tf.exp(log_result)
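
Example 5 is an older variant of the same Matérn 3/2 computation: rather than receiving pairwise_square_distance precomputed by the base class, _apply forms it inline by summing squared coordinate differences over the trailing feature_ndims dimensions, then proceeds exactly as Example 4 does.
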
Example 6
    def g(x):
        return util.sqrt_with_finite_grads(x)
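
The one-line wrapper in Example 6 comes from a gradient test: wrapping the call gives the test harness an ordinary Python function to differentiate. A sketch of exercising it directly with tf.GradientTape; note that `util` is TFP-internal, so the import path below is an assumption and has moved between releases:

import tensorflow as tf
# Assumed internal location of `sqrt_with_finite_grads`.
from tensorflow_probability.python.math.psd_kernels.internal import util

x = tf.constant(0.)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = util.sqrt_with_finite_grads(x)
print(tape.gradient(y, x))  # finite, whereas d(tf.sqrt)/dx at x == 0 is +inf
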
Example 7
    def testSqrtWithFiniteGradsHasCorrectValues(self):
        self.assertTrue(
            np.isnan(self.evaluate(util.sqrt_with_finite_grads(-1.))))
        xs = np.linspace(0., 10., 100)
        self.assertAllEqual(self.evaluate(tf.sqrt(xs)),
                            self.evaluate(util.sqrt_with_finite_grads(xs)))
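
Read alongside Example 2, this values test completes the picture: sqrt_with_finite_grads is tf.sqrt in value everywhere (NaN for negative inputs included); only the gradient at exactly zero, the x1 == x2 case the kernels above care about, is changed.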