# Shared imports assumed by all excerpts below; the exact module paths follow
# TFP's internal layout and may differ across versions.
import math

import numpy as np
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf

from tensorflow_probability.python.internal import test_util
from tensorflow_probability.python.math.psd_kernels.internal import util


def _apply(self, x1, x2, example_ndims=0):
  # Polynomial kernel:
  # k(x, y) = bias_variance**2
  #     + slope_variance**2 * ((x - shift) . (y - shift))**exponent
  if self.shift is None:
    dot_prod = util.sum_rightmost_ndims_preserving_shape(
        x1 * x2, ndims=self.feature_ndims)
  else:
    shift = tf.convert_to_tensor(self.shift)
    shift = util.pad_shape_with_ones(
        shift, example_ndims + self.feature_ndims)
    dot_prod = util.sum_rightmost_ndims_preserving_shape(
        (x1 - shift) * (x2 - shift), ndims=self.feature_ndims)
  if self.exponent is not None:
    exponent = tf.convert_to_tensor(self.exponent)
    exponent = util.pad_shape_with_ones(exponent, example_ndims)
    dot_prod **= exponent
  if self.slope_variance is not None:
    slope_variance = tf.convert_to_tensor(self.slope_variance)
    slope_variance = util.pad_shape_with_ones(slope_variance, example_ndims)
    dot_prod *= slope_variance**2.
  if self.bias_variance is not None:
    bias_variance = tf.convert_to_tensor(self.bias_variance)
    bias_variance = util.pad_shape_with_ones(bias_variance, example_ndims)
    dot_prod += bias_variance**2.
  return dot_prod

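# A minimal sketch of the two `util` helpers every excerpt here leans on. The
# signatures match TFP's internal psd_kernels util module, but the bodies are
# simplified assumptions: the real implementations also preserve static shape
# information when the rank is only known dynamically.
def sum_rightmost_ndims_preserving_shape(x, ndims):
  """Sums `x` over its rightmost `ndims` dimensions."""
  x = tf.convert_to_tensor(x)
  return tf.reduce_sum(x, axis=list(range(-ndims, 0))) if ndims else x


def pad_shape_with_ones(x, ndims, start=-1):
  """Inserts `ndims` size-1 dims into `x`'s shape, ending at index `start`."""
  x = tf.convert_to_tensor(x)
  shape = tf.shape(x)
  split = tf.rank(x) + start + 1  # Insertion point, counted from the left.
  new_shape = tf.concat(
      [shape[:split], tf.ones([ndims], dtype=tf.int32), shape[split:]],
      axis=0)
  return tf.reshape(x, new_shape)


# E.g. summing away two feature dims while keeping the batch dims:
#   sum_rightmost_ndims_preserving_shape(tf.ones([5, 4, 3, 2]), ndims=2)
# has shape [5, 4], and
#   pad_shape_with_ones(tf.ones([5, 4]), ndims=1, start=-2)
# has shape [5, 1, 4].
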
def _log_apply(self, lx1, lx2):
  # Computes prod_i sqrt(2 * lx1_i * lx2_i / (lx1_i**2 + lx2_i**2)) over the
  # feature dims, in log space for numerical stability.
  loglx1 = tf.math.log(lx1)
  loglx2 = tf.math.log(lx2)
  lognum = util.sum_rightmost_ndims_preserving_shape(
      loglx1 + loglx2 + math.log(2.0), self.feature_ndims)
  logdenom = util.sum_rightmost_ndims_preserving_shape(
      tf.math.log(lx1**2 + lx2**2), self.feature_ndims)
  return tf.exp(0.5 * (lognum - logdenom))

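# A quick NumPy check that `_log_apply` matches the direct per-feature product
# prod_i sqrt(2 * l1_i * l2_i / (l1_i**2 + l2_i**2)) (assuming
# feature_ndims == 1); the log-space form only exists to avoid
# underflow/overflow for extreme length scales.
l1 = np.random.rand(6) + 0.1
l2 = np.random.rand(6) + 0.1
direct = np.prod(np.sqrt(2. * l1 * l2 / (l1**2 + l2**2)))
via_logs = np.exp(0.5 * (np.sum(np.log(l1) + np.log(l2) + np.log(2.))
                         - np.sum(np.log(l1**2 + l2**2))))
np.testing.assert_allclose(direct, via_logs)
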
def testSumRightmostNdimsPreservingShapeDynamicRank(self):
  # Placeholders, and hence unknown rank, only exist in graph mode.
  if tf.executing_eagerly():
    return
  x = tf1.placeholder_with_default(np.ones((5, 4, 3, 2)), shape=None)
  self.assertIsNone(
      util.sum_rightmost_ndims_preserving_shape(x, ndims=2).shape.ndims)
  self.assertAllEqual(
      self.evaluate(
          util.sum_rightmost_ndims_preserving_shape(x, ndims=2)).shape,
      [5, 4])

def _fast_apply(self, x1, x2):
  # Only valid for a single scalar feature (see the guard in `_apply`), where
  # summing over the rightmost dim is the same as taking the product.
  lx1 = tf.convert_to_tensor(self._length_scale_fn(x1, *self._fn_args))
  lx2 = tf.convert_to_tensor(self._length_scale_fn(x2, *self._fn_args))
  lx12, lx22 = lx1**2, lx2**2
  scal = util.sum_rightmost_ndims_preserving_shape(
      tf.sqrt(2 * lx1 * lx2 / (lx12 + lx22)), self.feature_ndims)
  sqdist = tf.math.squared_difference(x1, x2)
  sqdist /= lx12 + lx22
  sqdist = util.sum_rightmost_ndims_preserving_shape(
      sqdist, self.feature_ndims)
  return scal * tf.exp(-sqdist)

def testSumRightmostNdimsPreservingShapeStaticRank(self):
  x = np.ones((5, 4, 3, 2))
  self.assertAllEqual(
      util.sum_rightmost_ndims_preserving_shape(x, ndims=2).shape,
      [5, 4])
  x = tf1.placeholder_with_default(
      np.ones((5, 4, 3, 2)), shape=[5, 4, None, None])
  self.assertAllEqual(
      util.sum_rightmost_ndims_preserving_shape(x, ndims=1).shape.as_list(),
      [5, 4, 3 if tf.executing_eagerly() else None])

def _apply(self, x1, x2, example_ndims=0):
  pairwise_square_distance = util.sum_rightmost_ndims_preserving_shape(
      tf.math.squared_difference(x1, x2), ndims=self.feature_ndims)
  return self._apply_with_distance(
      x1, x2, pairwise_square_distance, example_ndims=example_ndims)

def _apply(self, x1, x2, example_ndims=0):
  # ExpSinSquared: k(x, y) = amplitude**2 * exp(
  #     -2 * sum_i sin(pi * |x_i - y_i| / period)**2 / length_scale**2)
  difference = np.pi * tf.abs(x1 - x2)
  if self.period is not None:
    period = tf.convert_to_tensor(self.period)
    # period acts as a batch of periods, and hence we must additionally
    # pad the shape with self.feature_ndims number of ones.
    period = util.pad_shape_with_ones(
        period, ndims=(example_ndims + self.feature_ndims))
    difference /= period
  log_kernel = util.sum_rightmost_ndims_preserving_shape(
      -2 * tf.sin(difference)**2, ndims=self.feature_ndims)
  if self.length_scale is not None:
    length_scale = tf.convert_to_tensor(self.length_scale)
    length_scale = util.pad_shape_with_ones(
        length_scale, ndims=example_ndims)
    log_kernel /= length_scale**2
  if self.amplitude is not None:
    amplitude = tf.convert_to_tensor(self.amplitude)
    amplitude = util.pad_shape_with_ones(amplitude, ndims=example_ndims)
    log_kernel += 2. * tf.math.log(amplitude)
  return tf.exp(log_kernel)

def _apply(self, x1, x2, example_ndims=0):
  if self._length_scale_fn is not None:
    # Fast path for a single scalar feature, where the per-feature product in
    # `_log_apply` collapses to one term.
    if x1.shape[-1] == 1 and self.feature_ndims == 1:
      return self._fast_apply(x1, x2)
    lx1 = tf.convert_to_tensor(self._length_scale_fn(x1, *self._fn_args))
    lx2 = tf.convert_to_tensor(self._length_scale_fn(x2, *self._fn_args))
    scal = self._log_apply(lx1, lx2)
    sqdist = tf.math.squared_difference(x1, x2)
    sqdist /= lx1**2 + lx2**2
    sqdist = util.sum_rightmost_ndims_preserving_shape(
        sqdist, self.feature_ndims)
    return scal * tf.exp(-sqdist)
  sqdist = util.sum_rightmost_ndims_preserving_shape(
      tf.math.squared_difference(x1, x2), self.feature_ndims)
  return tf.exp(-sqdist / 2)

def _apply(self, x1, x2, example_ndims=0):
  cov = self._kernel._apply(x1, x2, example_ndims)
  if self._scaling_fn is not None:
    scal_x1 = tf.convert_to_tensor(self._scaling_fn(x1, *self._fn_args))
    scal_x2 = tf.convert_to_tensor(self._scaling_fn(x2, *self._fn_args))
    scal = util.sum_rightmost_ndims_preserving_shape(
        scal_x1 * scal_x2, self._feature_ndims)
    return scal * cov
  return cov

def _apply(self, x1, x2, example_ndims=0):
  # k(x1, x2) = (b / (sum(x1 + x2) + b))**a, computed in log space.
  a = tf.convert_to_tensor(self._concentration)
  b = tf.convert_to_tensor(self._rate)
  a = util.pad_shape_with_ones(a, ndims=example_ndims)
  b = util.pad_shape_with_ones(b, ndims=example_ndims)
  # The kernel is defined for scalars where t >= 0.
  # TODO(jburnim,srvasude): Raise or return NaN when `any(x1 < 0 | x2 < 0)`?
  sum_x1_x2 = util.sum_rightmost_ndims_preserving_shape(
      x1 + x2, ndims=self.feature_ndims)
  log_result = tf.math.xlogy(a, b) - tf.math.xlogy(a, sum_x1_x2 + b)
  return tf.math.exp(log_result)

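# Numeric sanity check of the log-space identity above: for positive inputs,
# exp(a * log(b) - a * log(s + b)) is just (b / (s + b))**a; `xlogy` only
# matters for gracefully handling a == 0.
a, b = 2.5, 1.5
s = np.random.rand(4) * 3.
direct = (b / (s + b))**a
via_logs = np.exp(a * np.log(b) - a * np.log(s + b))
np.testing.assert_allclose(direct, via_logs)
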
def _apply(self, x1, x2, example_ndims=0):
  component = (2.0 * math.pi * tf.sqrt(
      util.sum_rightmost_ndims_preserving_shape(
          tf.math.squared_difference(x1, x2), self.feature_ndims)))
  if self.length_scale is not None:
    length_scale = tf.convert_to_tensor(self._length_scale)
    length_scale = util.pad_shape_with_ones(length_scale, example_ndims)
    component /= length_scale**2
  if self.amplitude is not None:
    amplitude = tf.convert_to_tensor(self._amplitude)
    amplitude = util.pad_shape_with_ones(amplitude, example_ndims)
    return amplitude**2 * tf.math.cos(component)
  return tf.math.cos(component)

def _apply(self, x1, x2, example_ndims=0):
  exponent = -0.5 * util.sum_rightmost_ndims_preserving_shape(
      tf.math.squared_difference(x1, x2), self.feature_ndims)
  if self.length_scale is not None:
    length_scale = tf.convert_to_tensor(self.length_scale)
    length_scale = util.pad_shape_with_ones(length_scale, example_ndims)
    exponent /= length_scale**2
  if self.amplitude is not None:
    amplitude = tf.convert_to_tensor(self.amplitude)
    amplitude = util.pad_shape_with_ones(amplitude, example_ndims)
    exponent += 2. * tf.math.log(amplitude)
  return tf.exp(exponent)

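# The excerpt above is the exponentiated-quadratic (RBF) form
# k(x, y) = amplitude**2 * exp(-||x - y||**2 / (2 * length_scale**2)). A quick
# numeric cross-check against TFP's public ExponentiatedQuadratic kernel,
# assuming TFP is installed; the parameter values are arbitrary.
import tensorflow_probability as tfp

amplitude = np.float64(1.3)
length_scale = np.float64(0.6)
x1 = np.random.randn(5, 3)
x2 = np.random.randn(5, 3)
direct = amplitude**2 * np.exp(
    -np.sum((x1 - x2)**2, axis=-1) / (2. * length_scale**2))
kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)
np.testing.assert_allclose(kernel.apply(x1, x2).numpy(), direct)
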
def _apply(self, x1, x2, example_ndims=0):
  sqdist = util.sum_rightmost_ndims_preserving_shape(
      tf.math.squared_difference(x1, x2), self.feature_ndims)
  # The 1e-12 jitter keeps the sqrt, and its gradient, finite at x1 == x2.
  ndist = -0.5 * tf.sqrt(sqdist + 1e-12)
  if self.length_scale is not None:
    length_scale = tf.convert_to_tensor(self._length_scale)
    length_scale = util.pad_shape_with_ones(length_scale, example_ndims)
    ndist /= length_scale**2
  if self.amplitude is not None:
    amplitude = tf.convert_to_tensor(self._amplitude)
    amplitude = util.pad_shape_with_ones(amplitude, example_ndims)
    return amplitude**2 * tf.exp(ndist)
  return tf.exp(ndist)

def _matrix(self, x1, x2):
  cov = self._kernel._matrix(x1, x2)
  if self._scaling_fn is not None:
    # Pad the two scaling factors at different positions so their example
    # dims broadcast against each other, yielding a [..., e1, e2] scaling
    # matrix that matches the shape of `cov`.
    scal_x1 = util.pad_shape_with_ones(
        tf.convert_to_tensor(self._scaling_fn(x1, *self._fn_args)),
        ndims=1,
        start=-(self._feature_ndims + 1))
    scal_x2 = util.pad_shape_with_ones(
        tf.convert_to_tensor(self._scaling_fn(x2, *self._fn_args)),
        ndims=1,
        start=-(self._feature_ndims + 2))
    scal = util.sum_rightmost_ndims_preserving_shape(
        scal_x1 * scal_x2, ndims=self._feature_ndims)
    return scal * cov
  return cov

def _apply(self, x1, x2, example_ndims=0):
  # Add an extra dimension to x1 and x2 so it broadcasts with scales.
  # [B1, ...., E1, ...., E2, M, F1, ..., F2]
  x1 = util.pad_shape_with_ones(
      x1, ndims=1, start=-(self.feature_ndims + example_ndims + 1))
  x2 = util.pad_shape_with_ones(
      x2, ndims=1, start=-(self.feature_ndims + example_ndims + 1))
  scales = util.pad_shape_with_ones(
      self.scales, ndims=example_ndims, start=-(self.feature_ndims + 1))
  pairwise_square_distance = util.sum_rightmost_ndims_preserving_shape(
      tf.math.square(np.pi * (x1 - x2) * scales), ndims=self.feature_ndims)
  return self._apply_with_distance(
      x1, x2, pairwise_square_distance, example_ndims=example_ndims)

def _apply(self, x1, x2, example_ndims=0):
  # Matern 3/2: k(x, y) = amplitude**2 * (1 + sqrt(3) * d / length_scale)
  #     * exp(-sqrt(3) * d / length_scale), with d = ||x - y||.
  # Use util.sqrt_with_finite_grads to avoid NaN gradients when `x1 == x2`.
  norm = util.sqrt_with_finite_grads(
      util.sum_rightmost_ndims_preserving_shape(
          tf.math.squared_difference(x1, x2), self.feature_ndims))
  if self.length_scale is not None:
    length_scale = tf.convert_to_tensor(self.length_scale)
    length_scale = util.pad_shape_with_ones(
        length_scale, ndims=example_ndims)
    norm /= length_scale
  series_term = np.sqrt(3) * norm
  log_result = tf.math.log1p(series_term) - series_term
  if self.amplitude is not None:
    amplitude = tf.convert_to_tensor(self.amplitude)
    amplitude = util.pad_shape_with_ones(amplitude, example_ndims)
    log_result += 2. * tf.math.log(amplitude)
  return tf.exp(log_result)

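# A minimal sketch of the assumed `util.sqrt_with_finite_grads` helper used
# above: an ordinary sqrt in the forward pass, with the backward pass clamped
# away from the singularity at zero so `x1 == x2` cannot yield NaN gradients.
# The exact clamping policy here is an assumption, not necessarily TFP's.
@tf.custom_gradient
def sqrt_with_finite_grads(x):
  x = tf.convert_to_tensor(x)

  def grad(dy):
    eps = np.finfo(x.dtype.as_numpy_dtype).eps
    # d/dx sqrt(x) = 0.5 / sqrt(x); clamp x below by eps to keep it finite.
    return dy * 0.5 * tf.math.rsqrt(tf.maximum(x, eps))

  return tf.sqrt(x), grad
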
def testPairwiseSquareDistanceMatrix(self, feature_ndims, dims):
  batch_shape = [2, 3]
  seed_stream = test_util.test_seed_stream('pairwise_square_distance')
  x1 = tf.random.normal(
      dtype=np.float64,
      shape=batch_shape + [dims] * feature_ndims,
      seed=seed_stream())
  x2 = tf.random.normal(
      dtype=np.float64,
      shape=batch_shape + [dims] * feature_ndims,
      seed=seed_stream())
  pairwise_square_distance = util.pairwise_square_distance_matrix(
      x1, x2, feature_ndims)
  x1_pad = util.pad_shape_with_ones(x1, ndims=1, start=-(feature_ndims + 1))
  x2_pad = util.pad_shape_with_ones(x2, ndims=1, start=-(feature_ndims + 2))
  actual_square_distance = util.sum_rightmost_ndims_preserving_shape(
      tf.math.squared_difference(x1_pad, x2_pad), feature_ndims)
  pairwise_square_distance_, actual_square_distance_ = self.evaluate([
      pairwise_square_distance, actual_square_distance])
  self.assertAllClose(pairwise_square_distance_, actual_square_distance_)

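# A plausible reference implementation of `util.pairwise_square_distance_matrix`,
# mirroring exactly what the test above asserts: broadcast x1 against x2 along
# a fresh pair of example dims, then sum squared differences over the feature
# dims. The real implementation may use a faster dot-product expansion.
def pairwise_square_distance_matrix(x1, x2, feature_ndims):
  x1 = util.pad_shape_with_ones(x1, ndims=1, start=-(feature_ndims + 1))
  x2 = util.pad_shape_with_ones(x2, ndims=1, start=-(feature_ndims + 2))
  return util.sum_rightmost_ndims_preserving_shape(
      tf.math.squared_difference(x1, x2), feature_ndims)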