def __init__(self,
             membership_transform=None,
             predictions_transform=None,
             membership_kernel=None,
             predictions_kernel=None,
             name: Optional[Text] = None):
  """Initialize `MinDiffLoss` instance.

  Args:
    membership_transform: Optional callable, stored as
      `self.membership_transform`. Presumably applied to the membership input
      before the loss is computed (NOTE(review): confirm against callers).
    predictions_transform: Optional callable, stored as
      `self.predictions_transform`. Presumably applied to the predictions
      before the loss is computed (NOTE(review): confirm against callers).
    membership_kernel: Optional kernel spec, resolved with
      `kernel_utils._get_kernel` and stored as `self.membership_kernel`.
      Accepts `None` (left as `None`), a `MinDiffKernel` instance (used
      as-is), or a kernel name string such as `'gauss'` or `'laplacian'`.
    predictions_kernel: Same as `membership_kernel`, but stored as
      `self.predictions_kernel`.
    name: Optional name for the loss. If falsy, defaults to the snake_case
      form of the concrete class name.

  Raises:
    ValueError: If a `*_transform` parameter is passed in but is not
      callable.
    TypeError: If a `*_kernel` parameter is not `None`, a string, or a
      `MinDiffKernel` instance.
    ValueError: If a `*_kernel` parameter is a string that does not name a
      supported kernel.
  """
  # Reduction is fixed to NONE here; callers cannot configure it.
  super(MinDiffLoss, self).__init__(
      reduction=tf.keras.losses.Reduction.NONE, name=name)
  # Default to a name derived from the concrete class, so each subclass gets
  # its own default name.
  self.name = name or _to_snake_case(self.__class__.__name__)

  _validate_transform(membership_transform, 'membership_transform')
  self.membership_transform = membership_transform
  _validate_transform(predictions_transform, 'predictions_transform')
  self.predictions_transform = predictions_transform

  # `_get_kernel` turns a string alias into a kernel instance, passes
  # `MinDiffKernel` instances and `None` through, and raises otherwise.
  self.membership_kernel = kernel_utils._get_kernel(
      membership_kernel, 'membership_kernel')
  self.predictions_kernel = kernel_utils._get_kernel(
      predictions_kernel, 'predictions_kernel')
def testGetKernelRaisesErrors(self):
  """Checks that `_get_kernel` rejects bad types and unknown kernel names."""
  # Each entry: (expected exception, message regex, offending value). The
  # argument name passed to `_get_kernel` must appear in every message.
  bad_cases = (
      (TypeError, 'custom_name.*must be.*MinDiffKernel.*string.*4.*int', 4),
      (ValueError, 'custom_name.*must be.*supported values.*bad_name',
       'bad_name'),
  )
  for expected_error, pattern, bad_value in bad_cases:
    with self.assertRaisesRegex(expected_error, pattern):
      utils._get_kernel(bad_value, 'custom_name')
def testForCustomKernel(self):
  """Checks that a user-defined `MinDiffKernel` is returned unchanged."""

  class _UserKernel(base_kernel.MinDiffKernel):

    def call(self, x, y):
      pass

  custom = _UserKernel()
  # Identity (not just type) must be preserved.
  self.assertIs(utils._get_kernel(custom), custom)
def testForLaplacianKernel(self):
  """Checks that string aliases and instances resolve to `LaplacianKernel`."""
  expected_type = laplacian_kernel.LaplacianKernel

  # Every accepted spelling of the Laplacian kernel name.
  for alias in ('laplace', 'laplace_Kernel', 'laplacian', 'laplacian_kernel'):
    self.assertIsInstance(utils._get_kernel(alias), expected_type)

  # A default-constructed instance resolves to the same kernel type.
  self.assertIsInstance(
      utils._get_kernel(laplacian_kernel.LaplacianKernel()), expected_type)

  # A configured instance keeps its `kernel_length`.
  length = 3
  resolved = utils._get_kernel(laplacian_kernel.LaplacianKernel(length))
  self.assertIsInstance(resolved, expected_type)
  self.assertEqual(resolved.kernel_length, length)
def testForGaussianKernel(self):
  """Checks that string aliases and instances resolve to `GaussianKernel`."""
  expected_type = gaussian_kernel.GaussianKernel

  # Accepted spellings, including an oddly capitalized one ('GauSs').
  aliases = ('gauss', 'GauSs', 'gauss_kernel', 'gaussian', 'gaussian_kernel')
  for alias in aliases:
    self.assertIsInstance(utils._get_kernel(alias), expected_type)

  # A configured instance keeps its `kernel_length`.
  length = 3
  resolved = utils._get_kernel(gaussian_kernel.GaussianKernel(length))
  self.assertIsInstance(resolved, expected_type)
  self.assertEqual(resolved.kernel_length, length)
def testAcceptsNone(self):
  """Checks that `None` passes through `_get_kernel` as `None`."""
  self.assertIsNone(utils._get_kernel(None))