def testConstructProjectionHead(self, dtype):
     """Head maps a [3, 4] batch through dims (2048, 128) to [3, 128], keeping dtype."""
     batch_size, input_dim = 3, 4
     projected_dim = 128
     features = tf.random.uniform([batch_size, input_dim], seed=1, dtype=dtype)
     head = projection_head_lib.ProjectionHead(feature_dims=[2048, projected_dim])
     projected = head(features)
     self.assertListEqual([batch_size, projected_dim], projected.shape.as_list())
     self.assertEqual(features.dtype, projected.dtype)
 def testInputOutput(self):
     """Evaluating the head in a session yields finite values of shape (3, 128)."""
     features = tf.random.uniform((3, 4), dtype=tf.float64, seed=1)
     head = projection_head_lib.ProjectionHead(feature_dims=(128, 128))
     projected = head(features)
     with self.cached_session() as sess:
         sess.run(tf.compat.v1.global_variables_initializer())
         result = sess.run(projected)
         # The output must be entirely finite and have the projected shape.
         self.assertFalse(np.isnan(result).any())
         self.assertEqual(result.shape, (3, 128))
 def testBatchNormIsTraining(self, is_training):
     """Checks batch-norm moving statistics follow the `training` flag.

     In training mode, batch statistics are used, so the moving-average
     variables receive no gradient and update ops are registered; in
     inference mode the moving averages feed the output, so gradients exist.
     """
     feature_dims = (128, 128)
     inputs = tf.random.uniform((3, 4), dtype=tf.float64, seed=1)
     projection_head = projection_head_lib.ProjectionHead(
         feature_dims=feature_dims, use_batch_norm=True)
     outputs = projection_head(inputs, training=is_training)
     # Use tf.compat.v1 endpoints consistently with the rest of the file;
     # the bare TF1 aliases (tf.all_variables, tf.get_collection,
     # tf.GraphKeys) are deprecated and removed from the TF2 namespace.
     statistics_vars = [
         var for var in tf.compat.v1.global_variables()
         if 'moving_' in var.name
     ]
     self.assertLen(statistics_vars, 2)
     grads = tf.gradients(outputs, statistics_vars)
     self.assertLen(grads, 2)
     if is_training:
         # Moving averages are not on the forward path while training.
         self.assertAllEqual([None, None], grads)
         self.assertTrue(
             tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS))
     else:
         self.assertNotIn(None, grads)
# Example 4 (scraper artifact: original label "Esempio n. 4" with vote count "0")
    def __init__(self,
                 architecture=enums.EncoderArchitecture.RESNET_V1,
                 normalize_projection_head_input=True,
                 normalize_classification_head_input=True,
                 stop_gradient_before_projection_head=False,
                 stop_gradient_before_classification_head=True,
                 encoder_kwargs=None,
                 projection_head_kwargs=None,
                 classification_head_kwargs=None,
                 name='ContrastiveModel',
                 **kwargs):
        """Builds the encoder, projection head and classification head.

        Args:
          architecture: An `enums.EncoderArchitecture` selecting the encoder
            network (ResNet v1 or ResNeXt).
          normalize_projection_head_input: Whether to normalize the input fed
            to the projection head.
          normalize_classification_head_input: Whether to normalize the input
            fed to the classification head.
          stop_gradient_before_projection_head: Whether to stop gradients
            before the projection head.
          stop_gradient_before_classification_head: Whether to stop gradients
            before the classification head.
          encoder_kwargs: Keyword arguments forwarded to the encoder
            constructor. Required (must not be None).
          projection_head_kwargs: Optional keyword arguments forwarded to
            `projection_head.ProjectionHead`.
          classification_head_kwargs: Optional keyword arguments forwarded to
            `classification_head.ClassificationHead`.
          name: Name of this model.
          **kwargs: Forwarded to the base class.

        Raises:
          ValueError: If `architecture` is unknown or `encoder_kwargs` is None.
        """
        super(ContrastiveModel, self).__init__(name=name, **kwargs)

        self.normalize_projection_head_input = normalize_projection_head_input
        self.normalize_classification_head_input = (
            normalize_classification_head_input)
        self.stop_gradient_before_projection_head = (
            stop_gradient_before_projection_head)
        self.stop_gradient_before_classification_head = (
            stop_gradient_before_classification_head)

        encoder_fns = {
            enums.EncoderArchitecture.RESNET_V1: resnet.ResNetV1,
            enums.EncoderArchitecture.RESNEXT: resnet.ResNext,
        }
        if architecture not in encoder_fns:
            raise ValueError(
                f'Architecture should be one of {encoder_fns.keys()}, '
                f'found: {architecture}.')
        encoder_fn = encoder_fns[architecture]

        # Validate with an explicit exception rather than `assert`, which is
        # stripped when Python runs with optimizations (-O).
        if encoder_kwargs is None:
            raise ValueError('encoder_kwargs must not be None.')
        projection_head_kwargs = projection_head_kwargs or {}
        classification_head_kwargs = classification_head_kwargs or {}

        self.encoder = encoder_fn(name='Encoder', **encoder_kwargs)
        self.projection_head = projection_head.ProjectionHead(
            **projection_head_kwargs)
        self.classification_head = classification_head.ClassificationHead(
            **classification_head_kwargs)
 def testCreateVariables(self, num_projection_layers, use_batch_norm,
                         use_batch_norm_beta):
     """Checks trainable-variable counts per projection-head configuration."""
     feature_dims = (128, ) * num_projection_layers
     inputs = tf.random.uniform((3, 4), dtype=tf.float64, seed=1)
     projection_head = projection_head_lib.ProjectionHead(
         feature_dims=feature_dims,
         use_batch_norm=use_batch_norm,
         use_batch_norm_beta=use_batch_norm_beta)
     projection_head(inputs)
     # Use the tf.compat.v1 endpoint (consistent with the rest of the file;
     # bare tf.trainable_variables was removed in TF2) and fetch it once
     # instead of once per assertion.
     trainable_vars = tf.compat.v1.trainable_variables()
     self.assertLen([var for var in trainable_vars if 'kernel' in var.name],
                    num_projection_layers)
     # Dense biases exist only for the non-final layers when batch norm is off.
     self.assertLen([var for var in trainable_vars if 'bias' in var.name],
                    0 if use_batch_norm else num_projection_layers - 1)
     self.assertLen([var for var in trainable_vars if 'gamma' in var.name],
                    num_projection_layers - 1 if use_batch_norm else 0)
     self.assertLen([var for var in trainable_vars if 'beta' in var.name],
                    (num_projection_layers - 1 if
                     (use_batch_norm and use_batch_norm_beta) else 0))
 def testGradient(self):
     """The projection head's output is differentiable w.r.t. its input."""
     features = tf.random.uniform((3, 4), dtype=tf.float64, seed=1)
     head = projection_head_lib.ProjectionHead()
     projected = head(features)
     self.assertIsNotNone(tf.gradients(projected, features))
 def testIncorrectRank(self, rank):
     """Feeding an input whose rank is not 2 raises a ValueError."""
     bad_inputs = tf.compat.v1.placeholder(tf.float32, shape=[10] * rank)
     with self.assertRaisesRegex(ValueError, 'is expected to have rank 2'):
         head = projection_head_lib.ProjectionHead()
         head(bad_inputs)