def testShapesAreCorrect(self):
    """Checks the shapes returned by `apply` and `matrix`.

    A base kernel with scalar parameters is exercised first, followed by a
    base kernel with batched parameters combined with batched fixed inputs.
    """
    schur = tfpk.SchurComplement(
        tfpk.ExponentiatedQuadratic(np.float64(1.), np.float64(1.)),
        np.random.uniform(-1., 1., size=[2, 3]))

    x = np.ones([4, 3], np.float64)
    y = np.ones([5, 3], np.float64)
    x_stacked = tf.stack([x, x])  # shape [2, 4, 3]
    y_stacked = tf.stack([y, y])  # shape [2, 5, 3]

    # `apply` pairs the inputs up one example at a time.
    for points, expected_shape in ((x, [4]), (y, [5])):
        self.assertAllEqual(expected_shape, schur.apply(points, points).shape)
    self.assertAllEqual([4, 5], schur.matrix(x, y).shape)
    self.assertAllEqual([2, 4, 5], schur.matrix(x_stacked, y_stacked).shape)

    batched_kernel = tfpk.ExponentiatedQuadratic(
        amplitude=np.ones([2, 1, 1], np.float64),
        length_scale=np.ones([1, 3, 1], np.float64))
    # The fixed inputs carry their batch at the outermost shape position.
    batched_schur = tfpk.SchurComplement(
        batched_kernel,
        np.random.uniform(-1., 1., size=[7, 1, 1, 1, 2, 3]))

    self.assertAllEqual([7, 2, 3, 4], batched_schur.apply(x, x).shape)

    # Expected matrix shape, read left to right:
    #   7    - from the batch of fixed (observed) index points
    #   2, 3 - from broadcasting the kernel parameters
    #   2    - from the batch shape of the stacked inputs
    #   4, 5 - the matrix shape itself
    self.assertAllEqual(
        [7, 2, 3, 2, 4, 5],
        batched_schur.matrix(x_stacked, y_stacked).shape)
def testMismatchedFloatTypesAreBad(self):
    """Checks that fixed inputs with a mismatched float dtype are rejected.

    The base kernel is built from float64 parameters. Constructing a
    `SchurComplement` with the default `np.random.uniform` output must
    succeed, while float32 fixed inputs must raise `TypeError`.
    """
    base_kernel = tfpk.ExponentiatedQuadratic(np.float64(5.), np.float64(.2))

    # Should be OK.
    tfpk.SchurComplement(
        base_kernel=base_kernel,  # float64
        fixed_inputs=np.random.uniform(-1., 1., [2, 1]))

    # Build the mismatched inputs outside the assertRaises block so that only
    # the constructor under test can satisfy the expected TypeError; a
    # TypeError raised during fixture setup must not make the test pass.
    float32_inputs = np.random.uniform(-1., 1., [2, 1]).astype(np.float32)
    with self.assertRaises(TypeError):
        tfpk.SchurComplement(
            base_kernel=base_kernel, fixed_inputs=float32_inputs)
def testBaseKernelNoneDtype(self):
    """Checks dtype handling when the base kernel has no explicit dtype.

    With all kernel parameters left as None the kernel carries no explicit
    dtype, so float64 fixed inputs must be accepted even though they differ
    from the np.float32 "common_dtype" default. An explicitly float32 kernel
    with the same inputs must still raise `TypeError`.
    """
    fixed_inputs = np.arange(3, dtype=np.float64).reshape([3, 1])

    # An explicit float32 parameter conflicts with the float64 inputs.
    with self.assertRaises(TypeError):
        tfpk.SchurComplement(
            tfpk.ExponentiatedQuadratic(np.float32(1)), fixed_inputs)

    # No parameters means no explicit kernel dtype, so neither construction
    # nor building a matrix should throw.
    dtype_free_kernel = tfpk.ExponentiatedQuadratic()
    schur = tfpk.SchurComplement(dtype_free_kernel, fixed_inputs)
    schur.matrix(fixed_inputs, fixed_inputs)
def testApplyShapesAreCorrect(self):
    """Checks the shape of `apply` for 0 through 3 example dimensions.

    The expected shape is the broadcast of the kernel, input and fixed-input
    batch shapes, followed by the example dimensions of the input.
    """
    for example_ndims in range(4):
        # Draw each size from an increasing counter so that no two drawn
        # dimensions share a size.
        sizes = itertools.count(start=2, step=1)
        feature_shape = [next(sizes), next(sizes)]
        x_batch_shape = [next(sizes)]
        z_batch_shape = [next(sizes), 1]
        num_x = [next(sizes) for _ in range(example_ndims)]
        num_z = [next(sizes)]
        kernel_batch_size = next(sizes)

        x = np.ones(x_batch_shape + num_x + feature_shape, np.float64)
        z = np.random.uniform(
            -1., 1., size=z_batch_shape + num_z + feature_shape)

        base_kernel = tfpk.ExponentiatedQuadratic(
            amplitude=np.ones([kernel_batch_size, 1, 1], np.float64),
            feature_ndims=len(feature_shape))
        schur = tfpk.SchurComplement(base_kernel, fixed_inputs=z)

        expected_batch = broadcast_shapes(
            base_kernel.batch_shape, x_batch_shape, z_batch_shape)
        applied = schur.apply(x, x, example_ndims=example_ndims)
        self.assertAllEqual(expected_batch + num_x, applied.shape)
def testEmptyFixedInputs(self):
    """Checks that zero fixed inputs reproduce the base kernel.

    Also checks that the batch shape of a batched base kernel is reported
    unchanged, both statically and as a tensor.
    """
    base_kernel = tfpk.ExponentiatedQuadratic(1., 1.)
    # NOTE(review): the empty fixed inputs have feature size 2 while x and y
    # below have feature size 3; presumably the empty case never combines
    # them -- confirm this mismatch is intentional.
    schur = tfpk.SchurComplement(base_kernel, tf.ones([0, 2], np.float32))

    x = np.ones([4, 3], np.float32)
    y = np.ones([5, 3], np.float32)

    base_matrix = self.evaluate(base_kernel.matrix(x, y))
    schur_matrix = self.evaluate(schur.matrix(x, y))
    self.assertAllEqual(base_matrix, schur_matrix)

    # Test batch shapes.
    batched_schur = tfpk.SchurComplement(
        tfpk.ExponentiatedQuadratic([1., 2.]),
        tf.ones([0, 2], np.float32))
    self.assertAllEqual([2], batched_schur.batch_shape)
    self.assertAllEqual(
        [2], self.evaluate(batched_schur.batch_shape_tensor()))
def testNoneFixedInputs(self):
    """Checks that `fixed_inputs=None` reproduces the base kernel matrix."""
    base_kernel = tfpk.ExponentiatedQuadratic(1., 1.)
    schur = tfpk.SchurComplement(base_kernel, fixed_inputs=None)

    x = np.ones([4, 3], np.float32)
    y = np.ones([5, 3], np.float32)

    base_matrix = self.evaluate(base_kernel.matrix(x, y))
    schur_matrix = self.evaluate(schur.matrix(x, y))
    self.assertAllEqual(base_matrix, schur_matrix)
def testValuesAreCorrect(self, feature_ndims, dims):
    """Compares `apply` against a reference computed from the base kernel.

    For random inputs x and y and fixed inputs Z, the reference value is
    k(x, y) - k(x, Z) @ inv(k(Z, Z)) @ k(Z, y), with the inverse applied
    through two triangular solves against the Cholesky factor of k(Z, Z).

    Args:
      feature_ndims: Number of feature dimensions of every input.
      dims: Size of each feature dimension.
    """
    np.random.seed(42)
    num_obs = 5
    num_x = 3
    num_y = 3
    feature_shape = [dims] * feature_ndims

    base_kernel = tfpk.ExponentiatedQuadratic(
        np.float64(5.), np.float64(.2), feature_ndims=feature_ndims)
    fixed_inputs = np.random.uniform(
        -1., 1., size=[num_obs] + feature_shape)
    schur = tfpk.SchurComplement(
        base_kernel=base_kernel, fixed_inputs=fixed_inputs)

    obs_cov = self.evaluate(base_kernel.matrix(fixed_inputs, fixed_inputs))
    obs_chol = tf.linalg.LinearOperatorLowerTriangular(
        tf.linalg.cholesky(obs_cov))

    for _ in range(5):
        x = np.random.uniform(-1, 1, size=[num_x] + feature_shape)
        y = np.random.uniform(-1, 1, size=[num_y] + feature_shape)

        base_cov = self.evaluate(base_kernel.apply(x, y))
        x_obs_cov = self.evaluate(base_kernel.matrix(x, fixed_inputs))
        y_obs_cov = self.evaluate(base_kernel.matrix(y, fixed_inputs))

        # Turn each example's cross-covariance into a row (for x) and a
        # column (for y) so the einsum below is a per-example inner product.
        x_obs_row = np.expand_dims(x_obs_cov, -2)
        y_obs_col = np.expand_dims(y_obs_cov, -1)

        # Solve against the Cholesky factor, then against its adjoint.
        half_solved = obs_chol.solve(y_obs_col)
        solved = self.evaluate(obs_chol.solve(half_solved, adjoint=True))

        reduction = np.einsum('ijk,ikl->ijl', x_obs_row, solved)
        # Index away the two trailing unit axes (np.squeeze didn't like a
        # list of axes).
        reduction = reduction[..., 0, 0]

        self.assertAllClose(
            base_cov - reduction, self.evaluate(schur.apply(x, y)))
def testTensorShapesAreCorrect(self):
    """Checks the shape of `tensor` for every pair of example ndims in 0..2.

    The expected shape is the broadcast of the kernel, both input and the
    fixed-input batch shapes, followed by the example dimensions of the first
    and then the second input.
    """
    ndims_pairs = itertools.product(range(3), repeat=2)
    for x1_example_ndims, x2_example_ndims in ndims_pairs:
        # Draw each size from an increasing counter so that no two drawn
        # dimensions share a size.
        sizes = itertools.count(start=2, step=1)
        feature_shape = [next(sizes), next(sizes)]
        x_batch_shape = [next(sizes)]
        y_batch_shape = [next(sizes), 1]
        z_batch_shape = [next(sizes), 1, 1]
        num_x = [next(sizes) for _ in range(x1_example_ndims)]
        num_y = [next(sizes) for _ in range(x2_example_ndims)]
        num_z = [next(sizes)]
        kernel_batch_size = next(sizes)

        x = np.ones(x_batch_shape + num_x + feature_shape, np.float64)
        y = np.ones(y_batch_shape + num_y + feature_shape, np.float64)
        z = np.random.uniform(
            -1., 1., size=z_batch_shape + num_z + feature_shape)

        base_kernel = tfpk.ExponentiatedQuadratic(
            amplitude=np.ones([kernel_batch_size, 1, 1, 1], np.float64),
            feature_ndims=len(feature_shape))
        schur = tfpk.SchurComplement(base_kernel, fixed_inputs=z)

        expected_batch = broadcast_shapes(
            base_kernel.batch_shape,
            x_batch_shape,
            y_batch_shape,
            z_batch_shape)
        result = schur.tensor(
            x, y,
            x1_example_ndims=x1_example_ndims,
            x2_example_ndims=x2_example_ndims)
        self.assertAllEqual(expected_batch + num_x + num_y, result.shape)