コード例 #1
0
    def __init__(self,
                 membership_transform=None,
                 predictions_transform=None,
                 membership_kernel=None,
                 predictions_kernel=None,
                 name: Optional[Text] = None):
        """Initialize `MinDiffLoss` instance.

        Args:
          membership_transform: Optional transform, checked by
            `_validate_transform` and stored as `self.membership_transform`.
          predictions_transform: Optional transform, checked by
            `_validate_transform` and stored as `self.predictions_transform`.
          membership_kernel: Kernel specification (`None`, a kernel name
            string, or a `MinDiffKernel` instance) resolved with
            `kernel_utils._get_kernel` into `self.membership_kernel`.
          predictions_kernel: Kernel specification resolved the same way into
            `self.predictions_kernel`.
          name: Optional loss name. If falsy, the class name converted with
            `_to_snake_case` is used instead.

        Raises:
          ValueError: If a `*_transform` parameter is passed in but is not
            callable.
          TypeError: If a `*_kernel` parameter is neither `None`, a string,
            nor a `MinDiffKernel` instance.
          ValueError: If a `*_kernel` parameter is a string that does not name
            a supported kernel.
        """
        # NOTE(review): reduction is fixed to NONE here; presumably the
        # reduction is applied elsewhere — confirm against the call sites.
        super(MinDiffLoss, self).__init__(
            reduction=tf.keras.losses.Reduction.NONE, name=name)
        # Overwrite the name so that a falsy `name` falls back to a
        # class-derived default.
        self.name = name or _to_snake_case(self.__class__.__name__)

        # Transforms are validated (membership first) before any kernel is
        # resolved, so transform errors are reported ahead of kernel errors.
        _validate_transform(membership_transform, 'membership_transform')
        self.membership_transform = membership_transform
        _validate_transform(predictions_transform, 'predictions_transform')
        self.predictions_transform = predictions_transform

        self.membership_kernel = kernel_utils._get_kernel(
            membership_kernel, 'membership_kernel')
        self.predictions_kernel = kernel_utils._get_kernel(
            predictions_kernel, 'predictions_kernel')
コード例 #2
0
  def testGetKernelRaisesErrors(self):
    # Each case: (expected exception, message pattern, invalid kernel arg).
    # The custom name must appear in every error message.
    error_cases = (
        (TypeError,
         'custom_name.*must be.*MinDiffKernel.*string.*4.*int',
         4),
        (ValueError,
         'custom_name.*must be.*supported values.*bad_name',
         'bad_name'),
    )
    for expected_error, message_pattern, invalid_kernel in error_cases:
      with self.assertRaisesRegex(expected_error, message_pattern):
        utils._get_kernel(invalid_kernel, 'custom_name')
コード例 #3
0
    def testForCustomKernel(self):
        # A user-defined MinDiffKernel subclass instance must be returned
        # as the very same object.
        class _StubKernel(base_kernel.MinDiffKernel):
            """Minimal concrete kernel; `call` is only a stub."""

            def call(self, x, y):
                return None

        stub = _StubKernel()
        self.assertIs(utils._get_kernel(stub), stub)
コード例 #4
0
 def testForLaplacianKernel(self):
   # Every accepted string spelling resolves to a LaplacianKernel.
   laplacian_aliases = ('laplace', 'laplace_Kernel', 'laplacian',
                        'laplacian_kernel')
   for alias in laplacian_aliases:
     self.assertIsInstance(utils._get_kernel(alias),
                           laplacian_kernel.LaplacianKernel)

   # A default-constructed instance is accepted.
   self.assertIsInstance(
       utils._get_kernel(laplacian_kernel.LaplacianKernel()),
       laplacian_kernel.LaplacianKernel)

   # A configured instance keeps its kernel_length after resolution.
   expected_length = 3
   resolved = utils._get_kernel(
       laplacian_kernel.LaplacianKernel(expected_length))
   self.assertIsInstance(resolved, laplacian_kernel.LaplacianKernel)
   self.assertEqual(resolved.kernel_length, expected_length)
コード例 #5
0
 def testForGaussianKernel(self):
   # Every accepted string spelling, including an oddly capitalized one
   # ('GauSs'), resolves to a GaussianKernel. subTest reports which alias
   # failed instead of stopping at the first bad one.
   gaussian_aliases = ('gauss', 'GauSs', 'gauss_kernel', 'gaussian',
                       'gaussian_kernel')
   for alias in gaussian_aliases:
     with self.subTest(alias=alias):
       self.assertIsInstance(utils._get_kernel(alias),
                             gaussian_kernel.GaussianKernel)

   # A default-constructed instance is accepted (previously uncovered for
   # the Gaussian kernel).
   self.assertIsInstance(
       utils._get_kernel(gaussian_kernel.GaussianKernel()),
       gaussian_kernel.GaussianKernel)

   # A configured instance keeps its kernel_length after resolution.
   kernel_length = 3
   kernel = utils._get_kernel(gaussian_kernel.GaussianKernel(kernel_length))
   self.assertIsInstance(kernel, gaussian_kernel.GaussianKernel)
   self.assertEqual(kernel.kernel_length, kernel_length)
コード例 #6
0
 def testAcceptsNone(self):
   # `_get_kernel(None)` is expected to hand back None rather than raise.
   self.assertIsNone(utils._get_kernel(None))