def testConstructProjectionHead(self, dtype):
  """Checks output shape and dtype of a newly constructed projection head."""
  batch_size, input_dim = 3, 4
  feature_dims = [2048, 128]
  inputs = tf.random.uniform([batch_size, input_dim], seed=1, dtype=dtype)
  head = projection_head_lib.ProjectionHead(feature_dims=feature_dims)
  outputs = head(inputs)
  # The head projects to the last feature dimension while preserving the
  # batch size and the input dtype.
  self.assertListEqual([batch_size, feature_dims[-1]],
                       outputs.shape.as_list())
  self.assertEqual(inputs.dtype, outputs.dtype)
def testInputOutput(self):
  """Runs the projection head and verifies finite outputs of the right shape."""
  feature_dims = (128, 128)
  expected_output_shape = (3, 128)
  inputs = tf.random.uniform((3, 4), dtype=tf.float64, seed=1)
  head = projection_head_lib.ProjectionHead(feature_dims=feature_dims)
  output_tensor = head(inputs)
  with self.cached_session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    outputs = sess.run(output_tensor)
    # Make sure that there are no NaNs
    self.assertFalse(np.isnan(outputs).any())
    self.assertEqual(outputs.shape, expected_output_shape)
def testBatchNormIsTraining(self, is_training):
  """Tests batch-norm moving statistics in training vs. evaluation mode.

  In training mode the moving-average variables are not part of the forward
  computation (they receive no gradient and are refreshed via UPDATE_OPS);
  in evaluation mode the outputs are a differentiable function of them.

  Args:
    is_training: Whether the head is invoked with `training=True`.
  """
  feature_dims = (128, 128)
  inputs = tf.random.uniform((3, 4), dtype=tf.float64, seed=1)
  projection_head = projection_head_lib.ProjectionHead(
      feature_dims=feature_dims, use_batch_norm=True)
  outputs = projection_head(inputs, training=is_training)
  # `tf.all_variables()` is a long-removed alias; use the `tf.compat.v1`
  # symbol, consistent with the other tests in this file.
  statistics_vars = [
      var for var in tf.compat.v1.global_variables()
      if 'moving_' in var.name
  ]
  # One moving mean and one moving variance.
  self.assertLen(statistics_vars, 2)
  grads = tf.gradients(outputs, statistics_vars)
  self.assertLen(grads, 2)
  if is_training:
    # Moving statistics get no gradient during training; their update ops
    # are registered in the UPDATE_OPS collection instead.
    self.assertAllEqual([None, None], grads)
    self.assertTrue(
        tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS))
  else:
    self.assertNotIn(None, grads)
def __init__(self,
             architecture=enums.EncoderArchitecture.RESNET_V1,
             normalize_projection_head_input=True,
             normalize_classification_head_input=True,
             stop_gradient_before_projection_head=False,
             stop_gradient_before_classification_head=True,
             encoder_kwargs=None,
             projection_head_kwargs=None,
             classification_head_kwargs=None,
             name='ContrastiveModel',
             **kwargs):
  """Builds the encoder, projection head, and classification head.

  Args:
    architecture: An `enums.EncoderArchitecture` choosing the encoder
      network (ResNet v1 or ResNeXt).
    normalize_projection_head_input: Stored flag; presumably controls input
      normalization before the projection head — consumed elsewhere in the
      model.
    normalize_classification_head_input: Stored flag; presumably controls
      input normalization before the classification head — consumed
      elsewhere in the model.
    stop_gradient_before_projection_head: Stored flag; presumably stops
      gradients before the projection head — consumed elsewhere in the
      model.
    stop_gradient_before_classification_head: Stored flag; presumably stops
      gradients before the classification head — consumed elsewhere in the
      model.
    encoder_kwargs: Keyword arguments for the encoder constructor.
      Required.
    projection_head_kwargs: Optional keyword arguments for the projection
      head constructor.
    classification_head_kwargs: Optional keyword arguments for the
      classification head constructor.
    name: Name of this model.
    **kwargs: Forwarded to the base-class constructor.

  Raises:
    ValueError: If `architecture` is not a known encoder architecture, or
      if `encoder_kwargs` is None.
  """
  super(ContrastiveModel, self).__init__(name=name, **kwargs)

  self.normalize_projection_head_input = normalize_projection_head_input
  self.normalize_classification_head_input = (
      normalize_classification_head_input)
  self.stop_gradient_before_projection_head = (
      stop_gradient_before_projection_head)
  self.stop_gradient_before_classification_head = (
      stop_gradient_before_classification_head)

  encoder_fns = {
      enums.EncoderArchitecture.RESNET_V1: resnet.ResNetV1,
      enums.EncoderArchitecture.RESNEXT: resnet.ResNext,
  }
  if architecture not in encoder_fns:
    raise ValueError(
        f'Architecture should be one of {encoder_fns.keys()}, '
        f'found: {architecture}.')
  encoder_fn = encoder_fns[architecture]

  # `assert` is stripped when Python runs with -O, so validate required
  # arguments with an explicit exception instead.
  if encoder_kwargs is None:
    raise ValueError('encoder_kwargs must be provided.')
  projection_head_kwargs = projection_head_kwargs or {}
  classification_head_kwargs = classification_head_kwargs or {}

  self.encoder = encoder_fn(name='Encoder', **encoder_kwargs)
  self.projection_head = projection_head.ProjectionHead(
      **projection_head_kwargs)
  self.classification_head = classification_head.ClassificationHead(
      **classification_head_kwargs)
def testCreateVariables(self, num_projection_layers, use_batch_norm,
                        use_batch_norm_beta):
  """Verifies trainable-variable counts for each head configuration."""
  feature_dims = (128,) * num_projection_layers
  inputs = tf.random.uniform((3, 4), dtype=tf.float64, seed=1)
  head = projection_head_lib.ProjectionHead(
      feature_dims=feature_dims,
      use_batch_norm=use_batch_norm,
      use_batch_norm_beta=use_batch_norm_beta)
  head(inputs)

  def trainable_vars_containing(substring):
    # Trainable variables whose name mentions `substring`.
    return [v for v in tf.trainable_variables() if substring in v.name]

  # Every projection layer contributes a dense kernel.
  self.assertLen(trainable_vars_containing('kernel'), num_projection_layers)
  # Biases appear on all but the final layer, and only without batch norm.
  self.assertLen(
      trainable_vars_containing('bias'),
      0 if use_batch_norm else num_projection_layers - 1)
  # Batch norm adds a gamma to every non-final layer.
  self.assertLen(
      trainable_vars_containing('gamma'),
      num_projection_layers - 1 if use_batch_norm else 0)
  # Beta is only present when batch norm and its beta term are both enabled.
  self.assertLen(
      trainable_vars_containing('beta'),
      (num_projection_layers - 1
       if (use_batch_norm and use_batch_norm_beta) else 0))
def testGradient(self):
  """The projection head output must be differentiable w.r.t. its input."""
  input_tensor = tf.random.uniform((3, 4), dtype=tf.float64, seed=1)
  head = projection_head_lib.ProjectionHead()
  projected = head(input_tensor)
  grads = tf.gradients(projected, input_tensor)
  self.assertIsNotNone(grads)
def testIncorrectRank(self, rank):
  """Calling the head with a non-rank-2 input raises a ValueError."""
  bad_inputs = tf.compat.v1.placeholder(tf.float32, shape=[10] * rank)
  with self.assertRaisesRegex(ValueError, 'is expected to have rank 2'):
    head = projection_head_lib.ProjectionHead()
    head(bad_inputs)