コード例 #1
0
  def testValidateArgs(self):
    """Checks argument validation for negative and `None` parameters.

    A kernel built with negative `amplitude` and `length_scale` and
    `validate_args=True` is expected to raise `InvalidArgumentError` when
    either parameter is evaluated. A kernel built with both parameters set to
    `None` is expected to apply without raising.
    """
    negative_kernel = psd_kernels.MaternThreeHalves(
        amplitude=-1., length_scale=-1., validate_args=True)

    # Each parameter gets its own context so that both must fail independently.
    for param_name in ('amplitude', 'length_scale'):
      with self.assertRaises(tf.errors.InvalidArgumentError):
        self.evaluate(getattr(negative_kernel, param_name))

    # `None` parameters are accepted even with validation turned on.
    unset_kernel = psd_kernels.MaternThreeHalves(
        amplitude=None, length_scale=None, validate_args=True)
    self.evaluate(unset_kernel.apply([1.], [1.]))
コード例 #2
0
  def testShapesAreCorrect(self):
    """Checks the shape returned by `matrix` for plain and batched inputs."""
    unit_kernel = psd_kernels.MaternThreeHalves(amplitude=1., length_scale=1.)

    lhs = np.ones([4, 3], np.float32)
    rhs = np.ones([5, 3], np.float32)

    # Unbatched inputs: 4 points against 5 points.
    unbatched = unit_kernel.matrix(lhs, rhs)
    self.assertAllEqual(unbatched.shape, [4, 5])

    # Inputs stacked twice along a new leading axis.
    batched = unit_kernel.matrix(tf.stack([lhs] * 2), tf.stack([rhs] * 2))
    self.assertAllEqual(batched.shape, [2, 4, 5])

    # Kernel parameters of shapes [2, 1, 1] and [1, 3, 1] combined with the
    # stacked inputs are expected to yield a [2, 3, 2, 4, 5] result.
    batch_kernel = psd_kernels.MaternThreeHalves(
        amplitude=np.ones([2, 1, 1], np.float32),
        length_scale=np.ones([1, 3, 1], np.float32))
    stacked_lhs = tf.stack([lhs] * 2)  # shape [2, 4, 3]
    stacked_rhs = tf.stack([rhs] * 2)  # shape [2, 5, 3]
    self.assertAllEqual(
        batch_kernel.matrix(stacked_lhs, stacked_rhs).shape, [2, 3, 2, 4, 5])
コード例 #3
0
  def testValuesAreCorrect(self, feature_ndims, dtype, dims):
    """Compares `apply` against the reference `self._matern_three_halves`.

    Args:
      feature_ndims: Number of feature dimensions; passed to the kernel and
        used as the rank of the random inputs.
      dtype: NumPy dtype used for the kernel parameters and the inputs.
      dims: Size of every feature dimension of the random inputs.
    """
    amplitude = np.array(5., dtype=dtype)
    length_scale = np.array(.2, dtype=dtype)

    # Fixed seed so the sampled inputs are reproducible across runs.
    np.random.seed(42)
    kernel = psd_kernels.MaternThreeHalves(
        amplitude, length_scale, feature_ndims)
    input_shape = [dims] * feature_ndims

    def draw_input():
      # Uniform sample in [-1, 1), cast to the dtype under test.
      return np.random.uniform(-1, 1, size=input_shape).astype(dtype)

    for _ in range(5):
      x = draw_input()
      y = draw_input()
      actual = self.evaluate(kernel.apply(x, y))
      expected = self._matern_three_halves(amplitude, length_scale, x, y)
      self.assertAllClose(actual, expected)
コード例 #4
0
 def testMismatchedFloatTypesAreBad(self):
   """Checks that mixing float32 and float64 arguments raises `ValueError`."""
   single_precision = np.float32(1.)
   double_precision = np.float64(1.)
   # Only the kernel construction is expected to raise.
   with self.assertRaises(ValueError):
     psd_kernels.MaternThreeHalves(single_precision, double_precision)