def testValidateArgs(self):
  """Negative parameters raise on evaluation; unset (`None`) ones are fine.

  With `validate_args=True`, evaluating `amplitude` or `length_scale` set to
  -1. must raise `tf.errors.InvalidArgumentError`, while a kernel built with
  both parameters set to `None` can still be applied.
  """
  kernel = psd_kernels.MaternThreeHalves(
      amplitude=-1., length_scale=-1., validate_args=True)
  # Each property is read inside the context manager, so both the attribute
  # access and its evaluation fall under the expected-error check.
  for param_name in ('amplitude', 'length_scale'):
    with self.assertRaises(tf.errors.InvalidArgumentError):
      self.evaluate(getattr(kernel, param_name))

  # Leaving both parameters unset must not trip validation.
  unset_kernel = psd_kernels.MaternThreeHalves(
      amplitude=None, length_scale=None, validate_args=True)
  self.evaluate(unset_kernel.apply([1.], [1.]))
def testShapesAreCorrect(self):
  """Checks `matrix` output shapes for scalar and batched kernel parameters."""
  x = np.ones([4, 3], np.float32)
  y = np.ones([5, 3], np.float32)
  # Two stacked copies add a leading input batch dimension of size 2.
  batched_x = tf.stack([x] * 2)  # shape [2, 4, 3]
  batched_y = tf.stack([y] * 2)  # shape [2, 5, 3]

  scalar_kernel = psd_kernels.MaternThreeHalves(amplitude=1., length_scale=1.)
  self.assertAllEqual(scalar_kernel.matrix(x, y).shape, [4, 5])
  self.assertAllEqual(
      scalar_kernel.matrix(batched_x, batched_y).shape, [2, 4, 5])

  # Parameter shapes [2, 1, 1] and [1, 3, 1] combine with the size-2 input
  # batch; the expected result places batch dims [2, 3, 2] ahead of the
  # [4, 5] matrix.
  batched_kernel = psd_kernels.MaternThreeHalves(
      amplitude=np.ones([2, 1, 1], np.float32),
      length_scale=np.ones([1, 3, 1], np.float32))
  self.assertAllEqual(
      batched_kernel.matrix(batched_x, batched_y).shape, [2, 3, 2, 4, 5])
def testValuesAreCorrect(self, feature_ndims, dtype, dims):
  """Compares `apply` against the `_matern_three_halves` helper.

  The arguments are presumably supplied by a parameterization decorator
  outside this view (TODO confirm).

  Args:
    feature_ndims: Number of feature dimensions; passed to the kernel and used
      as the rank of the sampled inputs.
    dtype: NumPy dtype used for the parameters and the sampled inputs.
    dims: Size of each feature dimension of the sampled inputs.
  """
  amplitude = np.array(5., dtype=dtype)
  length_scale = np.array(.2, dtype=dtype)
  # A local generator gives reproducible draws without reseeding NumPy's
  # global RNG, which would leak state into other tests. RandomState(42)
  # produces the same sequence as the global RNG after np.random.seed(42).
  rng = np.random.RandomState(42)
  # Bind by keyword so the call does not depend on the constructor's
  # positional order. NOTE(review): newer TFP releases insert
  # `inverse_length_scale` before `feature_ndims`; verify against the
  # installed version.
  k = psd_kernels.MaternThreeHalves(
      amplitude=amplitude,
      length_scale=length_scale,
      feature_ndims=feature_ndims)
  shape = [dims] * feature_ndims
  for _ in range(5):
    x = rng.uniform(-1, 1, size=shape).astype(dtype)
    y = rng.uniform(-1, 1, size=shape).astype(dtype)
    self.assertAllClose(
        self.evaluate(k.apply(x, y)),
        self._matern_three_halves(amplitude, length_scale, x, y))
def testMismatchedFloatTypesAreBad(self):
  """Mixing float32 and float64 parameters raises `ValueError` at construction."""
  amplitude_f32 = np.float32(1.)
  length_scale_f64 = np.float64(1.)
  with self.assertRaises(ValueError):
    psd_kernels.MaternThreeHalves(
        amplitude=amplitude_f32, length_scale=length_scale_f64)