Example #1
0
 def test_batch_random_blur(self, blur_prob):
     """Checks how many images in a batch `batch_random_blur` modifies.

     A random NHWC float32 batch is blurred with probability `blur_prob`.
     An image counts as blurred when it is no longer approximately equal
     (per `np.allclose`) to its input. With `blur_prob < 1` not every image
     may be blurred; with `blur_prob > 0` at least one must be.

     Args:
       blur_prob: The blur probability forwarded to `batch_random_blur`.
     """
     n_images = 100
     size = 30
     images = np.random.uniform(
         low=-1., high=1., size=(n_images, size, size, 3)).astype(np.float32)
     blurred_op = preprocessing.batch_random_blur(
         tf.constant(images), size, blur_prob)
     with self.cached_session() as sess:
         blurred = sess.run(blurred_op)
     # Summing booleans counts the images whose pixels changed.
     num_blurred = sum(
         not np.allclose(images[idx], blurred[idx]) for idx in range(n_images))
     # Because the selection is random these checks can fail in principle,
     # but that should happen only rarely.
     if blur_prob < 1.:
         self.assertLess(num_blurred, n_images)
     else:
         self.assertEqual(num_blurred, n_images)
     if blur_prob > 0.:
         self.assertGreater(num_blurred, 0)
     else:
         self.assertEqual(num_blurred, 0)
Example #2
0
    def _call_model(self, training):
        """Passes data through the model.

        Manipulates the input data to get it ready for passing into the model,
        including applying some data augmentation that is more efficient to
        apply on the TPU than on the host. It then passes it into the model,
        which will first build the model and create its variables.

        Args:
          training: Whether the model should be run in training mode. This is
            forwarded to `self.model`. Whether deferred blurring is applied is
            decided by `self.train`, not by this argument.

        Returns:
          A tuple of the model outputs (as Tensors):
          * unnormalized_embedding: The output of the encoder, not including
            normalization, which is sometimes applied before this gets passed
            into the projection and classification heads.
          * normalized_embedding: A normalized version of
            `unnormalized_embedding`.
          * projection: The output of the projection head.
          * logits: The output of the classification head. For multiview
            inputs only the first of two splits along axis 1 is kept. If the
            model returns None here, None is returned.
        """
        with tf.name_scope('call_model'):
            model_inputs = self.model_inputs
            channels_first = self.data_format == 'channels_first'

            # In most cases, the data format NCHW instead of NHWC should be used
            # for a significant performance boost on GPU. NHWC should be used
            # only if the network needs to be run on CPU since the pooling
            # operations are only supported on NHWC. TPU uses XLA compiler to
            # figure out best layout.
            if channels_first:
                model_inputs = tf.transpose(model_inputs, [0, 3, 1, 2])

            # More than 3 channels means several views are stacked along the
            # channel axis; move them into the batch dimension.
            channels_index = 1 if channels_first else -1
            inputs_are_multiview = tf.compat.dimension_value(
                model_inputs.shape[channels_index]) > 3
            if inputs_are_multiview:
                model_inputs = utils.stacked_multiview_image_channels_to_batch(
                    model_inputs, self.data_format)

            # Perform blur augmentations here, since they're faster on TPU than
            # CPU.
            # NOTE(review): gated on `self.train` rather than the `training`
            # argument; presumably equivalent in practice — confirm.
            preprocessing_hparams = self.hparams.input_data.preprocessing
            if (preprocessing_hparams.augmentation_type
                    in (enums.AugmentationType.SIMCLR,
                        enums.AugmentationType.STACKED_RANDAUGMENT) and
                    preprocessing_hparams.blur_probability > 0. and
                    preprocessing_hparams.defer_blurring and
                    self.train):
                # batch_random_blur is driven with NHWC images and a single
                # side length taken from the height axis. In NCHW, shape[1]
                # would be the channel count, so temporarily return to NHWC
                # for the blur and restore the requested layout afterwards.
                if channels_first:
                    model_inputs = tf.transpose(model_inputs, [0, 2, 3, 1])
                model_inputs = preprocessing.batch_random_blur(
                    model_inputs,
                    tf.compat.dimension_value(model_inputs.shape[1]),
                    blur_probability=preprocessing_hparams.blur_probability)
                if channels_first:
                    model_inputs = tf.transpose(model_inputs, [0, 3, 1, 2])

            with tf.tpu.bfloat16_scope():
                model_outputs = self.model(model_inputs, training)

            if inputs_are_multiview:
                # Undo the view-to-batch stacking; outputs the model did not
                # produce (None) are passed through unchanged.
                model_outputs = [
                    utils.stacked_multiview_embeddings_to_channel(
                        tf.cast(x, tf.float32)) if x is not None else x
                    for x in model_outputs
                ]

            (unnormalized_embedding, normalized_embedding, projection,
             logits) = model_outputs

            if inputs_are_multiview and logits is not None:
                # If we keep everything in batch dimension then we don't need
                # this. In cross_entropy mode we should just stop generating the
                # 2nd augmentation.
                logits = tf.split(logits, 2, axis=1)[0]

            return unnormalized_embedding, normalized_embedding, projection, logits