def _apply(self, x1, x2, example_ndims):
    """Apply the base kernel to `x1`, `x2` and subtract the fixed-input term.

    Shape conventions used in the comments below:

      - `x1` has shape `B1 + E1 + F` (batch, example, feature);
      - `x2` has shape `B2 + E2 + F`;
      - `z` denotes `self._fixed_inputs`, of shape `Bz + [ez] + F`, i.e. it
        has exactly one example dimension;
      - `self.base_kernel` has batch shape `Bk`;
      - `bc(A, B, C)` is the result of broadcasting shapes A, B and C.

    Args:
      x1: First kernel input, shape `B1 + E1 + F`.
      x2: Second kernel input, shape `B2 + E2 + F`.
      example_ndims: Number of example dimensions in `x1` and `x2`.

    Returns:
      `base_kernel.apply(x1, x2)` when there are no fixed inputs; otherwise
      that value minus the sum over `ez` of the product of the two
      triangular-solve results, with shape
      `bc(Bz, Bk, B1, B2) + bc(E1, E2)`.
    """
    # Shape: bc(Bk, B1, B2) + bc(E1, E2)
    base_k12 = self.base_kernel.apply(x1, x2, example_ndims)
    if self._is_fixed_inputs_empty():
        # Nothing to subtract: the result is just the base kernel.
        return base_k12

    z = tf.convert_to_tensor(self._fixed_inputs)

    # Shape: bc(Bk, B1, Bz) + E1 + [ez]
    k_x1_z = self.base_kernel.tensor(
        x1, z, x1_example_ndims=example_ndims, x2_example_ndims=1)

    # Shape: bc(Bk, B2, Bz) + E2 + [ez]
    k_x2_z = self.base_kernel.tensor(
        x2, z, x1_example_ndims=example_ndims, x2_example_ndims=1)

    # Shape: bc(Bz, Bk) + [ez, ez]
    chol = self._divisor_matrix_cholesky(fixed_inputs=z)

    # Shape: bc(Bz, Bk) + [1, ..., 1] + [ez, ez]
    #                     `---------'
    #                          `-- (example_ndims - 1) ones
    # The extra ones line the batch dims up with those of `k_x1_z` /
    # `k_x2_z`: `example_ndims` because the example shapes have that rank,
    # and "- 1" because one of the `ez` dims here already shifts the batch
    # dims over by one.
    chol = util.pad_shape_with_ones(chol, ndims=example_ndims - 1, start=-3)
    chol_op = tf.linalg.LinearOperatorLowerTriangular(chol)

    def _solve_and_transpose(k_xz):
        # Triangular solve against the adjoint of `k_xz`, transposed back so
        # that `ez` is again the rightmost dimension.
        return tf.linalg.matrix_transpose(
            chol_op.solve(k_xz, adjoint_arg=True))

    # Shape: bc(Bz, Bk, B1) + E1 + [ez]
    solved_1 = _solve_and_transpose(k_x1_z)

    # Shape: bc(Bz, Bk, B2) + E2 + [ez]
    solved_2 = _solve_and_transpose(k_x2_z)

    # Contract over `ez`.
    correction = tf.reduce_sum(solved_1 * solved_2, axis=-1)

    # Shape: bc(Bz, Bk, B1, B2) + bc(E1, E2)
    return base_k12 - correction
def _apply(self, x1, x2, example_ndims=0):
    """Return `sum(x1 + x2)` over the last axis, optionally scaled.

    Args:
      x1: First input; converted to a tensor.
      x2: Second input; converted to a tensor.
      example_ndims: Number of trailing ones appended to the multiplier's
        shape before it is multiplied into the result.

    Returns:
      The reduced sum, multiplied by the padded multiplier when
      `self.multiplier` is not `None`.
    """
    lhs = tf.convert_to_tensor(x1)
    rhs = tf.convert_to_tensor(x2)
    summed = tf.reduce_sum(lhs + rhs, axis=-1)
    if self.multiplier is None:
        return summed
    # Pad so the multiplier's batch dims line up with the result's batch dims.
    scale = kernels_util.pad_shape_with_ones(self._multiplier, example_ndims)
    return summed * scale
def _apply_with_distance(self, x1, x2, pairwise_square_distance,
                         example_ndims=0):
    """Evaluate the kernel from precomputed squared distances.

    With `r = sqrt(pairwise_square_distance) / length_scale`, the result is
    `amplitude**2 * (1 + sqrt(3) * r) * exp(-sqrt(3) * r)`, computed in log
    space. `length_scale` and `amplitude` are skipped when they are `None`.

    Args:
      x1: Unused here; the distance is supplied directly.
      x2: Unused here; the distance is supplied directly.
      pairwise_square_distance: Squared distances between inputs.
      example_ndims: Number of ones appended to the parameters' shapes so
        that they broadcast against the example dimensions.

    Returns:
      Tensor of kernel values.
    """
    # sqrt_with_finite_grads avoids NaN gradients when `x1 == x2`.
    distance = util.sqrt_with_finite_grads(pairwise_square_distance)

    if self.length_scale is not None:
        scale = util.pad_shape_with_ones(
            tf.convert_to_tensor(self.length_scale), ndims=example_ndims)
        distance = distance / scale

    scaled_distance = np.sqrt(3) * distance
    log_kernel = tf.math.log1p(scaled_distance) - scaled_distance

    if self.amplitude is not None:
        amp = util.pad_shape_with_ones(
            tf.convert_to_tensor(self.amplitude), example_ndims)
        # Adding 2 * log(amplitude) multiplies the kernel by amplitude**2.
        log_kernel = log_kernel + 2. * tf.math.log(amp)

    return tf.exp(log_kernel)
def _apply_with_distance(self, x1, x2, pairwise_square_distance,
                         example_ndims=0):
    """Evaluate the kernel from precomputed squared distances.

    Computes `amplitude * 0.75 * relu(1 - d2 / length_scale**2)`, where `d2`
    is `pairwise_square_distance`. `length_scale` and `amplitude` are skipped
    when they are `None`.

    Args:
      x1: Unused here; the distance is supplied directly.
      x2: Unused here; the distance is supplied directly.
      pairwise_square_distance: Squared distances between inputs.
      example_ndims: Number of ones appended to the parameters' shapes so
        that they broadcast against the example dimensions.

    Returns:
      Tensor of kernel values.
    """
    sq_dist = pairwise_square_distance
    if self.length_scale is not None:
        scale = util.pad_shape_with_ones(
            tf.convert_to_tensor(self.length_scale), example_ndims)
        sq_dist = sq_dist / scale**2

    three_quarters = tf.cast(.75, sq_dist.dtype)
    # relu clips the parabola to zero once the scaled distance exceeds one.
    kernel_values = tf.nn.relu(1 - sq_dist) * three_quarters

    if self.amplitude is None:
        return kernel_values

    amp = util.pad_shape_with_ones(
        tf.convert_to_tensor(self.amplitude), example_ndims)
    return kernel_values * amp
def _apply(self, x1, x2, example_ndims=0):
    """Return a tensor of ones (times `coef`) over the non-feature dims.

    Args:
      x1: First input; only its shape is used.
      x2: Second input; only its shape is used.
      example_ndims: Number of ones appended to `coef`'s shape so that it
        broadcasts against the example dimensions.

    Returns:
      Ones of dtype `self._dtype` whose shape is the broadcast of the
      non-feature dims of `x1` and `x2`, multiplied by the padded `coef`
      when `self.coef` is not `None`.
    """
    # Strip the trailing feature dims from each input's shape.
    lead_shape_1 = x1.shape[:-(self.feature_ndims)]
    lead_shape_2 = x2.shape[:-(self.feature_ndims)]
    out_shape = tf.broadcast_dynamic_shape(lead_shape_1, lead_shape_2)
    result = tf.ones(out_shape, dtype=self._dtype)

    if self.coef is None:
        return result

    coef = util.pad_shape_with_ones(
        tf.convert_to_tensor(self._coef), example_ndims)
    return result * coef
def _apply_with_distance(self, x1, x2, pairwise_square_distance,
                         example_ndims=0):
    """Evaluate the kernel from precomputed squared distances.

    Computes `amplitude**2 * exp(-sqrt(d2) * inverse_length_scale)` in log
    space, where `d2` is `pairwise_square_distance` and
    `inverse_length_scale` is whatever
    `self._inverse_length_scale_parameter()` returns. Either parameter is
    skipped when it is `None`.

    Args:
      x1: Unused here; the distance is supplied directly.
      x2: Unused here; the distance is supplied directly.
      pairwise_square_distance: Squared distances between inputs.
      example_ndims: Number of ones appended to the parameters' shapes so
        that they broadcast against the example dimensions.

    Returns:
      Tensor of kernel values.
    """
    # sqrt_with_finite_grads avoids NaN gradients when `x1 == x2`.
    distance = util.sqrt_with_finite_grads(pairwise_square_distance)

    inv_scale = self._inverse_length_scale_parameter()
    if inv_scale is not None:
        inv_scale = util.pad_shape_with_ones(inv_scale, ndims=example_ndims)
        distance = distance * inv_scale

    log_kernel = -distance

    if self.amplitude is not None:
        amp = util.pad_shape_with_ones(
            tf.convert_to_tensor(self.amplitude), ndims=example_ndims)
        # Adding 2 * log(amplitude) multiplies the kernel by amplitude**2.
        log_kernel = log_kernel + 2. * tf.math.log(amp)

    return tf.exp(log_kernel)
def testPairwiseSquareDistanceMatrix(self, feature_ndims, dims):
    """Checks `pairwise_square_distance_matrix` against a broadcast reference.

    The reference pads `x1` and `x2` with a singleton dimension at different
    positions, takes the elementwise squared difference and sums it over the
    feature dimensions.
    """
    seeds = test_util.test_seed_stream('pairwise_square_distance')
    input_shape = [2, 3] + [dims] * feature_ndims
    x1 = tf.random.normal(
        dtype=np.float64, shape=input_shape, seed=seeds())
    x2 = tf.random.normal(
        dtype=np.float64, shape=input_shape, seed=seeds())

    computed = util.pairwise_square_distance_matrix(x1, x2, feature_ndims)

    # Insert the singleton at different offsets so the two inputs broadcast
    # against each other pairwise.
    x1_expanded = util.pad_shape_with_ones(
        x1, ndims=1, start=-(feature_ndims + 1))
    x2_expanded = util.pad_shape_with_ones(
        x2, ndims=1, start=-(feature_ndims + 2))
    reference = util.sum_rightmost_ndims_preserving_shape(
        tf.math.squared_difference(x1_expanded, x2_expanded), feature_ndims)

    computed_, reference_ = self.evaluate([computed, reference])
    self.assertAllClose(computed_, reference_)
def rescale_input(x, feature_ndims, example_ndims):
    """Computes `x / scale_diag`, i.e. `x * inverse_scale_diag`.

    `self` is taken from the enclosing scope. When `self.inverse_scale_diag`
    is `None`, the reciprocal of `self.scale_diag` is used instead.

    Args:
      x: Input to rescale.
      feature_ndims: Number of feature dimensions of `x`.
      example_ndims: Number of ones inserted into the scale's shape.

    Returns:
      `x` multiplied by the padded inverse scale.
    """
    inv_scale = self.inverse_scale_diag
    if inv_scale is None:
        inv_scale = tf.math.reciprocal(self.scale_diag)

    # Start before the first feature dimension. This assumes scale_diag has
    # at least as many dimensions as feature_ndims.
    padded_inv_scale = util.pad_shape_with_ones(
        tf.convert_to_tensor(inv_scale),
        example_ndims,
        start=-(feature_ndims + 1))
    return x * padded_inv_scale
def _apply(self, x1, x2, example_ndims=0):
    """Evaluate the kernel on `x1` and `x2`.

    Computes, in log space,
    `amplitude**2 * exp(-2 * sum(sin(pi * |x1 - x2| / period)**2)
    / length_scale**2)`, where the sum runs over the feature dimensions.
    Any of `period`, `length_scale` and `amplitude` that is `None` is
    skipped.

    Args:
      x1: First input tensor.
      x2: Second input tensor.
      example_ndims: Number of example dimensions; used to pad the
        parameters so they broadcast correctly.

    Returns:
      Tensor of kernel values.
    """
    angle = np.pi * tf.abs(x1 - x2)

    if self.period is not None:
        # `period` is applied before the feature dims are summed away, so it
        # needs `feature_ndims` extra ones on top of `example_ndims`.
        period = util.pad_shape_with_ones(
            tf.convert_to_tensor(self.period),
            ndims=(example_ndims + self.feature_ndims))
        angle = angle / period

    log_kernel = util.sum_rightmost_ndims_preserving_shape(
        -2 * tf.sin(angle) ** 2, ndims=self.feature_ndims)

    if self.length_scale is not None:
        scale = util.pad_shape_with_ones(
            tf.convert_to_tensor(self.length_scale), ndims=example_ndims)
        log_kernel = log_kernel / scale ** 2

    if self.amplitude is not None:
        amp = util.pad_shape_with_ones(
            tf.convert_to_tensor(self.amplitude), ndims=example_ndims)
        # Adding 2 * log(amplitude) multiplies the kernel by amplitude**2.
        log_kernel = log_kernel + 2. * tf.math.log(amp)

    return tf.exp(log_kernel)
def _matrix(self, x1, x2):
    """Build the kernel matrix between the example sets `x1` and `x2`.

    The inputs are first given an extra dimension so they broadcast against
    `self.scales`, and the pairwise squared distance of the scaled inputs is
    computed. They are then padded again so they broadcast against each
    other, and everything is passed to `self._apply_with_distance` with
    `example_ndims=2`.

    Args:
      x1: First input tensor.
      x2: Second input tensor.

    Returns:
      Whatever `self._apply_with_distance` returns for the padded inputs.
    """
    scale_axis = -(self.feature_ndims + 2)
    feature_axis = -(self.feature_ndims + 1)

    # Add an extra dimension to x1 and x2 so they broadcast with scales.
    x1 = util.pad_shape_with_ones(x1, ndims=1, start=scale_axis)
    x2 = util.pad_shape_with_ones(x2, ndims=1, start=scale_axis)
    scales = util.pad_shape_with_ones(
        self.scales, ndims=1, start=feature_axis)

    sq_dist = util.pairwise_square_distance_matrix(
        np.pi * x1 * scales, np.pi * x2 * scales, self.feature_ndims)

    # Expand `x1` and `x2` so that they broadcast against each other.
    x1 = util.pad_shape_with_ones(x1, ndims=1, start=feature_axis)
    x2 = util.pad_shape_with_ones(x2, ndims=1, start=scale_axis)

    return self._apply_with_distance(x1, x2, sq_dist, example_ndims=2)
def _matrix(self, x1, x2):
    """Return a batch of (possibly rectangular) identity matrices.

    The matrix has as many rows as `x1` has examples and as many columns as
    `x2` has examples, where the example dimension is the one immediately
    left of the feature dimensions. It is scaled by `noise` when
    `self.noise` is not `None`.

    Args:
      x1: First input; only its shape is used.
      x2: Second input; only its shape is used.

    Returns:
      Tensor of dtype `self._dtype` with the broadcast batch shape of `x1`
      and `x2` followed by `[num_rows, num_cols]`.
    """
    # Offset (from the right) of the single example dimension.
    example_axis = 1 + self.feature_ndims

    batch_shape = tf.broadcast_dynamic_shape(
        x1.shape[:-example_axis], x2.shape[:-example_axis])
    identity = tf.linalg.eye(
        x1.shape[-example_axis],
        x2.shape[-example_axis],
        batch_shape=batch_shape,
        dtype=self._dtype)

    if self.noise is None:
        return identity

    # Two ones: one per matrix dimension.
    noise = util.pad_shape_with_ones(tf.convert_to_tensor(self._noise), 2)
    return identity * noise
def _inner_apply(x1, x2):
    """Combine per-feature kernel values into a weighted polynomial sum.

    `self` and `example_ndims` come from the enclosing scope. For each index
    `i` along the last axis of `x1`/`x2`, `self.kernel[..., i]` is applied
    to the `i`-th feature. A scan over `i` maintains the running vector
    `e[0..order]` via `e_new[k] = e[k] + s_i * e[k - 1]`, with `e[0]` fixed
    at one. The final `e[1:]` is weighted by `amplitudes**2` and summed over
    the last axis.

    Args:
      x1: First input tensor.
      x2: Second input tensor.

    Returns:
      Tensor of combined kernel values.
    """
    order = ps.shape(self.amplitudes)[-1]

    def _step(poly, idx):
        # Kernel value for feature `idx`, keeping a size-1 feature axis.
        k_i = self.kernel[..., idx].apply(
            x1[..., idx][..., tf.newaxis],
            x2[..., idx][..., tf.newaxis],
            example_ndims=example_ndims)
        higher = poly[..., 1:] + k_i[..., tf.newaxis] * poly[..., :-1]
        # Re-attach the zero-th polynomial, which is identically one.
        zeroth = tf.ones_like(poly[..., 0][..., tf.newaxis])
        return tf.concat([zeroth, higher], axis=-1)

    # Broadcast the non-feature dims of the inputs with the kernel's batch
    # shape (padded with ones for the example dims).
    input_shape = ps.broadcast_shape(
        ps.shape(x1)[:-self.kernel.feature_ndims],
        ps.shape(x2)[:-self.kernel.feature_ndims])
    kernel_shape = ps.concat(
        [self.batch_shape_tensor(), [1] * example_ndims], axis=0)
    full_shape = ps.broadcast_shape(input_shape, kernel_shape)

    # Initial state: [1, 0, ..., 0] with `order` zeros.
    ones_part = tf.ones(
        ps.concat([full_shape, [1]], axis=0), dtype=self.dtype)
    zeros_part = tf.zeros(
        ps.concat([full_shape, [order]], axis=0), dtype=self.dtype)
    initial_poly = tf.concat([ones_part, zeros_part], axis=-1)

    history = tf.scan(
        _step,
        elems=ps.range(0, ps.shape(x1)[-1], dtype=tf.int32),
        parallel_iterations=32,
        initializer=initial_poly)
    # Keep only the last scan step and drop the zero-th polynomial.
    final_poly = history[-1, ..., 1:]

    amps = util.pad_shape_with_ones(
        self.amplitudes, ndims=example_ndims, start=-2)
    return tf.reduce_sum(final_poly * tf.math.square(amps), axis=-1)
def _matrix(self, x1, x2):
    """Build the kernel matrix by broadcasting `x1` against `x2`.

    A singleton dimension is inserted into each input at a different offset
    so that the two example dimensions broadcast into a matrix. The padded
    inputs are then passed to `self._call_apply` with `example_ndims=2`.

    Args:
      x1: First input tensor.
      x2: Second input tensor.

    Returns:
      Whatever `self._call_apply` returns for the padded inputs.
    """
    rows = util.pad_shape_with_ones(
        x1, ndims=1, start=-(self.feature_ndims + 1))
    cols = util.pad_shape_with_ones(
        x2, ndims=1, start=-(self.feature_ndims + 2))
    return self._call_apply(rows, cols, example_ndims=2)
def tensor(self, x1, x2, x1_example_ndims, x2_example_ndims):
    """Construct (batched) tensors from (batches of) collections of inputs.

    Args:
      x1: `Tensor` input to the first positional parameter of the kernel, of
        shape `B1 + E1 + F`, where `B1` and `E1` are arbitrary shapes which
        may be empty (i.e., no batch/example dims, resp.), and `F` (the
        feature shape) must have rank equal to the kernel's `feature_ndims`
        property. Batch shape must broadcast with the batch shape of `x2`
        and with the kernel's batch shape.
      x2: `Tensor` input to the second positional parameter of the kernel,
        of shape `B2 + E2 + F`, where `B2` and `E2` are arbitrary shapes
        which may be empty (i.e., no batch/example dims, resp.), and `F`
        (the feature shape) must have rank equal to the kernel's
        `feature_ndims` property. Batch shape must broadcast with the batch
        shape of `x1` and with the kernel's batch shape.
      x1_example_ndims: A python integer greater than or equal to 0, the
        number of example dims in the first input. This affects both the
        alignment of batch shapes and the shape of the final output of the
        function. Everything left of the feature shape and the example dims
        in `x1` is considered "batch shape", and must broadcast as specified
        above.
      x2_example_ndims: A python integer greater than or equal to 0, the
        number of example dims in the second input. This affects both the
        alignment of batch shapes and the shape of the final output of the
        function. Everything left of the feature shape and the example dims
        in `x2` is considered "batch shape", and must broadcast as specified
        above.

    Returns:
      `Tensor` containing (possibly batched) kernel applications to pairs
      from inputs `x1` and `x2`. If the kernel parameters' batch shape is
      `Bk` then the shape of the `Tensor` resulting from this method call is
      `broadcast(Bk, B1, B2) + E1 + E2`. Note this differs from `apply`: the
      example dimensions are concatenated, whereas in `apply` the example
      dims are broadcast together. It also differs from `matrix`: the
      example shapes are arbitrary here, and the result accrues a rank equal
      to the sum of the ranks of the input example shapes.

    #### Examples

    First, consider a kernel with a single scalar parameter.

    ```python
    import tensorflow_probability as tfp

    scalar_kernel = tfp.math.psd_kernels.SomeKernel(param=.5)
    scalar_kernel.batch_shape
    # ==> []

    # Our inputs are two rank-2 collections of 3-D vectors
    x = np.ones([5, 6, 3], np.float32)
    y = np.ones([7, 8, 3], np.float32)
    scalar_kernel.tensor(x, y, x1_example_ndims=2, x2_example_ndims=2).shape
    # ==> [5, 6, 7, 8]

    # Empty example shapes work too!
    x = np.ones([3], np.float32)
    y = np.ones([5, 3], np.float32)
    scalar_kernel.tensor(x, y, x1_example_ndims=0, x2_example_ndims=1).shape
    # ==> [5]
    ```

    The result comes from applying the kernel to the entries in `x` and `y`
    pairwise, across all pairs. For instance, with rank-1 example shapes
    `[5]` for `x` and `[4]` for `y`, the result is the matrix:

      ```none
      | k(x[0], y[0])    k(x[0], y[1])  ...  k(x[0], y[3]) |
      | k(x[1], y[0])    k(x[1], y[1])  ...  k(x[1], y[3]) |
      |      ...              ...                 ...      |
      | k(x[4], y[0])    k(x[4], y[1])  ...  k(x[4], y[3]) |
      ```

    Now consider a kernel with batched parameters.

    ```python
    batch_kernel = tfp.math.psd_kernels.SomeKernel(param=[1., .5])
    batch_kernel.batch_shape
    # ==> [2]

    # Inputs are two rank-2 collections of 3-D vectors
    x = np.ones([5, 6, 3], np.float32)
    y = np.ones([7, 8, 3], np.float32)
    batch_kernel.tensor(x, y, x1_example_ndims=2, x2_example_ndims=2).shape
    # ==> [2, 5, 6, 7, 8]
    ```

    We also support batching of the inputs. First, let's look at that with
    the scalar kernel again.

    ```python
    # Batch of 10 lists of 5x6 collections of dimension 3
    x = np.ones([10, 5, 6, 3], np.float32)

    # Batch of 10 lists of 7x8 collections of dimension 3
    y = np.ones([10, 7, 8, 3], np.float32)

    scalar_kernel.tensor(x, y, x1_example_ndims=2, x2_example_ndims=2).shape
    # ==> [10, 5, 6, 7, 8]
    ```

    The result is a batch of 10 tensors built from the batch of 10 rank-2
    collections of input vectors. The batch shapes have to be broadcastable.
    The following will *not* work:

    ```python
    x = np.ones([10, 5, 3], np.float32)
    y = np.ones([20, 4, 3], np.float32)
    scalar_kernel.tensor(x, y, x1_example_ndims=1, x2_example_ndims=1).shape
    # ==> Error! [10] and [20] can't broadcast.
    ```

    Now let's consider batches of inputs in conjunction with batches of
    kernel parameters. We require that the input batch shapes be
    broadcastable with the kernel parameter batch shapes, otherwise we get
    an error:

    ```python
    x = np.ones([10, 5, 6, 3], np.float32)
    y = np.ones([10, 7, 8, 3], np.float32)

    batch_kernel = tfp.math.psd_kernels.SomeKernel(param=[1., .5])
    batch_kernel.batch_shape
    # ==> [2]
    batch_kernel.tensor(x, y, x1_example_ndims=2, x2_example_ndims=2).shape
    # ==> Error! [2] and [10] can't broadcast.
    ```

    The fix is to make the kernel parameter shape broadcastable with `[10]`
    (or reshape the inputs to be broadcastable!):

    ```python
    x = np.ones([10, 5, 6, 3], np.float32)
    y = np.ones([10, 7, 8, 3], np.float32)

    batch_kernel = tfp.math.psd_kernels.SomeKernel(
        param=[[1.], [.5]])
    batch_kernel.batch_shape
    # ==> [2, 1]
    batch_kernel.tensor(x, y, x1_example_ndims=2, x2_example_ndims=2).shape
    # ==> [2, 10, 5, 6, 7, 8]

    # Or, make the inputs broadcastable:
    x = np.ones([10, 1, 5, 6, 3], np.float32)
    y = np.ones([10, 1, 7, 8, 3], np.float32)

    batch_kernel = tfp.math.psd_kernels.SomeKernel(
        param=[1., .5])
    batch_kernel.batch_shape
    # ==> [2]
    batch_kernel.tensor(x, y, x1_example_ndims=2, x2_example_ndims=2).shape
    # ==> [10, 2, 5, 6, 7, 8]
    ```

    """
    with self._name_and_control_scope(self._name):
        x1 = tf.convert_to_tensor(x1, name='x1', dtype_hint=self.dtype)
        x2 = tf.convert_to_tensor(x2, name='x2', dtype_hint=self.dtype)

        # Position (from the right) just before the feature dims.
        before_features = -(self.feature_ndims + 1)

        # x1: B1 + E1 + F  ->  B1 + E1 + [1] * x2_example_ndims + F
        x1_padded = util.pad_shape_with_ones(
            x1, ndims=x2_example_ndims, start=before_features)

        # x2: B2 + E2 + F  ->  B2 + [1] * x1_example_ndims + E2 + F
        x2_padded = util.pad_shape_with_ones(
            x2,
            ndims=x1_example_ndims,
            start=before_features - x2_example_ndims)

        # Broadcasting the padded example dims inside `apply` concatenates
        # E1 and E2 in the output.
        total_example_ndims = x1_example_ndims + x2_example_ndims
        return self.apply(
            x1_padded, x2_padded, example_ndims=total_example_ndims)
def testPadShapeStartWithOnes(self):
    """Checks that `start=-2` inserts the ones before the last dimension."""
    # Nominal behavior: shape [3] padded with three ones becomes [1, 1, 1, 3].
    vector = np.ones([3], np.float32)
    padded = self.evaluate(util.pad_shape_with_ones(vector, 3, start=-2))
    self.assertAllEqual(padded.shape, [1, 1, 1, 3])
def vector_transform(x, feature_ndims, param_expansion_ndims):
    """Multiply `x` elementwise by `scale_diag` padded with ones.

    `scale_diag` comes from the enclosing scope. Its shape receives
    `param_expansion_ndims + feature_ndims - 1` ones, inserted at
    `start=-2` (just before its last dimension).

    Args:
      x: Input to scale.
      feature_ndims: Number of feature dimensions of `x`.
      param_expansion_ndims: Number of extra dimensions the parameter must
        be expanded by.

    Returns:
      The padded `scale_diag` times `x`.
    """
    num_ones = param_expansion_ndims + feature_ndims - 1
    padded_diag = util.pad_shape_with_ones(scale_diag, num_ones, start=-2)
    return padded_diag * x
def testPadShapeRightWithOnes(self):
    """Checks that the default `start` appends the ones on the right."""
    # Nominal behavior: shape [3] padded with three ones becomes [3, 1, 1, 1].
    vector = np.ones([3], np.float32)
    padded = self.evaluate(util.pad_shape_with_ones(vector, 3))
    self.assertAllEqual(padded.shape, [3, 1, 1, 1])