Example #1
0
    def test_get_num_train_images(self, input_fn, max_samples,
                                  expected_images):
        """Checks that get_num_train_images honors the max_samples hparam."""
        hparams = make_params(input_fn)['hparams']
        # Cap the training set and verify the reported count reflects it.
        hparams.input_data.max_samples = max_samples
        self.assertEqual(inputs.get_num_train_images(hparams),
                         expected_images)
Example #2
0
def model_fn(features, labels, mode, params):
    """Contrastive model function.

    Builds a ContrastiveTrainer from `features`/`labels` and returns the
    TPUEstimatorSpec appropriate for `mode` (PREDICT, EVAL, or TRAIN).
    """
    hparams = params['hparams']
    model_mode = utils.estimator_mode_to_model_mode(mode)

    trainer = ContrastiveTrainer(
        model_inputs=features,
        labels=labels,
        train_global_batch_size=hparams.bs,
        hparams=hparams,
        mode=model_mode,
        num_classes=inputs.get_num_classes(hparams),
        training_set_size=inputs.get_num_train_images(hparams),
        is_tpu=params['use_tpu'])

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions_map = trainer.signature_def_map()
        exports = {}
        for signature_name, outputs in predictions_map.items():
            exports[signature_name] = tf.estimator.export.PredictOutput(outputs)
        # Export a default SignatureDef to keep the API happy.
        exports[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
            exports['contrastive_eval'])
        return tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=predictions_map['contrastive_eval'],
            export_outputs=exports)

    # We directly write summaries for the relevant losses, so just hard-code a
    # dummy value to keep the Estimator API happy.
    loss = tf.constant(0.)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode, loss=loss, eval_metrics=trainer.eval_metrics())

    # TRAIN
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        train_op=trainer.train_op(),
        loss=loss,
        scaffold_fn=trainer.scaffold_fn(),
        host_call=trainer.host_call(FLAGS.model_dir))
Example #3
0
def main(_):
    """Entry point: builds a TPUEstimator and runs train and/or eval.

    FLAGS.mode selects the behavior:
      * 'eval': loops over new checkpoints, evaluating and exporting each,
        until the final training step's checkpoint has been evaluated.
      * 'train' / 'train_then_eval': trains to completion, then (for
        'train_then_eval') runs one evaluation and export.
    """
    tf.disable_v2_behavior()
    tf.enable_resource_variables()

    if FLAGS.hparams is None:
        hparams = hparams_flags.hparams_from_flags()
    else:
        hparams = hparams_lib.HParams(FLAGS.hparams)

    cluster = None
    if FLAGS.use_tpu and FLAGS.master is None:
        if FLAGS.tpu_name:
            cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
                FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        else:
            cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
            tf.config.experimental_connect_to_cluster(cluster)
            tf.tpu.experimental.initialize_tpu_system(cluster)

    session_config = tf.ConfigProto()
    # Workaround for https://github.com/tensorflow/tensorflow/issues/26411 where
    # convolutions (used in blurring) get confused about data-format when used
    # inside a tf.data pipeline that is run on GPU.
    if (tf.test.is_built_with_cuda()
            and not hparams.input_data.preprocessing.defer_blurring):
        # RewriterConfig.OFF = 2
        session_config.graph_options.rewrite_options.layout_optimizer = 2
    run_config = tf.estimator.tpu.RunConfig(
        master=FLAGS.master,
        cluster=cluster,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=FLAGS.save_interval_steps,
        keep_checkpoint_max=FLAGS.max_checkpoints_to_keep,
        # RunConfig takes hours; the flag is expressed in seconds.
        keep_checkpoint_every_n_hours=(FLAGS.keep_checkpoint_interval_secs /
                                       (60.0 * 60.0)),
        log_step_count_steps=100,
        session_config=session_config,
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=FLAGS.steps_per_loop,
            per_host_input_for_training=True,
            experimental_host_call_every_n_steps=FLAGS.summary_interval_steps,
            tpu_job_name='train_tpu_worker' if FLAGS.mode == 'train' else None,
            eval_training_input_configuration=(
                tf.estimator.tpu.InputPipelineConfig.SLICED if FLAGS.use_tpu
                else tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1)))
    params = {
        'hparams': hparams,
        'use_tpu': FLAGS.use_tpu,
        'data_dir': FLAGS.data_dir,
    }
    estimator = tf.estimator.tpu.TPUEstimator(
        model_fn=model_fn,
        use_tpu=FLAGS.use_tpu,
        config=run_config,
        params=params,
        train_batch_size=hparams.bs,
        eval_batch_size=hparams.eval.batch_size)

    if hparams.input_data.input_fn not in dir(inputs):
        # BUG FIX: the original message lacked the f-prefix, so it printed the
        # literal '{hparams.input_data.input_fn}' instead of the bad value.
        raise ValueError(f'Unknown input_fn: {hparams.input_data.input_fn}')
    input_fn = getattr(inputs, hparams.input_data.input_fn)

    # Derive the total step budget from the two training stages' epoch counts.
    training_set_size = inputs.get_num_train_images(hparams)
    steps_per_epoch = training_set_size / hparams.bs
    stage_1_epochs = hparams.stage_1.training.train_epochs
    stage_2_epochs = hparams.stage_2.training.train_epochs
    total_steps = int((stage_1_epochs + stage_2_epochs) * steps_per_epoch)

    num_eval_examples = inputs.get_num_eval_images(hparams)
    eval_steps = num_eval_examples // hparams.eval.batch_size

    if FLAGS.mode == 'eval':
        for ckpt_str in tf.train.checkpoints_iterator(
                FLAGS.model_dir,
                min_interval_secs=FLAGS.eval_interval_secs,
                timeout=60 * 60):
            result = estimator.evaluate(input_fn=input_fn,
                                        checkpoint_path=ckpt_str,
                                        steps=eval_steps)
            estimator.export_saved_model(
                os.path.join(FLAGS.model_dir, 'exports'),
                lambda: input_fn(tf.estimator.ModeKeys.PREDICT, params),
                checkpoint_path=ckpt_str)
            # Stop once the checkpoint from the final training step is done.
            if result['global_step'] >= total_steps:
                return
    else:  # 'train' or 'train_then_eval'.
        estimator.train(input_fn=input_fn, max_steps=total_steps)
        if FLAGS.mode == 'train_then_eval':
            result = estimator.evaluate(input_fn=input_fn, steps=eval_steps)
            estimator.export_saved_model(
                os.path.join(FLAGS.model_dir, 'exports'),
                lambda: input_fn(tf.estimator.ModeKeys.PREDICT, params))