def main(_):
  """Parses the experiment configuration and runs it through `train_lib`.

  Reads the gin files/bindings and the experiment flags, serializes the parsed
  config for training modes, applies the configured mixed-precision policy,
  builds the distribution strategy and the task, then launches the experiment.

  Args:
    _: Unused positional argument.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  experiment_config = train_utils.parse_configuration(FLAGS)
  output_dir = FLAGS.model_dir

  # Only modes that train write the yaml config. A pure (continuous) eval job
  # could otherwise race against the train job for the same file.
  if 'train' in FLAGS.mode:
    train_utils.serialize_config(experiment_config, output_dir)

  runtime = experiment_config.runtime

  # 'mixed_float16' / 'mixed_bfloat16' can significantly change model speed
  # (float16 on GPUs, bfloat16 on TPUs). loss_scale only takes effect when the
  # dtype is float16.
  precision_dtype = runtime.mixed_precision_dtype
  if precision_dtype:
    performance.set_mixed_precision_policy(precision_dtype, runtime.loss_scale)

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=runtime.distribution_strategy,
      all_reduce_alg=runtime.all_reduce_alg,
      num_gpus=runtime.num_gpus,
      tpu_address=runtime.tpu,
      **runtime.model_parallelism())

  # The task is created under the strategy scope.
  with strategy.scope():
    task = task_factory.get_task(experiment_config.task, logging_dir=output_dir)

  train_lib.run_experiment(
      distribution_strategy=strategy,
      task=task,
      mode=FLAGS.mode,
      params=experiment_config,
      model_dir=output_dir)
예제 #2
0
def main(_):
  """Runs an AP parsing experiment and exports SavedModels after training.

  Builds either the CRF or the base AP parsing task depending on
  `params.task.use_crf`, runs the experiment with an explicitly created
  trainer, saves the gin config, and, for training modes, exports the latest
  model plus (when a best-checkpoint exporter exists) the best checkpoint.

  Args:
    _: Unused positional argument.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  config = train_utils.parse_configuration(FLAGS)
  output_dir = FLAGS.model_dir
  is_training = "train" in FLAGS.mode

  if is_training:
    train_utils.serialize_config(config, output_dir)

  runtime = config.runtime
  precision_dtype = runtime.mixed_precision_dtype
  if precision_dtype:
    performance.set_mixed_precision_policy(precision_dtype)

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=runtime.distribution_strategy,
      all_reduce_alg=runtime.all_reduce_alg,
      num_gpus=runtime.num_gpus,
      tpu_address=runtime.tpu,
      **runtime.model_parallelism())

  with strategy.scope():
    # Pick the task implementation from the config switch.
    task_cls = (
        ap_parsing_task.APParsingTaskCRF
        if config.task.use_crf else ap_parsing_task.APParsingTaskBase)
    task = task_cls(config.task)

    best_exporter = train_utils.maybe_create_best_ckpt_exporter(
        config, output_dir)
    trainer = train_utils.create_trainer(
        config,
        task,
        train=is_training,
        evaluate=("eval" in FLAGS.mode),
        checkpoint_exporter=best_exporter)

  model, _ = train_lib.run_experiment(
      distribution_strategy=strategy,
      task=task,
      mode=FLAGS.mode,
      params=config,
      trainer=trainer,
      model_dir=output_dir)

  train_utils.save_gin_config(FLAGS.mode, output_dir)

  # SavedModels are only exported by modes that train.
  if not is_training:
    return

  latest_export_dir = os.path.join(output_dir, "saved_models/latest")
  logging.info("Exporting SavedModel to %s", latest_export_dir)
  tf.saved_model.save(model, latest_export_dir)

  if not best_exporter:
    return

  logging.info("Loading best checkpoint for export")
  best_export_dir = os.path.join(output_dir, "saved_models/best")
  trainer.checkpoint.restore(best_exporter.best_ckpt_path)

  # Make sure the state was restored and not re-initialized.
  if trainer.global_step > 0:
    logging.info(
        "Exporting best saved model by %s (from global step: %d) to %s",
        config.trainer.best_checkpoint_eval_metric,
        trainer.global_step.numpy(), best_export_dir)
    tf.saved_model.save(trainer.model, best_export_dir)
    def testTrainCtl(self):
        """Trains into a source dir, then finetunes continuously from it.

        A one-step 'train' run writes checkpoints to a temporary source
        directory. Continuous finetuning is then run with the task's
        `init_checkpoint` pointing at that directory, and the returned eval
        metrics are checked for a 'best_acc' entry.
        """
        checkpoint_source_dir = self.get_temp_dir()

        optimizer_overrides = {
            'optimizer': {
                'type': 'sgd'
            },
            'learning_rate': {
                'type': 'constant'
            },
        }
        trainer_overrides = {
            'continuous_eval_timeout': 1,
            'steps_per_loop': 1,
            'train_steps': 1,
            'validation_steps': 1,
            'best_checkpoint_export_subdir': 'best_ckpt',
            'best_checkpoint_eval_metric': 'acc',
            'optimizer_config': optimizer_overrides,
        }
        flag_values = {
            'experiment': 'mock',
            'mode': 'continuous_train_and_eval',
            'model_dir': self._model_dir,
            'params_override': {
                'task': {
                    'init_checkpoint': checkpoint_source_dir,
                },
                'trainer': trainer_overrides,
            },
        }

        with flagsaver.flagsaver(**flag_values):
            # Phase 1: train and save some checkpoints into the source dir.
            config = train_utils.parse_configuration(flags.FLAGS)
            strategy = tf.distribute.get_strategy()
            with strategy.scope():
                task = task_factory.get_task(
                    config.task, logging_dir=checkpoint_source_dir)
            train_lib.run_experiment(
                distribution_strategy=strategy,
                task=task,
                mode='train',
                params=config,
                model_dir=checkpoint_source_dir)

            # Phase 2: continuous finetuning, using a freshly parsed config.
            config = train_utils.parse_configuration(FLAGS)
            finetune_metrics = (
                train_ctl_continuous_finetune.run_continuous_finetune(
                    FLAGS.mode, config, FLAGS.model_dir, run_post_eval=True))
            self.assertIn('best_acc', finetune_metrics)
예제 #4
0
def main(_):
    """Runs continuous finetuning from the parsed flags and gin config.

    Serializes the parsed config to the model directory before the run and
    saves the gin config after it.

    Args:
        _: Unused positional argument.
    """
    # TODO(b/177863554): consolidate to nlp/train.py
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
    experiment_config = train_utils.parse_configuration(FLAGS)
    output_dir = FLAGS.model_dir

    train_utils.serialize_config(experiment_config, output_dir)
    continuous_finetune_lib.run_continuous_finetune(
        FLAGS.mode,
        experiment_config,
        output_dir,
        pretrain_steps=FLAGS.pretrain_steps)
    train_utils.save_gin_config(FLAGS.mode, output_dir)
예제 #5
0
def main(_):
    """Parses the configuration and runs the experiment.

    For 'train_and_eval' modes the train and validation feature shapes must
    match. For 'assemblenet' experiments the backbone's `num_frames` is taken
    from the first entry of the feature shape of the data config that matches
    the mode.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If the mode contains 'train_and_eval' and the train and
            validation feature shapes differ.
    """
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
    params = train_utils.parse_configuration(FLAGS)
    model_dir = FLAGS.model_dir
    if 'train' in FLAGS.mode:
        # Pure eval modes do not output yaml files. Otherwise continuous eval job
        # may race against the train job for writing the same file.
        train_utils.serialize_config(params, model_dir)

    if 'train_and_eval' in FLAGS.mode:
        train_shape = params.task.train_data.feature_shape
        validation_shape = params.task.validation_data.feature_shape
        # Raise explicitly rather than `assert`: assertions are stripped when
        # Python runs with -O, which would silently skip this validation.
        if train_shape != validation_shape:
            raise ValueError(
                f'train {train_shape} != validate {validation_shape}')

    if 'assemblenet' in FLAGS.experiment:
        # Modes that evaluate take the feature shape from validation_data;
        # all other modes take it from train_data. The first entry of that
        # shape becomes the AssembleNet backbone's number of frames.
        if 'eval' in FLAGS.mode:
            data_config = params.task.validation_data
        else:
            data_config = params.task.train_data
        shape = data_config.feature_shape
        params.task.model.backbone.assemblenet.num_frames = shape[0]
        logging.info('mode %r num_frames %r feature shape %r', FLAGS.mode,
                     params.task.model.backbone.assemblenet.num_frames, shape)

    # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
    # can have significant impact on model speeds by utilizing float16 in case of
    # GPUs, and bfloat16 in the case of TPUs.
    if params.runtime.mixed_precision_dtype:
        performance.set_mixed_precision_policy(
            params.runtime.mixed_precision_dtype)
    distribution_strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=params.runtime.distribution_strategy,
        all_reduce_alg=params.runtime.all_reduce_alg,
        num_gpus=params.runtime.num_gpus,
        tpu_address=params.runtime.tpu)
    with distribution_strategy.scope():
        task = task_factory.get_task(params.task, logging_dir=model_dir)

    train_lib.run_experiment(distribution_strategy=distribution_strategy,
                             task=task,
                             mode=FLAGS.mode,
                             params=params,
                             model_dir=model_dir)

    train_utils.save_gin_config(FLAGS.mode, model_dir)
예제 #6
0
def main(_):
    """Parses the configuration and runs the experiment with input partitioning.

    The input partition dims passed to the distribution strategy are derived
    from the task's train or eval partition dims, depending on the mode. They
    stay `None` for any other mode or when the relevant dims are unset.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: In 'train_and_eval' mode, if the products of the train and
            eval input partition dims differ.
    """
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
    params = train_utils.parse_configuration(FLAGS)
    model_dir = FLAGS.model_dir
    if 'train' in FLAGS.mode:
        # Pure eval modes do not output yaml files. Otherwise continuous eval job
        # may race against the train job for writing the same file.
        train_utils.serialize_config(params, model_dir)

    # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
    # can have significant impact on model speeds by utilizing float16 in case of
    # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
    # dtype is float16
    if params.runtime.mixed_precision_dtype:
        performance.set_mixed_precision_policy(
            params.runtime.mixed_precision_dtype,
            params.runtime.loss_scale,
            use_experimental_api=True)

    input_partition_dims = None
    if FLAGS.mode == 'train_and_eval':
        if np.prod(params.task.train_input_partition_dims) != np.prod(
                params.task.eval_input_partition_dims):
            # The two literals are concatenated, so the first one needs a
            # trailing space to keep the words of the message apart.
            raise ValueError('Train and eval input partition dims can not be '
                             'partitioned on the same node')
        input_partition_dims = get_computation_shape_for_model_parallelism(
            params.task.train_input_partition_dims)
    elif FLAGS.mode == 'train':
        if params.task.train_input_partition_dims:
            input_partition_dims = get_computation_shape_for_model_parallelism(
                params.task.train_input_partition_dims)
    elif FLAGS.mode in ('eval', 'continuous_eval'):
        if params.task.eval_input_partition_dims:
            input_partition_dims = get_computation_shape_for_model_parallelism(
                params.task.eval_input_partition_dims)

    distribution_strategy = create_distribution_strategy(
        distribution_strategy=params.runtime.distribution_strategy,
        num_gpus=params.runtime.num_gpus,
        input_partition_dims=input_partition_dims,
        tpu_address=params.runtime.tpu)
    with distribution_strategy.scope():
        task = task_factory.get_task(params.task, logging_dir=model_dir)

    train_lib.run_experiment(distribution_strategy=distribution_strategy,
                             task=task,
                             mode=FLAGS.mode,
                             params=params,
                             model_dir=model_dir)
예제 #7
0
def main(_):
  """Runs a single-task or multi-task experiment based on the config type.

  A `cfg.ExperimentConfig` is run through `train_lib`; a
  `multi_cfg.MultiTaskExperimentConfig` is run through `train_lib_multitask`
  with a multi-head model. The gin config is saved after a successful run.

  Args:
    _: Unused positional argument.

  Raises:
    ValueError: If the parsed config is neither of the two supported types.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case of
  # GPUs, and bfloat16 in the case of TPUs.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)

  if isinstance(params, cfg.ExperimentConfig):
    with distribution_strategy.scope():
      task = task_factory.get_task(params.task, logging_dir=model_dir)

    train_lib.run_experiment(
        distribution_strategy=distribution_strategy,
        task=task,
        mode=FLAGS.mode,
        params=params,
        model_dir=model_dir)

  elif isinstance(params, multi_cfg.MultiTaskExperimentConfig):
    with distribution_strategy.scope():
      task = multitask.MultiTask.from_config(params.task, model_dir)
      model = multihead_model.build_model(params.task)

    train_lib_multitask.run_experiment(
        distribution_strategy=distribution_strategy,
        task=task,
        model=model,
        mode=FLAGS.mode,
        params=params,
        model_dir=model_dir)

  else:
    # The first literal ends with a space so the concatenated message does
    # not read "...ExperimentConfigor multi_cfg...".
    raise ValueError(
        "Expected config to be either type cfg.ExperimentConfig "
        "or multi_cfg.MultiTaskExperimentConfig, got %s" % (type(params),))

  train_utils.save_gin_config(FLAGS.mode, model_dir)
예제 #8
0
def main(_):
  """Parses the configuration and applies the mixed-precision policy.

  Args:
    _: Unused positional argument.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case of
  # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
  # dtype is float16
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
                                           params.runtime.loss_scale)
  # NOTE(review): a stray merge-conflict marker was removed from the call
  # above; confirm that no further steps of this function were lost with it.
  def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
    """Runs an experiment in `flag_mode` and checks the files it leaves behind.

    Verifies the validation summary directory (for eval-capable modes), the
    emptiness of the returned logs according to `run_post_eval`, and the
    presence of 'params.yaml'. Unless the mode is exactly 'eval', it also
    checks for a 'checkpoint' file and then runs a continuous evaluation.
    """
    workdir = self.get_temp_dir()
    flag_overrides = {
        'experiment': 'mock',
        'mode': flag_mode,
        'model_dir': workdir,
        'params_override': json.dumps(self._test_config),
    }
    with flagsaver.flagsaver(**flag_overrides):
      config = train_utils.parse_configuration(flags.FLAGS)
      train_utils.serialize_config(config, workdir)
      with distribution_strategy.scope():
        task = task_factory.get_task(config.task, logging_dir=workdir)

      _, logs = train_lib.run_experiment(
          distribution_strategy=distribution_strategy,
          task=task,
          mode=flag_mode,
          params=config,
          model_dir=workdir,
          run_post_eval=run_post_eval)

    if 'eval' in flag_mode:
      summary_dir = os.path.join(workdir,
                                 config.trainer.validation_summary_subdir)
      self.assertTrue(tf.io.gfile.exists(summary_dir))

    # Logs are expected only when a post-training evaluation was requested.
    check_logs = self.assertNotEmpty if run_post_eval else self.assertEmpty
    check_logs(logs)

    params_yaml_pattern = os.path.join(workdir, 'params.yaml')
    self.assertNotEmpty(tf.io.gfile.glob(params_yaml_pattern))

    if flag_mode == 'eval':
      return

    checkpoint_pattern = os.path.join(workdir, 'checkpoint')
    self.assertNotEmpty(tf.io.gfile.glob(checkpoint_pattern))

    # Tests continuous evaluation.
    _, logs = train_lib.run_experiment(
        distribution_strategy=distribution_strategy,
        task=task,
        mode='continuous_eval',
        params=config,
        model_dir=workdir,
        run_post_eval=run_post_eval)
    print(logs)
예제 #10
0
 def test_construct_experiment_from_flags(self):
     """Checks `ExperimentParser` against `parse_configuration`.

     Both are fed the same `ParseConfigOptions`; their results must format
     identically, and the overrides given in the options must show up in the
     parsed config.
     """
     parse_options = train_utils.ParseConfigOptions(
         experiment='foo',
         config_file=[],
         tpu='bar',
         tf_data_service='',
         params_override='task.model.model_id=new,'
         'trainer.train_steps=10,'
         'trainer.validation_steps=11')

     parsed_by_object = train_utils.ExperimentParser(parse_options).parse()
     parsed_by_function = train_utils.parse_configuration(parse_options)

     # Compare the pretty-printed dict forms of the two results.
     formatted_from_object = pprint.pformat(parsed_by_object.as_dict())
     formatted_from_function = pprint.pformat(parsed_by_function.as_dict())
     self.assertEqual(formatted_from_object, formatted_from_function)

     # The tpu option and every params_override entry must be applied.
     self.assertEqual(parsed_by_object.runtime.tpu, 'bar')
     self.assertEqual(parsed_by_object.task.model.model_id, 'new')
     trainer_config = parsed_by_object.trainer
     self.assertEqual(trainer_config.train_steps, 10)
     self.assertEqual(trainer_config.validation_steps, 11)
예제 #11
0
    def test_recovery(self, distribution_strategy, flag_mode):
        """Checks model weights are unchanged after a run at the loss bound.

        The trainer's `loss_upper_bound` is set to a threshold and
        `recovery_max_trials` to 1. The task's loss is replaced by one built
        from that same threshold, and the weights returned by the experiment
        are compared with those captured right after the reference checkpoint
        was saved. Presumably this exercises the trainer's recovery path;
        confirm against `train_lib`.
        """
        loss_threshold = 1.0
        workdir = self.get_temp_dir()
        flag_overrides = {
            'experiment': 'mock',
            'mode': flag_mode,
            'model_dir': workdir,
            'params_override': json.dumps(self._test_config),
        }
        with flagsaver.flagsaver(**flag_overrides):
            config = train_utils.parse_configuration(flags.FLAGS)
            config.trainer.loss_upper_bound = loss_threshold
            config.trainer.recovery_max_trials = 1
            train_utils.serialize_config(config, workdir)
            with distribution_strategy.scope():
                task = task_factory.get_task(config.task, logging_dir=workdir)

            # Saves a checkpoint for reference.
            reference_model = task.build_model()
            reference_checkpoint = tf.train.Checkpoint(model=reference_model)
            manager = tf.train.CheckpointManager(
                reference_checkpoint, self.get_temp_dir(), max_to_keep=2)
            manager.save()
            weights_before = reference_model.get_weights()

            def build_losses(labels, model_outputs, aux_losses=None):
                # Ignore the real inputs; the loss is the threshold constant
                # plus whatever auxiliary losses are passed in.
                del labels, model_outputs
                constant_loss = tf.constant([loss_threshold], tf.float32)
                return constant_loss + aux_losses

            task.build_losses = build_losses

            trained_model, _ = train_lib.run_experiment(
                distribution_strategy=distribution_strategy,
                task=task,
                mode=flag_mode,
                params=config,
                model_dir=workdir)
            weights_after = trained_model.get_weights()
            for expected, actual in zip(weights_before, weights_after):
                self.assertAllEqual(expected, actual)
예제 #12
0
def load_flags(CFG):
  """Builds the task and model described by `CFG`.

  When `CFG.model_dir` is set, an optimizer is created and the latest
  checkpoint found in that directory is restored into the model and the
  optimizer. Otherwise the task initializes the model itself.

  Args:
    CFG: Object passed to `train_utils.parse_configuration`; its `model_dir`
      attribute is also read.

  Returns:
    A `(task, model, params)` tuple.
  """
  params = train_utils.parse_configuration(CFG)
  checkpoint_dir = CFG.model_dir

  precision_dtype = params.runtime.mixed_precision_dtype
  if precision_dtype:
    performance.set_mixed_precision_policy(precision_dtype,
                                           params.runtime.loss_scale)

  task = task_factory.get_task(params.task, logging_dir=checkpoint_dir)
  model = task.build_model()

  # Without a model directory there is nothing to restore from.
  if checkpoint_dir is None or checkpoint_dir == "":
    task.initialize(model)
    return task, model, params

  optimizer = task.create_optimizer(params.trainer.optimizer_config,
                                    params.runtime)
  checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
  latest_path = tf.train.latest_checkpoint(checkpoint_dir)
  restore_status = checkpoint.restore(latest_path)

  restore_status.expect_partial().assert_existing_objects_matched()
  print(dir(restore_status), restore_status)

  return task, model, params
def main(_):
  """Serializes the parsed config and runs continuous finetuning.

  Args:
    _: Unused positional argument.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  experiment_config = train_utils.parse_configuration(FLAGS)
  output_dir = FLAGS.model_dir

  train_utils.serialize_config(experiment_config, output_dir)
  run_continuous_finetune(FLAGS.mode, experiment_config, output_dir)
예제 #14
0
def main(_) -> None:
  """Train and evaluate the Ranking model.

  Two execution paths are supported, selected by `params.trainer.use_orbit`:
  a `RankingTrainer` driven by `train_lib.run_experiment`, or a Keras
  compile/fit path with checkpointing, timing and optional TensorBoard
  callbacks.

  Args:
    _: Unused positional argument.

  Raises:
    NotImplementedError: On the compile/fit path, when the mode is not one of
      'train', 'train_and_eval' or 'eval'.
  """
  config = train_utils.parse_configuration(FLAGS)
  mode = FLAGS.mode
  model_dir = FLAGS.model_dir
  trains = 'train' in mode
  evaluates = 'eval' in mode
  trainer_config = config.trainer

  if trains:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(config, model_dir)

  seed = FLAGS.seed
  if seed is not None:
    logging.info('Setting tf seed.')
    tf.random.set_seed(seed)

  task = RankingTask(
      params=config.task,
      optimizer_config=trainer_config.optimizer_config,
      logging_dir=model_dir,
      steps_per_execution=trainer_config.steps_per_loop,
      name='RankingTask')

  enable_tensorboard = trainer_config.callbacks.enable_tensorboard

  runtime = config.runtime
  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=runtime.distribution_strategy,
      all_reduce_alg=runtime.all_reduce_alg,
      num_gpus=runtime.num_gpus,
      tpu_address=runtime.tpu)

  with strategy.scope():
    model = task.build_model()

  def distribute_inputs(data_config):
    """Builds a distributed dataset from the task's inputs for `data_config`."""
    return strategy.distribute_datasets_from_function(
        lambda input_context: task.build_inputs(data_config, input_context),
        options=tf.distribute.InputOptions(experimental_fetch_to_device=False))

  # Datasets are only built for the phases the mode actually runs.
  train_dataset = distribute_inputs(config.task.train_data) if trains else None
  validation_dataset = (
      distribute_inputs(config.task.validation_data) if evaluates else None)

  if trainer_config.use_orbit:
    with strategy.scope():
      best_exporter = train_utils.maybe_create_best_ckpt_exporter(
          config, model_dir)
      trainer = RankingTrainer(
          config=config,
          task=task,
          model=model,
          optimizer=model.optimizer,
          train=trains,
          evaluate=evaluates,
          train_dataset=train_dataset,
          validation_dataset=validation_dataset,
          checkpoint_exporter=best_exporter)

    train_lib.run_experiment(
        distribution_strategy=strategy,
        task=task,
        mode=mode,
        params=config,
        model_dir=model_dir,
        trainer=trainer)
    return

  # Compile/fit path.
  optimizer = model.optimizer
  checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)

  resume_from = tf.train.latest_checkpoint(model_dir)
  if resume_from:
    checkpoint.restore(resume_from)
    logging.info('Loaded checkpoint %s', resume_from)

  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=trainer_config.max_to_keep,
      step_counter=optimizer.iterations,
      checkpoint_interval=trainer_config.checkpoint_interval)

  callbacks = [
      keras_utils.SimpleCheckpoint(checkpoint_manager),
      keras_utils.TimeHistory(
          config.task.train_data.global_batch_size,
          trainer_config.time_history.log_steps,
          logdir=model_dir if enable_tensorboard else None),
  ]
  if enable_tensorboard:
    callbacks.append(
        tf.keras.callbacks.TensorBoard(
            log_dir=model_dir,
            update_freq=min(1000, trainer_config.validation_interval),
            profile_batch=FLAGS.profile_steps))

  # One Keras "epoch" spans a validation interval's worth of steps.
  steps_per_epoch = trainer_config.validation_interval
  num_epochs = trainer_config.train_steps // steps_per_epoch
  initial_epoch = optimizer.iterations.numpy() // steps_per_epoch
  eval_steps = trainer_config.validation_steps if evaluates else None

  if mode == 'eval':
    logging.info('Evaluation started')
    validation_output = model.evaluate(validation_dataset, steps=eval_steps)
    logging.info('Evaluation output: %s', validation_output)
  elif mode in ('train', 'train_and_eval'):
    logging.info('Training started')
    history = model.fit(
        train_dataset,
        initial_epoch=initial_epoch,
        epochs=num_epochs,
        steps_per_epoch=steps_per_epoch,
        validation_data=validation_dataset,
        validation_steps=eval_steps,
        callbacks=callbacks,
    )
    model.summary()
    logging.info('Train history: %s', history.history)
  else:
    raise NotImplementedError('The mode is not implemented: %s' % mode)
예제 #15
0
def main(_):
    """Parses the configuration, sets up the worker cluster and runs the experiment.

    Optionally loads the Habana module (when `params.runtime.num_hpus > 0`),
    seeds the random number generators for deterministic data configs, builds
    the worker host list from the `MULTI_HLS_IPS` environment variable and the
    MPI rank/size, then builds the task and launches the experiment.

    Args:
        _: Unused positional argument.
    """
    # `os` has to be imported unconditionally. A function-level import makes
    # `os` a local name, so importing it only inside the conditional branches
    # below would raise UnboundLocalError at the later `os.environ` uses
    # whenever neither branch runs.
    import os

    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
    params = train_utils.parse_configuration(FLAGS)

    if params.runtime.num_hpus > 0:
        #TODO: remove when SW-49334 is fixed [SW-49404]
        os.environ["TF_DISABLE_EAGER_TO_FUNC_REWRITER"] = "1"
        from habana_frameworks.tensorflow import load_habana_module
        load_habana_module()

    if params.task.train_data.deterministic or params.task.validation_data.deterministic:
        # Fix every seed used by this process so runs are repeatable.
        os.environ['PYTHONHASHSEED'] = '0'
        os.environ['TF_DETERMINISTIC_OPS'] = '1'
        import numpy
        numpy.random.seed(0)
        import tensorflow as tf
        tf.random.set_seed(0)
        tf.compat.v1.set_random_seed(0)
        import random
        random.seed(0)

    if FLAGS.dtype == "bf16":
        print("Using bf16 config list {}".format(FLAGS.bf16_config_path))
        os.environ['TF_BF16_CONVERSION'] = FLAGS.bf16_config_path

    hls_addresses = str(os.environ.get("MULTI_HLS_IPS",
                                       "127.0.0.1")).split(",")
    TF_BASE_PORT = 2410
    mpi_rank = comm_rank()
    mpi_size = comm_size()

    # With more than one HPU, each rank writes to its own sub-directory.
    if params.runtime.num_hpus > 1:
        model_dir = os.path.join(FLAGS.model_dir, "worker_" + str(mpi_rank))
    else:
        model_dir = FLAGS.model_dir

    # Prepare a comma-separated list of device addresses: the ranks are split
    # evenly across the hosts, each rank getting its own port on its host.
    ranks_per_host = mpi_size // len(hls_addresses)
    worker_hosts = ",".join(
        address + ':' + str(TF_BASE_PORT + rank)
        for address in hls_addresses
        for rank in range(ranks_per_host))
    task_index = mpi_rank

    # Configures cluster spec for distribution strategy.
    distribution_utils.configure_cluster(worker_hosts, task_index)
    if 'train' in FLAGS.mode:
        # Pure eval modes do not output yaml files. Otherwise continuous eval job
        # may race against the train job for writing the same file.
        train_utils.serialize_config(params, model_dir)

    # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
    # can have significant impact on model speeds by utilizing float16 in case of
    # GPUs, and bfloat16 in the case of TPUs.
    if params.runtime.mixed_precision_dtype:
        performance.set_mixed_precision_policy(
            params.runtime.mixed_precision_dtype)

    distribution_strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=params.runtime.distribution_strategy,
        all_reduce_alg=params.runtime.all_reduce_alg,
        num_gpus=params.runtime.num_gpus,
        num_hpus=params.runtime.num_hpus,
        tpu_address=params.runtime.tpu)

    with distribution_strategy.scope():
        task = task_factory.get_task(params.task, logging_dir=model_dir)

    train_lib.run_experiment(distribution_strategy=distribution_strategy,
                             task=task,
                             mode=FLAGS.mode,
                             params=params,
                             model_dir=model_dir)

    train_utils.save_gin_config(FLAGS.mode, model_dir)