def imagenet(mode, params):
  """An input_fn for ImageNet (ILSVRC 2012) data."""
  run_mode = utils.estimator_mode_to_model_mode(mode)
  hp = params['hparams']
  training = run_mode == enums.ModelMode.TRAIN

  # Multi-view preprocessing; inputs arrive already decoded, so decoding is
  # disabled in the dataset options.
  view_preprocessor = preprocessing.ImageToMultiViewedImagePreprocessor(
      is_training=training,
      preprocessing_options=hp.input_data.preprocessing,
      dataset_options=preprocessing.DatasetOptions(decode_input=False),
      bfloat16_supported=params['use_tpu'])

  dataset_input = TfdsInput(
      dataset_name='imagenet2012:5.*.*',
      split='train' if training else 'validation',
      mode=run_mode,
      preprocessor=view_preprocessor,
      shuffle_buffer=1024,
      shard_per_host=hp.input_data.shard_per_host,
      cache=training,  # cache only the training split
      num_parallel_calls=64,
      max_samples=hp.input_data.max_samples,
      label_noise_prob=hp.input_data.label_noise_prob,
      num_classes=get_num_classes(hp),
      data_dir=params['data_dir'],
  )
  return dataset_input.input_fn(params)
def cifar10(mode, params):
  """CIFAR10 dataset creator."""
  # Images are naturally 32x32.
  run_mode = utils.estimator_mode_to_model_mode(mode)
  hp = params['hparams']
  training = run_mode == enums.ModelMode.TRAIN

  # Per-channel mean/std used to normalize CIFAR-10 images; inputs are
  # already decoded, hence decode_input=False.
  channel_mean = np.array([[[-0.0172, -0.0356, -0.107]]])
  channel_std = np.array([[[0.4046, 0.3988, 0.402]]])

  view_preprocessor = preprocessing.ImageToMultiViewedImagePreprocessor(
      is_training=training,
      preprocessing_options=hp.input_data.preprocessing,
      dataset_options=preprocessing.DatasetOptions(
          decode_input=False,
          image_mean_std=(channel_mean, channel_std)),
      bfloat16_supported=params['use_tpu'])

  dataset_input = TfdsInput(
      dataset_name='cifar10:3.*.*',
      split='train' if training else 'test',
      mode=run_mode,
      preprocessor=view_preprocessor,
      shard_per_host=hp.input_data.shard_per_host,
      cache=training,  # cache only the training split
      shuffle_buffer=50000,
      num_parallel_calls=64,
      max_samples=hp.input_data.max_samples,
      label_noise_prob=hp.input_data.label_noise_prob,
      num_classes=get_num_classes(hp),
      data_dir=params['data_dir'],
  )
  return dataset_input.input_fn(params)
def model_fn(features, labels, mode, params):
  """Contrastive model function."""
  run_mode = utils.estimator_mode_to_model_mode(mode)
  hp = params['hparams']
  contrastive_trainer = ContrastiveTrainer(
      model_inputs=features,
      labels=labels,
      train_global_batch_size=hp.bs,
      hparams=hp,
      mode=run_mode,
      num_classes=inputs.get_num_classes(hp),
      training_set_size=inputs.get_num_train_images(hp),
      is_tpu=params['use_tpu'])

  if mode == tf.estimator.ModeKeys.PREDICT:
    signature_map = contrastive_trainer.signature_def_map()
    export_outputs = {}
    for name, outputs in signature_map.items():
      export_outputs[name] = tf.estimator.export.PredictOutput(outputs)
    # Export a default SignatureDef to keep the API happy.
    export_outputs[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
        export_outputs['contrastive_eval'])
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions=signature_map['contrastive_eval'],
        export_outputs=export_outputs)

  # We directly write summaries for the relevant losses, so just hard-code a
  # dummy value to keep the Estimator API happy.
  dummy_loss = tf.constant(0.)

  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=dummy_loss,
        eval_metrics=contrastive_trainer.eval_metrics())

  # TRAIN
  return tf.estimator.tpu.TPUEstimatorSpec(
      mode=mode,
      train_op=contrastive_trainer.train_op(),
      loss=dummy_loss,
      scaffold_fn=contrastive_trainer.scaffold_fn(),
      host_call=contrastive_trainer.host_call(FLAGS.model_dir))