def testCreateVariables(self):
    """Checks the variables created by one call of the classification head.

    After a single forward call, exactly one trainable variable whose name
    contains 'kernel' and exactly one whose name contains 'bias' must exist.
    """
    inputs = tf.random.uniform((3, 4), dtype=tf.float64, seed=1)
    classifier = classification_head.ClassificationHead(num_classes=10)
    classifier(inputs)
    # Use the tf.compat.v1 graph-collection API, matching the other tests in
    # this file (placeholder, global_variables_initializer); the bare
    # `tf.trainable_variables` alias is not available in the TF2 namespace.
    trainable_variables = tf.compat.v1.trainable_variables()
    self.assertLen(
        [var for var in trainable_variables if 'kernel' in var.name], 1)
    self.assertLen(
        [var for var in trainable_variables if 'bias' in var.name], 1)
def testConstructClassificationHead(self, dtype):
    """Output has shape [batch, num_classes] and the same dtype as the input.

    Args:
      dtype: dtype used for the random input tensor (supplied by the test's
        parameterization, which is not visible here).
    """
    batch_size, num_classes = 3, 10
    inputs = tf.random.uniform([batch_size, 4], seed=1, dtype=dtype)
    head = classification_head.ClassificationHead(num_classes=num_classes)
    result = head(inputs)
    self.assertListEqual([batch_size, num_classes], result.shape.as_list())
    # The head must not silently cast its input to another dtype.
    self.assertEqual(inputs.dtype, result.dtype)
def testInputOutput(self):
    """Evaluates the head in a session and checks its numeric output.

    The evaluated array must contain no NaNs and have shape
    (batch_size, num_classes).
    """
    batch_size, num_classes = 3, 10
    inputs = tf.random.uniform((batch_size, 4), dtype=tf.float64, seed=1)
    head = classification_head.ClassificationHead(num_classes=num_classes)
    output_tensor = head(inputs)
    with self.cached_session() as session:
        # Variables created by the head have to be initialized before the
        # output tensor can be evaluated.
        session.run(tf.compat.v1.global_variables_initializer())
        result = session.run(output_tensor)
    # The forward pass must not produce NaNs.
    self.assertFalse(np.isnan(result).any())
    self.assertEqual(result.shape, (batch_size, num_classes))
def __init__(self,
             architecture=enums.EncoderArchitecture.RESNET_V1,
             normalize_projection_head_input=True,
             normalize_classification_head_input=True,
             stop_gradient_before_projection_head=False,
             stop_gradient_before_classification_head=True,
             encoder_kwargs=None,
             projection_head_kwargs=None,
             classification_head_kwargs=None,
             name='ContrastiveModel',
             **kwargs):
    """Builds the encoder, projection head and classification head.

    Args:
      architecture: `enums.EncoderArchitecture` value selecting the encoder
        class; must be RESNET_V1 (`resnet.ResNetV1`) or RESNEXT
        (`resnet.ResNext`).
      normalize_projection_head_input: Stored as
        `self.normalize_projection_head_input`; it is not used in this
        constructor (presumably read by the forward pass, which is not
        visible here).
      normalize_classification_head_input: Stored as
        `self.normalize_classification_head_input`; not used in this
        constructor.
      stop_gradient_before_projection_head: Stored as
        `self.stop_gradient_before_projection_head`; not used in this
        constructor.
      stop_gradient_before_classification_head: Stored as
        `self.stop_gradient_before_classification_head`; not used in this
        constructor.
      encoder_kwargs: Required. Keyword arguments forwarded to the encoder
        constructor, which is always given `name='Encoder'`.
      projection_head_kwargs: Optional keyword arguments for
        `projection_head.ProjectionHead`; a falsy value means no arguments.
      classification_head_kwargs: Optional keyword arguments for
        `classification_head.ClassificationHead`; a falsy value means no
        arguments.
      name: Name passed to the parent class constructor.
      **kwargs: Extra keyword arguments forwarded to the parent constructor.

    Raises:
      ValueError: If `architecture` is not a supported encoder architecture,
        or if `encoder_kwargs` is None.
    """
    super(ContrastiveModel, self).__init__(name=name, **kwargs)
    self.normalize_projection_head_input = normalize_projection_head_input
    self.normalize_classification_head_input = (
        normalize_classification_head_input)
    self.stop_gradient_before_projection_head = (
        stop_gradient_before_projection_head)
    self.stop_gradient_before_classification_head = (
        stop_gradient_before_classification_head)

    encoder_fns = {
        enums.EncoderArchitecture.RESNET_V1: resnet.ResNetV1,
        enums.EncoderArchitecture.RESNEXT: resnet.ResNext,
    }
    if architecture not in encoder_fns:
        raise ValueError(
            f'Architecture should be one of {encoder_fns.keys()}, '
            f'found: {architecture}.')
    encoder_fn = encoder_fns[architecture]

    # Explicit check instead of `assert`: assert statements are stripped
    # under `python -O`. Without this check a missing value would surface as
    # an obscure TypeError from `**None` below.
    if encoder_kwargs is None:
        raise ValueError('encoder_kwargs must be provided, found: None.')
    projection_head_kwargs = projection_head_kwargs or {}
    classification_head_kwargs = classification_head_kwargs or {}

    # Sub-layers are created in the same order as before (encoder,
    # projection head, classification head) so any automatically generated
    # layer names stay unchanged.
    self.encoder = encoder_fn(name='Encoder', **encoder_kwargs)
    self.projection_head = projection_head.ProjectionHead(
        **projection_head_kwargs)
    self.classification_head = classification_head.ClassificationHead(
        **classification_head_kwargs)
def testGradient(self):
    """Checks that a gradient flows from the head's output back to its input."""
    inputs = tf.random.uniform((3, 4), dtype=tf.float64, seed=1)
    classifier = classification_head.ClassificationHead(num_classes=10)
    output = classifier(inputs)
    # tf.gradients returns a list with one entry per tensor in `xs`. An entry
    # is None when `output` does not depend differentiably on that tensor.
    # The list itself is never None, so asserting on it would be vacuous;
    # assert on the single entry instead.
    gradients = tf.gradients(output, inputs)
    self.assertLen(gradients, 1)
    self.assertIsNotNone(gradients[0])
def testIncorrectRank(self, rank):
    """Calling the head on an input whose rank is not 2 raises ValueError.

    Args:
      rank: Rank of the placeholder input. It is presumably never 2 (supplied
        by the test's parameterization, which is not visible here).
    """
    inputs = tf.compat.v1.placeholder(tf.float32, shape=[10] * rank)
    # Build the head outside the assertRaises block, so that only the call on
    # the badly-ranked input may satisfy the expected error; a ValueError
    # from the constructor would otherwise make this test pass spuriously.
    classifier = classification_head.ClassificationHead(num_classes=10)
    with self.assertRaisesRegex(ValueError, 'is expected to have rank 2'):
        classifier(inputs)