Beispiel #1
0
 def testCreateVariables(self):
   """Checks that calling the head creates one kernel and one bias variable."""
   inputs = tf.random.uniform((3, 4), dtype=tf.float64, seed=1)
   classifier = classification_head.ClassificationHead(num_classes=10)
   classifier(inputs)
   # `tf.trainable_variables` is TF1-only; the `tf.compat.v1` alias matches the
   # other compat.v1 calls in these tests and remains available under TF2.
   trainable_variables = tf.compat.v1.trainable_variables()
   self.assertLen(
       [var for var in trainable_variables if 'kernel' in var.name], 1)
   self.assertLen(
       [var for var in trainable_variables if 'bias' in var.name], 1)
Beispiel #2
0
 def testConstructClassificationHead(self, dtype):
   """Checks output shape is [batch, num_classes] and dtype matches inputs."""
   batch_size, num_classes = 3, 10
   inputs = tf.random.uniform([batch_size, 4], seed=1, dtype=dtype)
   head = classification_head.ClassificationHead(num_classes=num_classes)
   result = head(inputs)
   self.assertListEqual([batch_size, num_classes], result.shape.as_list())
   self.assertEqual(inputs.dtype, result.dtype)
Beispiel #3
0
 def testInputOutput(self):
   """Evaluates the head in a session; output must be NaN-free and shaped."""
   batch_size, num_classes = 3, 10
   inputs = tf.random.uniform((batch_size, 4), dtype=tf.float64, seed=1)
   head = classification_head.ClassificationHead(num_classes=num_classes)
   result_op = head(inputs)
   with self.cached_session() as session:
     # Variables must be initialized before the output can be evaluated.
     session.run(tf.compat.v1.global_variables_initializer())
     values = session.run(result_op)
     has_nan = np.isnan(values).any()
     self.assertFalse(has_nan)
     self.assertEqual(values.shape, (batch_size, num_classes))
Beispiel #4
0
    def __init__(self,
                 architecture=enums.EncoderArchitecture.RESNET_V1,
                 normalize_projection_head_input=True,
                 normalize_classification_head_input=True,
                 stop_gradient_before_projection_head=False,
                 stop_gradient_before_classification_head=True,
                 encoder_kwargs=None,
                 projection_head_kwargs=None,
                 classification_head_kwargs=None,
                 name='ContrastiveModel',
                 **kwargs):
        """Builds the encoder, projection head and classification head.

        Args:
          architecture: Selects the encoder class. Must be
            `enums.EncoderArchitecture.RESNET_V1` (uses `resnet.ResNetV1`) or
            `enums.EncoderArchitecture.RESNEXT` (uses `resnet.ResNext`).
          normalize_projection_head_input: Stored as an instance attribute;
            presumably consulted in the forward pass (not visible here).
          normalize_classification_head_input: Stored as an instance
            attribute; presumably consulted in the forward pass.
          stop_gradient_before_projection_head: Stored as an instance
            attribute; presumably consulted in the forward pass.
          stop_gradient_before_classification_head: Stored as an instance
            attribute; presumably consulted in the forward pass.
          encoder_kwargs: Required mapping of keyword arguments forwarded to
            the encoder constructor, which is also given `name='Encoder'`.
          projection_head_kwargs: Optional keyword arguments for
            `projection_head.ProjectionHead`; None means no extra arguments.
          classification_head_kwargs: Optional keyword arguments for
            `classification_head.ClassificationHead`; None means no extra
            arguments.
          name: Name passed to the parent class constructor.
          **kwargs: Additional arguments passed to the parent constructor.

        Raises:
          ValueError: If `architecture` is not supported, or if
            `encoder_kwargs` is None.
        """
        super().__init__(name=name, **kwargs)

        self.normalize_projection_head_input = normalize_projection_head_input
        self.normalize_classification_head_input = (
            normalize_classification_head_input)
        self.stop_gradient_before_projection_head = (
            stop_gradient_before_projection_head)
        self.stop_gradient_before_classification_head = (
            stop_gradient_before_classification_head)

        encoder_fns = {
            enums.EncoderArchitecture.RESNET_V1: resnet.ResNetV1,
            enums.EncoderArchitecture.RESNEXT: resnet.ResNext,
        }
        if architecture not in encoder_fns:
            raise ValueError(
                f'Architecture should be one of {encoder_fns.keys()}, '
                f'found: {architecture}.')
        encoder_fn = encoder_fns[architecture]

        # An explicit check rather than `assert`: assertions are stripped
        # under `python -O`, which would let None reach `**encoder_kwargs`.
        if encoder_kwargs is None:
            raise ValueError('encoder_kwargs must be provided, found: None.')
        projection_head_kwargs = projection_head_kwargs or {}
        classification_head_kwargs = classification_head_kwargs or {}

        self.encoder = encoder_fn(name='Encoder', **encoder_kwargs)
        self.projection_head = projection_head.ProjectionHead(
            **projection_head_kwargs)
        self.classification_head = classification_head.ClassificationHead(
            **classification_head_kwargs)
Beispiel #5
0
 def testGradient(self):
   """Checks that a gradient flows from the head's output back to its input."""
   inputs = tf.random.uniform((3, 4), dtype=tf.float64, seed=1)
   classifier = classification_head.ClassificationHead(num_classes=10)
   output = classifier(inputs)
   gradients = tf.gradients(output, inputs)
   # tf.gradients returns a list with one entry per x; the list itself is never
   # None, but an entry is None when `output` does not depend on that x.
   self.assertLen(gradients, 1)
   self.assertIsNotNone(gradients[0])
Beispiel #6
0
 def testIncorrectRank(self, rank):
   """Checks that calling the head on a non-rank-2 input raises ValueError."""
   inputs = tf.compat.v1.placeholder(tf.float32, shape=[10] * rank)
   # Construct the head outside the assertion context so that only the call
   # on the wrongly-ranked input can satisfy assertRaisesRegex.
   classifier = classification_head.ClassificationHead(num_classes=10)
   with self.assertRaisesRegex(ValueError, 'is expected to have rank 2'):
     classifier(inputs)