def test_invalid_strategy(self):
    with self.assertRaisesRegex(
            ValueError,
            'distribution_strategy must be a string but got: False. If'):
        distribute_utils.get_distribution_strategy(False)
    with self.assertRaisesRegex(
            ValueError,
            'distribution_strategy must be a string but got: 1'):
        distribute_utils.get_distribution_strategy(1)

def test_tpu_strategy(self):
    if not TPU_TEST:
        self.skipTest('Only Cloud TPU VM instances can have local TPUs.')
    with self.assertRaises(ValueError):
        _ = distribute_utils.get_distribution_strategy('tpu')

    ds = distribute_utils.get_distribution_strategy(
        'tpu', tpu_address='local')
    self.assertIsInstance(ds, tf.distribute.TPUStrategy)

def test_mwms(self):
    distribute_utils.configure_cluster(worker_hosts=None, task_index=-1)
    ds = distribute_utils.get_distribution_strategy(
        'multi_worker_mirrored', all_reduce_alg='nccl')
    self.assertIsInstance(
        ds, tf.distribute.experimental.MultiWorkerMirroredStrategy)

    with self.assertRaisesRegex(
            ValueError,
            'When used with `multi_worker_mirrored`, valid values.*'):
        _ = distribute_utils.get_distribution_strategy(
            'multi_worker_mirrored', all_reduce_alg='dummy')

def test_get_strategy_scope(self):
    ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=0)
    with distribute_utils.get_strategy_scope(ds):
        self.assertIs(tf.distribute.get_strategy(), ds)
    with distribute_utils.get_strategy_scope(None):
        self.assertIsNot(tf.distribute.get_strategy(), ds)
Example 5
def main(_):
    logging.info('Parsing config files...')
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
    params = get_exp_config()

    # Sets the mixed_precision policy. Using 'mixed_float16' or
    # 'mixed_bfloat16' can significantly speed up the model by using float16
    # on GPUs and bfloat16 on TPUs. loss_scale only takes effect when dtype is
    # float16.
    if params.runtime.mixed_precision_dtype:
        performance.set_mixed_precision_policy(
            params.runtime.mixed_precision_dtype,
            params.runtime.loss_scale,
            use_experimental_api=True)
    distribution_strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=params.runtime.distribution_strategy,
        all_reduce_alg=params.runtime.all_reduce_alg,
        num_gpus=params.runtime.num_gpus,
        tpu_address=params.runtime.tpu)

    with distribution_strategy.scope():
        task = distillation.BertDistillationTask(
            strategy=distribution_strategy,
            progressive=params.trainer.progressive,
            optimizer_config=params.trainer.optimizer_config,
            task_config=params.task)

    train_lib.run_experiment(distribution_strategy=distribution_strategy,
                             task=task,
                             mode=FLAGS.mode,
                             params=params,
                             model_dir=FLAGS.model_dir)
Example 6
def create_distribution_strategy(distribution_strategy,
                                 tpu_address,
                                 input_partition_dims=None,
                                 num_gpus=None):
    """Creates distribution strategy to use for computation."""

    if input_partition_dims is not None:
        if distribution_strategy != 'tpu':
            raise ValueError('Spatial partitioning is only supported '
                             'for TPUStrategy.')

        # When `input_partition_dims` is specified create custom TPUStrategy
        # instance with computation shape for model parallelism.
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=tpu_address)
        if tpu_address not in ('', 'local'):
            tf.config.experimental_connect_to_cluster(resolver)

        topology = tf.tpu.experimental.initialize_tpu_system(resolver)
        num_replicas = resolver.get_tpu_system_metadata().num_cores // np.prod(
            input_partition_dims)
        device_assignment = tf.tpu.experimental.DeviceAssignment.build(
            topology,
            num_replicas=num_replicas,
            computation_shape=input_partition_dims)
        return tf.distribute.TPUStrategy(
            resolver, experimental_device_assignment=device_assignment)

    return distribute_utils.get_distribution_strategy(
        distribution_strategy=distribution_strategy,
        tpu_address=tpu_address,
        num_gpus=num_gpus)
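A hypothetical call to create_distribution_strategy showing the spatial-partitioning path; the partition dims and the 'local' TPU address are illustrative placeholders, not values from the original code:

# Hypothetical usage sketch; partition dims and tpu_address are placeholders.
strategy = create_distribution_strategy(
    distribution_strategy='tpu',
    tpu_address='local',                # '' or 'local' skips connect_to_cluster
    input_partition_dims=[2, 2, 1, 2])  # np.prod(...) == 8 cores per replica
print('Replicas in sync:', strategy.num_replicas_in_sync)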
def main(_):
  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))

  if FLAGS.mode == 'export_only':
    export_squad(FLAGS.model_export_path, input_meta_data)
    return

  # Configures cluster spec for multi-worker distribution strategy.
  if FLAGS.num_gpus > 0:
    _ = distribute_utils.configure_cluster(FLAGS.worker_hosts, FLAGS.task_index)
  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg,
      tpu_address=FLAGS.tpu)

  if 'train' in FLAGS.mode:
    train_squad(strategy, input_meta_data, run_eagerly=FLAGS.run_eagerly)
  if 'predict' in FLAGS.mode:
    predict_squad(strategy, input_meta_data)
  if 'eval' in FLAGS.mode:
    eval_metrics = eval_squad(strategy, input_meta_data)
    f1_score = eval_metrics['final_f1']
    logging.info('SQuAD eval F1-score: %f', f1_score)
    summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
    summary_writer = tf.summary.create_file_writer(summary_dir)
    with summary_writer.as_default():
      # TODO(lehou): write to the correct step number.
      tf.summary.scalar('F1-score', f1_score, step=0)
      summary_writer.flush()
    # Also write eval_metrics to json file.
    squad_lib_sp.write_to_json_files(
        eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
    time.sleep(60)
    def __init__(self, flags_obj):
        """Init function of TransformerMain.

        Args:
          flags_obj: Object containing parsed flag values, i.e., FLAGS.

        Raises:
          ValueError: if not using static batch for input data on TPU.
        """
        self.flags_obj = flags_obj
        self.predict_model = None

        # Add flag-defined parameters to params object
        num_gpus = flags_core.get_num_gpus(flags_obj)
        self.params = params = misc.get_model_params(flags_obj.param_set,
                                                     num_gpus)

        params["num_gpus"] = num_gpus
        params["use_ctl"] = flags_obj.use_ctl
        params["data_dir"] = flags_obj.data_dir
        params["model_dir"] = flags_obj.model_dir
        params["static_batch"] = flags_obj.static_batch
        params["max_length"] = flags_obj.max_length
        params["decode_batch_size"] = flags_obj.decode_batch_size
        params["decode_max_length"] = flags_obj.decode_max_length
        params["padded_decode"] = flags_obj.padded_decode
        params["max_io_parallelism"] = (flags_obj.num_parallel_calls
                                        or tf.data.experimental.AUTOTUNE)

        params["use_synthetic_data"] = flags_obj.use_synthetic_data
        params["batch_size"] = flags_obj.batch_size or params[
            "default_batch_size"]
        params["repeat_dataset"] = None
        params["dtype"] = flags_core.get_tf_dtype(flags_obj)
        params["enable_tensorboard"] = flags_obj.enable_tensorboard
        params["enable_metrics_in_training"] = (
            flags_obj.enable_metrics_in_training)
        params["steps_between_evals"] = flags_obj.steps_between_evals
        params["enable_checkpointing"] = flags_obj.enable_checkpointing
        params["save_weights_only"] = flags_obj.save_weights_only

        self.distribution_strategy = distribute_utils.get_distribution_strategy(
            distribution_strategy=flags_obj.distribution_strategy,
            num_gpus=num_gpus,
            all_reduce_alg=flags_obj.all_reduce_alg,
            num_packs=flags_obj.num_packs,
            tpu_address=flags_obj.tpu or "")
        if self.use_tpu:
            params["num_replicas"] = (
                self.distribution_strategy.num_replicas_in_sync)
        else:
            logging.info("Running transformer with num_gpus = %d", num_gpus)

        if self.distribution_strategy:
            logging.info("For training, using distribution strategy: %s",
                         self.distribution_strategy)
        else:
            logging.info("Not using any distribution strategy.")

        performance.set_mixed_precision_policy(params["dtype"])
Example 9
def main(_):
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
    params = train_utils.parse_configuration(FLAGS)
    model_dir = FLAGS.model_dir
    if 'train' in FLAGS.mode:
        # Pure eval modes do not output yaml files. Otherwise continuous eval job
        # may race against the train job for writing the same file.
        train_utils.serialize_config(params, model_dir)

    # Sets the mixed_precision policy. Using 'mixed_float16' or
    # 'mixed_bfloat16' can significantly speed up the model by using float16
    # on GPUs and bfloat16 on TPUs. loss_scale only takes effect when dtype is
    # float16.
    if params.runtime.mixed_precision_dtype:
        performance.set_mixed_precision_policy(
            params.runtime.mixed_precision_dtype)
    distribution_strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=params.runtime.distribution_strategy,
        all_reduce_alg=params.runtime.all_reduce_alg,
        num_gpus=params.runtime.num_gpus,
        tpu_address=params.runtime.tpu,
        **params.runtime.model_parallelism())

    with distribution_strategy.scope():
        task = classification_example.ClassificationExampleTask(params.task)

    train_lib.run_experiment(distribution_strategy=distribution_strategy,
                             task=task,
                             mode=FLAGS.mode,
                             params=params,
                             model_dir=model_dir)

    train_utils.save_gin_config(FLAGS.mode, model_dir)
def main(_):
    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
        input_meta_data = json.loads(reader.read().decode('utf-8'))

    if not FLAGS.model_dir:
        FLAGS.model_dir = '/tmp/bert20/'

    strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.distribution_strategy,
        num_gpus=FLAGS.num_gpus,
        tpu_address=FLAGS.tpu)
    max_seq_length = input_meta_data['max_seq_length']
    train_input_fn = run_classifier_bert.get_dataset_fn(FLAGS.train_data_path,
                                                        max_seq_length,
                                                        FLAGS.train_batch_size,
                                                        is_training=True)
    eval_input_fn = run_classifier_bert.get_dataset_fn(FLAGS.eval_data_path,
                                                       max_seq_length,
                                                       FLAGS.eval_batch_size,
                                                       is_training=False)

    albert_config = albert_configs.AlbertConfig.from_json_file(
        FLAGS.bert_config_file)
    if FLAGS.mode == 'train_and_eval':
        run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
                                     train_input_fn, eval_input_fn)
    elif FLAGS.mode == 'predict':
        predict(strategy, albert_config, input_meta_data, eval_input_fn)
    else:
        raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
    return
    def test_invalid_args(self):
        with self.assertRaisesRegex(ValueError,
                                    '`num_gpus` can not be negative.'):
            _ = distribute_utils.get_distribution_strategy(num_gpus=-1)

        with self.assertRaisesRegex(ValueError,
                                    '.*If you meant to pass the string .*'):
            _ = distribute_utils.get_distribution_strategy(
                distribution_strategy=False, num_gpus=0)
        with self.assertRaisesRegex(ValueError, 'When 2 GPUs are specified.*'):
            _ = distribute_utils.get_distribution_strategy(
                distribution_strategy='off', num_gpus=2)
        with self.assertRaisesRegex(ValueError,
                                    '`OneDeviceStrategy` can not be used.*'):
            _ = distribute_utils.get_distribution_strategy(
                distribution_strategy='one_device', num_gpus=2)
Example 12
def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if "train" in FLAGS.mode:
    train_utils.serialize_config(params, model_dir)

  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu,
      **params.runtime.model_parallelism())

  with distribution_strategy.scope():
    if params.task.use_crf:
      task = ap_parsing_task.APParsingTaskCRF(params.task)
    else:
      task = ap_parsing_task.APParsingTaskBase(params.task)

    ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter(
        params, model_dir)
    trainer = train_utils.create_trainer(
        params,
        task,
        train="train" in FLAGS.mode,
        evaluate=("eval" in FLAGS.mode),
        checkpoint_exporter=ckpt_exporter)

  model, _ = train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      trainer=trainer,
      model_dir=model_dir)

  train_utils.save_gin_config(FLAGS.mode, model_dir)

  # Export saved model.
  if "train" in FLAGS.mode:
    saved_model_path = os.path.join(model_dir, "saved_models/latest")
    logging.info("Exporting SavedModel to %s", saved_model_path)
    tf.saved_model.save(model, saved_model_path)

    if ckpt_exporter:
      logging.info("Loading best checkpoint for export")
      trainer.checkpoint.restore(ckpt_exporter.best_ckpt_path)
      saved_model_path = os.path.join(model_dir, "saved_models/best")

      # Make sure restored and not re-initialized.
      if trainer.global_step > 0:
        logging.info(
            "Exporting best saved model by %s (from global step: %d) to %s",
            params.trainer.best_checkpoint_eval_metric,
            trainer.global_step.numpy(), saved_model_path)
        tf.saved_model.save(trainer.model, saved_model_path)
Example 13
def main(unused_argv):
  del unused_argv
  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.strategy_type,
      tpu_address=FLAGS.tpu)
  if strategy:
    logging.info("***** Number of cores used : %d",
                 strategy.num_replicas_in_sync)
  train_input_fn = functools.partial(data_utils.get_classification_input_data,
                                     FLAGS.train_batch_size, FLAGS.seq_len,
                                     strategy, True, FLAGS.train_tfrecord_path)
  test_input_fn = functools.partial(data_utils.get_classification_input_data,
                                    FLAGS.test_batch_size, FLAGS.seq_len,
                                    strategy, False, FLAGS.test_tfrecord_path)

  total_training_steps = FLAGS.train_steps
  steps_per_loop = FLAGS.iterations
  eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)
  eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                              eval_steps)
  optimizer, learning_rate_fn = optimization.create_optimizer(
      FLAGS.learning_rate,
      total_training_steps,
      FLAGS.warmup_steps,
      adam_epsilon=FLAGS.adam_epsilon)
  model_config = xlnet_config.XLNetConfig(FLAGS)
  run_config = xlnet_config.create_run_config(True, False, FLAGS)
  model_fn = functools.partial(get_classificationxlnet_model, model_config,
                               run_config, FLAGS.n_class, FLAGS.summary_type)
  input_meta_data = {}
  input_meta_data["d_model"] = FLAGS.d_model
  input_meta_data["mem_len"] = FLAGS.mem_len
  input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                               strategy.num_replicas_in_sync)
  input_meta_data["n_layer"] = FLAGS.n_layer
  input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
  input_meta_data["n_class"] = FLAGS.n_class

  training_utils.train(
      strategy=strategy,
      model_fn=model_fn,
      input_meta_data=input_meta_data,
      eval_fn=eval_fn,
      metric_fn=get_metric_fn,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      init_from_transformerxl=FLAGS.init_from_transformerxl,
      total_training_steps=total_training_steps,
      steps_per_loop=steps_per_loop,
      optimizer=optimizer,
      learning_rate_fn=learning_rate_fn,
      model_dir=FLAGS.model_dir,
      save_steps=FLAGS.save_steps)
    def test_mirrored_strategy(self):
        ds = distribute_utils.get_distribution_strategy(num_gpus=5)
        self.assertEqual(ds.num_replicas_in_sync, 5)
        self.assertEqual(len(ds.extended.worker_devices), 5)
        for device in ds.extended.worker_devices:
            self.assertIn('GPU', device)

        _ = distribute_utils.get_distribution_strategy(
            distribution_strategy='mirrored',
            num_gpus=2,
            all_reduce_alg='nccl',
            num_packs=2)
        with self.assertRaisesRegex(
                ValueError,
                'When used with `mirrored`, valid values for all_reduce_alg are.*'
        ):
            _ = distribute_utils.get_distribution_strategy(
                distribution_strategy='mirrored',
                num_gpus=2,
                all_reduce_alg='dummy',
                num_packs=2)
Example 15
def main(_):
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
    params = train_utils.parse_configuration(FLAGS)
    model_dir = FLAGS.model_dir
    if 'train' in FLAGS.mode:
        # Pure eval modes do not output yaml files. Otherwise continuous eval job
        # may race against the train job for writing the same file.
        train_utils.serialize_config(params, model_dir)

    if 'train_and_eval' in FLAGS.mode:
        assert (
            params.task.train_data.feature_shape ==
            params.task.validation_data.feature_shape), (
                f'train {params.task.train_data.feature_shape} != validate '
                f'{params.task.validation_data.feature_shape}')

    if 'assemblenet' in FLAGS.experiment:
        if 'eval' in FLAGS.mode:
            # Use the feature shape in validation_data for all jobs. The number of
            # frames in train_data will be used to construct the Assemblenet model.
            params.task.model.backbone.assemblenet.num_frames = (
                params.task.validation_data.feature_shape[0])
            shape = params.task.validation_data.feature_shape
        else:
            params.task.model.backbone.assemblenet.num_frames = (
                params.task.train_data.feature_shape[0])
            shape = params.task.train_data.feature_shape
        logging.info('mode %r num_frames %r feature shape %r', FLAGS.mode,
                     params.task.model.backbone.assemblenet.num_frames, shape)

    # Sets the mixed_precision policy. Using 'mixed_float16' or
    # 'mixed_bfloat16' can significantly speed up the model by using float16
    # on GPUs and bfloat16 on TPUs. loss_scale only takes effect when dtype is
    # float16.
    if params.runtime.mixed_precision_dtype:
        performance.set_mixed_precision_policy(
            params.runtime.mixed_precision_dtype)
    distribution_strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=params.runtime.distribution_strategy,
        all_reduce_alg=params.runtime.all_reduce_alg,
        num_gpus=params.runtime.num_gpus,
        tpu_address=params.runtime.tpu)
    with distribution_strategy.scope():
        task = task_factory.get_task(params.task, logging_dir=model_dir)

    train_lib.run_experiment(distribution_strategy=distribution_strategy,
                             task=task,
                             mode=FLAGS.mode,
                             params=params,
                             model_dir=model_dir)

    train_utils.save_gin_config(FLAGS.mode, model_dir)
Example 16
def run():
    """Runs NHNet using Keras APIs."""
    if FLAGS.enable_mlir_bridge:
        tf.config.experimental.enable_mlir_bridge()

    strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.distribution_strategy,
        tpu_address=FLAGS.tpu)
    if strategy:
        logging.info("***** Number of cores used : %d",
                     strategy.num_replicas_in_sync)

    params = models.get_model_params(FLAGS.model_type)
    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=True)
    params.override(
        {
            "len_title": FLAGS.len_title,
            "len_passage": FLAGS.len_passage,
            "num_hidden_layers": FLAGS.num_encoder_layers,
            "num_decoder_layers": FLAGS.num_decoder_layers,
            "passage_list":
                [chr(ord("b") + i) for i in range(FLAGS.num_nhnet_articles)],
        },
        is_strict=False)
    stats = {}
    if "train" in FLAGS.mode:
        stats = train(params, strategy)
    if "eval" in FLAGS.mode:
        timeout = 0 if FLAGS.mode == "train_and_eval" else FLAGS.eval_timeout
        # Uses padded decoding for TPU. Always uses cache.
        padded_decode = isinstance(strategy,
                                   tf.distribute.experimental.TPUStrategy)
        params.override({
            "padded_decode": padded_decode,
        }, is_strict=False)
        stats = evaluation.continuous_eval(
            strategy,
            params,
            model_type=FLAGS.model_type,
            eval_file_pattern=FLAGS.eval_file_pattern,
            batch_size=FLAGS.eval_batch_size,
            eval_steps=FLAGS.eval_steps,
            model_dir=FLAGS.model_dir,
            timeout=timeout)
    return stats
Example 17
def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  # Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can significantly speed up the model by using float16 on GPUs and bfloat16
  # on TPUs. loss_scale only takes effect when dtype is float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)

  if isinstance(params, cfg.ExperimentConfig):
    with distribution_strategy.scope():
      task = task_factory.get_task(params.task, logging_dir=model_dir)

    train_lib.run_experiment(
        distribution_strategy=distribution_strategy,
        task=task,
        mode=FLAGS.mode,
        params=params,
        model_dir=model_dir)

  elif isinstance(params, multi_cfg.MultiTaskExperimentConfig):
    with distribution_strategy.scope():
      task = multitask.MultiTask.from_config(params.task, model_dir)
      model = multihead_model.build_model(params.task)

    train_lib_multitask.run_experiment(
        distribution_strategy=distribution_strategy,
        task=task,
        model=model,
        mode=FLAGS.mode,
        params=params,
        model_dir=model_dir)

  else:
    raise ValueError(
        "Expected config to be either type cfg.ExperimentConfig or "
        "multi_cfg.MultiTaskExperimentConfig, got %s" % type(params))

  train_utils.save_gin_config(FLAGS.mode, model_dir)
def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'
  # Configures cluster spec for multi-worker distribution strategy.
  if FLAGS.num_gpus > 0:
    _ = distribute_utils.configure_cluster(FLAGS.worker_hosts, FLAGS.task_index)
  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg,
      tpu_address=FLAGS.tpu)
  if strategy:
    print('***** Number of cores used : ', strategy.num_replicas_in_sync)

  run_bert_pretrain(strategy)
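Several snippets call distribute_utils.configure_cluster(worker_hosts, task_index) before building a multi_worker_mirrored strategy. A sketch of how those arguments are typically filled in, assuming worker_hosts is a comma-separated list of host:port pairs (the addresses below are placeholders):

# Sketch only; the worker addresses are placeholders. configure_cluster is
# expected to turn these values into a TF_CONFIG-style cluster spec.
distribute_utils.configure_cluster(
    worker_hosts='10.0.0.1:8470,10.0.0.2:8470',  # comma-separated workers
    task_index=0)                                # index of this worker
strategy = distribute_utils.get_distribution_strategy(
    distribution_strategy='multi_worker_mirrored', all_reduce_alg='nccl')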
Example 19
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    superglue_flags.validate_flags(FLAGS, file_exists_fn=tf.io.gfile.exists)

    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
    distribution_strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.distribution_strategy,
        num_gpus=FLAGS.num_gpus,
        tpu_address=FLAGS.tpu)

    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
        input_meta_data = json.loads(reader.read().decode('utf-8'))

    with distribution_strategy.scope():
        task = None
        if 'train_eval' in FLAGS.mode:
            logging.info('Starting training and eval...')
            logging.info('Model dir: %s', FLAGS.model_dir)

            exp_config = _get_exp_config(input_meta_data=input_meta_data,
                                         exp_config_files=FLAGS.config_file)
            train_utils.serialize_config(exp_config, FLAGS.model_dir)
            task = task_factory.get_task(exp_config.task,
                                         logging_dir=FLAGS.model_dir)
            train_lib.run_experiment(
                distribution_strategy=distribution_strategy,
                task=task,
                mode='train_and_eval',
                params=exp_config,
                model_dir=FLAGS.model_dir)

        if 'predict' in FLAGS.mode:
            logging.info('Starting predict...')
            # When mode is `predict`, `task` will be None.
            if task is None:
                exp_config = _get_exp_config(input_meta_data=input_meta_data,
                                             exp_config_files=[
                                                 os.path.join(
                                                     FLAGS.model_dir,
                                                     'params.yaml')
                                             ])
                task = task_factory.get_task(exp_config.task,
                                             logging_dir=FLAGS.model_dir)
            _write_submission_file(task, input_meta_data['max_seq_length'])
Example 20
  def __init__(self, strategy_type=None, strategy_config=None):
    """Constructor.

    Args:
      strategy_type: string. One of 'tpu', 'mirrored', 'multi_worker_mirrored'.
        If None, the user is responsible for setting the strategy before
        calling build_executor(...).
      strategy_config: necessary config for constructing the proper Strategy.
        Check strategy_flags_dict() for examples of the structure.
    """
    _ = distribute_utils.configure_cluster(strategy_config.worker_hosts,
                                           strategy_config.task_index)
    self._strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=strategy_type,
        num_gpus=strategy_config.num_gpus,
        all_reduce_alg=strategy_config.all_reduce_alg,
        num_packs=strategy_config.num_packs,
        tpu_address=strategy_config.tpu)
Example 21
def get_v1_distribution_strategy(params):
    """Returns the distribution strategy to use."""
    if params["use_tpu"]:
        # Some of the networking libraries are quite chatty.
        for name in [
                "googleapiclient.discovery", "googleapiclient.discovery_cache",
                "oauth2client.transport"
        ]:
            logging.getLogger(name).setLevel(logging.ERROR)

        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=params["tpu"],
            zone=params["tpu_zone"],
            project=params["tpu_gcp_project"],
            coordinator_name="coordinator")

        logging.info("Issuing reset command to TPU to ensure a clean state.")
        tf.Session.reset(tpu_cluster_resolver.get_master())

        # Estimator looks at the master it connects to for MonitoredTrainingSession
        # by reading the `TF_CONFIG` environment variable, and the coordinator
        # is used by StreamingFilesDataset.
        tf_config_env = {
            "session_master": tpu_cluster_resolver.get_master(),
            "eval_session_master": tpu_cluster_resolver.get_master(),
            "coordinator":
                tpu_cluster_resolver.cluster_spec().as_dict()["coordinator"],
        }
        os.environ["TF_CONFIG"] = json.dumps(tf_config_env)

        distribution = tf.distribute.experimental.TPUStrategy(
            tpu_cluster_resolver, steps_per_run=100)

    else:
        distribution = distribute_utils.get_distribution_strategy(
            num_gpus=params["num_gpus"])

    return distribution
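A hypothetical params dictionary for get_v1_distribution_strategy above, containing only the keys the function reads; all values are placeholders:

# All values are placeholders for illustration.
params = {
    'use_tpu': False,        # True takes the TPUClusterResolver path above
    'tpu': '',               # TPU name or grpc address when use_tpu is True
    'tpu_zone': None,
    'tpu_gcp_project': None,
    'num_gpus': 1,           # used on the non-TPU path
}
strategy = get_v1_distribution_strategy(params)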
Example 22
def main(_):
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
    print(FLAGS.experiment)
    params = train_utils.parse_configuration(FLAGS)

    model_dir = FLAGS.model_dir
    if 'train' in FLAGS.mode:
        # Pure eval modes do not output yaml files. Otherwise continuous eval job
        # may race against the train job for writing the same file.
        train_utils.serialize_config(params, model_dir)

    # Sets the mixed_precision policy. Using 'mixed_float16' or
    # 'mixed_bfloat16' can significantly speed up the model by using float16
    # on GPUs and bfloat16 on TPUs. loss_scale only takes effect when dtype is
    # float16.
    if params.runtime.mixed_precision_dtype:
        performance.set_mixed_precision_policy(
            params.runtime.mixed_precision_dtype, params.runtime.loss_scale)
    if params.runtime.worker_hosts:
        num_workers = distribute_utils.configure_cluster(
            worker_hosts=params.runtime.worker_hosts,
            task_index=params.runtime.task_index)
        print(num_workers)
    distribution_strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=params.runtime.distribution_strategy,
        all_reduce_alg=params.runtime.all_reduce_alg,
        num_gpus=params.runtime.num_gpus,
        tpu_address=params.runtime.tpu)

    with distribution_strategy.scope():
        task = task_factory.get_task(params.task, logging_dir=model_dir)

    train_lib.run_experiment(distribution_strategy=distribution_strategy,
                             task=task,
                             mode=FLAGS.mode,
                             params=params,
                             model_dir=model_dir)
Example 23
def train_and_eval(
    params: base_configs.ExperimentConfig,
    strategy_override: tf.distribute.Strategy) -> Mapping[str, Any]:
  """Runs the train and eval path using compile/fit."""
  logging.info('Running train and eval.')

  distribute_utils.configure_cluster(params.runtime.worker_hosts,
                                     params.runtime.task_index)

  # Note: for TPUs, strategy and scope should be created before the dataset
  strategy = strategy_override or distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)

  strategy_scope = distribute_utils.get_strategy_scope(strategy)

  logging.info('Detected %d devices.',
               strategy.num_replicas_in_sync if strategy else 1)

  label_smoothing = params.model.loss.label_smoothing
  one_hot = label_smoothing and label_smoothing > 0

  builders = _get_dataset_builders(params, strategy, one_hot)
  datasets = [
      builder.build(strategy) if builder else None for builder in builders
  ]

  # Unpack datasets and builders based on train/val/test splits
  train_builder, validation_builder = builders  # pylint: disable=unbalanced-tuple-unpacking
  train_dataset, validation_dataset = datasets

  train_epochs = params.train.epochs
  train_steps = params.train.steps or train_builder.num_steps
  validation_steps = params.evaluation.steps or validation_builder.num_steps

  initialize(params, train_builder)

  logging.info('Global batch size: %d', train_builder.global_batch_size)

  with strategy_scope:
    model_params = params.model.model_params.as_dict()
    model = get_models()[params.model.name](**model_params)
    learning_rate = optimizer_factory.build_learning_rate(
        params=params.model.learning_rate,
        batch_size=train_builder.global_batch_size,
        train_epochs=train_epochs,
        train_steps=train_steps)
    optimizer = optimizer_factory.build_optimizer(
        optimizer_name=params.model.optimizer.name,
        base_learning_rate=learning_rate,
        params=params.model.optimizer.as_dict(),
        model=model)
    optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=train_builder.dtype == 'float16',
        loss_scale=get_loss_scale(params))

    metrics_map = _get_metrics(one_hot)
    metrics = [metrics_map[metric] for metric in params.train.metrics]
    steps_per_loop = train_steps if params.train.set_epoch_loop else 1

    if one_hot:
      loss_obj = tf.keras.losses.CategoricalCrossentropy(
          label_smoothing=params.model.loss.label_smoothing)
    else:
      loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
    model.compile(
        optimizer=optimizer,
        loss=loss_obj,
        metrics=metrics,
        steps_per_execution=steps_per_loop)

    initial_epoch = 0
    if params.train.resume_checkpoint:
      initial_epoch = resume_from_checkpoint(
          model=model, model_dir=params.model_dir, train_steps=train_steps)

    callbacks = custom_callbacks.get_callbacks(
        model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
        include_tensorboard=params.train.callbacks.enable_tensorboard,
        time_history=params.train.callbacks.enable_time_history,
        track_lr=params.train.tensorboard.track_lr,
        write_model_weights=params.train.tensorboard.write_model_weights,
        initial_step=initial_epoch * train_steps,
        batch_size=train_builder.global_batch_size,
        log_steps=params.train.time_history.log_steps,
        model_dir=params.model_dir,
        backup_and_restore=params.train.callbacks.enable_backup_and_restore)

  serialize_config(params=params, model_dir=params.model_dir)

  if params.evaluation.skip_eval:
    validation_kwargs = {}
  else:
    validation_kwargs = {
        'validation_data': validation_dataset,
        'validation_steps': validation_steps,
        'validation_freq': params.evaluation.epochs_between_evals,
    }

  history = model.fit(
      train_dataset,
      epochs=train_epochs,
      steps_per_epoch=train_steps,
      initial_epoch=initial_epoch,
      callbacks=callbacks,
      verbose=2,
      **validation_kwargs)

  validation_output = None
  if not params.evaluation.skip_eval:
    validation_output = model.evaluate(
        validation_dataset, steps=validation_steps, verbose=2)

  # TODO(dankondratyuk): eval and save final test accuracy
  stats = common.build_stats(history, validation_output, callbacks)
  return stats
Example 24
def test_one_device_strategy_gpu(self):
    ds = distribute_utils.get_distribution_strategy(num_gpus=1)
    self.assertEqual(ds.num_replicas_in_sync, 1)
    self.assertEqual(len(ds.extended.worker_devices), 1)
    self.assertIn('GPU', ds.extended.worker_devices[0])
Example 25
def test_mirrored_strategy(self):
    ds = distribute_utils.get_distribution_strategy(num_gpus=5)
    self.assertEqual(ds.num_replicas_in_sync, 5)
    self.assertEqual(len(ds.extended.worker_devices), 5)
    for device in ds.extended.worker_devices:
        self.assertIn('GPU', device)
Example 26
def run_executor(params,
                 mode,
                 checkpoint_path=None,
                 train_input_fn=None,
                 eval_input_fn=None,
                 callbacks=None,
                 prebuilt_strategy=None):
    """Runs the object detection model on distribution strategy defined by the user."""

    if params.architecture.use_bfloat16:
        policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
            'mixed_bfloat16')
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

    model_builder = model_factory.model_generator(params)

    if prebuilt_strategy is not None:
        strategy = prebuilt_strategy
    else:
        strategy_config = params.strategy_config
        distribute_utils.configure_cluster(strategy_config.worker_hosts,
                                           strategy_config.task_index)
        strategy = distribute_utils.get_distribution_strategy(
            distribution_strategy=params.strategy_type,
            num_gpus=strategy_config.num_gpus,
            all_reduce_alg=strategy_config.all_reduce_alg,
            num_packs=strategy_config.num_packs,
            tpu_address=strategy_config.tpu)

    num_workers = int(strategy.num_replicas_in_sync + 7) // 8
    is_multi_host = (int(num_workers) >= 2)

    if mode == 'train':

        def _model_fn(params):
            return model_builder.build_model(params, mode=ModeKeys.TRAIN)

        logging.info(
            'Train num_replicas_in_sync %d num_workers %d is_multi_host %s',
            strategy.num_replicas_in_sync, num_workers, is_multi_host)

        dist_executor = DetectionDistributedExecutor(
            strategy=strategy,
            params=params,
            model_fn=_model_fn,
            loss_fn=model_builder.build_loss_fn,
            is_multi_host=is_multi_host,
            predict_post_process_fn=model_builder.post_processing,
            trainable_variables_filter=(
                model_builder.make_filter_trainable_variables_fn()))

        if is_multi_host:
            train_input_fn = functools.partial(
                train_input_fn,
                batch_size=params.train.batch_size //
                strategy.num_replicas_in_sync)

        return dist_executor.train(
            train_input_fn=train_input_fn,
            model_dir=params.model_dir,
            iterations_per_loop=params.train.iterations_per_loop,
            total_steps=params.train.total_steps,
            init_checkpoint=model_builder.make_restore_checkpoint_fn(),
            custom_callbacks=callbacks,
            save_config=True)
    elif mode == 'eval' or mode == 'eval_once':

        def _model_fn(params):
            return model_builder.build_model(params,
                                             mode=ModeKeys.PREDICT_WITH_GT)

        logging.info(
            'Eval num_replicas_in_sync %d num_workers %d is_multi_host %s',
            strategy.num_replicas_in_sync, num_workers, is_multi_host)

        if is_multi_host:
            eval_input_fn = functools.partial(
                eval_input_fn,
                batch_size=params.eval.batch_size //
                strategy.num_replicas_in_sync)

        dist_executor = DetectionDistributedExecutor(
            strategy=strategy,
            params=params,
            model_fn=_model_fn,
            loss_fn=model_builder.build_loss_fn,
            is_multi_host=is_multi_host,
            predict_post_process_fn=model_builder.post_processing,
            trainable_variables_filter=(
                model_builder.make_filter_trainable_variables_fn()))

        if mode == 'eval':
            results = dist_executor.evaluate_from_model_dir(
                model_dir=params.model_dir,
                eval_input_fn=eval_input_fn,
                eval_metric_fn=model_builder.eval_metrics,
                eval_timeout=params.eval.eval_timeout,
                min_eval_interval=params.eval.min_eval_interval,
                total_steps=params.train.total_steps)
        else:
            # Run evaluation once for a single checkpoint.
            if not checkpoint_path:
                raise ValueError('checkpoint_path cannot be empty.')
            if tf.io.gfile.isdir(checkpoint_path):
                checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
            summary_writer = executor.SummaryWriter(params.model_dir, 'eval')
            results, _ = dist_executor.evaluate_checkpoint(
                checkpoint_path=checkpoint_path,
                eval_input_fn=eval_input_fn,
                eval_metric_fn=model_builder.eval_metrics,
                summary_writer=summary_writer)
        for k, v in results.items():
            logging.info('Final eval metric %s: %f', k, v)
        return results
    else:
        raise ValueError('Mode not found: %s.' % mode)
def run(flags_obj):
    """Run ResNet ImageNet training and eval loop using custom training loops.

    Args:
      flags_obj: An object containing parsed flag values.

    Raises:
      ValueError: If fp16 is passed as it is not currently supported.

    Returns:
      Dictionary of training and eval stats.
    """
    keras_utils.set_session_config()
    performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

    if tf.config.list_physical_devices('GPU'):
        if flags_obj.tf_gpu_thread_mode:
            keras_utils.set_gpu_thread_mode_and_count(
                per_gpu_thread_count=flags_obj.per_gpu_thread_count,
                gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
                num_gpus=flags_obj.num_gpus,
                datasets_num_private_threads=(
                    flags_obj.datasets_num_private_threads))
        common.set_cudnn_batchnorm_mode()

    data_format = flags_obj.data_format
    if data_format is None:
        data_format = ('channels_first'
                       if tf.config.list_physical_devices('GPU') else
                       'channels_last')
    tf.keras.backend.set_image_data_format(data_format)

    strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu)

    per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
        flags_obj)
    if flags_obj.steps_per_loop is None:
        steps_per_loop = per_epoch_steps
    elif flags_obj.steps_per_loop > per_epoch_steps:
        steps_per_loop = per_epoch_steps
        logging.warning(
            'Setting steps_per_loop to %d to respect epoch boundary.',
            steps_per_loop)
    else:
        steps_per_loop = flags_obj.steps_per_loop

    logging.info(
        'Training %d epochs, each epoch has %d steps, '
        'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
        train_epochs * per_epoch_steps, eval_steps)

    time_callback = keras_utils.TimeHistory(
        flags_obj.batch_size,
        flags_obj.log_steps,
        logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
    with distribute_utils.get_strategy_scope(strategy):
        runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
                                                  per_epoch_steps)

    eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
    checkpoint_interval = (steps_per_loop * 5
                           if flags_obj.enable_checkpoint_and_export else None)
    summary_interval = steps_per_loop if flags_obj.enable_tensorboard else None

    checkpoint_manager = tf.train.CheckpointManager(
        runnable.checkpoint,
        directory=flags_obj.model_dir,
        max_to_keep=10,
        step_counter=runnable.global_step,
        checkpoint_interval=checkpoint_interval)

    resnet_controller = orbit.Controller(
        strategy=strategy,
        trainer=runnable,
        evaluator=runnable if not flags_obj.skip_eval else None,
        global_step=runnable.global_step,
        steps_per_loop=steps_per_loop,
        checkpoint_manager=checkpoint_manager,
        summary_interval=summary_interval,
        summary_dir=flags_obj.model_dir,
        eval_summary_dir=os.path.join(flags_obj.model_dir, 'eval'))

    time_callback.on_train_begin()
    if not flags_obj.skip_eval:
        resnet_controller.train_and_evaluate(
            train_steps=per_epoch_steps * train_epochs,
            eval_steps=eval_steps,
            eval_interval=eval_interval)
    else:
        resnet_controller.train(steps=per_epoch_steps * train_epochs)
    time_callback.on_train_end()

    stats = build_stats(runnable, time_callback)
    return stats
Example 28
def main(unused_argv):
    del unused_argv
    num_hosts = 1
    strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.strategy_type, tpu_address=FLAGS.tpu)
    if FLAGS.strategy_type == "tpu":
        num_hosts = strategy.extended.num_hosts
    if strategy:
        logging.info("***** Number of cores used : %d",
                     strategy.num_replicas_in_sync)
        logging.info("***** Number of hosts used : %d", num_hosts)
    online_masking_config = data_utils.OnlineMaskingConfig(
        sample_strategy=FLAGS.sample_strategy,
        max_num_tokens=FLAGS.max_num_tokens,
        min_num_tokens=FLAGS.min_num_tokens,
        max_num_words=FLAGS.max_num_words,
        min_num_words=FLAGS.min_num_words)

    train_input_fn = functools.partial(
        data_utils.get_pretrain_input_data, FLAGS.train_batch_size,
        FLAGS.seq_len, strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len,
        FLAGS.perm_size, FLAGS.leak_ratio, FLAGS.num_predict, FLAGS.uncased,
        online_masking_config, num_hosts)

    total_training_steps = FLAGS.train_steps

    steps_per_loop = FLAGS.iterations

    optimizer, learning_rate_fn = optimization.create_optimizer(
        init_lr=FLAGS.learning_rate,
        num_train_steps=total_training_steps,
        num_warmup_steps=FLAGS.warmup_steps,
        min_lr_ratio=FLAGS.min_lr_ratio,
        adam_epsilon=FLAGS.adam_epsilon,
        weight_decay_rate=FLAGS.weight_decay_rate)

    model_config = xlnet_config.XLNetConfig(FLAGS)
    run_config = xlnet_config.create_run_config(True, False, FLAGS)
    input_meta_data = {}
    input_meta_data["d_model"] = FLAGS.d_model
    input_meta_data["mem_len"] = FLAGS.mem_len
    input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                                 strategy.num_replicas_in_sync)
    input_meta_data["n_layer"] = FLAGS.n_layer
    input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
    model_fn = functools.partial(get_pretrainxlnet_model, model_config,
                                 run_config)

    model = training_utils.train(
        strategy=strategy,
        model_fn=model_fn,
        input_meta_data=input_meta_data,
        eval_fn=None,
        metric_fn=None,
        train_input_fn=train_input_fn,
        init_checkpoint=FLAGS.init_checkpoint,
        init_from_transformerxl=FLAGS.init_from_transformerxl,
        total_training_steps=total_training_steps,
        steps_per_loop=steps_per_loop,
        optimizer=optimizer,
        learning_rate_fn=learning_rate_fn,
        model_dir=FLAGS.model_dir,
        save_steps=FLAGS.save_steps)

    # Export transformer-xl model checkpoint to be used in finetuning.
    checkpoint = tf.train.Checkpoint(transformer_xl=model.transformerxl_model)
    saved_path = checkpoint.save(
        os.path.join(FLAGS.model_dir, "pretrained/transformer_xl.ckpt"))
    logging.info(
        "Exporting the transformer-xl model as a new TF checkpoint: %s",
        saved_path)
Example 29
def run_continuous_finetune(
    mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    pretrain_steps: Optional[int] = None,
) -> Mapping[str, Any]:
    """Run modes with continuous training.

  Currently only supports continuous_train_and_eval.

  Args:
    mode: A 'str', specifying the mode. continuous_train_and_eval - monitors a
      checkpoint directory. Once a new checkpoint is discovered, loads the
      checkpoint, finetune the model by training it (probably on another dataset
      or with another task), then evaluate the finetuned model.
    params: ExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run post eval once after training, metrics logs
      are returned.
    pretrain_steps: Optional, the number of total training steps for the
      pretraining job.

  Returns:
    eval logs: returns eval metrics logs when run_post_eval is set to True,
      othewise, returns {}.
  """

    assert mode == 'continuous_train_and_eval', (
        'Only continuous_train_and_eval is supported by continuous_finetune. '
        'Got mode: {}'.format(mode))

    # Sets the mixed_precision policy. Using 'mixed_float16' or
    # 'mixed_bfloat16' can significantly speed up the model by using float16
    # on GPUs and bfloat16 on TPUs. loss_scale only takes effect when dtype is
    # float16.
    if params.runtime.mixed_precision_dtype:
        performance.set_mixed_precision_policy(
            params.runtime.mixed_precision_dtype, params.runtime.loss_scale)
    distribution_strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=params.runtime.distribution_strategy,
        all_reduce_alg=params.runtime.all_reduce_alg,
        num_gpus=params.runtime.num_gpus,
        tpu_address=params.runtime.tpu)

    retry_times = 0
    while not tf.io.gfile.isdir(params.task.init_checkpoint):
        # Wait for the init_checkpoint directory to be created.
        if retry_times >= 60:
            raise ValueError(
                'ExperimentConfig.task.init_checkpoint must be a directory for '
                'continuous_train_and_eval mode.')
        retry_times += 1
        time.sleep(60)

    summary_writer = tf.summary.create_file_writer(
        os.path.join(model_dir, 'eval'))

    global_step = 0

    def timeout_fn():
        if pretrain_steps and global_step < pretrain_steps:
            # Keeps waiting for another timeout period.
            logging.info(
                'Continue waiting for new checkpoint as current pretrain '
                'global_step=%d and target is %d.', global_step,
                pretrain_steps)
            return False
        # Quits the loop.
        return True

    for pretrain_ckpt in tf.train.checkpoints_iterator(
            checkpoint_dir=params.task.init_checkpoint,
            min_interval_secs=10,
            timeout=params.trainer.continuous_eval_timeout,
            timeout_fn=timeout_fn):
        with distribution_strategy.scope():
            global_step = train_utils.read_global_step_from_checkpoint(
                pretrain_ckpt)
        # Replaces params.task.init_checkpoint to make sure that we load
        # exactly this pretrain checkpoint.
        if params.trainer.best_checkpoint_export_subdir:
            best_ckpt_subdir = '{}_{}'.format(
                params.trainer.best_checkpoint_export_subdir, global_step)
            params_replaced = params.replace(
                task={'init_checkpoint': pretrain_ckpt},
                trainer={'best_checkpoint_export_subdir': best_ckpt_subdir})
        else:
            params_replaced = params.replace(
                task={'init_checkpoint': pretrain_ckpt})
        params_replaced.lock()
        logging.info('Running finetuning with params: %s', params_replaced)

        with distribution_strategy.scope():
            if isinstance(params, configs.MultiEvalExperimentConfig):
                task = task_factory.get_task(params_replaced.task)
                eval_tasks = multitask.MultiTask.from_config(
                    params_replaced.eval_tasks)
                (_, eval_metrics
                 ) = multitask_train_lib.run_experiment_wtih_multitask_eval(
                     distribution_strategy=distribution_strategy,
                     train_task=task,
                     eval_tasks=eval_tasks,
                     mode='train_and_eval',
                     params=params_replaced,
                     model_dir=model_dir,
                     run_post_eval=True,
                     save_summary=False)
            else:
                task = task_factory.get_task(params_replaced.task,
                                             logging_dir=model_dir)
                _, eval_metrics = train_lib.run_experiment(
                    distribution_strategy=distribution_strategy,
                    task=task,
                    mode='train_and_eval',
                    params=params_replaced,
                    model_dir=model_dir,
                    run_post_eval=True,
                    save_summary=False)
        logging.info('Evaluation finished. Pretrain global_step: %d',
                     global_step)
        train_utils.write_json_summary(model_dir, global_step, eval_metrics)

        if not os.path.basename(model_dir):  # if model_dir.endswith('/')
            summary_grp = os.path.dirname(model_dir) + '_' + task.name
        else:
            summary_grp = os.path.basename(model_dir) + '_' + task.name
        summaries = {}
        for name, value in _flatten_dict(eval_metrics).items():
            summaries[summary_grp + '/' + '-'.join(name)] = value
        train_utils.write_summary(summary_writer, global_step, summaries)

        train_utils.remove_ckpts(model_dir)
        # In TF2, the resource life cycle is bound with the python object life
        # cycle. Force trigger python garbage collection here so those resources
        # can be deallocated in time, so it doesn't cause OOM when allocating new
        # objects.
        # TODO(b/169178664): Fix cycle reference in Keras model and revisit to see
        # if we need gc here.
        gc.collect()

    if run_post_eval:
        return eval_metrics
    return {}
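A minimal, hypothetical invocation of run_continuous_finetune based only on the signature and docstring above; the config object and model directory are placeholders:

# Hypothetical usage sketch; experiment_config and the model_dir path are
# placeholders, not values from the original source.
eval_logs = run_continuous_finetune(
    mode='continuous_train_and_eval',
    params=experiment_config,       # a config_definitions.ExperimentConfig
    model_dir='/tmp/finetune_model',
    run_post_eval=True,
    pretrain_steps=100000)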
Example 30
def run_ncf(_):
    """Run NCF training and eval with Keras."""

    keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

    if FLAGS.seed is not None:
        print("Setting tf seed")
        tf.random.set_seed(FLAGS.seed)

    model_helpers.apply_clean(FLAGS)

    if FLAGS.dtype == "fp16" and FLAGS.fp16_implementation == "keras":
        tf.keras.mixed_precision.set_global_policy("mixed_float16")

    strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.distribution_strategy,
        num_gpus=FLAGS.num_gpus,
        tpu_address=FLAGS.tpu)

    params = ncf_common.parse_flags(FLAGS)
    params["distribute_strategy"] = strategy
    params["use_tpu"] = (FLAGS.distribution_strategy == "tpu")

    if params["use_tpu"] and not params["keras_use_ctl"]:
        logging.error(
            "Custom training loop must be used when using TPUStrategy.")
        return

    batch_size = params["batch_size"]
    time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
    callbacks = [time_callback]

    producer, input_meta_data = None, None
    generate_input_online = params["train_dataset_path"] is None

    if generate_input_online:
        # Start data producing thread.
        num_users, num_items, _, _, producer = ncf_common.get_inputs(params)
        producer.start()
        per_epoch_callback = IncrementEpochCallback(producer)
        callbacks.append(per_epoch_callback)
    else:
        assert params["eval_dataset_path"] and params["input_meta_data_path"]
        with tf.io.gfile.GFile(params["input_meta_data_path"], "rb") as reader:
            input_meta_data = json.loads(reader.read().decode("utf-8"))
            num_users = input_meta_data["num_users"]
            num_items = input_meta_data["num_items"]

    params["num_users"], params["num_items"] = num_users, num_items

    if FLAGS.early_stopping:
        early_stopping_callback = CustomEarlyStopping(
            "val_HR_METRIC", desired_value=FLAGS.hr_threshold)
        callbacks.append(early_stopping_callback)

    (train_input_dataset, eval_input_dataset, num_train_steps,
     num_eval_steps) = ncf_input_pipeline.create_ncf_input_data(
         params, producer, input_meta_data, strategy)
    steps_per_epoch = None if generate_input_online else num_train_steps

    with distribute_utils.get_strategy_scope(strategy):
        keras_model = _get_keras_model(params)
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=params["learning_rate"],
            beta_1=params["beta1"],
            beta_2=params["beta2"],
            epsilon=params["epsilon"])
        if FLAGS.fp16_implementation == "graph_rewrite":
            optimizer = (
                tf.compat.v1.train.experimental
                .enable_mixed_precision_graph_rewrite(
                    optimizer,
                    loss_scale=flags_core.get_loss_scale(
                        FLAGS, default_for_fp16="dynamic")))
        elif FLAGS.dtype == "fp16":
            loss_scale = flags_core.get_loss_scale(FLAGS,
                                                   default_for_fp16="dynamic")
            # Note Model.compile automatically wraps the optimizer with a
            # LossScaleOptimizer using dynamic loss scaling. We explicitly wrap it
            # here for the case where a custom training loop or fixed loss scale is
            # used.
            if loss_scale == "dynamic":
                optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
                    optimizer)
            else:
                optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
                    optimizer, dynamic=False, initial_scale=loss_scale)

        if params["keras_use_ctl"]:
            train_loss, eval_results = run_ncf_custom_training(
                params,
                strategy,
                keras_model,
                optimizer,
                callbacks,
                train_input_dataset,
                eval_input_dataset,
                num_train_steps,
                num_eval_steps,
                generate_input_online=generate_input_online)
        else:
            keras_model.compile(optimizer=optimizer,
                                run_eagerly=FLAGS.run_eagerly)

            if not FLAGS.ml_perf:
                # Create Tensorboard summary and checkpoint callbacks.
                summary_dir = os.path.join(FLAGS.model_dir, "summaries")
                summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
                checkpoint_path = os.path.join(FLAGS.model_dir, "checkpoint")
                checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
                    checkpoint_path, save_weights_only=True)

                callbacks += [summary_callback, checkpoint_callback]

            history = keras_model.fit(train_input_dataset,
                                      epochs=FLAGS.train_epochs,
                                      steps_per_epoch=steps_per_epoch,
                                      callbacks=callbacks,
                                      validation_data=eval_input_dataset,
                                      validation_steps=num_eval_steps,
                                      verbose=2)

            logging.info("Training done. Start evaluating")

            eval_loss_and_metrics = keras_model.evaluate(eval_input_dataset,
                                                         steps=num_eval_steps,
                                                         verbose=2)

            logging.info("Keras evaluation is done.")

            # Keras evaluate() API returns scalar loss and metric values from
            # evaluation as a list. Here, the returned list would contain
            # [evaluation loss, hr sum, hr count].
            eval_hit_rate = eval_loss_and_metrics[1] / eval_loss_and_metrics[2]

            # Format evaluation result into [eval loss, eval hit accuracy].
            eval_results = [eval_loss_and_metrics[0], eval_hit_rate]

            if history and history.history:
                train_history = history.history
                train_loss = train_history["loss"][-1]

    stats = build_stats(train_loss, eval_results, time_callback)
    return stats