def _apply_with_distance(self, x1, x2, pairwise_square_distance, example_ndims=0):
    """Evaluates the kernel given a pairwise squared distance.

    Computes the result in log-space and exponentiates at the end; the
    `df`-dependent terms suggest this is the generalized Matern form built
    on `tfp_math.log_bessel_kve` (modified Bessel K, exponentially scaled).

    Args:
      x1: First input (unused here; the distance is precomputed).
      x2: Second input (unused here; the distance is precomputed).
      pairwise_square_distance: Tensor of pairwise squared distances.
      example_ndims: Number of rightmost "example" dims over which parameter
        shapes are padded so they broadcast against the distance tensor.

    Returns:
      Tensor of kernel values, same shape as `pairwise_square_distance`.
    """
    # Use util.sqrt_with_finite_grads to avoid NaN gradients when `x1 == x2`.
    norm = util.sqrt_with_finite_grads(pairwise_square_distance)
    if self.length_scale is not None:
        length_scale = tf.convert_to_tensor(self.length_scale)
        length_scale = util.pad_shape_with_ones(length_scale, ndims=example_ndims)
        norm /= length_scale
    df = tf.convert_to_tensor(self.df)
    # NOTE(review): gradients with respect to `df` are deliberately blocked —
    # presumably because d/d(df) of log_bessel_kve is unavailable or unstable;
    # confirm against the kernel's parameterization docs.
    df = tf.stop_gradient(df)
    df = util.pad_shape_with_ones(df, ndims=example_ndims)
    norm = tf.math.sqrt(2 * df) * norm
    # When norm -> 0, the expression should tend to zero (along
    # with the gradient).
    # `safe_norm` substitutes 1 where norm == 0 so the log/Bessel terms below
    # are evaluated on valid inputs; the matching tf.where on `log_result`
    # then overrides those positions with 0. The two masks must stay identical.
    safe_norm = tf.where(
        tf.math.equal(norm, 0.),
        dtype_util.as_numpy_dtype(self.dtype)(1.),
        norm)
    log_result = tf.where(
        tf.math.equal(norm, 0.),
        dtype_util.as_numpy_dtype(self.dtype)(0.),
        df * tf.math.log(safe_norm) + tfp_math.log_bessel_kve(df, safe_norm) - safe_norm)
    # Normalization: subtract lgamma(df) and fold in the 2^(1 - df) factor.
    log_result = log_result - tf.math.lgamma(df) + (1. - df) * np.log(2.)
    if self.amplitude is not None:
        amplitude = tf.convert_to_tensor(self.amplitude)
        amplitude = util.pad_shape_with_ones(amplitude, ndims=example_ndims)
        # amplitude**2 scales the kernel, i.e. + 2*log(amplitude) in log-space.
        log_result = log_result + 2. * tf.math.log(amplitude)
    return tf.exp(log_result)
def testSqrtWithFiniteGradsHasCorrectGradients(self):
    """Gradients match tf.sqrt away from zero and stay finite at zero."""
    # Out-of-domain input should still produce NaN, like tf.sqrt.
    self.assertTrue(
        np.isnan(self.evaluate(util.sqrt_with_finite_grads(-1.))))
    # Strictly positive grid: both gradients must agree exactly.
    positive_xs = tf.constant(np.linspace(1e-10, 10., 100))
    _, reference_grad = value_and_gradient(tf.sqrt, positive_xs)
    _, candidate_grad = value_and_gradient(
        util.sqrt_with_finite_grads, positive_xs)
    self.assertAllEqual(*self.evaluate([reference_grad, candidate_grad]))
    # At exactly zero, tf.sqrt's gradient (1 / (2 sqrt(x))) blows up, so the
    # safe variant's gradient must differ from it there.
    origin = tf.constant(0.)
    _, reference_grad = value_and_gradient(tf.sqrt, origin)
    _, candidate_grad = value_and_gradient(
        util.sqrt_with_finite_grads, origin)
    self.assertNotEqual(*self.evaluate([reference_grad, candidate_grad]))
def _apply_with_distance(
    self, x1, x2, pairwise_square_distance, example_ndims=0):
    """Evaluates the kernel from a precomputed pairwise squared distance.

    Works in log-space and exponentiates once at the end. Parameter shapes
    are padded with ones so they broadcast over `example_ndims` example dims.
    """
    # util.sqrt_with_finite_grads keeps gradients finite when `x1 == x2`.
    scaled_norm = util.sqrt_with_finite_grads(pairwise_square_distance)
    inverse_length_scale = self._inverse_length_scale_parameter()
    if inverse_length_scale is not None:
        inverse_length_scale = util.pad_shape_with_ones(
            inverse_length_scale, ndims=example_ndims)
        scaled_norm = scaled_norm * inverse_length_scale
    # sqrt(5) * d / l, the argument of the Matern-5/2 series.
    sqrt_five = tf.math.sqrt(tf.constant(5., dtype=scaled_norm.dtype))
    series_term = sqrt_five * scaled_norm
    # log[(1 + t + t^2/3) * exp(-t)], via log1p for accuracy near t == 0.
    log_result = tf.math.log1p(series_term + series_term**2 / 3.) - series_term
    if self.amplitude is not None:
        amplitude = util.pad_shape_with_ones(
            tf.convert_to_tensor(self.amplitude), example_ndims)
        # amplitude**2 scaling becomes + 2*log(amplitude) in log-space.
        log_result = log_result + 2. * tf.math.log(amplitude)
    return tf.exp(log_result)
def _apply_with_distance(
    self, x1, x2, pairwise_square_distance, example_ndims=0):
    """Evaluates the kernel from a precomputed pairwise squared distance.

    Computes the result in log-space and exponentiates once at the end.
    """
    # util.sqrt_with_finite_grads keeps gradients finite when `x1 == x2`.
    distance = util.sqrt_with_finite_grads(pairwise_square_distance)
    if self.length_scale is not None:
        length_scale = util.pad_shape_with_ones(
            tf.convert_to_tensor(self.length_scale), ndims=example_ndims)
        distance = distance / length_scale
    # sqrt(3) * d / l, the argument of the Matern-3/2 series.
    series_term = np.sqrt(3) * distance
    # log[(1 + t) * exp(-t)], via log1p for accuracy near t == 0.
    log_result = tf.math.log1p(series_term) - series_term
    if self.amplitude is not None:
        amplitude = util.pad_shape_with_ones(
            tf.convert_to_tensor(self.amplitude), example_ndims)
        # amplitude**2 scaling becomes + 2*log(amplitude) in log-space.
        log_result = log_result + 2. * tf.math.log(amplitude)
    return tf.exp(log_result)
def _apply(self, x1, x2, example_ndims=0):
    """Evaluates the kernel on inputs `x1`, `x2`.

    Builds the pairwise distance from the inputs, then computes the result
    in log-space and exponentiates once at the end.
    """
    squared_distance = util.sum_rightmost_ndims_preserving_shape(
        tf.math.squared_difference(x1, x2), self.feature_ndims)
    # util.sqrt_with_finite_grads keeps gradients finite when `x1 == x2`.
    distance = util.sqrt_with_finite_grads(squared_distance)
    if self.length_scale is not None:
        length_scale = util.pad_shape_with_ones(
            tf.convert_to_tensor(self.length_scale), ndims=example_ndims)
        distance = distance / length_scale
    # sqrt(3) * d / l, the argument of the Matern-3/2 series.
    series_term = np.sqrt(3) * distance
    # log[(1 + t) * exp(-t)], via log1p for accuracy near t == 0.
    log_result = tf.math.log1p(series_term) - series_term
    if self.amplitude is not None:
        amplitude = util.pad_shape_with_ones(
            tf.convert_to_tensor(self.amplitude), example_ndims)
        # amplitude**2 scaling becomes + 2*log(amplitude) in log-space.
        log_result = log_result + 2. * tf.math.log(amplitude)
    return tf.exp(log_result)
def g(x):
    """Return `util.sqrt_with_finite_grads(x)` (thin pass-through)."""
    result = util.sqrt_with_finite_grads(x)
    return result
def testSqrtWithFiniteGradsHasCorrectValues(self):
    """Forward values match tf.sqrt everywhere on the tested domain."""
    # Out-of-domain input should still produce NaN, like tf.sqrt.
    self.assertTrue(
        np.isnan(self.evaluate(util.sqrt_with_finite_grads(-1.))))
    # On [0, 10] (including the zero endpoint) values must match exactly.
    grid = np.linspace(0., 10., 100)
    expected = self.evaluate(tf.sqrt(grid))
    actual = self.evaluate(util.sqrt_with_finite_grads(grid))
    self.assertAllEqual(expected, actual)