Example #1
    def testSqrtWithFiniteGradsHasCorrectGradients(self):
        self.assertTrue(
            np.isnan(self.evaluate(util.sqrt_with_finite_grads(-1.))))
        xs = tf.constant(np.linspace(1e-10, 10., 100))
        tf_sqrt = tf.sqrt(xs)
        safe_sqrt = util.sqrt_with_finite_grads(xs)

        self.assertAllEqual(self.evaluate(tf.gradients(tf_sqrt, xs)[0]),
                            self.evaluate(tf.gradients(safe_sqrt, xs)[0]))

        zero = tf.constant(0.)
        tf_sqrt = tf.sqrt(zero)
        safe_sqrt = util.sqrt_with_finite_grads(zero)
        self.assertNotEqual(self.evaluate(tf.gradients(tf_sqrt, zero)[0]),
                            self.evaluate(tf.gradients(safe_sqrt, zero)[0]))
Example #2
  def testSqrtWithFiniteGradsHasCorrectGradients(self):
    self.assertTrue(np.isnan(self.evaluate(util.sqrt_with_finite_grads(-1.))))
    xs = tf.constant(np.linspace(1e-10, 10., 100))
    tf_sqrt = tf.sqrt(xs)
    safe_sqrt = util.sqrt_with_finite_grads(xs)

    self.assertAllEqual(
        self.evaluate(tf.gradients(tf_sqrt, xs)[0]),
        self.evaluate(tf.gradients(safe_sqrt, xs)[0]))

    zero = tf.constant(0.)
    tf_sqrt = tf.sqrt(zero)
    safe_sqrt = util.sqrt_with_finite_grads(zero)
    self.assertNotEqual(
        self.evaluate(tf.gradients(tf_sqrt, zero)[0]),
        self.evaluate(tf.gradients(safe_sqrt, zero)[0]))
Example #3
 def testSqrtWithFiniteGradsWithDynamicShape(self):
     x = tf.placeholder_with_default([1.], shape=[None])
     with tf.GradientTape(persistent=True) as tape:
         tape.watch(x)
         tf_sqrt = tf.sqrt(x)
         safe_sqrt = util.sqrt_with_finite_grads(x)
     self.assertAllEqual(self.evaluate(tape.gradient(tf_sqrt, x)),
                         self.evaluate(tape.gradient(safe_sqrt, x)))
Example #4
    def testSqrtWithFiniteGradsHasCorrectGradients(self):
        self.assertTrue(
            np.isnan(self.evaluate(util.sqrt_with_finite_grads(-1.))))
        xs = tf.constant(np.linspace(1e-10, 10., 100))
        with tf.GradientTape(persistent=True) as tape:
            tape.watch(xs)
            tf_sqrt = tf.sqrt(xs)
            safe_sqrt = util.sqrt_with_finite_grads(xs)

        self.assertAllEqual(self.evaluate(tape.gradient(tf_sqrt, xs)),
                            self.evaluate(tape.gradient(safe_sqrt, xs)))

        zero = tf.constant(0.)
        with tf.GradientTape(persistent=True) as tape:
            tape.watch(zero)
            tf_sqrt = tf.sqrt(zero)
            safe_sqrt = util.sqrt_with_finite_grads(zero)
        self.assertNotEqual(self.evaluate(tape.gradient(tf_sqrt, zero)),
                            self.evaluate(tape.gradient(safe_sqrt, zero)))
Example #5
    def testSqrtWithFiniteGradsHasCorrectGradients(self):
        self.assertTrue(
            np.isnan(self.evaluate(util.sqrt_with_finite_grads(-1.))))
        xs = tf.constant(np.linspace(1e-10, 10., 100))
        _, grad_tf_sqrt = value_and_gradient(tf.sqrt, xs)
        _, grad_safe_sqrt = value_and_gradient(util.sqrt_with_finite_grads, xs)
        self.assertAllEqual(*self.evaluate([grad_tf_sqrt, grad_safe_sqrt]))

        zero = tf.constant(0.)
        _, grad_tf_sqrt = value_and_gradient(tf.sqrt, zero)
        _, grad_safe_sqrt = value_and_gradient(util.sqrt_with_finite_grads,
                                               zero)
        self.assertNotEqual(*self.evaluate([grad_tf_sqrt, grad_safe_sqrt]))
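The tests in Examples #1 through #5 all exercise the same contract: util.sqrt_with_finite_grads matches tf.sqrt in value and gradient away from zero, returns NaN for negative input, and replaces the infinite gradient of tf.sqrt at zero with a finite one. A minimal sketch of how such a function could be written with tf.custom_gradient follows; the cap used for the gradient at zero is a placeholder assumption, not the library's actual choice:

import tensorflow as tf

@tf.custom_gradient
def sqrt_with_finite_grads_sketch(x):
  """Like tf.sqrt, but with a finite (capped) gradient at x == 0."""
  x = tf.convert_to_tensor(x)

  def grad(dy):
    # tf.sqrt's gradient is 0.5 / sqrt(x), which is inf at x == 0; substitute
    # a large finite value there so downstream gradients stay finite.
    safe = tf.where(tf.equal(x, 0.),
                    tf.fill(tf.shape(x), tf.constant(1e19, x.dtype)),  # placeholder cap
                    0.5 * tf.math.rsqrt(x))
    return dy * safe

  return tf.sqrt(x), grad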
Example #6
    def _apply(self, x1, x2, param_expansion_ndims=0):
        # Use util.sqrt_with_finite_grads to avoid NaN gradients when `x1 == x2`.
        norm = util.sqrt_with_finite_grads(
            util.sum_rightmost_ndims_preserving_shape(
                tf.math.squared_difference(x1, x2), self.feature_ndims))
        if self.length_scale is not None:
            length_scale = util.pad_shape_right_with_ones(
                self.length_scale, ndims=param_expansion_ndims)
            norm /= length_scale
        log_result = -norm

        if self.amplitude is not None:
            amplitude = util.pad_shape_right_with_ones(
                self.amplitude, ndims=param_expansion_ndims)
            log_result += 2. * tf.math.log(amplitude)
        return tf.exp(log_result)
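The comment at the top of this _apply (and of the variants below) refers to a well-known pitfall: the derivative of tf.sqrt at zero is infinite, so differentiating a plain sqrt-of-squared-difference distance yields inf * 0 = NaN whenever x1 == x2. A small standalone illustration of that failure mode (not part of the original snippets):

import tensorflow as tf

x1 = tf.constant([1.0])
x2 = tf.constant([1.0])  # identical points: the problematic case
with tf.GradientTape() as tape:
  tape.watch(x1)
  dist = tf.sqrt(tf.reduce_sum(tf.math.squared_difference(x1, x2), axis=-1))
grad = tape.gradient(dist, x1)
print(grad)  # [nan]: d/dx sqrt(0) is inf, chained with a zero inner gradient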
Example #7
  def _apply(self, x1, x2, param_expansion_ndims=0):
    # Use util.sqrt_with_finite_grads to avoid NaN gradients when `x1 == x2`.
    norm = util.sqrt_with_finite_grads(
        util.sum_rightmost_ndims_preserving_shape(
            tf.squared_difference(x1, x2), self.feature_ndims))
    if self.length_scale is not None:
      length_scale = util.pad_shape_right_with_ones(
          self.length_scale, ndims=param_expansion_ndims)
      norm /= length_scale
    log_result = -norm

    if self.amplitude is not None:
      amplitude = util.pad_shape_right_with_ones(
          self.amplitude, ndims=param_expansion_ndims)
      log_result += 2. * tf.log(amplitude)
    return tf.exp(log_result)
Example #8
  def _apply(self, x1, x2, example_ndims=0):
    # Use util.sqrt_with_finite_grads to avoid NaN gradients when `x1 == x2`.
    norm = util.sqrt_with_finite_grads(
        util.sum_rightmost_ndims_preserving_shape(
            tf.math.squared_difference(x1, x2), self.feature_ndims))
    if self.length_scale is not None:
      length_scale = util.pad_shape_with_ones(
          self.length_scale, ndims=example_ndims)
      norm /= length_scale
    series_term = np.sqrt(5) * norm
    log_result = tf.math.log1p(series_term + series_term**2 / 3.) - series_term

    if self.amplitude is not None:
      amplitude = util.pad_shape_with_ones(self.amplitude, example_ndims)
      log_result += 2. * tf.math.log(amplitude)
    return tf.exp(log_result)
Example #9
  def _apply(self, x1, x2, param_expansion_ndims=0):
    # Use util.sqrt_with_finite_grads to avoid NaN gradients when `x1 == x2`.
    norm = util.sqrt_with_finite_grads(
        util.sum_rightmost_ndims_preserving_shape(
            tf.squared_difference(x1, x2), self.feature_ndims))
    if self.length_scale is not None:
      length_scale = util.pad_shape_right_with_ones(
          self.length_scale, ndims=param_expansion_ndims)
      norm /= length_scale
    series_term = np.sqrt(5) * norm
    result = (1. + series_term + series_term**2 / 3.) * tf.exp(-series_term)

    if self.amplitude is not None:
      amplitude = util.pad_shape_right_with_ones(self.amplitude,
                                                 param_expansion_ndims)
      result *= amplitude**2
    return result
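Examples #8 and #9 evaluate the same Matern-5/2-style expression, (1 + t + t**2/3) * exp(-t) with t = sqrt(5) * norm; #8 works in log space via log1p before the final exp, while #9 computes it directly. A quick NumPy check (illustrative only) that the two forms agree:

import numpy as np

t = np.linspace(0., 50., 6)  # sample values of series_term
direct = (1. + t + t**2 / 3.) * np.exp(-t)
log_space = np.exp(np.log1p(t + t**2 / 3.) - t)
print(np.allclose(direct, log_space))  # True: algebraically the same quantity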
Example #10
    def _apply(self, x1, x2, param_expansion_ndims=0):
        # Use util.sqrt_with_finite_grads to avoid NaN gradients when `x1 == x2`.
        #x1 = B,Np,D -> B,Np,1,D
        #x2 = B,N,D -> B,1,N,D
        #B, Np,N
        with tf.control_dependencies([
                tf.assert_equal(
                    tf.shape(self.heights)[-1] + 1,
                    tf.shape(self.edgescales)[-1])
        ]):
            norm = util.sqrt_with_finite_grads(
                util.sum_rightmost_ndims_preserving_shape(
                    tf.squared_difference(x1, x2), self.feature_ndims))
        #B(1),1,Np,N
        norm = tf.expand_dims(norm, -(param_expansion_ndims + 1))

        #B(1), H+1, 1, 1
        edgescales = util.pad_shape_right_with_ones(
            self.edgescales, ndims=param_expansion_ndims)
        norm *= edgescales
        norm *= 2 * np.pi

        zeros = tf.zeros(tf.shape(self.heights)[:-1],
                         dtype=self.heights.dtype)[..., None]
        # B(1),1+H+1
        heights = tf.concat([zeros, self.heights, zeros], axis=-1)
        # B(1), H+1
        dheights = heights[..., :-1] - heights[..., 1:]
        #B(1), H+1, 1, 1
        dheights = util.pad_shape_right_with_ones(dheights,
                                                  ndims=param_expansion_ndims)
        #B(1), H+1, 1, 1
        dheights *= edgescales

        def _sinc(x):
            return tf.sin(x) * tf.reciprocal(x)

        #B(1), H+1, N, Np
        sincs = tf.where(tf.less(norm, tf.constant(1e-15, dtype=norm.dtype)),
                         tf.ones_like(norm), _sinc(norm))
        #B(1), H+1, N, Np
        result = dheights * sincs
        #B(1), N,Np
        return tf.reduce_sum(result, axis=-3)
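Example #10 guards the sinc evaluation with tf.where because sin(x)/x is 0/0 at x == 0 even though its limit there is 1; a short NumPy check of that behavior (illustrative, not from the source):

import numpy as np

xs = np.array([1e-15, 1e-8, 1e-3, 1.0])
print(np.sin(xs) / xs)       # tends to 1 as x -> 0
with np.errstate(invalid='ignore'):
  print(np.sin(0.) / 0.)     # nan at exactly zero, hence the fallback to ones_like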
Example #11
    def _apply(self, lamda, return_d2K=False):
        """
        Calculate K, dK, d2K, d3K, K_theta, d2K_theta

        :param lamda: tf.Tensor
            Coordinates [N,M,3]
        :return: list of tf.Tensor
            shapes [N, M], [N,M,3], [N,M,3,3], [N,M,3,3,3], [N,M,T], [N,M,T,3,3]
        """
        if return_d2K:
            raise ValueError("No d2K defined for M12")

        with tf.name_scope('M12_apply', values=[lamda]):
            norm = util.sqrt_with_finite_grads(
                tf.reduce_sum(tf.math.square(lamda / self.lengthscale),
                              axis=-1))
            log_result = -norm
            if self.sigma is not None:
                log_result += 2. * tf.math.log(self.sigma)
            return tf.math.exp(log_result)
Example #12
 def testSqrtWithFiniteGradsHasCorrectValues(self):
     self.assertTrue(
         np.isnan(self.evaluate(util.sqrt_with_finite_grads(-1.))))
     xs = np.linspace(0., 10., 100)
     self.assertAllEqual(self.evaluate(tf.sqrt(xs)),
                         self.evaluate(util.sqrt_with_finite_grads(xs)))
Example #13
 def g(x):
     return util.sqrt_with_finite_grads(x)
Example #14
 def testSqrtWithFiniteGradsHasCorrectValues(self):
   self.assertTrue(np.isnan(self.evaluate(util.sqrt_with_finite_grads(-1.))))
   xs = np.linspace(0., 10., 100)
   self.assertAllEqual(
       self.evaluate(tf.sqrt(xs)),
       self.evaluate(util.sqrt_with_finite_grads(xs)))
Example #15
 def g(x):
   return util.sqrt_with_finite_grads(x)
Example #16
 def testSqrtWithFiniteGradsWithDynamicShape(self):
     x = tf.placeholder_with_default([1.], shape=[None])
     self.assertAllEqual(
         self.evaluate(tf.gradients(tf.sqrt(x), x)),
         self.evaluate(tf.gradients(util.sqrt_with_finite_grads(x), x)))