Exemplo n.º 1
0
  def testShapesAreCorrect(self):
    # Checks the shape of `matrix` for three setups: scalar parameters with
    # unbatched inputs, scalar parameters with batched inputs, and batched
    # parameters with batched inputs.
    kernel = psd_kernels.ExpSinSquared(
        amplitude=1., length_scale=1., period=3.)

    x_points = np.ones([4, 3], np.float32)
    y_points = np.ones([5, 3], np.float32)
    x_batched = tf.stack([x_points, x_points])  # shape [2, 4, 3]
    y_batched = tf.stack([y_points, y_points])  # shape [2, 5, 3]

    # Unbatched inputs: one row per `x` point, one column per `y` point.
    self.assertAllEqual([4, 5], kernel.matrix(x_points, y_points).shape)

    # A leading input batch dimension carries through to the result.
    self.assertAllEqual([2, 4, 5], kernel.matrix(x_batched, y_batched).shape)

    batched_kernel = psd_kernels.ExpSinSquared(
        amplitude=np.ones([2, 1, 1], np.float32),
        length_scale=np.ones([1, 3, 1], np.float32),
        period=np.ones([2, 1, 1, 1], np.float32))

    # Expected shape, read left to right:
    #   [2, 2, 3] - from broadcasting the kernel parameters
    #   [2]       - from the input batch shapes
    #   [4, 5]    - the matrix shape itself
    batched_shape = batched_kernel.matrix(x_batched, y_batched).shape
    self.assertAllEqual([2, 2, 3, 2, 4, 5], batched_shape)
Exemplo n.º 2
0
  def testValidateArgs(self):
    # Negative parameter values are expected to raise InvalidArgumentError
    # when `validate_args=True`.
    with self.assertRaises(tf.errors.InvalidArgumentError):
      kernel = psd_kernels.ExpSinSquared(
          amplitude=-1.,
          length_scale=-1.,
          period=-1.,
          validate_args=True)
      self.evaluate(kernel.amplitude)

    # The remaining parameters are only checked outside eager mode.
    # NOTE(review): presumably in eager mode the constructor itself raises,
    # leaving `kernel` unbound -- confirm.
    if not tf.executing_eagerly():
      with self.assertRaises(tf.errors.InvalidArgumentError):
        self.evaluate(kernel.length_scale)

      with self.assertRaises(tf.errors.InvalidArgumentError):
        self.evaluate(kernel.period)

    # Passing `None` for every parameter must still be accepted.
    kernel = psd_kernels.ExpSinSquared(
        amplitude=None,
        length_scale=None,
        period=None,
        validate_args=True)
    self.evaluate(kernel.apply([1.], [1.]))
    def testShapesAreCorrect(self):
        # Checks the shape of `matrix` for scalar parameters (unbatched and
        # batched inputs) and then for batched kernel parameters.
        scalar_kernel = psd_kernels.ExpSinSquared(
            amplitude=1., length_scale=1., period=3.)

        x_points = np.ones([4, 3], np.float32)
        y_points = np.ones([5, 3], np.float32)

        unbatched_shape = scalar_kernel.matrix(x_points, y_points).shape
        self.assertAllEqual(unbatched_shape, [4, 5])

        stacked_shape = scalar_kernel.matrix(
            tf.stack([x_points, x_points]),
            tf.stack([y_points, y_points])).shape
        self.assertAllEqual(stacked_shape, [2, 4, 5])

        amplitude = np.ones([2, 1, 1], np.float32)
        length_scale = np.ones([1, 3, 1], np.float32)
        period = np.ones([2, 1, 1, 1], np.float32)
        batched_kernel = psd_kernels.ExpSinSquared(
            amplitude=amplitude, length_scale=length_scale, period=period)

        x_batched = tf.stack([x_points, x_points])  # shape [2, 4, 3]
        y_batched = tf.stack([y_points, y_points])  # shape [2, 5, 3]
        batched_shape = batched_kernel.matrix(x_batched, y_batched).shape
        self.assertAllEqual(batched_shape, [2, 2, 3, 2, 4, 5])
    def testValuesAreCorrect(self, feature_ndims, dtype, dims):
        # Compares `apply` on random inputs against the reference values from
        # `self._exp_sin_squared_kernel`, for five independent draws.
        amplitude = np.array(5., dtype=dtype)
        period = np.array(1., dtype=dtype)
        length_scale = np.array(.2, dtype=dtype)

        # Fixed seed so the sampled inputs are reproducible across runs.
        np.random.seed(42)
        kernel = psd_kernels.ExpSinSquared(
            amplitude=amplitude,
            length_scale=length_scale,
            period=period,
            feature_ndims=feature_ndims)

        # Each input has `feature_ndims` axes, every one of length `dims`.
        input_shape = [dims] * feature_ndims
        for _ in range(5):
            x = np.random.uniform(-3., 3., size=input_shape).astype(dtype)
            y = np.random.uniform(-3., 3., size=input_shape).astype(dtype)
            actual = self.evaluate(kernel.apply(x, y))
            expected = self._exp_sin_squared_kernel(
                amplitude, length_scale, period, x, y)
            self.assertAllClose(actual, expected, rtol=1e-4)
 def testNoneShapes(self):
     # Only `amplitude` is supplied; the kernel's batch shape is expected to
     # equal that array's [3, 2] shape.
     amplitude = np.reshape([1.] * 6, [3, 2])
     kernel = psd_kernels.ExpSinSquared(amplitude=amplitude)
     batch_shape = kernel.batch_shape.as_list()
     self.assertEqual([3, 2], batch_shape)
 def testMismatchedFloatTypesAreBad(self):
     # Constructing the kernel with conflicting float dtypes must raise
     # TypeError; plain Python ints alongside a single float dtype must not.
     one_f64 = np.float64(1)
     one_f32 = np.float32(1)

     # Should be OK (float32 fallback).
     psd_kernels.ExpSinSquared(1, 1)

     # Should be OK: only one explicit float dtype is present.
     psd_kernels.ExpSinSquared(1, one_f64)

     # Should fail: float64 and float32 are mixed.
     with self.assertRaises(TypeError):
         psd_kernels.ExpSinSquared(1, one_f64, one_f32)