def train(flags_obj, model_function, dataset_name):
    """Build an Estimator from the parsed flags and train it.

    Args:
        flags_obj: Parsed flags object. Supplies `model_dir`, the model
            hyper-parameters forwarded as Estimator `params`, `batch_size`,
            `train_epochs`, `hooks`, and the input settings
            (`encode_method`, `model_name`, `input_tfrec`, `cpus`,
            `max_len`, `kmer`).
        model_function: Passed to `tf.estimator.Estimator` as `model_fn`.
        dataset_name: Dataset name recorded by the benchmark logger.
    """
    run_config = tf.estimator.RunConfig(save_checkpoints_steps=100000,
                                        keep_checkpoint_max=1000)

    # Flag values forwarded unchanged to `model_function` through `params`.
    param_names = (
        'num_classes',
        'vocab_size',
        'embedding_dim',
        'mlp_dim',
        'kmer',
        'max_len',
        'lr',
        'lr_decay',
        'cnn_num_filters',
        'cnn_filter_sizes',
        'lstm_dim',
        'pooling_type',
        'row',
        'da',
        'keep_prob',
    )
    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        params={name: getattr(flags_obj, name) for name in param_names})

    run_params = {
        'batch_size': flags_obj.batch_size,
        'train_epochs': flags_obj.train_epochs,
    }
    benchmark_logger = logger.config_benchmark_logger(flags_obj)
    benchmark_logger.log_run_info('model', dataset_name, run_params)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               batch_size=flags_obj.batch_size)

    # Models that are fed by the pad-to-fixed-length k-mer input function.
    fixed_len_models = ('embed_pool', 'embed_cnn', 'embed_lstm',
                        'embed_cnn_no_pool')

    def input_fn_train():
        """Return the training input for the configured encoding and model.

        Exactly one input function is invoked per call, so no input pipeline
        is built only to be thrown away.
        """
        if flags_obj.encode_method != 'kmer':
            return input_function_train_one_hot(flags_obj.input_tfrec,
                                                flags_obj.train_epochs,
                                                flags_obj.batch_size,
                                                flags_obj.cpus,
                                                flags_obj.max_len)

        if flags_obj.model_name in fixed_len_models:
            return input_function_train_kmer_pad_to_fixed_len(
                flags_obj.input_tfrec, flags_obj.train_epochs,
                flags_obj.batch_size, flags_obj.cpus, flags_obj.max_len,
                flags_obj.kmer)

        return input_function_train_kmer(flags_obj.input_tfrec,
                                         flags_obj.train_epochs,
                                         flags_obj.batch_size,
                                         flags_obj.cpus)

    classifier.train(input_fn=input_fn_train, hooks=train_hooks)
 def validate_train_hook_name(self,
                              test_hook_name,
                              expected_hook_name,
                              **kwargs):
   """Check that a single hook name yields exactly one hook of that class.

   Args:
     test_hook_name: Hook name passed (as a one-element list) to
       `hooks_helper.get_train_hooks`.
     expected_hook_name: Lower-cased class name the created hook must have.
     **kwargs: Extra keyword arguments forwarded to `get_train_hooks`.
   """
   hooks = hooks_helper.get_train_hooks([test_hook_name], **kwargs)
   self.assertEqual(len(hooks), 1)

   only_hook = hooks[0]
   self.assertIsInstance(only_hook, tf.train.SessionRunHook)

   actual_name = only_hook.__class__.__name__.lower()
   self.assertEqual(actual_name, expected_hook_name)
Beispiel #3
0
def run_loop(name,
             train_input_fn,
             eval_input_fn,
             model_column_fn,
             build_estimator_fn,
             flags_obj,
             tensors_to_log,
             early_stop=False):
    """Run alternating train / evaluate cycles, then optionally export.

    Args:
        name: Dataset name recorded by the benchmark logger.
        train_input_fn: Input function handed to `model.train`.
        eval_input_fn: Input function handed to `model.evaluate`.
        model_column_fn: Forwarded to `build_estimator_fn` and to
            `export_model`.
        build_estimator_fn: Callable that returns the estimator to run.
        flags_obj: Parsed flags object holding the run settings.
        tensors_to_log: Mapping whose values are format strings containing a
            `{loss_prefix}` placeholder.
        early_stop: When true, stop as soon as the evaluated accuracy passes
            `flags_obj.stop_threshold`.
    """
    model_helpers.apply_clean(flags.FLAGS)

    estimator_kwargs = dict(
        model_dir=flags_obj.model_dir,
        model_type=flags_obj.model_type,
        model_column_fn=model_column_fn,
        inter_op=flags_obj.inter_op_parallelism_threads,
        intra_op=flags_obj.intra_op_parallelism_threads)
    model = build_estimator_fn(**estimator_kwargs)

    run_info = {
        'batch_size': flags_obj.batch_size,
        'train_epochs': flags_obj.train_epochs,
        'model_type': flags_obj.model_type,
    }
    bench = logger.get_benchmark_logger()
    bench.log_run_info('wide_deep',
                       name,
                       run_info,
                       test_id=flags_obj.benchmark_test_id)

    # Fill the model-type specific loss prefix into every tensor name.
    prefix = LOSS_PREFIX.get(flags_obj.model_type, '')
    formatted_tensors = {}
    for label, template in tensors_to_log.items():
        formatted_tensors[label] = template.format(loss_prefix=prefix)

    train_hooks = hooks_helper.get_train_hooks(
        flags_obj.hooks,
        model_dir=flags_obj.model_dir,
        batch_size=flags_obj.batch_size,
        tensors_to_log=formatted_tensors)

    total_epochs = flags_obj.train_epochs
    epochs_per_cycle = flags_obj.epochs_between_evals

    # One cycle = (optional) training followed by an evaluation.
    for cycle in range(total_epochs // epochs_per_cycle):
        # NOTE(review): the explicit `== False` is kept on purpose; a plain
        # `not flags_obj.infer` would also train when the flag is None.
        if flags_obj.infer == False:  # pylint: disable=singleton-comparison
            model.train(input_fn=train_input_fn, hooks=train_hooks)

        results = model.evaluate(input_fn=eval_input_fn)

        # Report the evaluation metrics in key order.
        tf.logging.info('Results at epoch %d / %d',
                        (cycle + 1) * epochs_per_cycle, total_epochs)
        tf.logging.info('-' * 60)
        for key in sorted(results):
            tf.logging.info('%s: %s', key, results[key])

        bench.log_evaluation_result(results)

        if not early_stop:
            continue
        if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                             results['accuracy']):
            break

    # Export only when an export directory was requested.
    if flags_obj.export_dir is None:
        return
    export_model(model, flags_obj.model_type, flags_obj.export_dir,
                 model_column_fn)
def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                shape=None):
    """Train, evaluate and optionally export a ResNet estimator.

    Args:
        flags_obj: Parsed flags object; all run settings are read from it.
        model_function: Passed to `tf.estimator.Estimator` as `model_fn`.
        input_function: Callable invoked with `is_training`, `data_dir`,
            `batch_size`, `num_epochs` and `dtype` (plus `num_gpus` for
            training) to produce the estimator input.
        dataset_name: Name recorded by the benchmark logger; '-synthetic' is
            appended when `flags_obj.use_synthetic_data` is set.
        shape: Forwarded to the serving input receiver builder. Only used
            when `flags_obj.export_dir` is not None.
    """
    model_helpers.apply_clean(flags.FLAGS)

    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    strategy = distribution_utils.get_distribution_strategy(
        flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

    run_config = tf.estimator.RunConfig(train_distribute=strategy,
                                        session_config=session_config)

    # Warm-start every variable whose name does not contain "dense", but
    # only when a pretrained checkpoint was supplied.
    warm_start_settings = None
    if flags_obj.pretrained_model_checkpoint_path is not None:
        warm_start_settings = tf.estimator.WarmStartSettings(
            flags_obj.pretrained_model_checkpoint_path,
            vars_to_warm_start='^(?!.*dense)')

    model_params = {
        'resnet_size': int(flags_obj.resnet_size),
        'data_format': flags_obj.data_format,
        'batch_size': flags_obj.batch_size,
        'resnet_version': int(flags_obj.resnet_version),
        'loss_scale': flags_core.get_loss_scale(flags_obj),
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'fine_tune': flags_obj.fine_tune
    }
    classifier = tf.estimator.Estimator(model_fn=model_function,
                                        model_dir=flags_obj.model_dir,
                                        config=run_config,
                                        warm_start_from=warm_start_settings,
                                        params=model_params)

    run_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'resnet_size': flags_obj.resnet_size,
        'resnet_version': flags_obj.resnet_version,
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
    }
    if flags_obj.use_synthetic_data:
        dataset_name += '-synthetic'

    bench = logger.get_benchmark_logger()
    bench.log_run_info('resnet',
                       dataset_name,
                       run_params,
                       test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               model_dir=flags_obj.model_dir,
                                               batch_size=flags_obj.batch_size)

    def input_fn_train(num_epochs):
        """Training input covering `num_epochs` epochs."""
        device_batch = distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
        return input_function(is_training=True,
                              data_dir=flags_obj.data_dir,
                              batch_size=device_batch,
                              num_epochs=num_epochs,
                              num_gpus=flags_core.get_num_gpus(flags_obj),
                              dtype=flags_core.get_tf_dtype(flags_obj))

    def input_fn_eval():
        """Evaluation input: a single pass over the data."""
        device_batch = distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
        return input_function(is_training=False,
                              data_dir=flags_obj.data_dir,
                              batch_size=device_batch,
                              num_epochs=1,
                              dtype=flags_core.get_tf_dtype(flags_obj))

    if flags_obj.eval_only or not flags_obj.train_epochs:
        # Nothing to train: one cycle that only evaluates.
        schedule, n_loops = [0], 1
    else:
        # Every cycle trains `epochs_between_evals` epochs, except the last
        # one which trains whatever is left to reach `train_epochs`.
        n_loops = math.ceil(flags_obj.train_epochs /
                            flags_obj.epochs_between_evals)
        schedule = [flags_obj.epochs_between_evals] * int(n_loops)
        schedule[-1] = flags_obj.train_epochs - sum(schedule[:-1])

    for cycle_index, num_train_epochs in enumerate(schedule):
        tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

        if num_train_epochs:
            # The lambda is consumed immediately by `train`, so capturing
            # the loop variable is safe here.
            classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                             hooks=train_hooks,
                             max_steps=flags_obj.max_train_steps)

        tf.logging.info('Starting to evaluate.')
        eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                           steps=flags_obj.max_train_steps)
        bench.log_evaluation_result(eval_results)

        reached_threshold = model_helpers.past_stop_threshold(
            flags_obj.stop_threshold, eval_results['accuracy'])
        if reached_threshold:
            break

    if flags_obj.export_dir is None:
        return

    # Export a SavedModel for the trained classifier.
    dtype = flags_core.get_tf_dtype(flags_obj)
    input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape, batch_size=flags_obj.batch_size, dtype=dtype)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
Beispiel #5
0
def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                shape=None):
    """Shared main loop for ResNet models.

    Args:
        flags_obj: An object containing parsed flags. See
            define_resnet_flags() for details.
        model_function: The function that instantiates the model and builds
            the ops for train/eval. It is passed directly to the estimator.
        input_function: The function that processes the dataset and returns
            a dataset the estimator can train on. It is wrapped with the
            relevant flags before being passed to the estimator.
        dataset_name: The name of the dataset for training and evaluation;
            used for logging.
        shape: List of ints giving the shape of the training images. Only
            used when `flags_obj.export_dir` is set.
    """

    model_helpers.apply_clean(flags.FLAGS)

    # The Winograd non-fused algorithms give a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    # The session config follows the inter/intra op parallelism flags.
    # allow_soft_placement=True is required for multi-GPU and is harmless in
    # the other modes.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    distribution_strategy = distribution_utils.get_distribution_strategy(
        flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

    run_config = tf.estimator.RunConfig(train_distribute=distribution_strategy,
                                        session_config=session_config)

    # With a pretrained checkpoint, initialise everything except the dense
    # layer from it; otherwise start from scratch.
    if flags_obj.pretrained_model_checkpoint_path is None:
        warm_start_settings = None
    else:
        warm_start_settings = tf.estimator.WarmStartSettings(
            flags_obj.pretrained_model_checkpoint_path,
            vars_to_warm_start='^(?!.*dense)')

    estimator_params = {
        'resnet_size': int(flags_obj.resnet_size),
        'data_format': flags_obj.data_format,
        'batch_size': flags_obj.batch_size,
        'resnet_version': int(flags_obj.resnet_version),
        'loss_scale': flags_core.get_loss_scale(flags_obj),
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'fine_tune': flags_obj.fine_tune
    }
    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        warm_start_from=warm_start_settings,
        params=estimator_params)

    logged_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'resnet_size': flags_obj.resnet_size,
        'resnet_version': flags_obj.resnet_version,
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
    }
    if flags_obj.use_synthetic_data:
        dataset_name = '{}{}'.format(dataset_name, '-synthetic')

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info('resnet',
                                  dataset_name,
                                  logged_params,
                                  test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(
        flags_obj.hooks,
        model_dir=flags_obj.model_dir,
        batch_size=flags_obj.batch_size)

    def _device_batch_size():
        """Batch size per device for the configured number of GPUs."""
        return distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))

    def input_fn_train(num_epochs):
        """Training input for `num_epochs` epochs."""
        return input_function(
            is_training=True,
            data_dir=flags_obj.data_dir,
            batch_size=_device_batch_size(),
            num_epochs=num_epochs,
            num_gpus=flags_core.get_num_gpus(flags_obj),
            dtype=flags_core.get_tf_dtype(flags_obj))

    def input_fn_eval():
        """Evaluation input: one epoch."""
        return input_function(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=_device_batch_size(),
            num_epochs=1,
            dtype=flags_core.get_tf_dtype(flags_obj))

    def _train_for(num_epochs):
        """Train the classifier for `num_epochs` epochs."""
        classifier.train(input_fn=lambda: input_fn_train(num_epochs),
                         hooks=train_hooks,
                         max_steps=flags_obj.max_train_steps)

    skip_training = flags_obj.eval_only or not flags_obj.train_epochs
    if not skip_training:
        # Number of train/eval cycles. All but the last cycle train for
        # `epochs_between_evals` epochs; the last one trains for whatever is
        # needed to reach `train_epochs`. For instance, with
        #   train_epochs = 25 and epochs_between_evals = 10
        # the schedule is [10, 10, 5]:
        #   train 10 epochs, evaluate;
        #   train 10 more epochs, evaluate;
        #   train the final 5 epochs (25 in total), evaluate.
        n_loops = math.ceil(flags_obj.train_epochs /
                            flags_obj.epochs_between_evals)
        schedule = [
            flags_obj.epochs_between_evals for _ in range(int(n_loops))
        ]
        # Correct the last entry for over counting.
        schedule[-1] = flags_obj.train_epochs - sum(schedule[:-1])
    else:
        # --eval_only (or zero train epochs): a single cycle without
        # training.
        schedule, n_loops = [0], 1

    for cycle_index, num_train_epochs in enumerate(schedule):
        tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

        if num_train_epochs:
            _train_for(num_train_epochs)

        tf.logging.info('Starting to evaluate.')

        # flags_obj.max_train_steps is generally associated with testing and
        # profiling, so it is frequently combined with synthetic data, which
        # iterates forever. Passing steps=flags_obj.max_train_steps lets the
        # eval (generally unimportant in those circumstances) terminate.
        # Note that eval runs for max_train_steps in every cycle, regardless
        # of the global_step count.
        eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                           steps=flags_obj.max_train_steps)

        benchmark_logger.log_evaluation_result(eval_results)

        stop_now = model_helpers.past_stop_threshold(
            flags_obj.stop_threshold, eval_results['accuracy'])
        if stop_now:
            break

    if flags_obj.export_dir is not None:
        # Export a SavedModel for the classifier.
        export_dtype = flags_core.get_tf_dtype(flags_obj)
        receiver_fn = export.build_tensor_serving_input_receiver_fn(
            shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
        classifier.export_savedmodel(flags_obj.export_dir, receiver_fn)
 def test_raise_in_invalid_names(self):
   """`get_train_hooks` must raise ValueError for these hook names."""
   unsupported = ['StepCounterHook', 'StopAtStepHook']
   self.assertRaises(ValueError,
                     hooks_helper.get_train_hooks,
                     unsupported,
                     batch_size=256)
Beispiel #7
0
def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                shape=None):
    """Shared main loop for ResNet Models.

    Args:
        flags_obj: An object containing parsed flags. See
            define_resnet_flags() for details.
        model_function: The function that instantiates the Model and builds
            the ops for train/eval. This will be passed directly into the
            estimator.
        input_function: The function that processes the dataset and returns
            a dataset that the estimator can train on. This will be wrapped
            with all the relevant flags for running and passed to estimator.
        dataset_name: The name of the dataset for training and evaluation.
            This is used for logging purposes.
        shape: List of ints representing the shape of the images used for
            training. This is only used if flags_obj.export_dir is passed.

    Returns:
        Dict of results of the run. Contains the keys `eval_results` and
        `train_hooks`. `eval_results` holds the result of the last
        evaluation (accuracy (top_1) and accuracy_top_5); it is an empty
        dict when the `tf.estimator.train_and_evaluate` path is taken.
        `train_hooks` is the list of hook instances used during training.
    """

    model_helpers.apply_clean(flags.FLAGS)

    # Ensures flag override logic is only executed if explicitly triggered.
    if flags_obj.tf_gpu_thread_mode:
        override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

    # Configures cluster spec for distribution strategy.
    num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                       flags_obj.task_index)

    # Creates session config. allow_soft_placement = True, is required for
    # multi-GPU and is not harmful for other modes.
    session_config = tf.compat.v1.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    distribution_strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_core.get_num_gpus(flags_obj),
        num_workers=num_workers,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs)

    # Creates a `RunConfig` with time-based checkpointing turned off
    # (`save_checkpoints_secs=None`); checkpoints are requested every 2000
    # steps instead.
    run_config = tf.estimator.RunConfig(train_distribute=distribution_strategy,
                                        session_config=session_config,
                                        save_checkpoints_secs=None,
                                        save_checkpoints_steps=2000)

    # Initializes model with all but the dense layer from pretrained ResNet.
    if flags_obj.pretrained_model_checkpoint_path is not None:
        warm_start_settings = tf.estimator.WarmStartSettings(
            flags_obj.pretrained_model_checkpoint_path,
            vars_to_warm_start='^(?!.*dense)')
    else:
        warm_start_settings = None

    classifier = tf.estimator.Estimator(model_fn=model_function,
                                        model_dir=flags_obj.model_dir,
                                        config=run_config,
                                        warm_start_from=warm_start_settings,
                                        params={
                                            'resnet_size':
                                            int(flags_obj.resnet_size),
                                            'data_format':
                                            flags_obj.data_format,
                                            'batch_size':
                                            flags_obj.batch_size,
                                            'resnet_version':
                                            int(flags_obj.resnet_version),
                                            'loss_scale':
                                            flags_core.get_loss_scale(
                                                flags_obj,
                                                default_for_fp16=128),
                                            'dtype':
                                            flags_core.get_tf_dtype(flags_obj),
                                            'fine_tune':
                                            flags_obj.fine_tune,
                                            'num_workers':
                                            num_workers,
                                        })

    # Parameters recorded with the benchmark run info.
    run_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'resnet_size': flags_obj.resnet_size,
        'resnet_version': flags_obj.resnet_version,
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
        'num_workers': num_workers,
    }
    if flags_obj.use_synthetic_data:
        dataset_name = dataset_name + '-synthetic'

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info('resnet',
                                  dataset_name,
                                  run_params,
                                  test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               model_dir=flags_obj.model_dir,
                                               batch_size=flags_obj.batch_size)

    def input_fn_train(num_epochs, input_context=None):
        """Training input for `num_epochs` epochs (per-replica batch size)."""
        return input_function(
            is_training=True,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_replica_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=num_epochs,
            dtype=flags_core.get_tf_dtype(flags_obj),
            datasets_num_private_threads=flags_obj.
            datasets_num_private_threads,
            input_context=input_context)

    def input_fn_eval():
        """Evaluation input: a single epoch (per-replica batch size)."""
        return input_function(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_replica_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=1,
            dtype=flags_core.get_tf_dtype(flags_obj))

    # Zero train epochs means "evaluate only".
    train_epochs = (0 if flags_obj.eval_only or not flags_obj.train_epochs else
                    flags_obj.train_epochs)

    # More than one worker always goes through train_and_evaluate.
    use_train_and_evaluate = flags_obj.use_train_and_evaluate or num_workers > 1
    if use_train_and_evaluate:
        train_spec = tf.estimator.TrainSpec(
            input_fn=lambda input_context=None: input_fn_train(
                train_epochs, input_context=input_context),
            hooks=train_hooks,
            max_steps=flags_obj.max_train_steps)
        eval_spec = tf.estimator.EvalSpec(input_fn=input_fn_eval)
        tf.compat.v1.logging.info('Starting to train and evaluate.')
        tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
        # tf.estimator.train_and_evaluate doesn't return anything in
        # multi-worker case, so no evaluation results are reported here.
        eval_results = {}
    else:
        if train_epochs == 0:
            # If --eval_only is set (or there are no train epochs), perform a
            # single loop with zero train epochs.
            schedule, n_loops = [0], 1
        else:
            # Compute the number of times to loop while training. All but the last
            # pass will train for `epochs_between_evals` epochs, while the last will
            # train for the number needed to reach `training_epochs`. For instance if
            #   train_epochs = 25 and epochs_between_evals = 10
            # schedule will be set to [10, 10, 5]. That is to say, the loop will:
            #   Train for 10 epochs and then evaluate.
            #   Train for another 10 epochs and then evaluate.
            #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
            n_loops = math.ceil(train_epochs / flags_obj.epochs_between_evals)
            schedule = [
                flags_obj.epochs_between_evals for _ in range(int(n_loops))
            ]
            schedule[-1] = train_epochs - sum(schedule[:-1])  # over counting.

        for cycle_index, num_train_epochs in enumerate(schedule):
            tf.compat.v1.logging.info('Starting cycle: %d/%d', cycle_index,
                                      int(n_loops))

            if num_train_epochs:
                # Since we are calling classifier.train immediately in each loop, the
                # value of num_train_epochs in the lambda function will not be changed
                # before it is used. So it is safe to ignore the pylint error here
                # pylint: disable=cell-var-from-loop
                classifier.train(
                    input_fn=lambda input_context=None: input_fn_train(
                        num_train_epochs, input_context=input_context),
                    hooks=train_hooks,
                    max_steps=flags_obj.max_train_steps)

            # flags_obj.max_train_steps is generally associated with testing and
            # profiling. As a result it is frequently called with synthetic data,
            # which will iterate forever. Passing steps=flags_obj.max_train_steps
            # allows the eval (which is generally unimportant in those circumstances)
            # to terminate.  Note that eval will run for max_train_steps each loop,
            # regardless of the global_step count.
            tf.compat.v1.logging.info('Starting to evaluate.')
            eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                               steps=flags_obj.max_train_steps)

            benchmark_logger.log_evaluation_result(eval_results)

            # Stop early once the evaluated accuracy passes the threshold.
            if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                                 eval_results['accuracy']):
                break

    if flags_obj.export_dir is not None:
        # Exports a saved model for the given classifier.
        export_dtype = flags_core.get_tf_dtype(flags_obj)
        if flags_obj.image_bytes_as_serving_input:
            input_receiver_fn = functools.partial(image_bytes_serving_input_fn,
                                                  shape,
                                                  dtype=export_dtype)
        else:
            input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
                shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
        classifier.export_savedmodel(flags_obj.export_dir,
                                     input_receiver_fn,
                                     strip_default_attrs=True)

    # Results handed back to the caller: last eval results and the hooks.
    stats = {}
    stats['eval_results'] = eval_results
    stats['train_hooks'] = train_hooks

    return stats
Beispiel #8
0
def run_transformer(flags_obj):
  """Build a tf.Estimator for the transformer model, then train/evaluate it.

  The parameter set selected by the flags is completed with run-specific
  values, the train/eval loop is run, and (outside of TPU runs) the model is
  exported when an export directory is given.

  Args:
    flags_obj: Object containing parsed flag values.
  """
  num_gpus = flags_core.get_num_gpus(flags_obj)
  param_set = flags_obj.param_set

  # Pick the base parameter set; multi-GPU runs have dedicated variants.
  # NOTE(review): the selected object is modified in place below, so the
  # changes are visible through PARAMS_MAP / model_params as well.
  params = PARAMS_MAP[param_set]
  multi_gpu = num_gpus > 1
  if multi_gpu and param_set == "big":
    params = model_params.BIG_MULTI_GPU_PARAMS
  elif multi_gpu and param_set == "base":
    params = model_params.BASE_MULTI_GPU_PARAMS

  # Add flag-defined parameters to the params object.
  params["data_dir"] = flags_obj.data_dir
  params["model_dir"] = flags_obj.model_dir
  params["num_parallel_calls"] = flags_obj.num_parallel_calls

  use_tpu = bool(flags_obj.tpu)  # True when a TPU was specified.
  params["tpu"] = flags_obj.tpu
  params["use_tpu"] = use_tpu
  params["static_batch"] = flags_obj.static_batch or use_tpu
  params["allow_ffn_pad"] = not use_tpu

  params["use_synthetic_data"] = flags_obj.use_synthetic_data

  # The batch size depends on TPU/GPU availability and on the distribution
  # settings: an explicit flag wins, otherwise the device default is used.
  batch_size = flags_obj.batch_size
  if not batch_size:
    if use_tpu:
      batch_size = params["default_batch_size_tpu"]
    else:
      batch_size = params["default_batch_size"]
  if not use_tpu:
    batch_size = distribution_utils.per_device_batch_size(batch_size, num_gpus)
  params["batch_size"] = batch_size

  schedule_manager = schedule.Manager(
      train_steps=flags_obj.train_steps,
      steps_between_evals=flags_obj.steps_between_evals,
      train_epochs=flags_obj.train_epochs,
      epochs_between_evals=flags_obj.epochs_between_evals,
      default_train_epochs=DEFAULT_TRAIN_EPOCHS,
      batch_size=batch_size,
      max_length=params["max_length"],
      use_tpu=use_tpu,
      num_tpu_shards=flags_obj.num_tpu_shards)

  params["repeat_dataset"] = schedule_manager.repeat_dataset

  model_helpers.apply_clean(flags.FLAGS)

  # Hooks that log information about the training and metric values.
  hook_kwargs = dict(
      model_dir=flags_obj.model_dir,
      tensors_to_log=TENSORS_TO_LOG,  # used for logging hooks
      batch_size=schedule_manager.batch_size,  # for ExamplesPerSecondHook
      use_tpu=use_tpu)  # Not all hooks can run with TPUs
  train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks, **hook_kwargs)

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info(
      model_name="transformer",
      dataset_name="wmt_translate_ende",
      run_params=params,
      test_id=flags_obj.benchmark_test_id)

  # Train and evaluate the transformer model.
  estimator = construct_estimator(flags_obj, params, schedule_manager)
  run_loop(
      estimator=estimator,
      # Training arguments
      schedule_manager=schedule_manager,
      train_hooks=train_hooks,
      benchmark_logger=benchmark_logger,
      # BLEU calculation arguments
      bleu_source=flags_obj.bleu_source,
      bleu_ref=flags_obj.bleu_ref,
      bleu_threshold=flags_obj.stop_threshold,
      vocab_file=flags_obj.vocab_file)

  # Export is skipped when no directory is given and for TPU runs.
  if not flags_obj.export_dir or use_tpu:
    return

  serving_input_fn = export.build_tensor_serving_input_receiver_fn(
      shape=[None], dtype=tf.int64, batch_size=None)
  # The vocab file is saved next to the model so that inputs can be encoded
  # and outputs decoded consistently (see the "Export trained model" section
  # of the README). The model itself does not use the file, so it is stored
  # as an extra asset rather than a core asset.
  estimator.export_savedmodel(
      flags_obj.export_dir, serving_input_fn,
      assets_extra={"vocab.txt": flags_obj.vocab_file},
      strip_default_attrs=True)
Beispiel #9
0
def _warm_start_exclusion_pattern(optimizer_name, fine_tune, no_dense_init):
    """Pick the `vars_to_warm_start` regex for a pretrained checkpoint.

    Every pattern is a negative lookahead anchored at the start of the
    variable name: it rejects names containing any of the listed fragments.

    Args:
        optimizer_name: lower-cased optimizer name. 'adam' selects the
            patterns that also reject `beta1_power`/`beta2_power`; any other
            value selects the remaining patterns.
        fine_tune: truthy selects the `resnet_model/dense...` patterns, falsy
            the `endecoder` patterns.
        no_dense_init: only consulted when `fine_tune` is truthy. Truthy
            rejects every name containing `resnet_model/dense`; falsy rejects
            only that layer's `Momentum` entries.

    Returns:
        The regex string to pass (inside a list) as `vars_to_warm_start`.
    """
    uses_adam = optimizer_name == 'adam'
    if fine_tune:
        if uses_adam:
            if no_dense_init:
                return ('^(?!.*(resnet_model/dense|beta1_power|beta2_power|'
                        'Adam|global_step))')
            return ('^(?!.*(resnet_model/dense/kernel/Momentum|'
                    'resnet_model/dense/bias/Momentum|'
                    'beta1_power|beta2_power|Adam|global_step))')
        if no_dense_init:
            return '^(?!.*(resnet_model/dense|Momentum|global_step))'
        return ('^(?!.*(resnet_model/dense/kernel/Momentum|'
                'resnet_model/dense/bias/Momentum|global_step))')
    if uses_adam:
        return ('^(?!.*(endecoder|Momentum|beta1_power|beta2_power|'
                'global_step))')
    return '^(?!.*(endecoder|global_step))'


def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                shape=None):
    """Shared main loop for ResNet Models.

    Builds an Estimator (warm-started when
    `flags_obj.pretrained_model_checkpoint_path` is set), alternates training
    and evaluation according to an epoch schedule, and optionally exports a
    SavedModel.

    Args:
        flags_obj: An object containing parsed flags. See
            define_resnet_flags() for details.
        model_function: the function that instantiates the Model and builds
            the ops for train/eval. This will be passed directly into the
            estimator.
        input_function: the function that processes the dataset and returns a
            dataset that the estimator can train on. This will be wrapped
            with all the relevant flags for running and passed to estimator.
        dataset_name: the name of the dataset for training and evaluation.
            This is used for logging purpose.
        shape: list of ints representing the shape of the images used for
            training. This is only used if flags_obj.export_dir is passed.
    """

    model_helpers.apply_clean(flags.FLAGS)

    # Ensures flag override logic is only executed if explicitly triggered.
    if flags_obj.tf_gpu_thread_mode:
        override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

    # Creates session config. allow_soft_placement = True, is required for
    # multi-GPU and is not harmful for other modes.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    distribution_strategy = distribution_utils.get_distribution_strategy(
        flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

    # Creates a `RunConfig` that checkpoints every 24 hours which essentially
    # results in checkpoints determined only by `epochs_between_evals`.
    run_config = tf.estimator.RunConfig(train_distribute=distribution_strategy,
                                        session_config=session_config,
                                        save_checkpoints_secs=60 * 60 * 24)

    # `str.lower()` rather than `string.lower()`: the module-level function
    # only exists on Python 2, the method works on both major versions.
    optimizer_name = flags_obj.optimizer.lower()

    # Initializes the model from a pretrained checkpoint, skipping the
    # variables rejected by the selected exclusion pattern.
    if flags_obj.pretrained_model_checkpoint_path is not None:
        warm_start_settings = tf.estimator.WarmStartSettings(
            flags_obj.pretrained_model_checkpoint_path,
            vars_to_warm_start=[
                _warm_start_exclusion_pattern(optimizer_name,
                                              flags_obj.fine_tune,
                                              flags_obj.no_dense_init)
            ])
    else:
        warm_start_settings = None

    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=run_config,
        warm_start_from=warm_start_settings,
        params={
            'resnet_size': int(flags_obj.resnet_size),
            'data_format': flags_obj.data_format,
            'batch_size': flags_obj.batch_size,
            'resnet_version': int(flags_obj.resnet_version),
            'loss_scale': flags_core.get_loss_scale(flags_obj),
            'dtype': flags_core.get_tf_dtype(flags_obj),
            'fine_tune': flags_obj.fine_tune,
            'reconst_loss_scale': flags_obj.reconst_loss_scale,
            'use_ce': flags_obj.use_ce,
            'optimizer': optimizer_name,
            'clip_grad': flags_obj.clip_grad,
            'spectral_norm': flags_obj.spectral_norm,
            'ce_scale': flags_obj.ce_scale,
            'sep_grad_nrom': flags_obj.sep_grad_nrom,
            'norm_teach_feature': flags_obj.norm_teach_feature,
            'no_dense_init': flags_obj.no_dense_init,
            'compress_ratio': flags_obj.compress_ratio
        })

    # Same settings as the estimator params, but with the raw (uncast)
    # resnet_size/resnet_version flags; only used for benchmark logging.
    run_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'resnet_size': flags_obj.resnet_size,
        'resnet_version': flags_obj.resnet_version,
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
        'fine_tune': flags_obj.fine_tune,
        'reconst_loss_scale': flags_obj.reconst_loss_scale,
        'use_ce': flags_obj.use_ce,
        'optimizer': optimizer_name,
        'clip_grad': flags_obj.clip_grad,
        'spectral_norm': flags_obj.spectral_norm,
        'ce_scale': flags_obj.ce_scale,
        'sep_grad_nrom': flags_obj.sep_grad_nrom,
        'norm_teach_feature': flags_obj.norm_teach_feature,
        'no_dense_init': flags_obj.no_dense_init,
        'compress_ratio': flags_obj.compress_ratio,
    }
    if flags_obj.use_synthetic_data:
        dataset_name = dataset_name + '-synthetic'

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info('resnet',
                                  dataset_name,
                                  run_params,
                                  test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                               model_dir=flags_obj.model_dir,
                                               batch_size=flags_obj.batch_size)

    def input_fn_train(num_epochs):
        """Training input for `num_epochs` epochs, per-device batch size."""
        return input_function(
            is_training=True,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=num_epochs,
            dtype=flags_core.get_tf_dtype(flags_obj),
            datasets_num_private_threads=(
                flags_obj.datasets_num_private_threads),
            num_parallel_batches=flags_obj.datasets_num_parallel_batches)

    def input_fn_eval():
        """Evaluation input: a single epoch, per-device batch size."""
        return input_function(
            is_training=False,
            data_dir=flags_obj.data_dir,
            batch_size=distribution_utils.per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=1,
            dtype=flags_core.get_tf_dtype(flags_obj))

    if flags_obj.eval_only or not flags_obj.train_epochs:
        # If --eval_only is set, perform a single loop with zero train epochs.
        schedule, n_loops = [0], 1
    else:
        # Compute the number of times to loop while training. All but the last
        # pass will train for `epochs_between_evals` epochs, while the last will
        # train for the number needed to reach `training_epochs`. For instance if
        #   train_epochs = 25 and epochs_between_evals = 10
        # schedule will be set to [10, 10, 5]. That is to say, the loop will:
        #   Train for 10 epochs and then evaluate.
        #   Train for another 10 epochs and then evaluate.
        #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
        n_loops = math.ceil(flags_obj.train_epochs /
                            flags_obj.epochs_between_evals)
        schedule = [
            flags_obj.epochs_between_evals for _ in range(int(n_loops))
        ]
        schedule[-1] = flags_obj.train_epochs - sum(
            schedule[:-1])  # over counting.

    print('schedule: ', schedule, flags_obj.epochs_between_evals,
          flags_obj.max_train_steps)
    for cycle_index, num_train_epochs in enumerate(schedule):
        tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

        # A zero entry (eval-only run) skips training for this cycle.
        if num_train_epochs:
            classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                             hooks=train_hooks,
                             max_steps=flags_obj.max_train_steps)

        tf.logging.info('Starting to evaluate.')

        # flags_obj.max_train_steps is generally associated with testing and
        # profiling. As a result it is frequently called with synthetic data, which
        # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
        # eval (which is generally unimportant in those circumstances) to terminate.
        # Note that eval will run for max_train_steps each loop, regardless of the
        # global_step count.
        eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                           steps=flags_obj.max_train_steps)

        benchmark_logger.log_evaluation_result(eval_results)

        if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                             eval_results['accuracy']):
            break

    if flags_obj.export_dir is not None:
        # Exports a saved model for the given classifier.
        export_dtype = flags_core.get_tf_dtype(flags_obj)
        if flags_obj.image_bytes_as_serving_input:
            input_receiver_fn = functools.partial(image_bytes_serving_input_fn,
                                                  shape,
                                                  dtype=export_dtype)
        else:
            input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
                shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
        classifier.export_savedmodel(flags_obj.export_dir,
                                     input_receiver_fn,
                                     strip_default_attrs=True)
def run_wide_deep(flags_obj):
    """Train the Wide-Deep model, evaluating after every block of epochs.

    Args:
        flags_obj: An object containing parsed flag values.
    """

    # Start from an empty model directory; a missing one is not an error.
    shutil.rmtree(flags_obj.model_dir, ignore_errors=True)
    estimator = build_estimator(flags_obj.model_dir, flags_obj.model_type)

    train_path = os.path.join(flags_obj.data_dir, 'adult.data')
    eval_path = os.path.join(flags_obj.data_dir, 'adult.test')

    def _training_input():
        # One call feeds `epochs_between_evals` epochs of training data.
        return input_fn(train_path, flags_obj.epochs_between_evals, True,
                        flags_obj.batch_size)

    def _evaluation_input():
        # Evaluation reads the test file exactly once.
        return input_fn(eval_path, 1, False, flags_obj.batch_size)

    benchmark_logger = logger.config_benchmark_logger(flags_obj)
    benchmark_logger.log_run_info(
        'wide_deep', 'Census Income',
        {
            'batch_size': flags_obj.batch_size,
            'train_epochs': flags_obj.train_epochs,
            'model_type': flags_obj.model_type,
        })

    # Tensor names to log carry a prefix looked up by model type.
    prefix = LOSS_PREFIX.get(flags_obj.model_type, '')
    logged_tensors = {
        'average_loss': prefix + 'head/truediv',
        'loss': prefix + 'head/weighted_loss/Sum'
    }
    hooks = hooks_helper.get_train_hooks(flags_obj.hooks,
                                         batch_size=flags_obj.batch_size,
                                         tensors_to_log=logged_tensors)

    # Each cycle trains for `epochs_between_evals` epochs and then evaluates.
    total_cycles = flags_obj.train_epochs // flags_obj.epochs_between_evals
    for cycle in range(1, total_cycles + 1):
        estimator.train(input_fn=_training_input, hooks=hooks)
        metrics = estimator.evaluate(input_fn=_evaluation_input)

        # Display evaluation metrics
        tf.logging.info('Results at epoch %d / %d',
                        cycle * flags_obj.epochs_between_evals,
                        flags_obj.train_epochs)
        tf.logging.info('-' * 60)
        for metric_name in sorted(metrics):
            tf.logging.info('%s: %s', metric_name, metrics[metric_name])

        benchmark_logger.log_evaluation_result(metrics)

        reached_target = model_helpers.past_stop_threshold(
            flags_obj.stop_threshold, metrics['accuracy'])
        if reached_target:
            break

    # Export the model
    if flags_obj.export_dir is not None:
        export_model(estimator, flags_obj.model_type, flags_obj.export_dir)
def resnet_main(flags, model_function, input_function, shape=None):
    """Shared main loop for ResNet Models.

    Repeatedly trains for `flags.epochs_between_evals` epochs and evaluates,
    stopping early once the stop threshold is passed, then optionally exports
    a SavedModel.

    Args:
        flags: FLAGS object that contains the params for running. See
            ResnetArgParser for created flags.
        model_function: the function that instantiates the Model and builds
            the ops for train/eval. This will be passed directly into the
            estimator.
        input_function: the function that processes the dataset and returns a
            dataset that the estimator can train on. This will be wrapped
            with all the relevant flags for running and passed to estimator.
        shape: list of ints representing the shape of the images used for
            training. This is only used if flags.export_dir is passed.
    """

    # Using the Winograd non-fused algorithms provides a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    if flags.multi_gpu:
        validate_batch_size_for_multi_gpu(flags.batch_size)

        # Multi-GPU needs two wrappers: the model_fn (done here) and the
        # optimizer (done inside the model_fn where the optimizer is defined).
        model_function = tf.contrib.estimator.replicate_model_fn(
            model_function, loss_reduction=tf.losses.Reduction.MEAN)

    # Thread counts come from the flags; allow_soft_placement is required for
    # multi-GPU and not harmful for other modes.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags.intra_op_parallelism_threads,
        allow_soft_placement=True)

    # Set up a RunConfig to save checkpoint and set session config.
    run_config = tf.estimator.RunConfig().replace(
        save_checkpoints_secs=1e9, session_config=session_config)

    estimator_params = {
        'resnet_size': flags.resnet_size,
        'data_format': flags.data_format,
        'batch_size': flags.batch_size,
        'multi_gpu': flags.multi_gpu,
        'version': flags.version,
        'loss_scale': flags.loss_scale,
        'dtype': flags.dtype
    }
    classifier = tf.estimator.Estimator(model_fn=model_function,
                                        model_dir=flags.model_dir,
                                        config=run_config,
                                        params=estimator_params)

    benchmark_logger = logger.config_benchmark_logger(flags.benchmark_log_dir)
    benchmark_logger.log_run_info('resnet')

    def _build_input(is_training, num_epochs):
        # Flags are read when the estimator calls the input function.
        return input_function(is_training, flags.data_dir, flags.batch_size,
                              num_epochs, flags.num_parallel_calls,
                              flags.multi_gpu)

    def input_fn_train():
        return _build_input(True, flags.epochs_between_evals)

    def input_fn_eval():
        return _build_input(False, 1)

    num_cycles = flags.train_epochs // flags.epochs_between_evals
    for _ in range(num_cycles):
        # A fresh set of hooks is built for every train/eval cycle.
        cycle_hooks = hooks_helper.get_train_hooks(
            flags.hooks,
            batch_size=flags.batch_size,
            benchmark_log_dir=flags.benchmark_log_dir)

        print('Starting a training cycle.')
        classifier.train(input_fn=input_fn_train,
                         hooks=cycle_hooks,
                         max_steps=flags.max_train_steps)

        print('Starting to evaluate.')

        # flags.max_train_steps is generally associated with testing and profiling.
        # As a result it is frequently called with synthetic data, which will
        # iterate forever. Passing steps=flags.max_train_steps allows the eval
        # (which is generally unimportant in those circumstances) to terminate.
        # Note that eval will run for max_train_steps each loop, regardless of the
        # global_step count.
        eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                           steps=flags.max_train_steps)
        benchmark_logger.log_evaluation_result(eval_results)

        accuracy = eval_results['accuracy']
        if model_helpers.past_stop_threshold(flags.stop_threshold, accuracy):
            break

    if flags.export_dir is None:
        return

    warn_on_multi_gpu_export(flags.multi_gpu)

    # Exports a saved model for the given classifier.
    receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape, batch_size=flags.batch_size)
    classifier.export_savedmodel(flags.export_dir, receiver_fn)
Beispiel #12
0
def resnet_main(seed, flags, model_function, input_function, shape=None):
    """Shared main loop for ResNet Models.

    Alternates training for `flags.epochs_between_evals` epochs with an
    evaluation pass, emitting mlperf_log entries along the way, and stops
    early once the stop threshold is passed.

    Args:
        seed: random seed. Logged under RUN_SET_RANDOM_SEED and passed to the
            RunConfig as `tf_random_seed`.
        flags: FLAGS object that contains the params for running. See
            ResnetArgParser for created flags.
        model_function: the function that instantiates the Model and builds
            the ops for train/eval. This will be passed directly into the
            estimator.
        input_function: the function that processes the dataset and returns a
            dataset that the estimator can train on. This will be wrapped
            with all the relevant flags for running and passed to estimator.
        shape: list of ints representing the shape of the images used for
            training. Not referenced by this function (it performs no
            export); kept so existing call sites stay valid.
    """

    mlperf_log.resnet_print(key=mlperf_log.RUN_START)

    # Using the Winograd non-fused algorithms provides a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    # Create session config based on values of inter_op_parallelism_threads and
    # intra_op_parallelism_threads. Note that we default to having
    # allow_soft_placement = True, which is required for multi-GPU and not
    # harmful for other modes.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags.intra_op_parallelism_threads,
        allow_soft_placement=True)

    if flags.num_gpus == 0:
        distribution = tf.contrib.distribute.OneDeviceStrategy('device:CPU:0')
    elif flags.num_gpus == 1:
        distribution = tf.contrib.distribute.OneDeviceStrategy('device:GPU:0')
    else:
        distribution = tf.contrib.distribute.MirroredStrategy(
            num_gpus=flags.num_gpus)

    mlperf_log.resnet_print(key=mlperf_log.RUN_SET_RANDOM_SEED, value=seed)
    run_config = tf.estimator.RunConfig(train_distribute=distribution,
                                        save_summary_steps=2000,
                                        save_checkpoints_steps=1000,
                                        session_config=session_config,
                                        tf_random_seed=seed,
                                        keep_checkpoint_max=2)

    mlperf_log.resnet_print(key=mlperf_log.INPUT_BATCH_SIZE,
                            value=flags.batch_size)

    classifier = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags.model_dir,
        config=run_config,
        params={
            'resnet_size': flags.resnet_size,
            'final_size': flags.final_size,
            'pickle_model': flags.pickle_model,
            'random_init': flags.random_init,
            'data_format': flags.data_format,
            'batch_size': flags.batch_size,
            'train_epochs': flags.train_epochs,
            'version': flags.version,
            'version_t': flags.version_t,
            'loss_scale': flags.loss_scale,
            'gap_train': flags.gap_train,
            'gap_lambda': flags.gap_lambda,
            'gap_ft': flags.gap_ft,
            'gap_start': flags.gap_start,
            'dtype': flags.dtype,
            'learn_rate': flags.learn_rate,
            'label_smoothing': flags.label_smoothing,
            'enable_lars': flags.enable_lars,
            'enable_cos': flags.enable_cos,
            'cos_alpha': flags.cos_alpha,
            'warm_up': flags.warm_up,
            'weight_decay': flags.weight_decay,
            'fine_tune': flags.fine_tune,
            'enable_kd': flags.enable_kd,
            'kd_size': flags.kd_size,
            'temp_dst': flags.temp_dst,
            'w_dst': flags.w_dst,
            'mix_up': flags.mix_up,
            'mx_mode': flags.mx_mode,
            'enable_quantize': flags.enable_quantize,
            'online_quantize': flags.online_quantize,
            'enable_at': flags.enable_at,
            'w_at': flags.w_at,
        })

    # Benchmark logging is optional; it is enabled only when a log directory
    # was supplied.
    if flags.benchmark_log_dir is not None:
        benchmark_logger = logger.BenchmarkLogger(flags.benchmark_log_dir)
        benchmark_logger.log_run_info('resnet')
    else:
        benchmark_logger = None

    # The input functions only read `flags`, so they are defined once instead
    # of being re-created on every cycle.
    def input_fn_eval():
        return input_function(is_training=False,
                              data_dir=flags.data_dir,
                              batch_size=per_device_batch_size(
                                  flags.batch_size, flags.num_gpus),
                              num_epochs=1,
                              dtype=flags.dtype,
                              oss_load=flags.oss_load)

    def input_fn_train():
        return input_function(is_training=True,
                              data_dir=flags.data_dir,
                              batch_size=per_device_batch_size(
                                  flags.batch_size, flags.num_gpus),
                              num_epochs=flags.epochs_between_evals,
                              num_gpus=flags.num_gpus,
                              dtype=flags.dtype,
                              mix_up=flags.mix_up,
                              oss_load=flags.oss_load)

    mlperf_log.resnet_print(key=mlperf_log.TRAIN_LOOP)

    # The reference performs the first evaluation on the fourth epoch. (offset
    # eval by 3 epochs)
    mlperf_log.resnet_print(key=mlperf_log.EVAL_EPOCH_OFFSET, value=3)
    success = False
    print('Training epochs: {}'.format(flags.train_epochs))
    for i in range(flags.train_epochs // flags.epochs_between_evals):
        # Data for epochs_between_evals (i.e. 4 epochs between evals) worth of
        # epochs is concatenated and run as a single block inside a session. For
        # this reason we declare all of the epochs that will be run at the start.
        # Submitters may report in a way which is reasonable for their control flow.
        for j in range(flags.epochs_between_evals):
            mlperf_log.resnet_print(key=mlperf_log.TRAIN_EPOCH,
                                    value=i * flags.epochs_between_evals + j)

        # hooks for training
        train_hooks = hooks_helper.get_train_hooks(
            flags.hooks,
            batch_size=flags.batch_size,
            benchmark_log_dir=flags.benchmark_log_dir)

        # Holds at most one entry: the latest tensor values seen by the
        # compliance hook during this cycle's training run.
        _log_cache = []

        def formatter(x):
            """Abuse side effects to get tensors out of the model_fn."""
            if _log_cache:
                _log_cache.pop()
            _log_cache.append(x.copy())
            return str(x)

        # every_n_iter is huge so the hook effectively reports only at the
        # end of the training run (at_end=True).
        compliance_hook = tf.train.LoggingTensorHook(
            tensors={_NUM_EXAMPLES_NAME: _NUM_EXAMPLES_NAME},
            every_n_iter=int(1e10),
            at_end=True,
            formatter=formatter)

        extra_hooks = [compliance_hook]
        if flags.enable_quantize:
            if flags.online_quantize:
                # online calculate the KL scale before train-eval; the call is
                # made for its hook side effects, its result is not needed.
                quant_online_hook = QuantHook(bits=flags.q_bits,
                                              online=True,
                                              quant_mode=flags.q_mode)
                classifier.evaluate(input_fn=input_fn_eval,
                                    steps=1,
                                    hooks=[quant_online_hook])

            quant_train_hook = QuantHook(bits=flags.q_bits,
                                         quant_copy_num=flags.copy_num,
                                         quant_mode=flags.q_mode)
            extra_hooks.append(quant_train_hook)

        print('Starting a training cycle.')
        classifier.train(input_fn=input_fn_train,
                         hooks=train_hooks + extra_hooks,
                         max_steps=flags.max_train_steps)

        # `train` can return without running any hook (for example when the
        # checkpointed global step already reached max_steps), which leaves
        # the cache empty; popping unconditionally would raise IndexError.
        if _log_cache:
            train_examples = int(_log_cache.pop()[_NUM_EXAMPLES_NAME])
            mlperf_log.resnet_print(key=mlperf_log.INPUT_SIZE,
                                    value=train_examples)
        else:
            tf.logging.warning(
                'No training step ran in this cycle; skipping the '
                'INPUT_SIZE log entry.')

        # Evaluate the model and print results
        print('Starting to evaluate.')
        mlperf_log.resnet_print(key=mlperf_log.EVAL_START)
        # flags.max_train_steps is generally associated with testing and profiling.
        # As a result it is frequently called with synthetic data, which will
        # iterate forever. Passing steps=flags.max_train_steps allows the eval
        # (which is generally unimportant in those circumstances) to terminate.
        # Note that eval will run for max_train_steps each loop, regardless of the
        # global_step count.
        eval_hooks = None
        if flags.enable_quantize:
            quant_eval_hook = QuantHook(bits=flags.q_bits,
                                        quant_mode=flags.q_mode)
            eval_hooks = [quant_eval_hook]
        eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                           steps=flags.max_train_steps,
                                           hooks=eval_hooks)
        mlperf_log.resnet_print(key=mlperf_log.EVAL_STOP)
        mlperf_log.resnet_print(key=mlperf_log.EVAL_SIZE,
                                value=int(eval_results[_NUM_EXAMPLES_NAME]))
        mlperf_log.resnet_print(key=mlperf_log.EVAL_ACCURACY,
                                value=float(eval_results['accuracy']))
        mlperf_log.resnet_print(key=mlperf_log.EVAL_TARGET,
                                value=flags.stop_threshold)
        print(eval_results)

        if benchmark_logger:
            benchmark_logger.log_estimator_evaluation_result(eval_results)

        if model_helpers.past_stop_threshold(flags.stop_threshold,
                                             eval_results['accuracy']):
            success = True
            break

    mlperf_log.resnet_print(key=mlperf_log.RUN_STOP,
                            value={"success": success})
    mlperf_log.resnet_print(key=mlperf_log.RUN_FINAL)
Beispiel #13
0
def resnet_main(flags_obj,
                model_function,
                input_function,
                dataset_name,
                shape=None):
    """Run the train/evaluate cycle shared by the ResNet models.

    Wraps `model_function` in a tf.estimator.Estimator, then alternates
    between training and evaluation until every cycle has run or the
    accuracy stop threshold is passed. A SavedModel is exported afterwards
    when an export directory is configured.

    Args:
      flags_obj: An object containing parsed flags. See define_resnet_flags()
        for details.
      model_function: the function that instantiates the Model and builds the
        ops for train/eval. It is handed to the estimator unchanged.
      input_function: the function that processes the dataset and returns a
        dataset the estimator can consume. It is wrapped here with the
        relevant flag values before being given to the estimator.
      dataset_name: the name of the dataset for training and evaluation; it
        is only used for logging.
      shape: list of ints giving the shape of the training images. Only used
        when flags_obj.export_dir is set.
    """
    model_helpers.apply_clean(flags.FLAGS)

    # The Winograd non-fused algorithms give a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    # Thread-pool sizes come straight from the flags. allow_soft_placement is
    # always on: multi-GPU needs it and it is harmless in the other modes.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    strategy = distribution_utils.get_distribution_strategy(
        flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)
    estimator_config = tf.estimator.RunConfig(
        train_distribute=strategy, session_config=session_config)

    # Hyper-parameters forwarded to the model function.
    model_params = {
        'resnet_size': int(flags_obj.resnet_size),
        'data_format': flags_obj.data_format,
        'batch_size': flags_obj.batch_size,
        'resnet_version': int(flags_obj.resnet_version),
        'loss_scale': flags_core.get_loss_scale(flags_obj),
        'dtype': flags_core.get_tf_dtype(flags_obj),
    }
    classifier = tf.estimator.Estimator(model_fn=model_function,
                                        model_dir=flags_obj.model_dir,
                                        config=estimator_config,
                                        params=model_params)

    # Run description recorded by the benchmark logger.
    logged_params = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'resnet_size': flags_obj.resnet_size,
        'resnet_version': flags_obj.resnet_version,
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
    }
    logged_name = dataset_name
    if flags_obj.use_synthetic_data:
        logged_name = logged_name + '-synthetic'

    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info('resnet',
                                  logged_name,
                                  logged_params,
                                  test_id=flags_obj.benchmark_test_id)

    train_hooks = hooks_helper.get_train_hooks(
        flags_obj.hooks,
        model_dir=flags_obj.model_dir,
        batch_size=flags_obj.batch_size)

    def _device_batch_size():
        # Evaluated lazily, i.e. only when the estimator builds its inputs.
        return distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))

    def train_inputs():
        return input_function(is_training=True,
                              data_dir=flags_obj.data_dir,
                              batch_size=_device_batch_size(),
                              num_epochs=flags_obj.epochs_between_evals,
                              num_gpus=flags_core.get_num_gpus(flags_obj))

    def eval_inputs():
        return input_function(is_training=False,
                              data_dir=flags_obj.data_dir,
                              batch_size=_device_batch_size(),
                              num_epochs=1)

    num_cycles = flags_obj.train_epochs // flags_obj.epochs_between_evals
    for cycle in range(num_cycles):
        tf.logging.info('Starting a training cycle: %d/%d', cycle, num_cycles)

        classifier.train(input_fn=train_inputs,
                         hooks=train_hooks,
                         max_steps=flags_obj.max_train_steps)

        tf.logging.info('Starting to evaluate.')

        # max_train_steps is mostly used for testing and profiling, often
        # with synthetic data that never runs out. Bounding the evaluation
        # by the same step count lets it terminate in that situation. Note
        # that evaluation therefore runs for max_train_steps on every cycle,
        # independent of the global_step count.
        metrics = classifier.evaluate(input_fn=eval_inputs,
                                      steps=flags_obj.max_train_steps)

        benchmark_logger.log_evaluation_result(metrics)

        reached_target = model_helpers.past_stop_threshold(
            flags_obj.stop_threshold, metrics['accuracy'])
        if reached_target:
            break

    if flags_obj.export_dir is None:
        return

    # Export a SavedModel for the trained classifier.
    receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape, batch_size=flags_obj.batch_size)
    classifier.export_savedmodel(flags_obj.export_dir, receiver_fn)
Beispiel #14
0
def main(argv):
    """Train, evaluate and optionally export the MNIST estimator.

    Args:
      argv: command-line argument list. The first element is skipped and the
        remainder is parsed with MNISTArgParser.
    """
    flags = MNISTArgParser().parse_args(args=argv[1:])

    if flags.multi_gpu:
        validate_batch_size_for_multi_gpu(flags.batch_size)

        # Multi-GPU needs two wrappers: one around the model_fn (done here)
        # and one around the optimizer (done inside the model_fn itself,
        # where the optimizer is defined).
        model_function = tf.contrib.estimator.replicate_model_fn(
            model_fn, loss_reduction=tf.losses.Reduction.MEAN)
    else:
        model_function = model_fn

    data_format = flags.data_format
    if data_format is None:
        # No explicit choice: pick the layout based on CUDA availability.
        if tf.test.is_built_with_cuda():
            data_format = 'channels_first'
        else:
            data_format = 'channels_last'

    estimator_params = {
        'data_format': data_format,
        'multi_gpu': flags.multi_gpu,
    }
    mnist_classifier = tf.estimator.Estimator(model_fn=model_function,
                                              model_dir=flags.model_dir,
                                              params=estimator_params)

    def train_input_fn():
        """Build the training dataset."""
        # Larger shuffle buffers give better randomness, smaller ones use
        # less memory. MNIST is small enough to shuffle a full epoch.
        shuffled = dataset.train(flags.data_dir).cache().shuffle(
            buffer_size=50000)
        batched = shuffled.batch(flags.batch_size)

        # Each training session walks the data `epochs_between_evals` times.
        return batched.repeat(flags.epochs_between_evals)

    def eval_input_fn():
        """Build the evaluation input tensors."""
        test_batches = dataset.test(flags.data_dir).batch(flags.batch_size)
        return test_batches.make_one_shot_iterator().get_next()

    # Training hooks requested through the flags.
    train_hooks = hooks_helper.get_train_hooks(flags.hooks,
                                               batch_size=flags.batch_size)

    # Alternate training and evaluation; stop early once the accuracy
    # threshold has been passed.
    num_cycles = flags.train_epochs // flags.epochs_between_evals
    for _ in range(num_cycles):
        mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
        eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
        print('\nEvaluation results:\n\t%s\n' % eval_results)

        reached_target = model_helpers.past_stop_threshold(
            flags.stop_threshold, eval_results['accuracy'])
        if reached_target:
            break

    if flags.export_dir is None:
        return

    # Export the trained model with a raw-tensor serving signature.
    image = tf.placeholder(tf.float32, [None, 28, 28])
    serving_input_fn = (
        tf.estimator.export.build_raw_serving_input_receiver_fn(
            {'image': image}))
    mnist_classifier.export_savedmodel(flags.export_dir, serving_input_fn)
Beispiel #15
0
def run_mnist(flags_obj):
    """Train and evaluate the MNIST estimator, then optionally export it.

    Args:
      flags_obj: An object containing parsed flag values.
    """
    model_helpers.apply_clean(flags_obj)

    # GPU count derived from the --num_gpus flag and the GPUs available on
    # the machine.
    num_gpus = flags_core.get_num_gpus(flags_obj)
    multi_gpu = num_gpus > 1

    if not multi_gpu:
        model_function = model_fn
    else:
        # Called only for its validation: the batch must split across the
        # devices.
        distribution_utils.per_device_batch_size(flags_obj.batch_size,
                                                 num_gpus)

        # Multi-GPU needs two wrappers: one around the model_fn (done here)
        # and one around the optimizer (done inside the model_fn itself,
        # where the optimizer is defined).
        gpu_devices = ["/device:GPU:%d" % d for d in range(num_gpus)]
        model_function = tf.contrib.estimator.replicate_model_fn(
            model_fn,
            loss_reduction=tf.losses.Reduction.MEAN,
            devices=gpu_devices)

    data_format = flags_obj.data_format
    if data_format is None:
        # No explicit choice: pick the layout based on CUDA availability.
        if tf.test.is_built_with_cuda():
            data_format = 'channels_first'
        else:
            data_format = 'channels_last'

    estimator_params = {
        'data_format': data_format,
        'multi_gpu': multi_gpu,
    }
    mnist_classifier = tf.estimator.Estimator(model_fn=model_function,
                                              model_dir=flags_obj.model_dir,
                                              params=estimator_params)

    def train_input_fn():
        """Build the training dataset."""
        # Larger shuffle buffers give better randomness, smaller ones use
        # less memory. MNIST is small enough to shuffle a full epoch.
        shuffled = dataset.train(flags_obj.data_dir).cache().shuffle(
            buffer_size=50000)
        batched = shuffled.batch(flags_obj.batch_size)

        # Each training session walks the data `epochs_between_evals` times.
        return batched.repeat(flags_obj.epochs_between_evals)

    def eval_input_fn():
        """Build the evaluation input tensors."""
        test_batches = dataset.test(flags_obj.data_dir).batch(
            flags_obj.batch_size)
        return test_batches.make_one_shot_iterator().get_next()

    # Training hooks requested through the flags.
    train_hooks = hooks_helper.get_train_hooks(
        flags_obj.hooks,
        model_dir=flags_obj.model_dir,
        batch_size=flags_obj.batch_size)

    # Alternate training and evaluation; stop early once the accuracy
    # threshold has been passed.
    num_cycles = flags_obj.train_epochs // flags_obj.epochs_between_evals
    for _ in range(num_cycles):
        mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
        eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
        print('\nEvaluation results:\n\t%s\n' % eval_results)

        reached_target = model_helpers.past_stop_threshold(
            flags_obj.stop_threshold, eval_results['accuracy'])
        if reached_target:
            break

    if flags_obj.export_dir is None:
        return

    # Export the trained model with a raw-tensor serving signature.
    image = tf.placeholder(tf.float32, [None, 28, 28])
    serving_input_fn = (
        tf.estimator.export.build_raw_serving_input_receiver_fn(
            {'image': image}))
    mnist_classifier.export_savedmodel(flags_obj.export_dir,
                                       serving_input_fn)
Beispiel #16
0
def run_mnist(flags_obj):
    """Train and evaluate the MNIST estimator, then optionally export it.

    This variant configures the estimator with a distribution strategy and
    an explicit session config.

    Args:
      flags_obj: An object containing parsed flag values.
    """
    model_helpers.apply_clean(flags_obj)

    # Thread-pool sizes come from the flags; soft placement is always on.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_core.get_num_gpus(flags_obj),
        all_reduce_alg=flags_obj.all_reduce_alg)

    estimator_config = tf.estimator.RunConfig(
        train_distribute=strategy, session_config=session_config)

    data_format = flags_obj.data_format
    if data_format is None:
        # No explicit choice: pick the layout based on CUDA availability.
        if tf.test.is_built_with_cuda():
            data_format = 'channels_first'
        else:
            data_format = 'channels_last'

    mnist_classifier = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=flags_obj.model_dir,
        config=estimator_config,
        params={'data_format': data_format})

    def train_input_fn():
        """Build the training dataset."""
        # Larger shuffle buffers give better randomness, smaller ones use
        # less memory. MNIST is small enough to shuffle a full epoch.
        shuffled = dataset.train(flags_obj.data_dir).cache().shuffle(
            buffer_size=50000)
        batched = shuffled.batch(flags_obj.batch_size)

        # Each training session walks the data `epochs_between_evals` times.
        return batched.repeat(flags_obj.epochs_between_evals)

    def eval_input_fn():
        """Build the evaluation input tensors."""
        test_batches = dataset.test(flags_obj.data_dir).batch(
            flags_obj.batch_size)
        return test_batches.make_one_shot_iterator().get_next()

    # Training hooks requested through the flags.
    train_hooks = hooks_helper.get_train_hooks(
        flags_obj.hooks,
        model_dir=flags_obj.model_dir,
        batch_size=flags_obj.batch_size)

    # Alternate training and evaluation; stop early once the accuracy
    # threshold has been passed.
    num_cycles = flags_obj.train_epochs // flags_obj.epochs_between_evals
    for _ in range(num_cycles):
        mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
        eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
        print('\nEvaluation results:\n\t%s\n' % eval_results)

        reached_target = model_helpers.past_stop_threshold(
            flags_obj.stop_threshold, eval_results['accuracy'])
        if reached_target:
            break

    if flags_obj.export_dir is None:
        return

    # Export the trained model with a raw-tensor serving signature.
    image = tf.placeholder(tf.float32, [None, 28, 28])
    serving_input_fn = (
        tf.estimator.export.build_raw_serving_input_receiver_fn(
            {'image': image}))
    mnist_classifier.export_savedmodel(flags_obj.export_dir,
                                       serving_input_fn,
                                       strip_default_attrs=True)
 def test_raise_in_non_list_names(self):
   """Hook names given as one comma-separated string must raise ValueError."""
   hook_names = 'LoggingTensorHook, ProfilerHook'
   with self.assertRaises(ValueError):
     hooks_helper.get_train_hooks(hook_names, batch_size=256)
Beispiel #18
0
def run_mnist(flags_obj):
    """Train and evaluate MNIST via tf.estimator.train_and_evaluate.

    After training, the model is exported when an export directory is set,
    except on nodes of a multi-node run that are not the master.

    Args:
      flags_obj: An object containing parsed flag values.
    """
    model_helpers.apply_clean(flags_obj)

    # Thread-pool sizes come from the flags; soft placement is always on.
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    strategy = distribution_utils.get_distribution_strategy(
        flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

    # Checkpoint, summary and step-logging cadence are all flag-driven.
    estimator_config = tf.estimator.RunConfig(
        train_distribute=strategy,
        session_config=session_config,
        save_checkpoints_steps=flags_obj.ckpt_steps,
        keep_checkpoint_max=flags_obj.max_ckpts,
        save_summary_steps=flags_obj.save_summary_steps,
        log_step_count_steps=flags_obj.log_step_count_steps)

    data_format = flags_obj.data_format
    if data_format is None:
        # No explicit choice: pick the layout based on CUDA availability.
        if tf.test.is_built_with_cuda():
            data_format = 'channels_first'
        else:
            data_format = 'channels_last'

    mnist_classifier = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=flags_obj.model_dir,
        config=estimator_config,
        params={'data_format': data_format})

    def train_input_fn():
        """Build the training dataset."""
        # Larger shuffle buffers give better randomness, smaller ones use
        # less memory. MNIST is small enough to shuffle a full epoch.
        shuffled = dataset.train(flags_obj.data_dir).cache().shuffle(
            buffer_size=50000)
        batched = shuffled.batch(flags_obj.batch_size)

        # The data is repeated `epochs_between_evals` times.
        return batched.repeat(flags_obj.epochs_between_evals)

    def eval_input_fn():
        """Build the evaluation input tensors."""
        test_batches = dataset.test(flags_obj.data_dir).batch(
            flags_obj.batch_size)
        return test_batches.make_one_shot_iterator().get_next()

    # Training hooks requested through the flags.
    train_hooks = hooks_helper.get_train_hooks(
        flags_obj.hooks,
        model_dir=flags_obj.model_dir,
        batch_size=flags_obj.batch_size)

    train_spec = tf.estimator.TrainSpec(
        input_fn=train_input_fn,
        hooks=train_hooks,
        max_steps=flags_obj.max_steps)
    eval_spec = tf.estimator.EvalSpec(
        input_fn=eval_input_fn,
        steps=None,
        start_delay_secs=10,
        throttle_secs=flags_obj.eval_secs)

    tf.estimator.train_and_evaluate(mnist_classifier, train_spec, eval_spec)

    # In a multi-node experiment (PS_CONFIG set) only the master exports.
    multinode = os.environ.get('PS_CONFIG')
    if multinode and os.environ.get('TYPE') != 'master':
        tf.logging.debug('No model was exported')
        return

    if not flags_obj.export_dir:
        return

    tf.logging.debug(
        'Starting to Export model to {}'.format(str(flags_obj.export_dir)))
    image = tf.placeholder(tf.float32, [None, 28, 28])
    serving_input_fn = (
        tf.estimator.export.build_raw_serving_input_receiver_fn(
            {'image': image}))
    mnist_classifier.export_savedmodel(flags_obj.export_dir,
                                       serving_input_fn,
                                       strip_default_attrs=True)
    tf.logging.debug('Model Exported')
Beispiel #19
0
def run_transformer(flags_obj):
    """Create tf.Estimator to train and evaluate transformer model.

  Args:
    flags_obj: Object containing parsed flag values.
  """
    num_gpus = flags_core.get_num_gpus(flags_obj)

    # Add flag-defined parameters to params object
    params = PARAMS_MAP[flags_obj.param_set]
    if num_gpus > 1:
        if flags_obj.param_set == "big":
            params = model_params.BIG_MULTI_GPU_PARAMS
        elif flags_obj.param_set == "base":
            params = model_params.BASE_MULTI_GPU_PARAMS

    params["data_dir"] = flags_obj.data_dir
    params["model_dir"] = flags_obj.model_dir
    params["num_parallel_calls"] = flags_obj.num_parallel_calls

    params["tpu"] = flags_obj.tpu
    params["use_tpu"] = bool(flags_obj.tpu)  # was a tpu specified.
    params["static_batch"] = flags_obj.static_batch or params["use_tpu"]
    params["allow_ffn_pad"] = not params["use_tpu"]

    params["use_synthetic_data"] = flags_obj.use_synthetic_data

    params["worker_hosts"] = flags_obj.worker_hosts
    params["task_index"] = flags_obj.task_index
    params["server_protocol"] = flags_obj.server_protocol

    # Set batch size parameter, which depends on the availability of
    # TPU and GPU, and distribution settings.
    params["batch_size"] = (
        flags_obj.batch_size
        or (params["default_batch_size_tpu"]
            if params["use_tpu"] else params["default_batch_size"]))

    if not params["use_tpu"]:
        params["batch_size"] = distribution_utils.per_device_batch_size(
            params["batch_size"], num_gpus)

    print("============== Batch Size for each GPU ==============")
    print("Batch Size for each GPU", params["batch_size"])
    print("============== Batch Size for each GPU ==============")

    schedule_manager = schedule.Manager(
        train_steps=flags_obj.train_steps,
        steps_between_evals=flags_obj.steps_between_evals,
        train_epochs=flags_obj.train_epochs,
        epochs_between_evals=flags_obj.epochs_between_evals,
        default_train_epochs=DEFAULT_TRAIN_EPOCHS,
        batch_size=params["batch_size"],
        max_length=params["max_length"],
        use_tpu=params["use_tpu"],
        num_tpu_shards=flags_obj.num_tpu_shards)

    params["repeat_dataset"] = schedule_manager.repeat_dataset

    model_helpers.apply_clean(flags.FLAGS)

    print("============== Train Hooks ==============")
    print(flags_obj.hooks)
    print("============== Train Hooks ==============")

    # Create hooks that log information about the training and metric values
    train_hooks = hooks_helper.get_train_hooks(
        flags_obj.hooks,
        model_dir=flags_obj.model_dir,
        save_steps=5000,
        tensors_to_log=TENSORS_TO_LOG,  # used for logging hooks
        batch_size=schedule_manager.batch_size  # for ExamplesPerSecondHook
    )
    benchmark_logger = logger.get_benchmark_logger()
    benchmark_logger.log_run_info(model_name="transformer",
                                  dataset_name="wmt_translate_ende",
                                  run_params=params,
                                  test_id=flags_obj.benchmark_test_id)

    # Train and evaluate transformer model
    network = construct_network(num_gpus, flags_obj, params, schedule_manager)
    run_loop(
        network=network,
        # Training arguments
        schedule_manager=schedule_manager,
        train_hooks=train_hooks,
        benchmark_logger=benchmark_logger,
        # BLEU calculation arguments
        bleu_source=flags_obj.bleu_source,
        bleu_ref=flags_obj.bleu_ref,
        bleu_threshold=flags_obj.stop_threshold,
        vocab_file=flags_obj.vocab_file)

    if flags_obj.export_dir and not params["use_tpu"]:
        serving_input_fn = export.build_tensor_serving_input_receiver_fn(
            shape=[None], dtype=tf.int64, batch_size=None)