def test_preprocess_image(self, decode_image, image_dtype, image_size,
                          is_training, augmentation_type, warp_prob,
                          augmentation_magnitude, eval_crop_method,
                          image_mean_std):
  """Tests `preprocessing.preprocess_image` over a grid of options.

  A random 300x400x3 uint8 image is passed either JPEG-encoded (when
  `decode_image` is True) or as a raw tensor cast to `image_dtype`. A raw
  (non-encoded) input whose dtype is not uint8 is expected to raise an
  AssertionError. Every other combination must produce a float32 tensor of
  shape [image_size, image_size, 3].
  """
  # tf.random.uniform() doesn't allow generating random values that are uint8.
  image = tf.cast(
      tf.random.uniform(
          shape=(300, 400, 3), minval=0, maxval=255, dtype=tf.float32),
      dtype=tf.uint8)
  if decode_image:
    image = tf.image.encode_jpeg(image)
  else:
    image = tf.cast(image, image_dtype)

  def run_preprocessing():
    return preprocessing.preprocess_image(
        image,
        is_training=is_training,
        bfloat16_supported=False,
        preprocessing_options=hparams.ImagePreprocessing(
            image_size=image_size,
            augmentation_type=augmentation_type,
            warp_probability=warp_prob,
            augmentation_magnitude=augmentation_magnitude,
            eval_crop_method=eval_crop_method,
        ),
        dataset_options=preprocessing.DatasetOptions(
            image_mean_std=image_mean_std, decode_input=decode_image))

  expect_error = not decode_image and image_dtype != tf.uint8
  if expect_error:
    with self.assertRaises(AssertionError):
      run_preprocessing()
    return

  # The output checks must stay outside any assertRaises(AssertionError).
  # Otherwise a failing assertEqual (which itself raises AssertionError)
  # would be mistaken for the expected error and the test would pass.
  output = run_preprocessing()
  self.assertEqual(output.dtype, tf.float32)
  self.assertEqual([image_size, image_size, 3], output.shape.as_list())
def test_input_class(self, input_class, model_mode, image_size, max_samples):
  """Checks the structure and length of the data produced by an input class.

  For INFERENCE mode the result must be a TensorServingInputReceiver whose
  features have an unknown batch dimension and 3 channels. For other modes
  the result must be a tf.data.Dataset of (images, labels) with a fixed batch
  size and 6 channels (two 3-channel views). The test then pulls the expected
  number of batches plus one more. Under EVAL that extra fetch must raise
  OutOfRangeError.
  """
  is_training = model_mode == enums.ModelMode.TRAIN
  is_inference = model_mode == enums.ModelMode.INFERENCE
  split = 'train' if is_training else 'test'
  batch_size = 2
  dataset_size = 10

  # max_samples only caps the number of batches in TRAIN mode.
  if max_samples is not None and is_training:
    num_batches = max_samples // batch_size
  else:
    num_batches = dataset_size // batch_size

  if input_class != 'TfdsInput':
    raise ValueError(f'Unknown input class {input_class}')

  with tfds.testing.mock_data(num_examples=dataset_size):
    preprocessor = preprocessing.ImageToMultiViewedImagePreprocessor(
        is_training=is_training,
        preprocessing_options=hparams.ImagePreprocessing(
            image_size=image_size, num_views=2),
        dataset_options=preprocessing.DatasetOptions(decode_input=False))
    tfds_input = inputs.TfdsInput(
        'cifar10',
        split,
        mode=model_mode,
        preprocessor=preprocessor,
        max_samples=max_samples,
        num_classes=10)
    data = tfds_input.input_fn({'batch_size': batch_size})

  if is_inference:
    self.assertIsInstance(data,
                          tf.estimator.export.TensorServingInputReceiver)
    feature_shape = data.features.shape.as_list()
    self.assertEqual([None, image_size, image_size, 3], feature_shape)
    return

  self.assertIsInstance(data, tf.data.Dataset)
  output_shapes = tf.data.get_output_shapes(data)
  image_dims = output_shapes[0].as_list()
  label_dims = output_shapes[1].as_list()
  self.assertEqual([batch_size], label_dims)
  # num_views=2, so the two 3-channel views give 6 channels.
  self.assertEqual([batch_size, image_size, image_size, 6], image_dims)

  # Now extract the Tensors
  next_images = tf.data.make_one_shot_iterator(data).get_next()[0]
  with self.cached_session() as sess:
    for _ in range(num_batches):
      sess.run(next_images)
    if model_mode == enums.ModelMode.EVAL:
      # The EVAL dataset is expected to be exhausted after num_batches.
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(next_images)
    else:
      sess.run(next_images)
def test_image_to_multi_viewed_image_preprocessor(self, is_training,
                                                  decode_image, num_views):
  """Checks that the multi-view preprocessor emits 3 * num_views channels.

  A random 300x400x3 uint8 image (optionally JPEG-encoded) is preprocessed.
  The result must be a float32 tensor of shape
  [image_size, image_size, 3 * num_views].
  """
  image_size = 32
  # tf.random.uniform() doesn't allow generating random values that are uint8.
  random_pixels = tf.random.uniform(
      shape=(300, 400, 3), minval=0, maxval=255, dtype=tf.float32)
  image = tf.cast(random_pixels, dtype=tf.uint8)
  if decode_image:
    image = tf.image.encode_jpeg(image)

  options = hparams.ImagePreprocessing(
      image_size=image_size, num_views=num_views)
  dataset_options = preprocessing.DatasetOptions(decode_input=decode_image)
  preprocessor = preprocessing.ImageToMultiViewedImagePreprocessor(
      is_training=is_training,
      preprocessing_options=options,
      dataset_options=dataset_options)

  output = preprocessor.preprocess(image)
  self.assertEqual(output.dtype, tf.float32)
  expected_shape = [image_size, image_size, 3 * num_views]
  self.assertEqual(expected_shape, output.shape.as_list())
def _stage_from_flags(stage_number):
  """Builds the `hparams.Stage` for one training stage from its flags.

  Every per-stage flag is named `stage_<stage_number>_<suffix>`, and both
  stages read the same set of suffixes. Building them in one place keeps the
  two stage definitions from drifting apart.

  Args:
    stage_number: Stage index used in the flag-name prefix (1 or 2).

  Returns:
    An `hparams.Stage` populated from that stage's flags.
  """
  prefix = f'stage_{stage_number}_'

  def stage_flag(suffix):
    # Equivalent to FLAGS.stage_<n>_<suffix>.
    return getattr(FLAGS, prefix + suffix)

  return hparams.Stage(
      training=hparams.TrainingStage(
          train_epochs=stage_flag('epochs'),
          learning_rate_warmup_epochs=stage_flag('warmup_epochs'),
          base_learning_rate=stage_flag('base_learning_rate'),
          learning_rate_decay=stage_flag('learning_rate_decay'),
          decay_rate=stage_flag('decay_rate'),
          decay_boundary_epochs=tuple(
              map(int, stage_flag('decay_boundary_epochs'))),
          epochs_per_decay=stage_flag('epochs_per_decay'),
          optimizer=stage_flag('optimizer'),
          update_encoder_batch_norm=stage_flag('update_encoder_batch_norm'),
          rmsprop_epsilon=stage_flag('rmsprop_epsilon')),
      loss=hparams.LossStage(
          contrastive_weight=stage_flag('contrastive_loss_weight'),
          cross_entropy_weight=stage_flag('cross_entropy_loss_weight'),
          weight_decay_coeff=stage_flag('weight_decay'),
          use_encoder_weight_decay=stage_flag('use_encoder_weight_decay'),
          use_projection_head_weight_decay=stage_flag(
              'use_projection_head_weight_decay'),
          use_classification_head_weight_decay=stage_flag(
              'use_classification_head_weight_decay')),
  )


def hparams_from_flags():
  """Assembles the complete `hparams.HParams` configuration from FLAGS.

  Returns:
    An `hparams.HParams` covering architecture, losses, both training stages,
    evaluation, input data, and warm-start settings. Stage-specific values
    come from the `stage_1_*` / `stage_2_*` flags.
  """
  return hparams.HParams(
      bs=FLAGS.batch_size,
      architecture=hparams.Architecture(
          encoder_architecture=FLAGS.resnet_architecture,
          encoder_depth=FLAGS.resnet_depth,
          encoder_width=FLAGS.resnet_width,
          first_conv_kernel_size=FLAGS.first_conv_kernel_size,
          first_conv_stride=FLAGS.first_conv_stride,
          use_initial_max_pool=FLAGS.use_initial_max_pool,
          projection_head_layers=tuple(map(int, FLAGS.projection_head_layers)),
          projection_head_use_batch_norm=FLAGS.use_projection_batch_norm,
          projection_head_use_batch_norm_beta=(
              FLAGS.use_projection_batch_norm_beta),
          # A single flag controls input normalization for both heads.
          normalize_projection_head_inputs=FLAGS.normalize_embedding,
          normalize_classifier_inputs=FLAGS.normalize_embedding,
          zero_initialize_classifier=FLAGS.zero_initialize_classifier,
          stop_gradient_before_classification_head=(
              FLAGS.stop_gradient_before_classification_head),
          stop_gradient_before_projection_head=(
              FLAGS.stop_gradient_before_projection_head),
          use_global_batch_norm=FLAGS.use_global_batch_norm),
      loss_all_stages=hparams.LossAllStages(
          contrastive=hparams.ContrastiveLoss(
              use_labels=FLAGS.use_labels,
              temperature=FLAGS.temperature,
              contrast_mode=FLAGS.contrast_mode,
              summation_location=FLAGS.summation_location,
              denominator_mode=FLAGS.denominator_mode,
              positives_cap=FLAGS.positives_cap,
              scale_by_temperature=FLAGS.scale_by_temperature),
          cross_entropy=hparams.CrossEntropyLoss(
              label_smoothing=FLAGS.label_smoothing),
          include_bias_in_weight_decay=FLAGS.use_bias_weight_decay),
      stage_1=_stage_from_flags(1),
      stage_2=_stage_from_flags(2),
      eval=hparams.Eval(batch_size=FLAGS.eval_batch_size),
      input_data=hparams.InputData(
          input_fn=FLAGS.input_fn,
          preprocessing=hparams.ImagePreprocessing(
              allow_mixed_precision=FLAGS.allow_mixed_precision,
              image_size=FLAGS.image_size,
              augmentation_type=FLAGS.augmentation_type,
              augmentation_magnitude=FLAGS.augmentation_magnitude,
              blur_probability=FLAGS.blur_probability,
              defer_blurring=FLAGS.defer_blurring,
              use_pytorch_color_jitter=FLAGS.use_pytorch_color_jitter,
              apply_whitening=FLAGS.apply_whitening,
              crop_area_range=tuple(map(float, FLAGS.crop_area_range)),
              eval_crop_method=FLAGS.eval_crop_method,
              crop_padding=FLAGS.crop_padding,
          ),
          max_samples=FLAGS.num_images,
          label_noise_prob=FLAGS.label_noise_prob,
          shard_per_host=FLAGS.shard_per_host),
      warm_start=hparams.WarmStart(
          warm_start_classifier=FLAGS.warm_start_classifier,
          ignore_missing_checkpoint_vars=FLAGS.ignore_missing_checkpoint_vars,
          warm_start_projection_head=FLAGS.warm_start_projection_head,
          warm_start_encoder=FLAGS.warm_start_encoder,
          batch_norm_in_train_mode=FLAGS.batch_norm_in_train_mode,
      ),
  )