def test_preprocess_image(self, decode_image, image_dtype, image_size,
                            is_training, augmentation_type, warp_prob,
                            augmentation_magnitude, eval_crop_method,
                            image_mean_std):
    """Checks the dtype and shape of `preprocessing.preprocess_image` output.

    A random 300x400x3 uint8 image is JPEG-encoded when `decode_image` is
    set and cast to `image_dtype` otherwise. For an un-encoded input whose
    dtype is not uint8 the call is expected to raise `AssertionError`; in
    every other case the output must be a float32 tensor of shape
    `[image_size, image_size, 3]`.
    """
    # tf.random.uniform() doesn't allow generating random values that are uint8.
    image = tf.cast(
        tf.random.uniform(
            shape=(300, 400, 3), minval=0, maxval=255, dtype=tf.float32),
        dtype=tf.uint8)
    if decode_image:
      image = tf.image.encode_jpeg(image)
    else:
      image = tf.cast(image, image_dtype)

    def run_preprocessing():
      # The option objects are built here so that they are constructed under
      # the same context as the preprocessing call itself.
      return preprocessing.preprocess_image(
          image,
          is_training=is_training,
          bfloat16_supported=False,
          preprocessing_options=hparams.ImagePreprocessing(
              image_size=image_size,
              augmentation_type=augmentation_type,
              warp_probability=warp_prob,
              augmentation_magnitude=augmentation_magnitude,
              eval_crop_method=eval_crop_method,
          ),
          dataset_options=preprocessing.DatasetOptions(
              image_mean_std=image_mean_std, decode_input=decode_image))

    expect_error = (not decode_image and image_dtype != tf.uint8)
    if expect_error:
      # Only the call sits under assertRaises. The output checks below fail
      # with AssertionError too, so running them inside the same context would
      # let a genuine check failure be counted as the expected error.
      with self.assertRaises(AssertionError):
        run_preprocessing()
      return

    output = run_preprocessing()
    self.assertEqual(output.dtype, tf.float32)
    self.assertEqual([image_size, image_size, 3], output.shape.as_list())
# Example #2
# 0
    def test_input_class(self, input_class, model_mode, image_size,
                         max_samples):
        """Builds the requested input pipeline and checks its shapes and length.

        Only `'TfdsInput'` is supported; any other `input_class` raises
        `ValueError`. In INFERENCE mode the input function must return a
        `TensorServingInputReceiver` whose features have an unknown batch
        dimension and 3 channels. In the other modes it must return a
        `tf.data.Dataset` of (image, label) batches with 6 image channels,
        which is then iterated: every expected batch must be readable, and in
        EVAL mode the read after the last batch must raise `OutOfRangeError`.
        """
        in_train_mode = model_mode == enums.ModelMode.TRAIN
        in_inference_mode = model_mode == enums.ModelMode.INFERENCE

        batch_size = 2
        dataset_size = 10
        split = 'train' if in_train_mode else 'test'
        if in_train_mode and max_samples is not None:
            # `max_samples` only shortens the expected length when training.
            expected_num_batches = max_samples // batch_size
        else:
            expected_num_batches = dataset_size // batch_size

        if input_class != 'TfdsInput':
            raise ValueError(f'Unknown input class {input_class}')

        with tfds.testing.mock_data(num_examples=dataset_size):
            preprocessor = preprocessing.ImageToMultiViewedImagePreprocessor(
                is_training=in_train_mode,
                preprocessing_options=hparams.ImagePreprocessing(
                    image_size=image_size, num_views=2),
                dataset_options=preprocessing.DatasetOptions(
                    decode_input=False))
            tfds_input = inputs.TfdsInput(
                'cifar10',
                split,
                mode=model_mode,
                preprocessor=preprocessor,
                max_samples=max_samples,
                num_classes=10)
            data = tfds_input.input_fn({'batch_size': batch_size})

        if in_inference_mode:
            self.assertIsInstance(
                data, tf.estimator.export.TensorServingInputReceiver)
            self.assertEqual([None, image_size, image_size, 3],
                             data.features.shape.as_list())
            # Nothing to iterate over for a serving input receiver.
            return

        self.assertIsInstance(data, tf.data.Dataset)
        output_shapes = tf.data.get_output_shapes(data)
        image_shape = output_shapes[0].as_list()
        label_shape = output_shapes[1].as_list()
        self.assertEqual([batch_size], label_shape)
        # 6 channels is consistent with two 3-channel views (num_views=2).
        self.assertEqual([batch_size, image_size, image_size, 6], image_shape)

        # Now extract the image Tensor and pull batches through a session.
        next_images = tf.data.make_one_shot_iterator(data).get_next()[0]
        in_eval_mode = model_mode == enums.ModelMode.EVAL

        with self.cached_session() as sess:
            for step in range(expected_num_batches + 1):
                past_last_batch = step == expected_num_batches
                if not (past_last_batch and in_eval_mode):
                    sess.run(next_images)
                    continue
                # In EVAL mode the dataset must be exhausted at this point.
                with self.assertRaises(tf.errors.OutOfRangeError):
                    sess.run(next_images)
                break
  def test_image_to_multi_viewed_image_preprocessor(self, is_training,
                                                    decode_image, num_views):
    """Checks output dtype and shape of `ImageToMultiViewedImagePreprocessor`.

    A random 300x400x3 uint8 image (JPEG-encoded when `decode_image` is set)
    is preprocessed with `image_size=32`; the result must be a float32 tensor
    of shape `[32, 32, 3 * num_views]`.
    """
    image_size = 32

    # tf.random.uniform() doesn't allow generating random values that are
    # uint8, so sample float32 values first and cast them afterwards.
    random_pixels = tf.random.uniform(
        shape=(300, 400, 3), minval=0, maxval=255, dtype=tf.float32)
    image = tf.cast(random_pixels, dtype=tf.uint8)
    if decode_image:
      image = tf.image.encode_jpeg(image)

    preprocessing_options = hparams.ImagePreprocessing(
        image_size=image_size, num_views=num_views)
    dataset_options = preprocessing.DatasetOptions(decode_input=decode_image)
    preprocessor = preprocessing.ImageToMultiViewedImagePreprocessor(
        is_training=is_training,
        preprocessing_options=preprocessing_options,
        dataset_options=dataset_options)
    output = preprocessor.preprocess(image)

    expected_shape = [image_size, image_size, 3 * num_views]
    self.assertEqual(output.dtype, tf.float32)
    self.assertEqual(expected_shape, output.shape.as_list())
def hparams_from_flags():
    """Assembles an `hparams.HParams` object from the values held in `FLAGS`.

    Each sub-configuration (architecture, losses, the two stages, eval, input
    data and warm start) is built separately from its flags and then combined.
    List-valued flags are converted to tuples of `int` (or `float` for
    `crop_area_range`).

    Returns:
        The populated `hparams.HParams` instance.
    """
    architecture = hparams.Architecture(
        encoder_architecture=FLAGS.resnet_architecture,
        encoder_depth=FLAGS.resnet_depth,
        encoder_width=FLAGS.resnet_width,
        first_conv_kernel_size=FLAGS.first_conv_kernel_size,
        first_conv_stride=FLAGS.first_conv_stride,
        use_initial_max_pool=FLAGS.use_initial_max_pool,
        projection_head_layers=tuple(
            int(units) for units in FLAGS.projection_head_layers),
        projection_head_use_batch_norm=FLAGS.use_projection_batch_norm,
        projection_head_use_batch_norm_beta=(
            FLAGS.use_projection_batch_norm_beta),
        # Both normalization options are driven by the same flag.
        normalize_projection_head_inputs=FLAGS.normalize_embedding,
        normalize_classifier_inputs=FLAGS.normalize_embedding,
        zero_initialize_classifier=FLAGS.zero_initialize_classifier,
        stop_gradient_before_classification_head=(
            FLAGS.stop_gradient_before_classification_head),
        stop_gradient_before_projection_head=(
            FLAGS.stop_gradient_before_projection_head),
        use_global_batch_norm=FLAGS.use_global_batch_norm)

    contrastive_loss = hparams.ContrastiveLoss(
        use_labels=FLAGS.use_labels,
        temperature=FLAGS.temperature,
        contrast_mode=FLAGS.contrast_mode,
        summation_location=FLAGS.summation_location,
        denominator_mode=FLAGS.denominator_mode,
        positives_cap=FLAGS.positives_cap,
        scale_by_temperature=FLAGS.scale_by_temperature)
    cross_entropy_loss = hparams.CrossEntropyLoss(
        label_smoothing=FLAGS.label_smoothing)
    loss_all_stages = hparams.LossAllStages(
        contrastive=contrastive_loss,
        cross_entropy=cross_entropy_loss,
        include_bias_in_weight_decay=FLAGS.use_bias_weight_decay)

    stage_1_training = hparams.TrainingStage(
        train_epochs=FLAGS.stage_1_epochs,
        learning_rate_warmup_epochs=FLAGS.stage_1_warmup_epochs,
        base_learning_rate=FLAGS.stage_1_base_learning_rate,
        learning_rate_decay=FLAGS.stage_1_learning_rate_decay,
        decay_rate=FLAGS.stage_1_decay_rate,
        decay_boundary_epochs=tuple(
            int(epoch) for epoch in FLAGS.stage_1_decay_boundary_epochs),
        epochs_per_decay=FLAGS.stage_1_epochs_per_decay,
        optimizer=FLAGS.stage_1_optimizer,
        update_encoder_batch_norm=FLAGS.stage_1_update_encoder_batch_norm,
        rmsprop_epsilon=FLAGS.stage_1_rmsprop_epsilon)
    stage_1_loss = hparams.LossStage(
        contrastive_weight=FLAGS.stage_1_contrastive_loss_weight,
        cross_entropy_weight=FLAGS.stage_1_cross_entropy_loss_weight,
        weight_decay_coeff=FLAGS.stage_1_weight_decay,
        use_encoder_weight_decay=FLAGS.stage_1_use_encoder_weight_decay,
        use_projection_head_weight_decay=(
            FLAGS.stage_1_use_projection_head_weight_decay),
        use_classification_head_weight_decay=(
            FLAGS.stage_1_use_classification_head_weight_decay))
    stage_1 = hparams.Stage(training=stage_1_training, loss=stage_1_loss)

    stage_2_training = hparams.TrainingStage(
        train_epochs=FLAGS.stage_2_epochs,
        learning_rate_warmup_epochs=FLAGS.stage_2_warmup_epochs,
        base_learning_rate=FLAGS.stage_2_base_learning_rate,
        learning_rate_decay=FLAGS.stage_2_learning_rate_decay,
        decay_rate=FLAGS.stage_2_decay_rate,
        decay_boundary_epochs=tuple(
            int(epoch) for epoch in FLAGS.stage_2_decay_boundary_epochs),
        epochs_per_decay=FLAGS.stage_2_epochs_per_decay,
        optimizer=FLAGS.stage_2_optimizer,
        update_encoder_batch_norm=FLAGS.stage_2_update_encoder_batch_norm,
        rmsprop_epsilon=FLAGS.stage_2_rmsprop_epsilon)
    stage_2_loss = hparams.LossStage(
        contrastive_weight=FLAGS.stage_2_contrastive_loss_weight,
        cross_entropy_weight=FLAGS.stage_2_cross_entropy_loss_weight,
        weight_decay_coeff=FLAGS.stage_2_weight_decay,
        use_encoder_weight_decay=FLAGS.stage_2_use_encoder_weight_decay,
        use_projection_head_weight_decay=(
            FLAGS.stage_2_use_projection_head_weight_decay),
        use_classification_head_weight_decay=(
            FLAGS.stage_2_use_classification_head_weight_decay))
    stage_2 = hparams.Stage(training=stage_2_training, loss=stage_2_loss)

    evaluation = hparams.Eval(batch_size=FLAGS.eval_batch_size)

    image_preprocessing = hparams.ImagePreprocessing(
        allow_mixed_precision=FLAGS.allow_mixed_precision,
        image_size=FLAGS.image_size,
        augmentation_type=FLAGS.augmentation_type,
        augmentation_magnitude=FLAGS.augmentation_magnitude,
        blur_probability=FLAGS.blur_probability,
        defer_blurring=FLAGS.defer_blurring,
        use_pytorch_color_jitter=FLAGS.use_pytorch_color_jitter,
        apply_whitening=FLAGS.apply_whitening,
        crop_area_range=tuple(
            float(bound) for bound in FLAGS.crop_area_range),
        eval_crop_method=FLAGS.eval_crop_method,
        crop_padding=FLAGS.crop_padding,
    )
    input_data = hparams.InputData(
        input_fn=FLAGS.input_fn,
        preprocessing=image_preprocessing,
        max_samples=FLAGS.num_images,
        label_noise_prob=FLAGS.label_noise_prob,
        shard_per_host=FLAGS.shard_per_host)

    warm_start = hparams.WarmStart(
        warm_start_classifier=FLAGS.warm_start_classifier,
        ignore_missing_checkpoint_vars=FLAGS.ignore_missing_checkpoint_vars,
        warm_start_projection_head=FLAGS.warm_start_projection_head,
        warm_start_encoder=FLAGS.warm_start_encoder,
        batch_norm_in_train_mode=FLAGS.batch_norm_in_train_mode,
    )

    return hparams.HParams(
        bs=FLAGS.batch_size,
        architecture=architecture,
        loss_all_stages=loss_all_stages,
        stage_1=stage_1,
        stage_2=stage_2,
        eval=evaluation,
        input_data=input_data,
        warm_start=warm_start,
    )