Example #1
    def testShapesAreCorrect(self):
        base_kernel = tfpk.ExponentiatedQuadratic(np.float64(1.),
                                                  np.float64(1.))
        fixed_inputs = np.random.uniform(-1., 1., size=[2, 3])
        k = tfpk.SchurComplement(base_kernel, fixed_inputs)

        x = np.ones([4, 3], np.float64)
        y = np.ones([5, 3], np.float64)

        self.assertAllEqual([4], k.apply(x, x).shape)
        self.assertAllEqual([5], k.apply(y, y).shape)
        self.assertAllEqual([4, 5], k.matrix(x, y).shape)
        self.assertAllEqual([2, 4, 5],
                            k.matrix(tf.stack([x] * 2),
                                     tf.stack([y] * 2)).shape)

        base_kernel = tfpk.ExponentiatedQuadratic(
            amplitude=np.ones([2, 1, 1], np.float64),
            length_scale=np.ones([1, 3, 1], np.float64))
        # Batch these at the outermost shape position
        fixed_inputs = np.random.uniform(-1., 1., size=[7, 1, 1, 1, 2, 3])
        k = tfpk.SchurComplement(base_kernel, fixed_inputs)
        self.assertAllEqual([7, 2, 3, 4], k.apply(x, x).shape)
        self.assertAllEqual(
            [7, 2, 3, 2, 4, 5],
            #|  `--'  |  `--'
            #|    |   |    `- matrix shape
            #|    |   `- from input batch shapes
            #|    `- from broadcasting kernel params
            #`- from batch of obs index points
            k.matrix(
                tf.stack([x] * 2),  # shape [2, 4, 3]
                tf.stack([y] * 2)  # shape [2, 5, 3]
            ).shape)
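A quick sanity check of the expected shape above: the result's batch shape is the broadcast of the fixed_inputs batch shape [7, 1, 1, 1], the kernel's parameter batch shape (amplitude [2, 1, 1] against length_scale [1, 3, 1]), and the inputs' batch shape [2], with the [4, 5] matrix shape appended. A minimal NumPy check (np.broadcast_shapes needs NumPy >= 1.20):

import numpy as np

kernel_batch = np.broadcast_shapes((2, 1, 1), (1, 3, 1))            # (2, 3, 1)
full_batch = np.broadcast_shapes((7, 1, 1, 1), kernel_batch, (2,))  # (7, 2, 3, 2)
print(full_batch + (4, 5))                                          # (7, 2, 3, 2, 4, 5)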
Example #2
    def testMismatchedFloatTypesAreBad(self):
        base_kernel = tfpk.ExponentiatedQuadratic(np.float64(5.),
                                                  np.float64(.2))

        # Should be OK
        tfpk.SchurComplement(
            base_kernel=base_kernel,  # float64
            fixed_inputs=np.random.uniform(-1., 1., [2, 1]))

        with self.assertRaises(TypeError):
            float32_inputs = np.random.uniform(-1., 1.,
                                               [2, 1]).astype(np.float32)

            tfpk.SchurComplement(base_kernel=base_kernel,
                                 fixed_inputs=float32_inputs)
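SchurComplement requires fixed_inputs to share the base kernel's dtype, which is why the float32 inputs above raise a TypeError against the float64 kernel. A minimal sketch of the fix, assuming float64 is the intended common dtype:

fixed_inputs = np.random.uniform(-1., 1., [2, 1]).astype(np.float64)
tfpk.SchurComplement(base_kernel=base_kernel, fixed_inputs=fixed_inputs)  # OK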
Example #3
    def testBaseKernelNoneDtype(self):
        # Test that there are no problems when base_kernel has no explicit dtype
        # (i.e., its params are all None) but fixed_inputs has a dtype other than
        # the "common_dtype" default of np.float32.
        fixed_inputs = np.arange(3, dtype=np.float64).reshape([3, 1])

        # Should raise when there's an explicit mismatch.
        with self.assertRaises(TypeError):
            schur_complement = tfpk.SchurComplement(
                tfpk.ExponentiatedQuadratic(np.float32(1)), fixed_inputs)

        # Should not throw an exception when the kernel doesn't get an explicit
        # dtype from its inputs.
        schur_complement = tfpk.SchurComplement(tfpk.ExponentiatedQuadratic(),
                                                fixed_inputs)
        schur_complement.matrix(fixed_inputs, fixed_inputs)
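Because ExponentiatedQuadratic() is constructed with no parameters, it carries no dtype of its own; the SchurComplement then adopts float64 from fixed_inputs instead of defaulting to float32, so the second construction and the matrix call succeed. A quick way to confirm the inherited dtype (assuming the usual dtype property on TFP kernels):

schur_complement.dtype  # ==> tf.float64, inherited from fixed_inputs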
Example #4
    def testApplyShapesAreCorrect(self):
        for example_ndims in range(0, 4):
            # An integer generator.
            ints = itertools.count(start=2, step=1)
            feature_shape = [next(ints), next(ints)]

            x_batch_shape = [next(ints)]
            z_batch_shape = [next(ints), 1]
            num_x = [next(ints) for _ in range(example_ndims)]
            num_z = [next(ints)]

            x_shape = x_batch_shape + num_x + feature_shape
            z_shape = z_batch_shape + num_z + feature_shape

            x = np.ones(x_shape, np.float64)
            z = np.random.uniform(-1., 1., size=z_shape)

            base_kernel = tfpk.ExponentiatedQuadratic(
                amplitude=np.ones([next(ints), 1, 1], np.float64),
                feature_ndims=len(feature_shape))

            k = tfpk.SchurComplement(base_kernel, fixed_inputs=z)

            expected = broadcast_shapes(base_kernel.batch_shape, x_batch_shape,
                                        z_batch_shape) + num_x
            actual = k.apply(x, x, example_ndims=example_ndims).shape

            self.assertAllEqual(expected, actual)
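A hand-check of one loop iteration, say example_ndims=2 (the ints generator restarts at 2 each pass): feature_shape=[2, 3], x_batch_shape=[4], z_batch_shape=[5, 1], num_x=[6, 7], num_z=[8], and the amplitude has batch shape [9, 1, 1]. The expected result of apply is then broadcast([9, 1, 1], [4], [5, 1]) + [6, 7] = [9, 5, 4, 6, 7].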
Example #5
    def testEmptyFixedInputs(self):
        base_kernel = tfpk.ExponentiatedQuadratic(1., 1.)
        fixed_inputs = tf.ones([0, 2], np.float32)
        schur = tfpk.SchurComplement(base_kernel, fixed_inputs)

        x = np.ones([4, 3], np.float32)
        y = np.ones([5, 3], np.float32)

        self.assertAllEqual(self.evaluate(base_kernel.matrix(x, y)),
                            self.evaluate(schur.matrix(x, y)))

        # Test batch shapes
        base_kernel = tfpk.ExponentiatedQuadratic([1., 2.])
        fixed_inputs = tf.ones([0, 2], np.float32)
        schur = tfpk.SchurComplement(base_kernel, fixed_inputs)
        self.assertAllEqual([2], schur.batch_shape)
        self.assertAllEqual([2], self.evaluate(schur.batch_shape_tensor()))
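With an empty batch of fixed inputs there is nothing to condition on: the subtracted term k(x, Z) k(Z, Z)^{-1} k(Z, y) vanishes, so the Schur complement matches its base kernel exactly, and the batch shape [2] comes purely from the length-2 amplitude of the base kernel.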
Example #6
    def testNoneFixedInputs(self):
        base_kernel = tfpk.ExponentiatedQuadratic(1., 1.)
        schur = tfpk.SchurComplement(base_kernel, fixed_inputs=None)

        x = np.ones([4, 3], np.float32)
        y = np.ones([5, 3], np.float32)

        self.assertAllEqual(self.evaluate(base_kernel.matrix(x, y)),
                            self.evaluate(schur.matrix(x, y)))
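Passing fixed_inputs=None behaves like the empty-fixed-inputs case above: no conditioning term is subtracted, and the Schur-complement kernel agrees with its base kernel.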
Example #7
    def testValuesAreCorrect(self, feature_ndims, dims):
        np.random.seed(42)
        num_obs = 5
        num_x = 3
        num_y = 3

        shape = [dims] * feature_ndims

        base_kernel = tfpk.ExponentiatedQuadratic(
            np.float64(5.), np.float64(.2), feature_ndims=feature_ndims)

        fixed_inputs = np.random.uniform(-1., 1., size=[num_obs] + shape)

        k = tfpk.SchurComplement(
            base_kernel=base_kernel,
            fixed_inputs=fixed_inputs)

        k_obs = self.evaluate(base_kernel.matrix(fixed_inputs, fixed_inputs))

        k_obs_chol_linop = tf.linalg.LinearOperatorLowerTriangular(
            tf.linalg.cholesky(k_obs))
        for _ in range(5):
            x = np.random.uniform(-1, 1, size=[num_x] + shape)
            y = np.random.uniform(-1, 1, size=[num_y] + shape)

            k_x_y = self.evaluate(base_kernel.apply(x, y))
            k_x_obs = self.evaluate(base_kernel.matrix(x, fixed_inputs))
            k_obs_y = self.evaluate(base_kernel.matrix(y, fixed_inputs))

            k_x_obs = np.expand_dims(k_x_obs, -2)
            k_obs_y = np.expand_dims(k_obs_y, -1)

            k_obs_inv_k_obs_y = self.evaluate(
                k_obs_chol_linop.solve(
                    k_obs_chol_linop.solve(k_obs_y),
                    adjoint=True))

            cov_dec = np.einsum('ijk,ikl->ijl', k_x_obs, k_obs_inv_k_obs_y)
            cov_dec = cov_dec[..., 0, 0]  # np.squeeze didn't like list of axes
            expected = k_x_y - cov_dec
            self.assertAllClose(expected, self.evaluate(k.apply(x, y)))
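For reference, the quantity this test verifies is the Schur-complement (posterior predictive) kernel

    k_schur(x, y) = k(x, y) - k(x, Z) @ inv(k(Z, Z)) @ k(Z, y),

where Z are the fixed inputs; the test applies inv(k(Z, Z)) stably via two triangular solves against the Cholesky factor of k(Z, Z). A minimal NumPy restatement with a hypothetical scalar kernel callable k (illustration only, not library code):

import numpy as np

def schur_apply(k, x, y, z):
    # z: sequence of fixed inputs; k: symmetric positive-definite kernel.
    k_zz = np.array([[k(zi, zj) for zj in z] for zi in z])
    k_xz = np.array([k(x, zi) for zi in z])
    k_zy = np.array([k(zi, y) for zi in z])
    return k(x, y) - k_xz @ np.linalg.solve(k_zz, k_zy)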
Example #8
    def testTensorShapesAreCorrect(self):
        for x1_example_ndims in range(0, 3):
            for x2_example_ndims in range(0, 3):
                # An integer generator.
                ints = itertools.count(start=2, step=1)
                feature_shape = [next(ints), next(ints)]

                x_batch_shape = [next(ints)]
                y_batch_shape = [next(ints), 1]
                z_batch_shape = [next(ints), 1, 1]

                num_x = [next(ints) for _ in range(x1_example_ndims)]
                num_y = [next(ints) for _ in range(x2_example_ndims)]
                num_z = [next(ints)]

                x_shape = x_batch_shape + num_x + feature_shape
                y_shape = y_batch_shape + num_y + feature_shape
                z_shape = z_batch_shape + num_z + feature_shape

                x = np.ones(x_shape, np.float64)
                y = np.ones(y_shape, np.float64)
                z = np.random.uniform(-1., 1., size=z_shape)

                base_kernel = tfpk.ExponentiatedQuadratic(
                    amplitude=np.ones([next(ints), 1, 1, 1], np.float64),
                    feature_ndims=len(feature_shape))

                k = tfpk.SchurComplement(base_kernel, fixed_inputs=z)

                expected = broadcast_shapes(base_kernel.batch_shape,
                                            x_batch_shape, y_batch_shape,
                                            z_batch_shape) + num_x + num_y

                mat = k.tensor(x,
                               y,
                               x1_example_ndims=x1_example_ndims,
                               x2_example_ndims=x2_example_ndims)
                actual = mat.shape
                self.assertAllEqual(expected, actual)
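A hand-check of one iteration, say x1_example_ndims=1 and x2_example_ndims=2: the generator yields feature_shape=[2, 3], x_batch_shape=[4], y_batch_shape=[5, 1], z_batch_shape=[6, 1, 1], num_x=[7], num_y=[8, 9], num_z=[10], and an amplitude batch shape of [11, 1, 1, 1]. The expected tensor shape is broadcast([11, 1, 1, 1], [4], [5, 1], [6, 1, 1]) + [7] + [8, 9] = [11, 6, 5, 4, 7, 8, 9].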