def test_get_num_train_images(self, input_fn, max_samples, expected_images):
  # Build a params dict for the given input_fn, cap the sample count, and
  # check that the training-set size reported by `inputs` matches.
  params = make_params(input_fn)
  hparams = params['hparams']
  hparams.input_data.max_samples = max_samples
  self.assertEqual(inputs.get_num_train_images(hparams), expected_images)
def model_fn(features, labels, mode, params):
  """Contrastive model function.

  Builds a ContrastiveTrainer from `params['hparams']` and returns the
  TPUEstimatorSpec appropriate for the given Estimator mode.
  """
  hp = params['hparams']
  trainer = ContrastiveTrainer(
      model_inputs=features,
      labels=labels,
      train_global_batch_size=hp.bs,
      hparams=hp,
      mode=utils.estimator_mode_to_model_mode(mode),
      num_classes=inputs.get_num_classes(hp),
      training_set_size=inputs.get_num_train_images(hp),
      is_tpu=params['use_tpu'])

  if mode == tf.estimator.ModeKeys.PREDICT:
    signature_map = trainer.signature_def_map()
    export_outputs = {
        name: tf.estimator.export.PredictOutput(outputs)
        for name, outputs in signature_map.items()
    }
    # Export a default SignatureDef to keep the API happy.
    export_outputs[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
        export_outputs['contrastive_eval'])
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions=signature_map['contrastive_eval'],
        export_outputs=export_outputs)

  # The relevant losses are written as summaries directly; the Estimator API
  # still requires a scalar loss, so hand it a dummy constant.
  dummy_loss = tf.constant(0.)

  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode, loss=dummy_loss, eval_metrics=trainer.eval_metrics())

  # TRAIN
  return tf.estimator.tpu.TPUEstimatorSpec(
      mode=mode,
      train_op=trainer.train_op(),
      loss=dummy_loss,
      scaffold_fn=trainer.scaffold_fn(),
      host_call=trainer.host_call(FLAGS.model_dir))
def main(_):
  """Entry point: builds the TPUEstimator and runs train/eval per FLAGS.mode."""
  tf.disable_v2_behavior()
  tf.enable_resource_variables()

  if FLAGS.hparams is None:
    hparams = hparams_flags.hparams_from_flags()
  else:
    hparams = hparams_lib.HParams(FLAGS.hparams)

  cluster = None
  if FLAGS.use_tpu and FLAGS.master is None:
    if FLAGS.tpu_name:
      cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
          FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
      cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(cluster)
    tf.tpu.experimental.initialize_tpu_system(cluster)

  session_config = tf.ConfigProto()
  # Workaround for https://github.com/tensorflow/tensorflow/issues/26411 where
  # convolutions (used in blurring) get confused about data-format when used
  # inside a tf.data pipeline that is run on GPU.
  if (tf.test.is_built_with_cuda() and
      not hparams.input_data.preprocessing.defer_blurring):
    # RewriterConfig.OFF = 2
    session_config.graph_options.rewrite_options.layout_optimizer = 2
  run_config = tf.estimator.tpu.RunConfig(
      master=FLAGS.master,
      cluster=cluster,
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=FLAGS.save_interval_steps,
      keep_checkpoint_max=FLAGS.max_checkpoints_to_keep,
      keep_checkpoint_every_n_hours=(FLAGS.keep_checkpoint_interval_secs /
                                     (60.0 * 60.0)),
      log_step_count_steps=100,
      session_config=session_config,
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=FLAGS.steps_per_loop,
          per_host_input_for_training=True,
          experimental_host_call_every_n_steps=FLAGS.summary_interval_steps,
          tpu_job_name='train_tpu_worker' if FLAGS.mode == 'train' else None,
          eval_training_input_configuration=(
              tf.estimator.tpu.InputPipelineConfig.SLICED if FLAGS.use_tpu
              else tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1)))
  params = {
      'hparams': hparams,
      'use_tpu': FLAGS.use_tpu,
      'data_dir': FLAGS.data_dir,
  }
  estimator = tf.estimator.tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      config=run_config,
      params=params,
      train_batch_size=hparams.bs,
      eval_batch_size=hparams.eval.batch_size)

  if hparams.input_data.input_fn not in dir(inputs):
    # BUG FIX: the original message lacked the f-prefix and printed the
    # literal '{hparams.input_data.input_fn}' instead of the bad value.
    raise ValueError(f'Unknown input_fn: {hparams.input_data.input_fn}')
  input_fn = getattr(inputs, hparams.input_data.input_fn)

  training_set_size = inputs.get_num_train_images(hparams)
  steps_per_epoch = training_set_size / hparams.bs
  stage_1_epochs = hparams.stage_1.training.train_epochs
  stage_2_epochs = hparams.stage_2.training.train_epochs
  total_steps = int((stage_1_epochs + stage_2_epochs) * steps_per_epoch)

  num_eval_examples = inputs.get_num_eval_images(hparams)
  eval_steps = num_eval_examples // hparams.eval.batch_size

  if FLAGS.mode == 'eval':
    # Poll for new checkpoints, evaluating and exporting each one until the
    # final training step has been evaluated.
    for ckpt_str in tf.train.checkpoints_iterator(
        FLAGS.model_dir,
        min_interval_secs=FLAGS.eval_interval_secs,
        timeout=60 * 60):
      result = estimator.evaluate(
          input_fn=input_fn, checkpoint_path=ckpt_str, steps=eval_steps)
      estimator.export_saved_model(
          os.path.join(FLAGS.model_dir, 'exports'),
          lambda: input_fn(tf.estimator.ModeKeys.PREDICT, params),
          checkpoint_path=ckpt_str)
      if result['global_step'] >= total_steps:
        return
  else:  # 'train' or 'train_then_eval'.
    estimator.train(input_fn=input_fn, max_steps=total_steps)
    if FLAGS.mode == 'train_then_eval':
      result = estimator.evaluate(input_fn=input_fn, steps=eval_steps)
      estimator.export_saved_model(
          os.path.join(FLAGS.model_dir, 'exports'),
          lambda: input_fn(tf.estimator.ModeKeys.PREDICT, params))