Code example #1
    def train(self,
              train_input_fn,
              run_eval_after_train=False,
              eval_input_fn=None):
        """Run distributed training on Mask RCNN model."""

        self._save_config()
        train_run_config = self.build_strategy_configuration('train')
        train_params = self.build_model_parameters('train')
        train_estimator = self.build_mask_rcnn_estimator(
            train_params, train_run_config, 'train')

        with dump_callback():
            train_estimator.train(
                input_fn=train_input_fn,
                max_steps=self._runtime_config.total_steps,
                hooks=get_training_hooks(
                    mode="train",
                    model_dir=self._runtime_config.model_dir,
                    checkpoint_path=self._runtime_config.checkpoint,
                    skip_checkpoint_variables=self._runtime_config.
                    skip_checkpoint_variables,
                    batch_size=train_params['batch_size'],
                    save_summary_steps=self._runtime_config.save_summary_steps,
                ))

        if not run_eval_after_train:
            return None

        if eval_input_fn is None:
            raise ValueError(
                'Eval input_fn must be passed to conduct evaluation after training.'
            )

        eval_run_config = self.build_strategy_configuration('eval')
        eval_params = self.build_model_parameters('eval')
        eval_estimator = self.build_mask_rcnn_estimator(
            eval_params, eval_run_config, 'eval')

        last_ckpt = self.get_last_checkpoint_path()
        logging.info("Restoring parameters from %s\n" % last_ckpt)

        eval_results, predictions = evaluation.evaluate(
            eval_estimator,
            eval_input_fn,
            self._runtime_config.eval_samples,
            self._runtime_config.eval_batch_size,
            self._runtime_config.include_mask,
            self._runtime_config.val_json_file,
            report_frequency=self._runtime_config.report_frequency,
            checkpoint_path=last_ckpt)

        output_dir = os.path.join(self._runtime_config.model_dir, 'eval')
        tf.io.gfile.makedirs(output_dir)

        # Summary writer writes out eval metrics.
        self._write_summary(output_dir, eval_results, predictions,
                            self._runtime_config.total_steps)

        return eval_results
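
The pattern shared by all of these examples is to wrap the actual training call in the dump_callback context manager so that tensor dumping is scoped to the train loop. A minimal sketch of that pattern, assuming the import path shown in code example #9 (TensorFlow.common.debug) and an already-built tf.estimator.Estimator; the helper name is illustrative:

from TensorFlow.common.debug import dump_callback  # import path as in code example #9

def train_with_dump(estimator, train_input_fn, max_steps, hooks=None):
    # Enable dump collection only for the duration of the Estimator train call.
    with dump_callback():
        estimator.train(input_fn=train_input_fn,
                        max_steps=max_steps,
                        hooks=hooks or [])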
Code example #2
def main(_):
    common.initialize_preloading()
    if flags.FLAGS.use_horovod and flags.FLAGS.distribution_strategy != "off":
        raise RuntimeError(
            "Horovod and distribution strategy cannot be used together. Please select one of the scaleout methods."
        )
    if flags.FLAGS.distribution_strategy not in ["off", "hpu"]:
        raise RuntimeError(
            "Currently HPU supports only HPUStrategy, please set --distribution_strategy=hpu or use horovod"
        )
    if flags.FLAGS.use_horovod:
        if flags.FLAGS.horovod_hierarchical_allreduce:
            os.environ['HOROVOD_HIERARCHICAL_ALLREDUCE'] = "1"
        hvd_init()
    else:
        synapse_logger_init()
    load_habana_module()

    if flags.FLAGS.global_seed:
        tf.random.set_seed(flags.FLAGS.global_seed)

    with dump_callback():
        model_helpers.apply_clean(flags.FLAGS)
        with logger.benchmark_context(flags.FLAGS):
            stats = run(flags.FLAGS)
        logging.info('Run stats:\n%s', stats)
Code example #3
File: main.py  Project: HabanaAI/Model-References
def main():
    params = parse_args()
    tf.random.set_seed(params.seed)
    tf.get_logger().setLevel(logging.ERROR)

    params = setup_horovod(params)
    set_flags(params)
    model_dir = prepare_model_dir(params)
    logger = get_logger(params)

    dataset = Dataset(data_dir=params.data_dir,
                      batch_size=params.batch_size,
                      fold_idx=params.fold,
                      n_folds=params.num_folds,
                      params=params,
                      seed=params.seed)

    estimator = build_estimator(params, model_dir)

    if params.tensorboard_logging and (params.worker_id == 0
                                       or params.log_all_workers):
        from TensorFlow.common.tb_utils import write_hparams_v1
        write_hparams_v1(params.log_dir, vars(params))

    if not params.benchmark:
        params.max_steps = params.max_steps // params.num_workers
    if 'train' in params.exec_mode:
        with dump_callback(params.dump_config):
            training_hooks = get_hooks(params, logger)
            dataset_fn = dataset.synth_train_fn if params.synth_data else dataset.train_fn

            estimator.train(input_fn=dataset_fn,
                            steps=params.max_steps,
                            hooks=training_hooks)

    if 'evaluate' in params.exec_mode:
        result = estimator.evaluate(input_fn=dataset.eval_fn,
                                    steps=dataset.eval_size)
        data = parse_evaluation_results(result)
        if params.worker_id == 0:
            logger.log(step=(), data=data)

    if 'predict' == params.exec_mode:
        inference_hooks = get_hooks(params, logger)
        if params.worker_id == 0:
            count = 1 if not params.benchmark else 2 * params.warmup_steps * params.batch_size // dataset.test_size
            predictions = estimator.predict(input_fn=lambda: dataset.test_fn(
                count=count, drop_remainder=params.benchmark),
                                            hooks=inference_hooks)

            for idx, p in enumerate(predictions):
                volume = p['predictions']
                if not params.benchmark:
                    np.save(
                        os.path.join(params.model_dir,
                                     "vol_{}.npy".format(idx)), volume)
Code example #4
def main(_):
    common.initialize_preloading()
    if flags.FLAGS.use_horovod:
        hvd_init()
    else:
        synapse_logger_init()
    load_habana_module()

    with dump_callback():
        model_helpers.apply_clean(flags.FLAGS)
        with logger.benchmark_context(flags.FLAGS):
            stats = run(flags.FLAGS)
        logging.info('Run stats:\n%s', stats)
Code example #5
File: unet2d.py  Project: HabanaAI/Model-References
def main():
    """
    Starting point of the application
    """
    params = parse_args(description="UNet-medical")
    if params.use_horovod:
        hvd_init()
    set_flags(params)

    model_dir = prepare_model_dir(params)
    params.model_dir = model_dir
    logger = get_logger(params)

    tb_logger = None
    if params.tensorboard_logging:
        log_dir = params.log_dir
        if horovod_enabled() and params.log_all_workers:
            log_dir = os.path.join(log_dir, f'worker_{hvd_rank()}')
        tb_logger = namedtuple('TBSummaryWriters', 'train_writer eval_writer')(
            tf.summary.create_file_writer(log_dir),
            tf.summary.create_file_writer(os.path.join(log_dir, 'eval')))

    model = Unet()

    dataset = Dataset(data_dir=params.data_dir,
                      batch_size=params.batch_size,
                      fold=params.fold,
                      augment=params.augment,
                      hpu_id=hvd_rank() if horovod_enabled() else 0,
                      num_hpus=hvd_size() if horovod_enabled() else 1,
                      seed=params.seed)

    if 'train' in params.exec_mode:
        with dump_callback(params.dump_config):
            train(params, model, dataset, logger, tb_logger)

    if 'evaluate' in params.exec_mode:
        evaluate(params, model, dataset, logger, tb_logger)

    if 'predict' in params.exec_mode:
        predict(params, model, dataset, logger)
Code example #6
File: train.py  Project: HabanaAI/Model-References
                                 profile_batch=0,
                                 update_freq=params.save_summary_steps))
    callbacks.append(
        ExamplesPerSecondKerasHookV2(every_n_steps=params.save_summary_steps,
                                     output_dir=log_dir,
                                     batch_size=params.batch_size))

# Prepare model
if params.dtype == 'bf16':
    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
optimizer = tf.keras.optimizers.Adam(params.lr)
model_kwargs = {'dropout_rate': 0.0} if params.no_dropout else {}
model = T5.from_pretrained(os.path.join(params.data_dir, 't5_base'),
                           **model_kwargs)
model.compile(optimizer=optimizer)

# Run training
steps = params.steps or (len(train_ds) // params.batch_size)
valid_steps = 0 if params.no_eval else (len(valid_ds) // params.val_batch_size)

with dump_callback(params.dump_config):
    model.fit(tf_train_ds.repeat(),
              epochs=params.epochs,
              steps_per_epoch=steps,
              callbacks=callbacks,
              validation_data=tf_valid_ds.repeat(),
              validation_steps=valid_steps)

if not params.no_checkpoints:
    model.save_pretrained(os.path.join(params.model_dir, 'checkpoints'))
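
Code example #6 sets the global mixed_bfloat16 policy before the T5 model is created, so Keras layers are built with bf16 compute and fp32 variables. A minimal standalone illustration of what the policy does (the tiny Dense layer is purely for demonstration):

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
layer = tf.keras.layers.Dense(4)
layer.build((None, 8))
print(layer.compute_dtype)  # bfloat16 -- the layer's math runs in bf16
print(layer.dtype)          # float32  -- its variables stay in fp32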
Code example #7
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file)

  if FLAGS.deterministic_run and (albert_config.attention_probs_dropout_prob or
                                  albert_config.hidden_dropout_prob):
    albert_config.attention_probs_dropout_prob = 0.0
    albert_config.hidden_dropout_prob = 0.0

  validate_flags_or_throw(albert_config)

  tf.gfile.MakeDirs(FLAGS.output_dir)
  model_dir = FLAGS.output_dir
  if horovod_enabled():
    model_dir = os.path.join(FLAGS.output_dir, "worker_" + str(hvd.rank()))

  tokenizer = fine_tuning_utils.create_vocab(
      vocab_file=FLAGS.vocab_file,
      do_lower_case=FLAGS.do_lower_case,
      spm_model_file=FLAGS.spm_model_file,
      hub_module=FLAGS.albert_hub_module_handle)

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2
  if FLAGS.do_train:
    iterations_per_loop = int(min(FLAGS.iterations_per_loop,
                                  FLAGS.save_checkpoints_steps))
  else:
    iterations_per_loop = FLAGS.iterations_per_loop

  # The Scoped Allocator Optimization is enabled by default unless disabled by a flag.
  if FLAGS.enable_scoped_allocator:
    from tensorflow.core.protobuf import rewriter_config_pb2  # pylint: disable=import-error

    session_config = tf.compat.v1.ConfigProto()
    session_config.graph_options.rewrite_options.scoped_allocator_optimization = rewriter_config_pb2.RewriterConfig.ON

    enable_op = session_config.graph_options.rewrite_options.scoped_allocator_opts.enable_op
    del enable_op[:]
    enable_op.append("HorovodAllreduce")
  else:
    session_config = None

  run_config = tf.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=model_dir,
      keep_checkpoint_max=0,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      save_summary_steps=FLAGS.save_summary_steps,
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host),
      session_config=session_config)

  train_examples = None
  num_train_steps = None
  num_warmup_steps = None

  train_batch_size = FLAGS.train_batch_size
  if horovod_enabled():
    train_batch_size = train_batch_size * hvd.size()

  if FLAGS.do_train:
    train_examples = squad_utils.read_squad_examples(
        input_file=FLAGS.train_file, is_training=True)
    num_train_steps = int(
        len(train_examples) / train_batch_size * FLAGS.num_train_epochs)
    if FLAGS.train_steps > 0:
      num_train_steps = FLAGS.train_steps
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    # Pre-shuffle the input to avoid having to make a very large shuffle
    # buffer in the `input_fn`.
    rng = random.Random(12345)
    rng.shuffle(train_examples)

  start_index = 0
  end_index = len(train_examples)
  per_worker_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record")]
  worker_id = 0

  if horovod_enabled():
    per_worker_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record_{}".format(i)) for i in range(hvd.local_size())]
    num_examples_per_rank = len(train_examples) // hvd.size()
    remainder = len(train_examples) % hvd.size()
    worker_id = hvd.rank()
    if worker_id < remainder:
      start_index = worker_id * (num_examples_per_rank + 1)
      end_index = start_index + num_examples_per_rank + 1
    else:
      start_index = worker_id * num_examples_per_rank + remainder
      end_index = start_index + (num_examples_per_rank)

  learning_rate = FLAGS.learning_rate

  model_fn = squad_utils.v1_model_fn_builder(
      albert_config=albert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu,
      use_einsum=FLAGS.use_einsum,
      hub_module=FLAGS.albert_hub_module_handle)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.estimator.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)

  write_hparams_v1(FLAGS.output_dir, {
    'batch_size': FLAGS.train_batch_size,
    **{x: getattr(FLAGS, x) for x in FLAGS}
  })

  if FLAGS.do_train:
    # We write to a temporary file to avoid storing very large constant tensors
    # in memory.

    tf.logging.info("***** Running training *****")
    tf.logging.info("  Num orig examples = %d", len(train_examples))
    tf.logging.info("  Num steps = %d", num_train_steps)
    tf.logging.info("  Per-worker batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Total batch size = %d", train_batch_size)

    ## use pre-generated tf_record as input
    if FLAGS.input_file:
      if horovod_enabled():
        per_worker_filenames_temp = [os.path.join(FLAGS.input_file, "train.tf_record") for i in range(hvd.local_size())]
      else:
        per_worker_filenames_temp = [os.path.join(FLAGS.input_file, "train.tf_record")]

      if tf.gfile.Exists(per_worker_filenames_temp[hvd.local_rank() if horovod_enabled() else worker_id]):
        per_worker_filenames = per_worker_filenames_temp

    if not tf.gfile.Exists(per_worker_filenames[hvd.local_rank() if horovod_enabled() else worker_id]):
      train_writer = squad_utils.FeatureWriter(
          filename=per_worker_filenames[hvd.local_rank() if horovod_enabled() else worker_id], is_training=True)
      squad_utils.convert_examples_to_features(
          examples=train_examples[start_index:end_index],
          tokenizer=tokenizer,
          max_seq_length=FLAGS.max_seq_length,
          doc_stride=FLAGS.doc_stride,
          max_query_length=FLAGS.max_query_length,
          is_training=True,
          output_fn=train_writer.process_feature,
          do_lower_case=FLAGS.do_lower_case)
      tf.logging.info("  Num split examples = %d", train_writer.num_features)
      train_writer.close()

      del train_examples

    train_input_fn = squad_utils.input_fn_builder(
        input_file=per_worker_filenames,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True,
        use_tpu=FLAGS.use_tpu,
        bsz=FLAGS.train_batch_size,
        is_v2=False)

    train_hooks = [habana_hooks.PerfLoggingHook(batch_size=train_batch_size, mode="train")]
    if horovod_enabled():
      train_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

    if "range" == os.environ.get("HABANA_SYNAPSE_LOGGER", "False").lower():
      from habana_frameworks.tensorflow.synapse_logger_helpers import SynapseLoggerHook
      begin = 670
      end = begin + 10
      print("Begin: {}".format(begin))
      print("End: {}".format(end))
      train_hooks.append(SynapseLoggerHook(list(range(begin, end)), False))

    with dump_callback():
      estimator.train(input_fn=train_input_fn, max_steps=num_train_steps, hooks=train_hooks)

  if FLAGS.do_predict:
    with tf.gfile.Open(FLAGS.predict_file) as predict_file:
      prediction_json = json.load(predict_file)["data"]

    eval_examples = squad_utils.read_squad_examples(
        input_file=FLAGS.predict_file, is_training=False)

    eval_writer = squad_utils.FeatureWriter(
        filename=os.path.join(model_dir, "eval.tf_record"), is_training=False)
    eval_features = []

    def append_feature(feature):
      eval_features.append(feature)
      eval_writer.process_feature(feature)

    squad_utils.convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=FLAGS.max_seq_length,
        doc_stride=FLAGS.doc_stride,
        max_query_length=FLAGS.max_query_length,
        is_training=False,
        output_fn=append_feature,
        do_lower_case=FLAGS.do_lower_case)
    eval_writer.close()

    with tf.gfile.Open(os.path.join(model_dir, "eval_left.tf_record"), "wb") as fout:
      pickle.dump(eval_features, fout)

    tf.logging.info("***** Running predictions *****")
    tf.logging.info("  Num orig examples = %d", len(eval_examples))
    tf.logging.info("  Num split examples = %d", len(eval_features))
    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    predict_input_fn = squad_utils.input_fn_builder(
        input_file=os.path.join(model_dir, "eval.tf_record"),
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=False,
        use_tpu=FLAGS.use_tpu,
        bsz=FLAGS.predict_batch_size,
        is_v2=False)

    eval_hooks = [habana_hooks.PerfLoggingHook(batch_size=FLAGS.predict_batch_size, mode="eval")]

    def get_result(checkpoint):
      """Evaluate the checkpoint on SQuAD 1.0."""
      # If running eval on the TPU, you will need to specify the number of
      # steps.
      reader = tf.train.NewCheckpointReader(checkpoint)
      global_step = reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
      all_results = []
      for result in estimator.predict(
          predict_input_fn, yield_single_examples=True,
          checkpoint_path=checkpoint, hooks=eval_hooks):
        if len(all_results) % 1000 == 0:
          tf.logging.info("Processing example: %d" % (len(all_results)))
        unique_id = int(result["unique_ids"])
        start_log_prob = [float(x) for x in result["start_log_prob"].flat]
        end_log_prob = [float(x) for x in result["end_log_prob"].flat]
        all_results.append(
            squad_utils.RawResult(
                unique_id=unique_id,
                start_log_prob=start_log_prob,
                end_log_prob=end_log_prob))

      output_prediction_file = os.path.join(
          model_dir, "predictions.json")
      output_nbest_file = os.path.join(
          model_dir, "nbest_predictions.json")

      result_dict = {}
      squad_utils.accumulate_predictions_v1(
          result_dict, eval_examples, eval_features,
          all_results, FLAGS.n_best_size, FLAGS.max_answer_length)
      predictions = squad_utils.write_predictions_v1(
          result_dict, eval_examples, eval_features, all_results,
          FLAGS.n_best_size, FLAGS.max_answer_length,
          output_prediction_file, output_nbest_file)

      return squad_utils.evaluate_v1(
          prediction_json, predictions), int(global_step)

    def _find_valid_cands(curr_step):
      filenames = tf.gfile.ListDirectory(model_dir)
      candidates = []
      for filename in filenames:
        if filename.endswith(".index"):
          ckpt_name = filename[:-6]
          idx = ckpt_name.split("-")[-1]
          if idx != "best" and int(idx) > curr_step:
            candidates.append(filename)
      return candidates

    output_eval_file = os.path.join(model_dir, "eval_results.txt")
    checkpoint_path = os.path.join(model_dir, "model.ckpt-best")
    key_name = "f1"
    writer = tf.gfile.GFile(output_eval_file, "w")
    if tf.gfile.Exists(checkpoint_path + ".index"):
      result = get_result(checkpoint_path)
      exact_match = result[0]["exact_match"]
      f1 = result[0]["f1"]
      with TBSummary(os.path.join(model_dir, 'eval')) as summary_writer:
          summary_writer.add_scalar('f1', f1, 0)
          summary_writer.add_scalar('exact_match', exact_match, 0)
      best_perf = result[0][key_name]
      global_step = result[1]
    else:
      global_step = -1
      best_perf = -1
      checkpoint_path = None
    while global_step < num_train_steps:
      steps_and_files = {}
      filenames = tf.gfile.ListDirectory(model_dir)
      for filename in filenames:
        if filename.endswith(".index"):
          ckpt_name = filename[:-6]
          cur_filename = os.path.join(model_dir, ckpt_name)
          if cur_filename.split("-")[-1] == "best":
            continue
          gstep = int(cur_filename.split("-")[-1])
          if gstep not in steps_and_files:
            tf.logging.info("Add {} to eval list.".format(cur_filename))
            steps_and_files[gstep] = cur_filename
      tf.logging.info("found {} files.".format(len(steps_and_files)))
      if not steps_and_files:
        tf.logging.info("found 0 file, global step: {}. Sleeping."
                        .format(global_step))
        time.sleep(60)
      else:
        for ele in sorted(steps_and_files.items()):
          step, checkpoint_path = ele
          if global_step >= step:
            if len(_find_valid_cands(step)) > 1:
              for ext in ["meta", "data-00000-of-00001", "index"]:
                src_ckpt = checkpoint_path + ".{}".format(ext)
                tf.logging.info("removing {}".format(src_ckpt))
                tf.gfile.Remove(src_ckpt)
            continue
          result, global_step = get_result(checkpoint_path)
          exact_match = result["exact_match"]
          f1 = result["f1"]
          with TBSummary(os.path.join(model_dir, 'eval')) as summary_writer:
            summary_writer.add_scalar('f1', f1, 0)
            summary_writer.add_scalar('exact_match', exact_match, 0)
          tf.logging.info("***** Eval results *****")
          for key in sorted(result.keys()):
            tf.logging.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))
          if result[key_name] > best_perf:
            best_perf = result[key_name]
            for ext in ["meta", "data-00000-of-00001", "index"]:
              src_ckpt = checkpoint_path + ".{}".format(ext)
              tgt_ckpt = checkpoint_path.rsplit(
                  "-", 1)[0] + "-best.{}".format(ext)
              tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt))
              tf.gfile.Copy(src_ckpt, tgt_ckpt, overwrite=True)
              writer.write("saved {} to {}\n".format(src_ckpt, tgt_ckpt))
          writer.write("best {} = {}\n".format(key_name, best_perf))
          tf.logging.info("  best {} = {}\n".format(key_name, best_perf))

          if len(_find_valid_cands(global_step)) > 2:
            for ext in ["meta", "data-00000-of-00001", "index"]:
              src_ckpt = checkpoint_path + ".{}".format(ext)
              tf.logging.info("removing {}".format(src_ckpt))
              tf.gfile.Remove(src_ckpt)
          writer.write("=" * 50 + "\n")

    checkpoint_path = os.path.join(model_dir, "model.ckpt-best")
    result, global_step = get_result(checkpoint_path)
    tf.logging.info("***** Final Eval results *****")
    for key in sorted(result.keys()):
      tf.logging.info("  %s = %s", key, str(result[key]))
      writer.write("%s = %s\n" % (key, str(result[key])))
    writer.write("best perf happened at step: {}".format(global_step))

  if FLAGS.export_dir:
    tf.gfile.MakeDirs(FLAGS.export_dir)
    squad_serving_input_fn = (
        build_squad_serving_input_fn(FLAGS.max_seq_length))
    tf.logging.info("Starting to export model.")
    subfolder = estimator.export_saved_model(
        export_dir_base=os.path.join(FLAGS.export_dir, "saved_model"),
        serving_input_receiver_fn=squad_serving_input_fn)

    tf.logging.info("Starting to export TFLite.")
    converter = tf.lite.TFLiteConverter.from_saved_model(
        subfolder,
        input_arrays=["input_ids", "input_mask", "segment_ids"],
        output_arrays=["start_logits", "end_logits"])
    float_model = converter.convert()
    tflite_file = os.path.join(FLAGS.export_dir, "albert_model.tflite")
    with tf.gfile.GFile(tflite_file, "wb") as f:
      f.write(float_model)
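
The Horovod branch in code example #7 shards train_examples so that the first len(train_examples) % hvd.size() ranks take one extra example each, which tiles the dataset exactly once with no gaps or overlap. A small self-contained sketch of that partition arithmetic (the function name and toy sizes are illustrative, not from the original):

def shard_bounds(num_examples, num_ranks, rank):
    # Mirrors the partition in code example #7: ranks below the remainder get
    # one extra example, and the resulting slices cover the dataset exactly once.
    per_rank = num_examples // num_ranks
    remainder = num_examples % num_ranks
    if rank < remainder:
        start = rank * (per_rank + 1)
        end = start + per_rank + 1
    else:
        start = rank * per_rank + remainder
        end = start + per_rank
    return start, end

# Toy check: 10 examples over 4 ranks -> [(0, 3), (3, 6), (6, 8), (8, 10)]
print([shard_bounds(10, 4, r) for r in range(4)])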
Code example #8
def main(_):
    #tf.disable_v2_behavior() ###
    tf.compat.v1.disable_eager_execution()
    tf.compat.v1.enable_resource_variables()

    # Enable habana bf16 conversion pass
    if FLAGS.dtype == 'bf16':
        os.environ['TF_BF16_CONVERSION'] = flags.FLAGS.bf16_config_path
        FLAGS.precision = 'bf16'
    else:
        os.environ['TF_BF16_CONVERSION'] = "0"

    if FLAGS.use_horovod:
        hvd_init()

    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=True,
            use_grayscale=FLAGS.use_grayscale)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs

        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #if FLAGS.quantize_delay >= 0:
        #  quantize.create_training_graph(quant_delay=FLAGS.quantize_delay) #for debugging!!

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Compute the total loss and the gradients across all clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        if horovod_enabled():
            hvd.broadcast_global_variables(0)
        ###########################
        # Kicks off the training. #
        ###########################
        with dump_callback():
            with logger.benchmark_context(FLAGS):
                eps1 = ExamplesPerSecondKerasHook(FLAGS.log_every_n_steps,
                                                  output_dir=FLAGS.train_dir,
                                                  batch_size=FLAGS.batch_size)

                write_hparams_v1(
                    eps1.writer, {
                        'batch_size': FLAGS.batch_size,
                        **{x: getattr(FLAGS, x)
                           for x in FLAGS}
                    })

                train_step_kwargs = {}
                if FLAGS.max_number_of_steps:
                    should_stop_op = math_ops.greater_equal(
                        global_step, FLAGS.max_number_of_steps)
                else:
                    should_stop_op = constant_op.constant(False)
                train_step_kwargs['should_stop'] = should_stop_op
                if FLAGS.log_every_n_steps > 0:
                    train_step_kwargs['should_log'] = math_ops.equal(
                        math_ops.mod(global_step, FLAGS.log_every_n_steps), 0)

                eps1.on_train_begin()
                train_step_kwargs['EPS'] = eps1

                slim.learning.train(
                    train_tensor,
                    logdir=FLAGS.train_dir,
                    train_step_fn=train_step1,
                    train_step_kwargs=train_step_kwargs,
                    master=FLAGS.master,
                    is_chief=(FLAGS.task == 0),
                    init_fn=_get_init_fn(),
                    summary_op=summary_op,
                    summary_writer=None,
                    number_of_steps=FLAGS.max_number_of_steps,
                    log_every_n_steps=FLAGS.log_every_n_steps,
                    save_summaries_secs=FLAGS.save_summaries_secs,
                    save_interval_secs=FLAGS.save_interval_secs,
                    sync_optimizer=optimizer if FLAGS.sync_replicas else None)
Code example #9
def train(model,
          train_images,
          train_annotations,
          input_height=None,
          input_width=None,
          n_classes=None,
          verify_dataset=True,
          checkpoints_path=None,
          epochs=5,
          batch_size=2,
          validate=False,
          val_images=None,
          val_annotations=None,
          auto_resume_checkpoint=False,
          load_weights=None,
          steps_per_epoch=None,
          val_steps_per_epoch=None,
          gen_use_multiprocessing=False,
          ignore_zero_class=False,
          optimizer_name='adam',
          do_augment=False,
          augmentation_name="aug_all",
          data_type='fp32',
          tb_location=None,
          deterministic=False,
          model_dir=None,
          dump_config=None,
          distributed=False,
          use_upsampling=False,
          loss_type=0,
          train_engine='hpu',
          not_cached=False):

    if train_engine == 'hpu':
        from habana_frameworks.tensorflow import load_habana_module
        load_habana_module()
        print("Loaded HPU modules")
        from TensorFlow.common.debug import dump_callback
        # For Habana Model runner hooks
        from TensorFlow.common.tb_utils import (TensorBoardWithHParamsV2,
                                                ExamplesPerSecondKerasHookV2)
    else:

        class dump_callback(object):
            def __init__(self, file_name):
                pass

            def __enter__(self):
                pass

            def __exit__(self, type, value, traceback):
                pass

    if data_type == 'bf16' and train_engine == 'hpu':
        bf16_json = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                 '../bf16_segnet.json')
        os.environ['TF_BF16_CONVERSION'] = os.environ.get(
            'TF_BF16_CONVERSION', bf16_json)
        print("Setting BF16:", os.getenv('TF_BF16_CONVERSION'))

    shard_id = 0
    num_shards = 1

    if distributed:
        import horovod.tensorflow.keras as hvd
        print("hvd init")
        hvd.init()
        if train_engine == 'gpu':
            gpus = tf.config.experimental.list_physical_devices('GPU')
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            if gpus:
                tf.config.experimental.set_visible_devices(
                    gpus[hvd.local_rank()], 'GPU')
            print("Set memory growth for GPUS")

        shard_id = hvd.rank()
        num_shards = hvd.size()
        if num_shards == 1:
            print("Distributed training was requested, but Horovod "
                  "initialization did not succeed.")
            exit()

    print("num_shards: " + str(num_shards) + " shard_id: " + str(shard_id))

    from keras_segmentation.models.all_models import model_from_name
    # check if user gives model name instead of the model object
    if isinstance(model, six.string_types):
        # create the model from the name
        assert (n_classes is not None), "Please provide the n_classes"
        if (input_height is not None) and (input_width is not None):
            model = model_from_name[model](n_classes,
                                           input_height=input_height,
                                           input_width=input_width,
                                           batch_size=batch_size,
                                           use_upsampling=use_upsampling,
                                           loss_type=loss_type)
        else:
            model = model_from_name[model](n_classes,
                                           batch_size=batch_size,
                                           use_upsampling=use_upsampling,
                                           loss_type=loss_type)

    #model.save('my_segnet_model.h5')
    n_classes = model.n_classes
    input_height = model.input_height
    input_width = model.input_width
    output_height = model.output_height
    output_width = model.output_width

    if steps_per_epoch is None:
        steps_per_epoch = len(
            os.listdir(train_images)) // (batch_size * num_shards)
    if val_steps_per_epoch is None:
        val_steps_per_epoch = len(os.listdir(val_images)) // batch_size

    print("Steps per epoch: " + str(steps_per_epoch))

    def optimized_xent_loss_custom_grad(ytrue, ypred):
        @tf.custom_gradient
        def loss_without_mean(ytrue, ypred):
            with tf.name_scope("softmax_cross_entropy"):
                logits_t = tf.transpose(ypred,
                                        perm=(0, 1, 3, 2),
                                        name="logits_t")  # BS H N W
                reduce_max = tf.reduce_max(logits_t, 2,
                                           name="reduce_max")  # BS H W
                max_logits = tf.expand_dims(reduce_max, 3)  # BS H W 1
                shifted_logits = tf.subtract(ypred,
                                             max_logits,
                                             name="shifted_logits")  # BS H W N
                exp_shifted_logits = tf.math.exp(
                    shifted_logits, name="exp_shifted_logits")  # BS H W N
                reduce_sum_filter = tf.fill([1, 1, n_classes, 1], 1.0)
                sum_exp = tf.nn.conv2d(exp_shifted_logits,
                                       reduce_sum_filter,
                                       strides=1,
                                       padding="VALID",
                                       name="sum_exp")  # BS H W 1
                log_sum_exp = tf.math.log(sum_exp,
                                          name="log_sum_exp")  # BS H W 1
                shifted_logits2 = tf.nn.conv2d(
                    shifted_logits * ytrue,
                    reduce_sum_filter,
                    strides=1,
                    padding="VALID",
                    name="shifted_logits2")  # BS H W 1
                loss = tf.subtract(log_sum_exp,
                                   shifted_logits2,
                                   name="loss/sub")  # BS H W 1

                def custom_grad(dy):  # dy is BS H W 1
                    with tf.name_scope("gradients/softmax_cross_entropy"):
                        div = tf.math.truediv(exp_shifted_logits,
                                              sum_exp,
                                              name="div")  # BS H W N
                        sub = tf.math.subtract(div, ytrue,
                                               name="sub")  # BS H W N
                        ret = tf.math.multiply(sub, dy, name="mul")
                    return -dy * shifted_logits, ret

                return loss, custom_grad

        return tf.math.reduce_mean(loss_without_mean(ytrue, ypred))

    if validate:
        assert val_images is not None
        assert val_annotations is not None

    if optimizer_name is not None:

        if ignore_zero_class:
            loss_k = masked_categorical_crossentropy
        elif loss_type == 1:
            loss_k = tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True)
        elif loss_type == 2:
            loss_k = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
        else:
            loss_k = optimized_xent_loss_custom_grad

        print(optimizer_name)
        if num_shards > 1:
            optimizer = Adam(lr=LearningRate)
            optimizer_name = hvd.DistributedOptimizer(optimizer)

        model.compile(loss=loss_k,
                      optimizer=optimizer_name,
                      metrics=['accuracy'])

    if checkpoints_path is not None:
        with open(checkpoints_path + "_config.json", "w") as f:
            json.dump(
                {
                    "model_class": model.model_name,
                    "n_classes": n_classes,
                    "input_height": input_height,
                    "input_width": input_width,
                    "output_height": output_height,
                    "output_width": output_width
                }, f)

    if load_weights is not None and len(load_weights) > 0:
        print("Loading weights from ", load_weights)
        status = model.load_weights(load_weights)
        print(status)

    if auto_resume_checkpoint and (checkpoints_path is not None):
        latest_checkpoint = find_latest_checkpoint(checkpoints_path)
        if latest_checkpoint is not None:
            print("Loading the weights from latest checkpoint ",
                  latest_checkpoint)
            model.load_weights(latest_checkpoint)

    if verify_dataset:
        print("Verifying training dataset")
        verified = verify_segmentation_dataset(train_images, train_annotations,
                                               n_classes, deterministic)
        assert verified
        if validate:
            print("Verifying validation dataset")
            verified = verify_segmentation_dataset(val_images, val_annotations,
                                                   n_classes, deterministic)
            assert verified

    if not_cached:
        train_gen = image_segmentation_generator(
            train_images,
            train_annotations,
            batch_size,
            n_classes,
            input_height,
            input_width,
            output_height,
            output_width,
            deterministic,
            do_augment=do_augment,
            augmentation_name=augmentation_name,
            num_shards=num_shards,
            shard_id=shard_id,
            loss_type=loss_type)
    else:
        train_gen = image_segmentation_generator(
            train_images,
            train_annotations,
            1,
            n_classes,
            input_height,
            input_width,
            output_height,
            output_width,
            deterministic,
            do_augment=do_augment,
            augmentation_name=augmentation_name,
            num_shards=num_shards,
            shard_id=shard_id,
            loss_type=loss_type)

        train_gen = cached_image_generator(train_gen, num_shards, shard_id,
                                           batch_size,
                                           len(os.listdir(train_images)),
                                           deterministic)

    callbacks = []

    if num_shards > 1:
        callbacks.append(hvd.callbacks.BroadcastGlobalVariablesCallback(0))
        callbacks.append(hvd.callbacks.MetricAverageCallback())

    callbacks.append(CheckpointsCallback(checkpoints_path))
    #if shard_id == 0:
    #    callbacks.append(ModelCheckpoint( self.checkpoints_path, monitor='loss', verbose=2, mode='min', save_best_only=True, save_weights_only=True))

    if model_dir is not None:
        hparams = {
            "model_name": model,
            "optimizer": optimizer_name,
            "batch_size": batch_size
        }

        if train_engine == 'hpu':
            callbacks += [
                TensorBoardWithHParamsV2(hparams,
                                         log_dir=model_dir,
                                         update_freq=5),
                ExamplesPerSecondKerasHookV2(5,
                                             batch_size=batch_size,
                                             output_dir=model_dir)
            ]

    if tb_location != '':
        tensorboard_callback = TensorBoard(log_dir=tb_location,
                                           histogram_freq=1)
        callbacks.append(tensorboard_callback)
        print("TB:", tb_location)

    if not validate:
        with dump_callback(dump_config):
            start_compilation = time.time()
            model.fit(train_gen, steps_per_epoch=1, epochs=1)
            stop_compilation = time.time()
            history = model.fit(train_gen,
                                steps_per_epoch=steps_per_epoch,
                                epochs=epochs,
                                callbacks=callbacks,
                                verbose=1 if shard_id == 0 else 0)
            stop_training = time.time()
        with open('./trainHistoryDict_' + str(shard_id), 'wb') as file_pi:
            pickle.dump(history.history, file_pi)
        avg_time_per_batch = (stop_training -
                              stop_compilation) / (steps_per_epoch * epochs)
        print('Compile time in seconds:',
              (stop_compilation - start_compilation))
        print('Average time per batch in seconds (leaving out compilation):',
              avg_time_per_batch)
        print('Average time per image in seconds (leaving out compilation)',
              avg_time_per_batch / batch_size)
        print('Average images per sec (leaving out compilation):',
              batch_size / avg_time_per_batch)

        if loss_type == 1:
            print('Eval for LOSS_FUNC_TYPE=1 is WIP')
            exit()

        if shard_id == 0:
            if not_cached:
                val_gen = image_segmentation_generator(val_images,
                                                       val_annotations,
                                                       batch_size,
                                                       n_classes,
                                                       input_height,
                                                       input_width,
                                                       output_height,
                                                       output_width,
                                                       deterministic,
                                                       num_shards=1,
                                                       shard_id=shard_id,
                                                       loss_type=loss_type)
            else:
                val_gen = image_segmentation_generator(val_images,
                                                       val_annotations,
                                                       1,
                                                       n_classes,
                                                       input_height,
                                                       input_width,
                                                       output_height,
                                                       output_width,
                                                       deterministic,
                                                       num_shards=1,
                                                       shard_id=shard_id,
                                                       loss_type=loss_type)
                val_gen = cached_image_generator(val_gen, 1, 0, batch_size,
                                                 len(os.listdir(val_images)))
            f1_metric = FBetaScore(num_classes=n_classes)
            model.compile(loss=model.loss,
                          metrics=[
                              tf.keras.metrics.CategoricalAccuracy(
                                  name="categorical_accuracy", dtype=None),
                              f1_metric
                          ])
            test_loss, test_acc, test_f1 = model.evaluate(
                val_gen, steps=(len(os.listdir(val_images)) // batch_size))
            train_loss, train_acc, train_f1 = model.evaluate(
                train_gen, steps=(len(os.listdir(train_images)) // batch_size))
            print(
                f'test loss : {test_loss}, test accuracy : {test_acc}, test f1 : {test_f1}'
            )
            print(
                f'train loss : {train_loss}, train accuracy : {train_acc}, train f1 : {train_f1}'
            )

    else:
        assert num_shards == 1, \
            "Training with validation is only supported in a single-HPU setup"
        if not_cached:
            val_gen = image_segmentation_generator(val_images,
                                                   val_annotations,
                                                   batch_size,
                                                   n_classes,
                                                   input_height,
                                                   input_width,
                                                   output_height,
                                                   output_width,
                                                   deterministic,
                                                   num_shards=num_shards,
                                                   shard_id=shard_id,
                                                   loss_type=loss_type)
        else:
            val_gen = image_segmentation_generator(val_images,
                                                   val_annotations,
                                                   1,
                                                   n_classes,
                                                   input_height,
                                                   input_width,
                                                   output_height,
                                                   output_width,
                                                   deterministic,
                                                   num_shards=num_shards,
                                                   shard_id=shard_id,
                                                   loss_type=loss_type)
            val_gen = cached_image_generator(val_gen, num_shards, shard_id,
                                             batch_size,
                                             len(os.listdir(val_images)),
                                             deterministic)

        start_compilation = time.time()
        model.fit(train_gen, steps_per_epoch=1, epochs=1)
        stop_compilation = time.time()
        model.fit(train_gen,
                  steps_per_epoch=steps_per_epoch,
                  validation_data=val_gen,
                  validation_steps=val_steps_per_epoch,
                  epochs=epochs,
                  callbacks=callbacks,
                  use_multiprocessing=gen_use_multiprocessing,
                  verbose=1 if shard_id == 0 else 0)
        stop_training = time.time()
        avg_time_per_batch = (stop_training -
                              stop_compilation) / (steps_per_epoch * epochs)
        print('Compile time in seconds:',
              (stop_compilation - start_compilation))
        print('Average time per batch in seconds (leaving out compilation):',
              avg_time_per_batch)
        print('Average time per image in seconds (leaving out compilation)',
              avg_time_per_batch / batch_size)
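
When train_engine is not 'hpu', code example #9 defines a do-nothing dump_callback class so the later `with dump_callback(dump_config):` block runs unchanged on any device. A slightly more compact sketch of the same fallback using contextlib (a suggestion, not the original code):

import contextlib

@contextlib.contextmanager
def dump_callback(dump_config=None):
    # Off-HPU fallback: accept the config argument and do nothing, so the
    # training code can keep the same `with dump_callback(...)` structure.
    yield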
Code example #10
def main(_):
    with dump_callback(), logger.benchmark_context(
            flags.FLAGS), disable_session_recovery():
        run_imagenet(flags.FLAGS)
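
Code example #10 enters dump_callback(), logger.benchmark_context(...), and disable_session_recovery() in a single with statement. When any of them needs to be applied conditionally, contextlib.ExitStack gives the same nesting at runtime; a sketch assuming the same names are in scope as in the example (the enable_dump flag is hypothetical):

import contextlib

def run_with_contexts(flags_obj, enable_dump=True):
    # Enter the same contexts as code example #10, but only when wanted.
    with contextlib.ExitStack() as stack:
        if enable_dump:
            stack.enter_context(dump_callback())
        stack.enter_context(logger.benchmark_context(flags_obj))
        stack.enter_context(disable_session_recovery())
        return run_imagenet(flags_obj)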
Code example #11
File: train.py  Project: HabanaAI/Model-References
def main():
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument('--dataset', '--dataset_dir', metavar='PATH',
                        default=config.DEFAULT_DATASET_DIR, help='Dataset directory.')
    parser.add_argument('--optimizer', default='sgd',
                        choices=['sgd', 'adam', 'rmsprop'], help='Optimizer.')
    parser.add_argument('-d', '--dtype', default='fp32',
                        choices=['fp32', 'bf16'], help='Data type.')
    parser.add_argument('--batch_size', type=int,
                        default=32, help='Global batch size.')
    parser.add_argument('--lr_sched', default='WarmupCosine', choices=[
                        'linear', 'exp', 'steps', 'constant', 'WarmupCosine'], help='Learning rate scheduler.')
    parser.add_argument('--initial_lr', type=float,
                        default=6e-2, help='Initial learning rate.')
    parser.add_argument('--final_lr', type=float,
                        default=1e-5, help='Final learning rate.')
    parser.add_argument('--warmup_steps', type=int,
                        default=4000, help='Warmup steps.')
    parser.add_argument('--epochs', type=int, default=10,
                        help='Total number of epochs for training.')
    parser.add_argument('--steps_per_epoch', type=int,
                        help='Number of steps for training per epoch, overrides default value.')
    parser.add_argument('--validation_steps', type=int,
                        help='Number of steps for validation, overrides default value.')
    parser.add_argument('--model', default='ViT-B_16',
                        choices=['ViT-B_16', 'ViT-L_16', 'ViT-B_32', 'ViT-L_32'], help='Model.')
    parser.add_argument('--train_subset', default='train',
                        help='Pattern to detect train subset in dataset directory.')
    parser.add_argument('--val_subset', default='validation',
                        help='Pattern to detect validation subset in dataset directory.')
    parser.add_argument('--grad_accum_steps', type=int,
                        default=8, help='Gradient accumulation steps.')
    parser.add_argument('--resume_from_checkpoint_path',
                        metavar='PATH', help='Path to checkpoint to start from.')
    parser.add_argument('--resume_from_epoch', metavar='EPOCH_INDEX',
                        type=int, default=0, help='Initial epoch index.')
    parser.add_argument('--evaluate_checkpoint_path', metavar='PATH',
                        help='Checkpoint path for evaluating the model on --val_subset')
    parser.add_argument('--weights_path', metavar='PATH',
                        help='Path to weights cache directory. ~/.keras is used if not set.')
    parser.add_argument('--deterministic', action='store_true', default=False,
                        help='Enable deterministic behavior; this also disables data augmentation. --seed must be set.')
    parser.add_argument('--seed', type=int,
                        help='Seed to be used by random functions.')
    parser.add_argument('--device', default='HPU',
                        choices=['CPU', 'HPU'], help='Device type.')
    parser.add_argument('--distributed', action='store_true',
                        default=False, help='Enable distributed training.')
    parser.add_argument('--base_tf_server_port', type=int,
                        default=7850, help='Rank 0 port used by tf.distribute.')
    parser.add_argument('--save_summary_steps', type=int, default=0,
                        help='Steps between saving summaries to TensorBoard.')
    parser.add_argument('--recipe_cache', default='/tmp/vit_recipe_cache',
                        help='Path to the recipe cache directory. Set to empty to disable the recipe cache. An externally set \'TF_RECIPE_CACHE_PATH\' overrides this setting.')
    parser.add_argument(
        '--dump_config', help='Side-by-side config file. Internal, do not use.')
    args = parser.parse_args()

    if args.weights_path is not None:
        config.WEIGHTS_DIR = args.weights_path

    if args.dtype == 'bf16':
        tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

    if args.device == 'HPU':
        if args.distributed:
            os.environ['TF_HCCL_MEMORY_ALLOWANCE_MB'] = '500'
        from habana_frameworks.tensorflow import load_habana_module
        from habana_frameworks.tensorflow.ops.layer_norm import HabanaLayerNormalization
        load_habana_module()
        tf.keras.layers.LayerNormalization = HabanaLayerNormalization

        # Handle recipe caching.
        recipe_cache = args.recipe_cache
        if 'TF_RECIPE_CACHE_PATH' not in os.environ.keys() and recipe_cache:
            os.environ['TF_RECIPE_CACHE_PATH'] = recipe_cache

        # Clear previous recipe cache.
        if not args.distributed or comm_rank() == 0:
            if os.path.exists(recipe_cache) and os.path.isdir(recipe_cache):
                import shutil
                shutil.rmtree(recipe_cache)
        # Wait for rank 0 to remove cache.
        if args.distributed:
            from mpi4py import MPI
            MPI.COMM_WORLD.Barrier()

    # Handle determinism.
    config.DETERMINISTIC = args.deterministic
    config.SEED = args.seed
    if args.deterministic:
        assert args.seed is not None, "Deterministic behavior requires --seed to be set."
        tf.config.threading.set_inter_op_parallelism_threads(1)
        tf.config.threading.set_intra_op_parallelism_threads(1)
        os.environ['TF_DETERMINISTIC_OPS'] = '1'
        config.DATA_AUGMENTATION = False
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        tf.random.set_seed(args.seed)

    # Handle distribution strategy.
    if args.distributed:
        tf_distribute_config(args.base_tf_server_port)
        if args.device == 'HPU':
            os.environ['HBN_TF_REGISTER_DATASETOPS'] = '1'
            from habana_frameworks.tensorflow.distribute import HPUStrategy
            strategy = HPUStrategy()
        else:
            strategy = tf.distribute.MultiWorkerMirroredStrategy()
    else:
        strategy = tf.distribute.OneDeviceStrategy(f'device:{args.device}:0')

    if not args.distributed or comm_rank() == 0:
        print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

    num_classes = 1000
    batch_size = args.batch_size
    nb_epoch = args.epochs
    dataset = args.dataset
    resume_from_checkpoint_path = args.resume_from_checkpoint_path
    resume_from_epoch = args.resume_from_epoch
    optim_name = args.optimizer
    initial_lr = args.initial_lr
    final_lr = args.final_lr
    lr_sched = args.lr_sched
    warmup_steps = args.warmup_steps
    model_name = args.model
    grad_accum_steps = args.grad_accum_steps

    ds_train = get_dataset(dataset, args.train_subset, batch_size,
                           is_training=True, distributed=args.distributed)
    ds_valid = get_dataset(dataset, args.val_subset,
                           batch_size, False, distributed=args.distributed)

    if args.dump_config is not None:
        vit.CONFIG_B['dropout'] = 0.0
        vit.CONFIG_L['dropout'] = 0.0

    # Load our model
    with strategy.scope():
        image_size = 384
        if model_name == 'ViT-B_16':
            model = vit.vit_b16(
                image_size=image_size,
                activation='softmax',
                pretrained=True,
                include_top=True,
                pretrained_top=False,
                classes=num_classes,
                weights="imagenet21k")
        elif model_name == 'ViT-L_16':
            model = vit.vit_l16(
                image_size=image_size,
                activation='softmax',
                pretrained=True,
                include_top=True,
                pretrained_top=False,
                classes=num_classes,
                weights="imagenet21k")
        elif model_name == 'ViT-B_32':
            model = vit.vit_b32(
                image_size=image_size,
                activation='softmax',
                pretrained=True,
                include_top=True,
                pretrained_top=False,
                classes=num_classes,
                weights="imagenet21k")
        elif model_name == 'ViT-L_32':
            model = vit.vit_l32(
                image_size=image_size,
                activation='softmax',
                pretrained=True,
                include_top=True,
                pretrained_top=False,
                classes=num_classes,
                weights="imagenet21k")
        else:
            print(
                "Model is not supported; please use one of: ViT-B_16, ViT-L_16, ViT-B_32, ViT-L_32.")
            exit(1)

        optimizer = get_optimizer(
            optim_name, initial_lr, accumulation_steps=grad_accum_steps, epsilon=1e-2)
        model.compile(optimizer=optimizer, loss='categorical_crossentropy',
                      metrics=['accuracy'], run_eagerly=False)

        # Start training

        steps_per_epoch = 1281167 // batch_size
        if args.steps_per_epoch is not None:
            steps_per_epoch = args.steps_per_epoch
        validation_steps = 50000 // batch_size
        if args.validation_steps is not None:
            validation_steps = args.validation_steps

        total_steps = nb_epoch * steps_per_epoch
        resume_step = resume_from_epoch * steps_per_epoch

        lrate = get_lr_func(nb_epoch, lr_sched, initial_lr,
                            final_lr, warmup_steps, resume_step, total_steps)

        save_name = model_name if not model_name.endswith('.h5') else \
            os.path.split(model_name)[-1].split('.')[0].split('-')[0]
        model_ckpt = tf.keras.callbacks.ModelCheckpoint(
            os.path.join(config.SAVE_DIR, save_name) + '-ckpt-{epoch:03d}.h5',
            monitor='train_loss')

        callbacks = [lrate, model_ckpt]
        if args.save_summary_steps > 0:
            callbacks += [TensorBoardWithHParamsV2(
                vars(args), log_dir=config.LOG_DIR, update_freq=args.save_summary_steps)]
            callbacks += [ExamplesPerSecondKerasHookV2(
                output_dir=config.LOG_DIR, every_n_steps=args.save_summary_steps, batch_size=args.batch_size)]

        if args.evaluate_checkpoint_path is not None:
            model.load_weights(args.evaluate_checkpoint_path)
            results = model.evaluate(x=ds_valid, steps=validation_steps)
            print("Test loss, Test acc:", results)
            exit()

        if resume_from_epoch is not None and resume_from_checkpoint_path is not None:
            model.load_weights(resume_from_checkpoint_path)

        with dump_callback(args.dump_config):
            model.fit(x=ds_train, y=None,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      initial_epoch=resume_from_epoch,
                      epochs=nb_epoch,
                      shuffle=not args.deterministic,
                      verbose=1 if not args.distributed else comm_rank() == 0,
                      validation_data=(ds_valid, None),
                      validation_steps=validation_steps,
                      )

        if not args.distributed or comm_rank() == 0:
            model.save(f'{config.SAVE_DIR}/{save_name}-model-final.h5')
Code example #12
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if horovod_enabled():
        FLAGS.output_dir = FLAGS.output_dir if hvd_rank(
        ) == 0 else os.path.join(FLAGS.output_dir, str(hvd_rank()))

    albert_config = modeling.AlbertConfig.from_json_file(
        FLAGS.albert_config_file)
    if FLAGS.deterministic_run and (albert_config.attention_probs_dropout_prob
                                    or albert_config.hidden_dropout_prob):
        albert_config.attention_probs_dropout_prob = 0.0
        albert_config.hidden_dropout_prob = 0.0

    tf.gfile.MakeDirs(FLAGS.output_dir)

    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.gfile.Glob(input_pattern))

    if FLAGS.use_horovod and len(input_files) < hvd.size():
        input_files = [input_files[0] for i in range(hvd.size())]

    tf.logging.info("*** Input Files ***")
    for input_file in input_files:
        tf.logging.info("  %s" % input_file)

    eval_files = []
    for eval_pattern in FLAGS.eval_file.split(","):
        eval_files.extend(tf.gfile.Glob(eval_pattern))

    if FLAGS.use_horovod and len(eval_files) < hvd.size():
        eval_files = [eval_files[0] for i in range(hvd.size())]

    tf.logging.info("*** Eval Files ***")
    for eval_file in eval_files:
        tf.logging.info("  %s" % eval_file)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        save_summary_steps=FLAGS.save_summary_steps,
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    num_train_steps = FLAGS.num_train_steps
    num_warmup_steps = FLAGS.num_warmup_steps
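    # With Horovod, the per-worker step counts below are divided by the world size,
    # while the learning rate passed to model_fn_builder is scaled up by hvd_size()
    # (the usual linear-scaling heuristic for data-parallel training).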
    if FLAGS.do_train and horovod_enabled():
        num_train_steps //= hvd_size()
        num_warmup_steps //= hvd_size()

    model_fn = model_fn_builder(
        albert_config=albert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate
        if not FLAGS.use_horovod else FLAGS.learning_rate * hvd_size(),
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu,
        optimizer=FLAGS.optimizer,
        poly_power=FLAGS.poly_power,
        start_warmup_step=FLAGS.start_warmup_step,
        use_einsum=FLAGS.use_einsum)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.estimator.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    write_hparams_v1(
        FLAGS.output_dir, {
            'batch_size': FLAGS.train_batch_size,
            'batch_size_per_pu': FLAGS.train_batch_size,
            **{x: getattr(FLAGS, x)
               for x in FLAGS}
        })

    if FLAGS.do_train:
        training_hooks = []
        if horovod_enabled():
            training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

        tf.logging.info("***** Running training *****")
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        train_input_fn = input_fn_builder(
            input_files=input_files,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=True)
        with dump_callback():
            estimator.train(input_fn=train_input_fn,
                            hooks=training_hooks,
                            max_steps=FLAGS.num_train_steps)

    if FLAGS.do_eval and (not FLAGS.use_horovod or hvd_rank() == 0):
        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)
        global_step = -1
        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        writer = tf.gfile.GFile(output_eval_file, "w")
        eval_input_fn = input_fn_builder(
            input_files=eval_files,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=False)
        best_perf = 0
        key_name = "masked_lm_accuracy"
        while global_step < FLAGS.num_train_steps:
            if estimator.latest_checkpoint() is None:
                tf.logging.info("No checkpoint found yet. Sleeping.")
                time.sleep(1)
            else:
                result = estimator.evaluate(input_fn=eval_input_fn,
                                            steps=FLAGS.max_eval_steps)
                global_step = result["global_step"]
                tf.logging.info("***** Eval results *****")
                checkpoint_path = estimator.latest_checkpoint()
                for key in sorted(result.keys()):
                    tf.logging.info("  %s = %s", key, str(result[key]))
                    writer.write("%s = %s\n" % (key, str(result[key])))
                    if result[key_name] > best_perf:
                        best_perf = result[key_name]
                        for ext in ["meta", "data-00000-of-00001", "index"]:
                            src_ckpt = checkpoint_path + ".{}".format(ext)
                            tgt_ckpt = checkpoint_path.rsplit(
                                "-", 1)[0] + "-best.{}".format(ext)
                            tf.logging.info("saving {} to {}".format(
                                src_ckpt, tgt_ckpt))
                            tf.gfile.Copy(src_ckpt, tgt_ckpt, overwrite=True)
                            writer.write("saved {} to {}\n".format(
                                src_ckpt, tgt_ckpt))
Code example #13
def main(_):
  if FLAGS.enable_packed_data_mode:
    FLAGS.num_accumulation_steps = round(FLAGS.num_accumulation_steps / FLAGS.avg_seq_per_pack)

  os.environ["TF_XLA_FLAGS"] = "--tf_xla_enable_lazy_compilation=false" #causes memory fragmentation for bert leading to OOM

  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  dllogging = dllogger_class.dllogger_class(FLAGS.dllog_path)

  if not FLAGS.do_train and not FLAGS.do_eval:
    raise ValueError("At least one of `do_train` or `do_eval` must be True.")

  # In a multi-node scenario, each HLS must have a checkpoint directly in output_dir (read by Phase 2).
  # Exactly one worker per machine has comm_local_rank() == 0, and that worker puts its checkpoints there.
  # All other workers keep their checkpoints in sub-directories.
  if horovod_enabled() and comm_local_rank() != 0:
    FLAGS.output_dir = os.path.join(FLAGS.output_dir, f'worker_{hvd_rank()}')

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  tf.io.gfile.makedirs(FLAGS.output_dir)

  input_files = []
  for input_file_dir in FLAGS.input_files_dir.split(","):
    input_files.extend(tf.io.gfile.glob(os.path.join(input_file_dir, "*")))

  if FLAGS.horovod and len(input_files) < hvd.size():
      tf.compat.v1.logging.warning("Input file count is lower than expected. Using a single file for OVERFIT test.")
      input_files = [input_files[0] for i in range(hvd.size())]
  if FLAGS.amp and FLAGS.manual_fp16:
      raise ValueError("AMP and manual mixed-precision training cannot both be enabled.")

  is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2

  # The Scoped Allocator Optimization is enabled by default unless disabled by a flag.
  if FLAGS.enable_scoped_allocator:
    from tensorflow.core.protobuf import rewriter_config_pb2  # pylint: disable=import-error

    session_config = tf.compat.v1.ConfigProto()
    session_config.graph_options.rewrite_options.scoped_allocator_optimization = rewriter_config_pb2.RewriterConfig.ON

    enable_op = session_config.graph_options.rewrite_options.scoped_allocator_opts.enable_op
    del enable_op[:]
    enable_op.append("HorovodAllreduce")
  else:
    session_config = tf.compat.v1.ConfigProto()

  if FLAGS.horovod:
    session_config.gpu_options.visible_device_list = str(hvd.local_rank())
    if hvd.rank() == 0:
      tf.compat.v1.logging.info("***** Configuration *****")
      for key in FLAGS.__flags.keys():
          tf.compat.v1.logging.info('  {}: {}'.format(key, getattr(FLAGS, key)))
      tf.compat.v1.logging.info("**************************")

#    config.gpu_options.per_process_gpu_memory_fraction = 0.7
  if FLAGS.use_xla:
      # Import here as well, since the scoped-allocator branch above may not have run.
      from tensorflow.core.protobuf import rewriter_config_pb2  # pylint: disable=import-error
      session_config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
      session_config.graph_options.rewrite_options.memory_optimization = rewriter_config_pb2.RewriterConfig.NO_MEM_OPT
      if FLAGS.amp:
        tf.compat.v1.enable_resource_variables()

  run_config = tf.estimator.RunConfig(
      model_dir=FLAGS.output_dir,
      session_config=session_config,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=2,
      save_summary_steps=FLAGS.save_summary_steps,
      log_step_count_steps=1)

  model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate if not FLAGS.horovod else FLAGS.learning_rate*hvd_size(),
      num_train_steps=FLAGS.num_train_steps,
      num_warmup_steps=FLAGS.num_warmup_steps,
      use_one_hot_embeddings=False)

  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      config=run_config)

  batch_size_per_node = FLAGS.train_batch_size * FLAGS.num_accumulation_steps
  global_batch_size = (hvd.size() if FLAGS.horovod else 1) * batch_size_per_node
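  # Worked example (illustrative numbers): train_batch_size=32 with num_accumulation_steps=4
  # gives 128 sentences per worker per optimizer step; with hvd.size()=8 the global batch
  # is 8 * 128 = 1024 sentences.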
  write_hparams_v1(FLAGS.output_dir, {
    'batch_size': FLAGS.train_batch_size,
    'batch_size_per_pu': FLAGS.train_batch_size,
    'batch_size_per_node': batch_size_per_node,
    'global_batch_size': global_batch_size,
    **{x: getattr(FLAGS, x) for x in FLAGS}
  })

  if FLAGS.do_train:

    training_hooks = []
    if horovod_enabled():
      training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

    train_log_hook = _LogSessionRunHook(
      global_batch_size, FLAGS.num_accumulation_steps, dllogging,
      FLAGS.display_loss_steps, FLAGS.save_checkpoints_steps, FLAGS.report_loss)
    training_hooks.append(train_log_hook)

    training_hooks.append(ExamplesPerSecondEstimatorHook(
      batch_size=batch_size_per_node, output_dir=FLAGS.output_dir,
      extra_metrics={'global_examples/sec': global_batch_size}))

    tf.compat.v1.logging.info("***** Running training *****")
    tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    train_input_fn = input_fn_builder(
        input_files=input_files,
        batch_size=FLAGS.train_batch_size,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        is_training=True)

    train_start_time = time.time()
    with dump_callback():
      estimator.train(input_fn=train_input_fn, hooks=training_hooks, max_steps=FLAGS.num_train_steps)
    train_time_elapsed = time.time() - train_start_time

    if (not FLAGS.horovod or hvd_rank() == 0):
        train_time_wo_overhead = train_log_hook.total_time
        avg_sentences_per_second = FLAGS.num_train_steps * global_batch_size * 1.0 / train_time_elapsed
        if FLAGS.enable_packed_data_mode: avg_sentences_per_second *= FLAGS.avg_seq_per_pack
        try:
            ss_sentences_per_second = (FLAGS.num_train_steps - train_log_hook.skipped) * global_batch_size * 1.0 / train_time_wo_overhead
            if FLAGS.enable_packed_data_mode: ss_sentences_per_second *= FLAGS.avg_seq_per_pack
            throughput_avg_wo_overhead_msg = ["Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second]
        except ZeroDivisionError:
            # train_time_wo_overhead is 0 when fewer than skip_iters steps were run.
            ss_sentences_per_second = float('nan')
            throughput_avg_wo_overhead_msg = [f"Throughput Average W/O Overhead is not logged when num_train_steps < {train_log_hook.skip_iters}"]

        tf.compat.v1.logging.info("-----------------------------")
        tf.compat.v1.logging.info("Total Training Time = %0.2f for Sentences = %d", train_time_elapsed,
                        FLAGS.num_train_steps * global_batch_size)
        tf.compat.v1.logging.info("Total Training Time W/O Overhead = %0.2f for Sentences = %d", train_time_wo_overhead,
                        (FLAGS.num_train_steps - train_log_hook.skipped) * global_batch_size)
        tf.compat.v1.logging.info("Throughput Average (sentences/sec) with overhead = %0.2f", avg_sentences_per_second)
        tf.compat.v1.logging.info(*throughput_avg_wo_overhead_msg)
        dllogging.logger.log(step=(), data={"throughput_train": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
        tf.compat.v1.logging.info("-----------------------------")

  if FLAGS.do_eval and (not FLAGS.horovod or hvd_rank() == 0):
    tf.compat.v1.logging.info("***** Running evaluation *****")
    tf.compat.v1.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

    eval_files = []
    for eval_file_dir in FLAGS.eval_files_dir.split(","):
        eval_files.extend(tf.io.gfile.glob(os.path.join(eval_file_dir, "*")))

    eval_input_fn = input_fn_builder(
        input_files=eval_files,
        batch_size=FLAGS.eval_batch_size,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        is_training=False)

    eval_hooks = [LogEvalRunHook(FLAGS.eval_batch_size)]
    eval_start_time = time.time()
    result = estimator.evaluate(
        input_fn=eval_input_fn, steps=FLAGS.max_eval_steps, hooks=eval_hooks)

    eval_time_elapsed = time.time() - eval_start_time
    time_list = eval_hooks[-1].time_list
    time_list.sort()
    # Removing outliers (init/warmup) in throughput computation.
    eval_time_wo_overhead = sum(time_list[:int(len(time_list) * 0.99)])
    num_sentences = (int(len(time_list) * 0.99)) * FLAGS.eval_batch_size

    ss_sentences_per_second = num_sentences * 1.0 / eval_time_wo_overhead
    if FLAGS.enable_packed_data_mode: ss_sentences_per_second *= FLAGS.avg_seq_per_pack
    
    tf.compat.v1.logging.info("-----------------------------")
    tf.compat.v1.logging.info("Total Inference Time = %0.2f for Sentences = %d", eval_time_elapsed,
                    eval_hooks[-1].count * FLAGS.eval_batch_size)
    tf.compat.v1.logging.info("Total Inference Time W/O Overhead = %0.2f for Sentences = %d", eval_time_wo_overhead,
                    num_sentences)
    tf.compat.v1.logging.info("Summary Inference Statistics on EVAL set")
    tf.compat.v1.logging.info("Batch size = %d", FLAGS.eval_batch_size)
    tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
    tf.compat.v1.logging.info("Precision = %s", "fp16" if FLAGS.amp else "fp32")
    tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
    dllogging.logger.log(step=(), data={"throughput_val": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
    tf.compat.v1.logging.info("-----------------------------")

    output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
    with tf.io.gfile.GFile(output_eval_file, "w") as writer:
      tf.compat.v1.logging.info("***** Eval results *****")
      for key in sorted(result.keys()):
        tf.compat.v1.logging.info("  %s = %s", key, str(result[key]))
        writer.write("%s = %s\n" % (key, str(result[key])))
Code example #14
File: main.py  Project: HabanaAI/Model-References
def main(argv):
    del argv  # Unused.

    # If given an EfficientDet ckpt, don't use the default backbone ckpt.
    if FLAGS.backbone_ckpt == BACKBONE_CKPT_DEFAULT_DIR and FLAGS.ckpt is not None:
        print("Using ckpt flag: {}, ignoring default backbone_ckpt: {}".format(
            FLAGS.ckpt, FLAGS.backbone_ckpt))
        FLAGS.backbone_ckpt = None

    if FLAGS.use_horovod is not None:
        if FLAGS.dump_all_ranks:
            FLAGS.model_dir += "/worker_" + str(hvd.rank())
        if 'HOROVOD_CYCLE_TIME' not in os.environ:
            os.environ['HOROVOD_CYCLE_TIME'] = '0.5'
        if 'HABANA_HCCL_COMM_API' not in os.environ:
            os.environ['HABANA_HCCL_COMM_API'] = '0'
        hvd_init()

    if not FLAGS.no_hpu:
        from habana_frameworks.tensorflow import load_habana_module
        load_habana_module()

        if FLAGS.use_horovod:
            assert horovod_enabled()

    set_env(use_amp=FLAGS.use_amp)

    # deterministic setting
    if FLAGS.sbs_test or FLAGS.deterministic:
        set_deterministic()

    # Check data path
    if FLAGS.mode in (
            'train', 'train_and_eval') and FLAGS.training_file_pattern is None:
        raise RuntimeError(
            'You must specify --training_file_pattern for training.')
    if FLAGS.mode in ('eval', 'train_and_eval'):
        if FLAGS.validation_file_pattern is None:
            raise RuntimeError('You must specify --validation_file_pattern '
                               'for evaluation.')
        if not FLAGS.val_json_file and not FLAGS.testdev_dir:
            raise RuntimeError(
                'You must specify --val_json_file or --testdev for evaluation.'
            )

    # Parse and override hparams
    config = hparams_config.get_detection_config(FLAGS.model_name)
    config.override(FLAGS.hparams)

    # The following is for spatial partitioning. `features` has one tensor, while
    # `labels` has 4 + (`max_level` - `min_level` + 1) * 2 tensors. The input
    # partition is performed on `features` and all partitionable tensors of
    # `labels`; see the partition logic below.
    # In the TPUEstimator context, `shard` and `replica` mean the same thing;
    # following the API, both terms are used interchangeably here.
    if FLAGS.use_spatial_partition:
        # Checks input_partition_dims agrees with num_cores_per_replica.
        if FLAGS.num_cores_per_replica != np.prod(FLAGS.input_partition_dims):
            raise RuntimeError(
                '--num_cores_per_replica must be a product of array '
                'elements in --input_partition_dims.')

        labels_partition_dims = {
            'mean_num_positives': None,
            'source_ids': None,
            'groundtruth_data': None,
            'image_scales': None,
        }
        # The Input Partition Logic: We partition only the partition-able tensors.
        # Spatial partition requires that the to-be-partitioned tensors must have a
        # dimension that is a multiple of `partition_dims`. Depending on the
        # `partition_dims` and the `image_size` and the `max_level` in config, some
        # high-level anchor labels (i.e., `cls_targets` and `box_targets`) cannot
        # be partitioned. For example, when `partition_dims` is [1, 4, 2, 1], image
        # size is 1536, `max_level` is 9, `cls_targets_8` has a shape of
        # [batch_size, 6, 6, 9], which cannot be partitioned (6 % 4 != 0). In this
        # case, the level-8 and level-9 target tensors are not partition-able, and
        # the highest partition-able level is 7.
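        # Concrete check of the example above: with input_partition_dims = [1, 4, 2, 1]
        # and image_size = 1536, level 8 gives spatial_dim = 1536 // 2**8 = 6 and
        # 6 % 4 != 0 (not partition-able), while level 7 gives 1536 // 2**7 = 12,
        # which is divisible by every entry of [1, 4, 2, 1] (partition-able).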
        image_size = config.get('image_size')
        for level in range(config.get('min_level'),
                           config.get('max_level') + 1):

            def _can_partition(spatial_dim):
                partitionable_index = np.where(
                    spatial_dim % np.array(FLAGS.input_partition_dims) == 0)
                return len(partitionable_index[0]) == len(
                    FLAGS.input_partition_dims)

            spatial_dim = image_size // (2**level)
            if _can_partition(spatial_dim):
                labels_partition_dims['box_targets_%d' %
                                      level] = FLAGS.input_partition_dims
                labels_partition_dims['cls_targets_%d' %
                                      level] = FLAGS.input_partition_dims
            else:
                labels_partition_dims['box_targets_%d' % level] = None
                labels_partition_dims['cls_targets_%d' % level] = None
        num_cores_per_replica = FLAGS.num_cores_per_replica
        input_partition_dims = [
            FLAGS.input_partition_dims, labels_partition_dims
        ]
        num_shards = FLAGS.num_cores // num_cores_per_replica
    else:
        num_cores_per_replica = None
        input_partition_dims = None
        num_shards = FLAGS.num_cores
        if horovod_enabled():
            num_shards = hvd.size()
        else:
            num_shards = 1

    params = build_estimator_params('train', config, num_shards)
    # disabling input data scaling/flip manipulations.
    if FLAGS.sbs_test:
        sbs_params = dict(input_rand_hflip=False,
                          train_scale_min=1,
                          train_scale_max=1,
                          dropout_rate=0.0)
        params.update(sbs_params)

    tf_random_seed = 0 if FLAGS.deterministic else None
    run_config = build_estimator_config('train', config, num_shards,
                                        num_cores_per_replica,
                                        input_partition_dims)
    write_hparams_v1(FLAGS.model_dir, {
        'batch_size': FLAGS.train_batch_size,
        **FLAGS.flag_values_dict()
    })

    model_fn_instance = det_model_fn.get_model_fn(FLAGS.model_name)

    # TPU Estimator
    logging.info(params)

    if FLAGS.mode == 'train':
        train_estimator = HorovodEstimator(model_fn=model_fn_instance,
                                           model_dir=FLAGS.model_dir,
                                           config=run_config,
                                           params=params)

        # For a deterministic run, pass is_training=False so the dataloader does not
        # manipulate (shuffle/augment) the input data.
        is_training = not FLAGS.deterministic
        use_fake_data = FLAGS.use_fake_data or FLAGS.deterministic

        input_fn = dataloader.InputReader(FLAGS.training_file_pattern,
                                          is_training=is_training,
                                          params=params,
                                          use_fake_data=use_fake_data,
                                          is_deterministic=FLAGS.deterministic)
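        # Every global step consumes train_batch_size examples on each of the num_shards
        # workers, so the step budget is epochs * examples_per_epoch divided by
        # (train_batch_size * num_shards); the trailing +1 covers any remainder.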
        max_steps = int((FLAGS.num_epochs * FLAGS.num_examples_per_epoch) /
                        (FLAGS.train_batch_size * num_shards)) + 1

        # For the SBS test, train under the SBS dump callback.
        if FLAGS.sbs_test:
            from TensorFlow.common.debug import dump_callback
            SBS_TEST_CONFIG = os.path.join(
                os.environ['TF_TESTS_ROOT'],
                "tests/tf_training_tests/side_by_side/topologies/efficientdet/dump_config.json"
            )
            with dump_callback(SBS_TEST_CONFIG):
                train_estimator.train(input_fn=input_fn, max_steps=max_steps)
        else:
            if FLAGS.ckpt is not None:
                train_estimator.train(input_fn=input_fn, steps=max_steps)
            else:
                train_estimator.train(input_fn=input_fn, max_steps=max_steps)

    elif FLAGS.mode == 'eval':
        eval_params = build_estimator_params('eval', config, num_shards)
        eval_config = build_estimator_config('eval', config, num_shards,
                                             num_cores_per_replica,
                                             input_partition_dims)

        # Eval only runs on CPU or GPU host with batch_size = 1.
        # Override the default options: disable randomization in the input pipeline
        # and don't run on the TPU.
        # Also, disable use_bfloat16 for eval on CPU/GPU.

        eval_estimator = tf.estimator.tpu.TPUEstimator(
            model_fn=model_fn_instance,
            use_tpu=False,
            train_batch_size=FLAGS.train_batch_size,
            eval_batch_size=FLAGS.eval_batch_size,
            config=eval_config,
            params=eval_params)

        def terminate_eval():
            logging.info('Terminating eval after %d seconds of no checkpoints',
                         FLAGS.eval_timeout)
            return True

        # Run evaluation when there's a new checkpoint
        for ckpt in tf.train.checkpoints_iterator(
                FLAGS.model_dir,
                min_interval_secs=FLAGS.min_eval_interval,
                timeout=FLAGS.eval_timeout,
                timeout_fn=terminate_eval):

            logging.info('Starting to evaluate.')
            try:
                eval_results = eval_estimator.evaluate(
                    input_fn=dataloader.InputReader(
                        FLAGS.validation_file_pattern, is_training=False),
                    steps=FLAGS.eval_samples // FLAGS.eval_batch_size)
                logging.info('Eval results: %s', eval_results)

                # Terminate eval job when final checkpoint is reached.
                try:
                    current_step = int(os.path.basename(ckpt).split('-')[1])
                except IndexError:
                    logging.info('%s has no global step info: stop!', ckpt)
                    break

                write_summary(eval_results, ckpt, current_step)

                utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)
                total_step = int(
                    (FLAGS.num_epochs * FLAGS.num_examples_per_epoch) /
                    FLAGS.train_batch_size)
                if current_step >= total_step:
                    logging.info('Evaluation finished after training step %d',
                                 current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint',
                    ckpt)

    elif FLAGS.mode == 'train_and_eval':
        train_params = build_estimator_params('train', config, num_shards)
        train_config = build_estimator_config('train', config, num_shards,
                                              num_cores_per_replica,
                                              input_partition_dims)
        train_estimator = HorovodEstimator(model_fn=model_fn_instance,
                                           model_dir=FLAGS.model_dir,
                                           config=train_config,
                                           params=train_params)

        eval_estimator = None

        for cycle in range(FLAGS.num_epochs):
            logging.info('Starting training cycle, epoch: %d.', cycle)

            train_estimator.train(
                input_fn=dataloader.InputReader(
                    FLAGS.training_file_pattern,
                    is_training=True,
                    use_fake_data=FLAGS.use_fake_data),
                max_steps=(cycle + 1) *
                int(FLAGS.num_examples_per_epoch / FLAGS.train_batch_size))

            # synchronization point for all ranks
            if horovod_enabled():
                hvd.allreduce(tf.constant(0))

            logging.info('Starting evaluation cycle, epoch: %d.', cycle)
            # Run evaluation after every epoch.

            if eval_estimator is None:
                eval_params = build_estimator_params('eval', config,
                                                     num_shards)
                eval_config = build_estimator_config('eval', config,
                                                     num_shards,
                                                     num_cores_per_replica,
                                                     input_partition_dims)
                eval_estimator = tf.estimator.tpu.TPUEstimator(
                    model_fn=model_fn_instance,
                    use_tpu=False,
                    train_batch_size=FLAGS.train_batch_size,
                    eval_batch_size=FLAGS.eval_batch_size,
                    config=eval_config,
                    params=eval_params)

            if is_rank0():
                eval_results = eval_estimator.evaluate(
                    input_fn=dataloader.InputReader(
                        FLAGS.validation_file_pattern, is_training=False),
                    steps=FLAGS.eval_samples // FLAGS.eval_batch_size)

                checkpoint_path = Path(FLAGS.model_dir)
                last_ckpt = tf.train.latest_checkpoint(str(checkpoint_path),
                                                       latest_filename=None)
                current_step = int(os.path.basename(last_ckpt).split('-')[1])
                write_summary(eval_results, FLAGS.model_dir, current_step)
                logging.info('Evaluation results: %s', eval_results)

                ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
                utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)
        pass

    else:
        logging.info('Mode not found.')
Code example #15
File: trainer.py  Project: HabanaAI/Model-References
def main(argv):
  tf.disable_v2_behavior()
  tf.enable_resource_variables()

  if FLAGS.use_hpu and FLAGS.recipe_cache:
    prepare_recipe_cache()

  if FLAGS.use_horovod:
    if FLAGS.use_hpu:
      from TensorFlow.common.horovod_helpers import hvd_init, horovod_enabled, hvd
      hvd_init()
      assert horovod_enabled()
      if FLAGS.recipe_cache:
        # Other ranks should wait for recipe cache to be removed.
        # This operation can't be done before hvd_init.
        from mpi4py import MPI
        MPI.COMM_WORLD.Barrier()
    else:
      import horovod.tensorflow as hvd
      hvd.init()
      assert hvd.size() > 1
      os.environ['CUDA_VISIBLE_DEVICES'] = str(hvd.local_rank())

  if FLAGS.use_hpu:
    if FLAGS.use_bf16:
      os.environ['TF_BF16_CONVERSION'] = FLAGS.bf16_config_path

    dyn_shapes_flag = 'TF_ENABLE_DYNAMIC_SHAPES'
    if dyn_shapes_flag not in os.environ:
        os.environ[dyn_shapes_flag] = 'false'

    from habana_frameworks.tensorflow import load_habana_module  # noqa
    load_habana_module()

  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # If we just have to print the registry, do that and exit early.
  maybe_log_registry_and_exit()

  # Create HParams.
  if argv:
    set_hparams_from_args(argv[1:])
  if FLAGS.schedule != "run_std_server":
    hparams = create_hparams()
  if FLAGS.gpu_automatic_mixed_precision:
    setattr(hparams, "gpu_automatic_mixed_precision", True)
  if FLAGS.deterministic_dataset:
    hparams.add_hparam("deterministic_dataset", True)

  hparams.add_hparam("use_horovod", FLAGS.use_horovod)
  hparams.add_hparam("use_hpu", FLAGS.use_hpu)
  if FLAGS.use_horovod:
    hparams.add_hparam("hvd_worker_id", hvd.rank())
    hparams.add_hparam("hvd_size", hvd.size())

  if FLAGS.schedule == "run_std_server":
    run_std_server()
  trainer_lib.set_random_seed(FLAGS.random_seed)

  if FLAGS.generate_data:
    generate_data()

  exp_fn = create_experiment_fn()
  exp = exp_fn(create_run_config(hparams), hparams)
  if is_chief():
    save_metadata(hparams)

  with dump_callback():
    execute_schedule(exp)
Code example #16
def main():
    parser = DenseNetArgumentParser(
        description=(
            "train.py is the main training/evaluation script for DenseNet. "
            "In order to run training on multiple Gaudi cards, use demo_densenet.py or run "
            "train.py with mpirun."))
    args, _ = parser.parse_known_args()

    strategy = None
    verbose = 1

    os.environ['ENABLE_EXPERIMENTAL_FLAGS'] = 'true'
    os.environ['RUN_TPC_FUSER'] = '******'

    if args.deterministic:
        if args.inputs is None:
            raise ValueError("Must provide inputs for deterministic mode")
        if args.resume_from_checkpoint_path is None:
            raise ValueError("Must provide checkpoint for deterministic mode")

    if args.dtype == 'bf16':
        os.environ['TF_BF16_CONVERSION'] = '1'

    if args.run_on_hpu:
        load_habana_module()
        if args.use_hpu_strategy:
            hls_addresses = str(os.environ.get(
                "MULTI_HLS_IPS", "127.0.0.1")).split(",")
            TF_BASE_PORT = 2410
            mpi_rank = comm_rank()
            mpi_size = comm_size()
            if mpi_rank > 0:
                verbose = 0
            # worker_hosts: comma-separated list of worker ip:port pairs.
            worker_hosts = ",".join(
                address + ':' + str(TF_BASE_PORT + rank)
                for address in hls_addresses
                for rank in range(mpi_size // len(hls_addresses)))
            task_index = mpi_rank

            # Configures cluster spec for distribution strategy.
            _ = distribution_utils.configure_cluster(worker_hosts, task_index)
            strategy = HPUStrategy()
            print('Number of devices: {}'.format(
                strategy.num_replicas_in_sync))
    else:
        strategy = tf.distribute.MultiWorkerMirroredStrategy()
        print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

    if args.seed is not None:
        os.environ['TF_DETERMINISTIC_OPS'] = '1'
        random.seed(args.seed)
        np.random.seed(args.seed)
        tf.random.set_seed(args.seed)

    img_rows, img_cols = 224, 224  # Resolution of inputs
    channel = 3
    num_classes = 1000
    batch_size = args.batch_size
    nb_epoch = args.epochs
    dataset_dir = args.dataset_dir
    resume_from_checkpoint_path = args.resume_from_checkpoint_path
    resume_from_epoch = args.resume_from_epoch
    dropout_rate = args.dropout_rate
    weight_decay = args.weight_decay
    optim_name = args.optimizer
    initial_lr = args.initial_lr
    model_name = args.model
    save_summary_steps = args.save_summary_steps

    if model_name == "densenet121":
        growth_rate = 32
        nb_filter = 64
        nb_layers = [6, 12, 24, 16]

    elif model_name == "densenet161":
        growth_rate = 48
        nb_filter = 96
        nb_layers = [6, 12, 36, 24]

    elif model_name == "densenet169":
        growth_rate = 32
        nb_filter = 64
        nb_layers = [6, 12, 32, 32]

    else:
        print("Model is not supported; use densenet121, densenet161, or densenet169.")
        exit(1)

    # Load our model
    if strategy:
        with strategy.scope():
            model = densenet_model(img_rows=img_rows, img_cols=img_cols, color_type=channel,
                                   dropout_rate=dropout_rate, weight_decay=weight_decay, num_classes=num_classes,
                                   growth_rate=growth_rate, nb_filter=nb_filter, nb_layers=nb_layers)
            optimizer = get_optimizer(
                model_name, optim_name, initial_lr, epsilon=1e-2)
            model.compile(optimizer=optimizer,
                          loss='categorical_crossentropy', metrics=['accuracy'])
    else:
        model = densenet_model(img_rows=img_rows, img_cols=img_cols, color_type=channel,
                               dropout_rate=dropout_rate, weight_decay=weight_decay, num_classes=num_classes,
                               growth_rate=growth_rate, nb_filter=nb_filter, nb_layers=nb_layers)
        optimizer = get_optimizer(
            model_name, optim_name, initial_lr, epsilon=1e-2)
        model.compile(optimizer=optimizer,
                      loss='categorical_crossentropy', metrics=['accuracy'])

    # Start training
    steps_per_epoch = 1281167 // batch_size
    if args.steps_per_epoch is not None:
        steps_per_epoch = args.steps_per_epoch
    validation_steps = 50000 // batch_size
    if args.validation_steps is not None:
        validation_steps = args.validation_steps
    warmup_steps = args.warmup_epochs * steps_per_epoch
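    # Step-decay schedule: multipliers relative to initial_lr, keyed by epoch
    # (drop by 10x at epochs 30, 60 and 80); converted to global-step keys below.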
    lr_sched = {0: 1, 30: 0.1, 60: 0.01, 80: 0.001}
    lr_sched_steps = {
        epoch * steps_per_epoch: multiplier for (epoch, multiplier) in lr_sched.items()}

    lrate = StepLearningRateScheduleWithWarmup(initial_lr=initial_lr,
                                               initial_global_step=0,
                                               warmup_steps=warmup_steps,
                                               decay_schedule=lr_sched_steps,
                                               verbose=0)

    save_name = model_name if not model_name.endswith('.h5') else \
        os.path.split(model_name)[-1].split('.')[0].split('-')[0]

    model_ckpt = tf.keras.callbacks.ModelCheckpoint(
        os.path.join(args.model_dir, config.SAVE_DIR,
                     save_name) + '-ckpt-{epoch:03d}.h5',
        monitor='train_loss')

    callbacks = [lrate, model_ckpt]

    if save_summary_steps is not None and save_summary_steps > 0:
        log_dir = os.path.join(args.model_dir, config.LOG_DIR)
        local_batch_size = batch_size
        
        if args.use_hpu_strategy:
            log_dir = os.path.join(log_dir, 'worker_' + str(comm_rank()))
            local_batch_size = batch_size // strategy.num_replicas_in_sync

        callbacks += [
            TensorBoardWithHParamsV2(
                args.__dict__, log_dir=log_dir,
                update_freq=save_summary_steps, profile_batch=0),
            ExamplesPerSecondKerasHookV2(
                save_summary_steps, output_dir=log_dir,
                batch_size=local_batch_size),
        ]

    if args.evaluate_checkpoint_path is not None:
        model.load_weights(args.evaluate_checkpoint_path)
        # Build the validation dataset here; it is otherwise only created in the training branch below.
        ds_valid = get_dataset(dataset_dir, args.val_subset, batch_size)
        results = model.evaluate(x=ds_valid, steps=validation_steps)
        print("Test loss, Test acc:", results)
        exit()

    if resume_from_epoch is not None and resume_from_checkpoint_path is not None:
        model.load_weights(resume_from_checkpoint_path)

    if args.deterministic:
        set_deterministic()
        if not os.path.isfile(args.dump_config):
            raise FileNotFoundError(f"Dump config file not found: {args.dump_config}")

        import pickle
        x_path = os.path.join(args.inputs, "input")
        y_path = os.path.join(args.inputs, "target")
        x = pickle.load(open(x_path, 'rb'))
        y = pickle.load(open(y_path, 'rb'))

        with dump_callback(args.dump_config):
          model.fit(x=x, y=y,
                  steps_per_epoch=steps_per_epoch,
                  callbacks=callbacks,
                  initial_epoch=resume_from_epoch,
                  epochs=nb_epoch,
                  shuffle=False,
                  verbose=verbose,
                  validation_data=None,
                  validation_steps=0,
                  )
    else:
      ds_train = get_dataset(dataset_dir, args.train_subset, batch_size)
      ds_valid = get_dataset(dataset_dir, args.val_subset, batch_size)

      model.fit(x=ds_train, y=None,
                steps_per_epoch=steps_per_epoch,
                callbacks=callbacks,
                initial_epoch=resume_from_epoch,
                epochs=nb_epoch,
                shuffle=True,
                verbose=verbose,
                validation_data=(ds_valid, None),
                validation_steps=validation_steps,
                validation_freq=1,
                )
Code example #17
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    processors = {
        "cola": classifier_utils.ColaProcessor,
        "mnli": classifier_utils.MnliProcessor,
        "mismnli": classifier_utils.MisMnliProcessor,
        "mrpc": classifier_utils.MrpcProcessor,
        "rte": classifier_utils.RteProcessor,
        "sst-2": classifier_utils.Sst2Processor,
        "sts-b": classifier_utils.StsbProcessor,
        "qqp": classifier_utils.QqpProcessor,
        "qnli": classifier_utils.QnliProcessor,
        "wnli": classifier_utils.WnliProcessor,
    }

    if not (FLAGS.do_train or FLAGS.do_eval or FLAGS.do_predict
            or FLAGS.export_dir):
        raise ValueError(
            "At least one of `do_train`, `do_eval`, `do_predict' or `export_dir` "
            "must be True.")

    if not FLAGS.albert_config_file and not FLAGS.albert_hub_module_handle:
        raise ValueError("At least one of `--albert_config_file` and "
                         "`--albert_hub_module_handle` must be set")

    if FLAGS.albert_config_file:
        albert_config = modeling.AlbertConfig.from_json_file(
            FLAGS.albert_config_file)
        if FLAGS.max_seq_length > albert_config.max_position_embeddings:
            raise ValueError(
                "Cannot use sequence length %d because the ALBERT model "
                "was only trained up to sequence length %d" %
                (FLAGS.max_seq_length, albert_config.max_position_embeddings))
    else:
        albert_config = None  # Get the config from TF-Hub.

    if FLAGS.deterministic_run and albert_config and (
            albert_config.attention_probs_dropout_prob
            or albert_config.hidden_dropout_prob):
        albert_config.attention_probs_dropout_prob = 0.0
        albert_config.hidden_dropout_prob = 0.0

    tf.gfile.MakeDirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name](
        use_spm=bool(FLAGS.spm_model_file),
        do_lower_case=FLAGS.do_lower_case)

    label_list = processor.get_labels()

    tokenizer = fine_tuning_utils.create_vocab(
        vocab_file=FLAGS.vocab_file,
        do_lower_case=FLAGS.do_lower_case,
        spm_model_file=FLAGS.spm_model_file,
        hub_module=FLAGS.albert_hub_module_handle)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    model_dir = FLAGS.output_dir
    if horovod_enabled():
        model_dir = os.path.join(FLAGS.output_dir, "worker_" + str(hvd.rank()))

    # The Scoped Allocator Optimization is enabled by default unless disabled by a flag.
    if FLAGS.enable_scoped_allocator:
        from tensorflow.core.protobuf import rewriter_config_pb2  # pylint: disable=import-error

        session_config = tf.compat.v1.ConfigProto()
        session_config.graph_options.rewrite_options.scoped_allocator_optimization = rewriter_config_pb2.RewriterConfig.ON

        enable_op = session_config.graph_options.rewrite_options.scoped_allocator_opts.enable_op
        del enable_op[:]
        enable_op.append("HorovodAllreduce")
    else:
        session_config = None

    is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2
    if FLAGS.do_train:
        iterations_per_loop = int(
            min(FLAGS.iterations_per_loop, FLAGS.save_checkpoints_steps))
    else:
        iterations_per_loop = FLAGS.iterations_per_loop
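    # Capping iterations_per_loop at save_checkpoints_steps keeps checkpointing on
    # schedule: TPUEstimator only returns control to the host, where checkpoints are
    # written, once per inner loop.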

    run_config = tf.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=model_dir,
        save_checkpoints_steps=int(FLAGS.save_checkpoints_steps),
        keep_checkpoint_max=0,
        save_summary_steps=FLAGS.save_summary_steps,
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host),
        session_config=session_config)

    train_examples = None

    train_batch_size = FLAGS.train_batch_size
    if horovod_enabled():
        train_batch_size = train_batch_size * hvd.size()

    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)

        start_index = 0
        end_index = len(train_examples)
        worker_id = 0
        per_worker_filenames = [
            os.path.join(FLAGS.output_dir, "train.tf_record")
        ]

        if horovod_enabled():
            per_worker_filenames = [
                os.path.join(FLAGS.output_dir, "train.tf_record_{}".format(i))
                for i in range(hvd.size())
            ]
            num_examples_per_rank = len(train_examples) // hvd.size()
            remainder = len(train_examples) % hvd.size()
            worker_id = hvd.rank()
            if worker_id < remainder:
                start_index = worker_id * (num_examples_per_rank + 1)
                end_index = start_index + num_examples_per_rank + 1
            else:
                start_index = worker_id * num_examples_per_rank + remainder
                end_index = start_index + (num_examples_per_rank)
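            # Example split (illustrative): 10 examples over hvd.size() = 4 workers gives
            # num_examples_per_rank = 2 and remainder = 2, so ranks 0 and 1 take slices
            # [0:3] and [3:6] (3 examples each) while ranks 2 and 3 take [6:8] and [8:10].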

    learning_rate = FLAGS.learning_rate
    if horovod_enabled():
        learning_rate = learning_rate * hvd.size()

    model_fn = classifier_utils.model_fn_builder(
        albert_config=albert_config,
        num_labels=len(label_list),
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=learning_rate,
        num_train_steps=FLAGS.train_step,
        num_warmup_steps=FLAGS.warmup_step,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu,
        task_name=task_name,
        hub_module=FLAGS.albert_hub_module_handle,
        optimizer=FLAGS.optimizer)

    if not math.isnan(FLAGS.threshold_to_export):
        model_fn = _add_threshold_to_model_fn(model_fn,
                                              FLAGS.threshold_to_export)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.estimator.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.predict_batch_size,
        export_to_tpu=False)  # http://yaqs/4707241341091840

    write_hparams_v1(
        FLAGS.output_dir, {
            'batch_size': FLAGS.train_batch_size,
            **{x: getattr(FLAGS, x)
               for x in FLAGS}
        })

    if FLAGS.do_train:
        if FLAGS.deterministic_run and not horovod_enabled(
        ) and FLAGS.input_file:
            per_worker_filenames = [
                os.path.join(FLAGS.input_file, "train.tf_record")
            ]
        if not tf.gfile.Exists(per_worker_filenames[worker_id]):
            classifier_utils.file_based_convert_examples_to_features(
                train_examples[start_index:end_index], label_list,
                FLAGS.max_seq_length, tokenizer,
                per_worker_filenames[worker_id], task_name)
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num examples = %d", len(train_examples))
        tf.logging.info("  Per-worker batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Total batch size = %d", train_batch_size)
        tf.logging.info("  Num steps = %d", FLAGS.train_step)
        train_input_fn = classifier_utils.file_based_input_fn_builder(
            input_file=per_worker_filenames,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True,
            task_name=task_name,
            use_tpu=FLAGS.use_tpu,
            bsz=FLAGS.train_batch_size)

        train_hooks = [
            habana_hooks.PerfLoggingHook(batch_size=train_batch_size,
                                         mode="train")
        ]
        if horovod_enabled():
            train_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

        if os.environ.get("HABANA_SYNAPSE_LOGGER", "False").lower() == "range":
            from habana_frameworks.tensorflow.synapse_logger_helpers import SynapseLoggerHook
            begin = 30
            end = begin + 10
            print("Begin: {}".format(begin))
            print("End: {}".format(end))
            train_hooks.append(
                SynapseLoggerHook(list(range(begin, end)), False))

        with dump_callback():
            estimator.train(input_fn=train_input_fn,
                            max_steps=FLAGS.train_step,
                            hooks=train_hooks)

    if FLAGS.do_eval:
        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        num_actual_eval_examples = len(eval_examples)
        if FLAGS.use_tpu:
            # TPU requires a fixed batch size for all batches, therefore the number
            # of examples must be a multiple of the batch size, or else examples
            # will get dropped. So we pad with fake examples which are ignored
            # later on. These do NOT count towards the metric (all tf.metrics
            # support a per-instance weight, and these get a weight of 0.0).
            while len(eval_examples) % FLAGS.eval_batch_size != 0:
                eval_examples.append(classifier_utils.PaddingInputExample())

        cached_dir = FLAGS.cached_dir
        if not cached_dir:
            cached_dir = FLAGS.output_dir
        eval_file = os.path.join(cached_dir, task_name + "_eval.tf_record")
        if not tf.gfile.Exists(eval_file):
            classifier_utils.file_based_convert_examples_to_features(
                eval_examples, label_list, FLAGS.max_seq_length, tokenizer,
                eval_file, task_name)

        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(eval_examples), num_actual_eval_examples,
                        len(eval_examples) - num_actual_eval_examples)
        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        # This tells the estimator to run through the entire set.
        eval_steps = None
        # However, if running eval on the TPU, you will need to specify the
        # number of steps.
        if FLAGS.use_tpu:
            assert len(eval_examples) % FLAGS.eval_batch_size == 0
            eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)

        eval_drop_remainder = FLAGS.use_tpu
        eval_input_fn = classifier_utils.file_based_input_fn_builder(
            input_file=eval_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=eval_drop_remainder,
            task_name=task_name,
            use_tpu=FLAGS.use_tpu,
            bsz=FLAGS.eval_batch_size)

        eval_hooks = [
            habana_hooks.PerfLoggingHook(batch_size=FLAGS.eval_batch_size,
                                         mode="eval")
        ]
        best_trial_info_file = os.path.join(FLAGS.output_dir, "best_trial.txt")
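        # best_trial.txt stores "global_step:best_metric_global_step:metric_value",
        # which lets an interrupted evaluation loop resume where it left off.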

        def _best_trial_info():
            """Returns information about which checkpoints have been evaled so far."""
            if tf.gfile.Exists(best_trial_info_file):
                with tf.gfile.GFile(best_trial_info_file, "r") as best_info:
                    global_step, best_metric_global_step, metric_value = (
                        best_info.read().split(":"))
                    global_step = int(global_step)
                    best_metric_global_step = int(best_metric_global_step)
                    metric_value = float(metric_value)
            else:
                metric_value = -1
                best_metric_global_step = -1
                global_step = -1
            tf.logging.info(
                "Best trial info: Step: %s, Best Value Step: %s, "
                "Best Value: %s", global_step, best_metric_global_step,
                metric_value)
            return global_step, best_metric_global_step, metric_value

        def _remove_checkpoint(checkpoint_path):
            for ext in ["meta", "data-00000-of-00001", "index"]:
                src_ckpt = checkpoint_path + ".{}".format(ext)
                tf.logging.info("removing {}".format(src_ckpt))
                tf.gfile.Remove(src_ckpt)

        def _find_valid_cands(curr_step):
            filenames = tf.gfile.ListDirectory(model_dir)
            candidates = []
            for filename in filenames:
                if filename.endswith(".index"):
                    ckpt_name = filename[:-6]
                    idx = ckpt_name.split("-")[-1]
                    if int(idx) > curr_step:
                        candidates.append(filename)
            return candidates

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")

        if task_name == "sts-b":
            key_name = "pearson"
        elif task_name == "cola":
            key_name = "matthew_corr"
        else:
            key_name = "eval_accuracy"

        global_step, best_perf_global_step, best_perf = _best_trial_info()
        writer = tf.gfile.GFile(output_eval_file, "w")
        while global_step < FLAGS.train_step:
            steps_and_files = {}
            filenames = tf.gfile.ListDirectory(model_dir)
            for filename in filenames:
                if filename.endswith(".index"):
                    ckpt_name = filename[:-6]
                    cur_filename = os.path.join(model_dir, ckpt_name)
                    if cur_filename.split("-")[-1] == "best":
                        continue
                    gstep = int(cur_filename.split("-")[-1])
                    if gstep not in steps_and_files:
                        tf.logging.info(
                            "Add {} to eval list.".format(cur_filename))
                        steps_and_files[gstep] = cur_filename
            tf.logging.info("found {} files.".format(len(steps_and_files)))
            if not steps_and_files:
                tf.logging.info(
                    "found 0 file, global step: {}. Sleeping.".format(
                        global_step))
                time.sleep(60)
            else:
                for step, checkpoint_path in sorted(steps_and_files.items()):
                    if global_step >= step:
                        if (best_perf_global_step != step
                                and len(_find_valid_cands(step)) > 1):
                            _remove_checkpoint(checkpoint_path)
                        continue
                    result = estimator.evaluate(
                        input_fn=eval_input_fn,
                        steps=eval_steps,
                        checkpoint_path=checkpoint_path,
                        hooks=eval_hooks)
                    global_step = result["global_step"]
                    tf.logging.info("***** Eval results *****")
                    for key in sorted(result.keys()):
                        tf.logging.info("  %s = %s", key, str(result[key]))
                        writer.write("%s = %s\n" % (key, str(result[key])))
                    writer.write("best = {}\n".format(best_perf))
                    if result[key_name] > best_perf:
                        best_perf = result[key_name]
                        best_perf_global_step = global_step
                    elif len(_find_valid_cands(global_step)) > 1:
                        _remove_checkpoint(checkpoint_path)
                    writer.write("=" * 50 + "\n")
                    writer.flush()
                    with tf.gfile.GFile(best_trial_info_file,
                                        "w") as best_info:
                        best_info.write("{}:{}:{}".format(
                            global_step, best_perf_global_step, best_perf))
        writer.close()

        for ext in ["meta", "data-00000-of-00001", "index"]:
            src_ckpt = "model.ckpt-{}.{}".format(best_perf_global_step, ext)
            tgt_ckpt = "model.ckpt-best.{}".format(ext)
            tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt))
            tf.io.gfile.rename(os.path.join(model_dir, src_ckpt),
                               os.path.join(model_dir, tgt_ckpt),
                               overwrite=True)

    if FLAGS.do_predict:
        predict_examples = processor.get_test_examples(FLAGS.data_dir)
        num_actual_predict_examples = len(predict_examples)
        if FLAGS.use_tpu:
            # TPU requires a fixed batch size for all batches, therefore the number
            # of examples must be a multiple of the batch size, or else examples
            # will get dropped. So we pad with fake examples which are ignored
            # later on.
            while len(predict_examples) % FLAGS.predict_batch_size != 0:
                predict_examples.append(classifier_utils.PaddingInputExample())

        predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
        classifier_utils.file_based_convert_examples_to_features(
            predict_examples, label_list, FLAGS.max_seq_length, tokenizer,
            predict_file, task_name)

        tf.logging.info("***** Running prediction*****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(predict_examples), num_actual_predict_examples,
                        len(predict_examples) - num_actual_predict_examples)
        tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        predict_drop_remainder = FLAGS.use_tpu
        predict_input_fn = classifier_utils.file_based_input_fn_builder(
            input_file=predict_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=predict_drop_remainder,
            task_name=task_name,
            use_tpu=FLAGS.use_tpu,
            bsz=FLAGS.predict_batch_size)

        checkpoint_path = os.path.join(model_dir, "model.ckpt-best")
        result = estimator.predict(input_fn=predict_input_fn,
                                   checkpoint_path=checkpoint_path)

        output_predict_file = os.path.join(FLAGS.output_dir,
                                           "test_results.tsv")
        output_submit_file = os.path.join(FLAGS.output_dir,
                                          "submit_results.tsv")
        with tf.gfile.GFile(output_predict_file, "w") as pred_writer,\
            tf.gfile.GFile(output_submit_file, "w") as sub_writer:
            sub_writer.write("index" + "\t" + "prediction\n")
            num_written_lines = 0
            tf.logging.info("***** Predict results *****")
            for i, (example, prediction) in enumerate(
                    zip(predict_examples, result)):
                probabilities = prediction["probabilities"]
                if i >= num_actual_predict_examples:
                    break
                output_line = "\t".join(
                    str(class_probability)
                    for class_probability in probabilities) + "\n"
                pred_writer.write(output_line)

                if task_name != "sts-b":
                    actual_label = label_list[int(prediction["predictions"])]
                else:
                    actual_label = str(prediction["predictions"])
                sub_writer.write(example.guid + "\t" + actual_label + "\n")
                num_written_lines += 1
        assert num_written_lines == num_actual_predict_examples

    if FLAGS.export_dir:
        tf.gfile.MakeDirs(FLAGS.export_dir)
        checkpoint_path = os.path.join(model_dir, "model.ckpt-best")
        tf.logging.info("Starting to export model.")
        subfolder = estimator.export_saved_model(
            export_dir_base=FLAGS.export_dir,
            serving_input_receiver_fn=_serving_input_receiver_fn,
            checkpoint_path=checkpoint_path)
        tf.logging.info("Model exported to %s.", subfolder)
コード例 #18
0
def run_experiment(
    distribution_strategy: tf.distribute.Strategy,
    task: base_task.Task,
    mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    trainer: Optional[base_trainer.Trainer] = None
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
    """Runs train/eval configured by the experiment params.

  Args:
    distribution_strategy: A distribution distribution_strategy.
    task: A Task instance.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: ExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run post eval once after training, metrics logs
      are returned.
    save_summary: Whether to save train and validation summary.
    trainer: the base_trainer.Trainer instance. It should be created within the
      strategy.scope().

  Returns:
    A 2-tuple of (model, eval_logs).
      model: `tf.keras.Model` instance.
      eval_logs: returns eval metrics logs when run_post_eval is set to True,
        otherwise, returns {}.
  """

    with distribution_strategy.scope():
        if not trainer:
            trainer = train_utils.create_trainer(
                params,
                task,
                train='train' in mode,
                evaluate=('eval' in mode) or run_post_eval,
                checkpoint_exporter=maybe_create_best_ckpt_exporter(
                    params, model_dir))

    if trainer.checkpoint:
        checkpoint_manager = tf.train.CheckpointManager(
            trainer.checkpoint,
            directory=model_dir,
            max_to_keep=params.trainer.max_to_keep,
            step_counter=trainer.global_step,
            checkpoint_interval=params.trainer.checkpoint_interval,
            init_fn=trainer.initialize)
        # Adds recovery handling.
        trainer.add_recovery(params.trainer,
                             checkpoint_manager=checkpoint_manager)
    else:
        checkpoint_manager = None

    # Create logs matching the TensorBoard log parser format
    # (see tensorboard_for_parser.md).
    hparams = {
        "batch_size": params.task.train_data.global_batch_size,
        "precision": params.runtime.mixed_precision_dtype
    }

    controller = orbit.Controller(
        strategy=distribution_strategy,
        trainer=trainer if 'train' in mode else None,
        evaluator=trainer,
        global_step=trainer.global_step,
        steps_per_loop=params.trainer.steps_per_loop,
        checkpoint_manager=checkpoint_manager,
        summary_dir=model_dir if save_summary else None,
        eval_summary_dir=os.path.join(
            model_dir,
            params.trainer.validation_summary_subdir) if save_summary else None,
        summary_interval=params.trainer.summary_interval
        if save_summary else None,
        hparams=hparams if save_summary else None,
        train_actions=None,
        eval_actions=actions.get_eval_actions(params, trainer, model_dir))

    logging.info('Starts to execute mode: %s', mode)
    with distribution_strategy.scope():
        if params.runtime.dump_config:
            from TensorFlow.common.debug import dump_callback
            dump_ctx = dump_callback(params.runtime.dump_config)
        else:
            # No dump config requested: fall back to a no-op context manager.
            dump_ctx = contextlib.ExitStack()
        with dump_ctx:
            if mode == 'train':
                controller.train(steps=params.trainer.train_steps)
            elif mode == 'train_and_eval':
                controller.train_and_evaluate(
                    train_steps=params.trainer.train_steps,
                    eval_steps=params.trainer.validation_steps,
                    eval_interval=params.trainer.validation_interval)
            elif mode == 'eval':
                controller.evaluate(steps=params.trainer.validation_steps)
            elif mode == 'continuous_eval':

                def timeout_fn():
                    # Stop continuous evaluation once training has reached its
                    # final step.
                    return (trainer.global_step.numpy() >=
                            params.trainer.train_steps)

                controller.evaluate_continuously(
                    steps=params.trainer.validation_steps,
                    timeout=params.trainer.continuous_eval_timeout,
                    timeout_fn=timeout_fn)
            else:
                raise NotImplementedError('The mode is not implemented: %s' %
                                          mode)

    num_params = train_utils.try_count_params(trainer.model)
    if num_params is not None:
        logging.info('Number of trainable params in model: %f Millions.',
                     num_params / 10.**6)

    if run_post_eval:
        with distribution_strategy.scope():
            return trainer.model, trainer.evaluate(
                tf.convert_to_tensor(params.trainer.validation_steps))
    else:
        return trainer.model, {}
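A hedged sketch of how run_experiment above is typically driven. The imports and the experiment name follow the upstream TF Model Garden train.py pattern and are assumptions here rather than part of this example:

import tensorflow as tf

# Assumed TF Model Garden-style helpers; the Habana copy may lay these out differently.
from official.core import exp_factory, task_factory

model_dir = "/tmp/my_experiment"                         # assumed output path
params = exp_factory.get_exp_config("resnet_imagenet")   # assumed experiment name
params.override({"trainer": {"train_steps": 1000}})      # example override

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

model, eval_logs = run_experiment(
    distribution_strategy=strategy,
    task=task,
    mode="train_and_eval",
    params=params,
    model_dir=model_dir,
    run_post_eval=True)
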
コード例 #19
0
def train(args, cycle_gan_model, train_ds, test_ds, checkpoint=None, horovod=None):
    gen_X = cycle_gan_model.gen_X
    gen_Y = cycle_gan_model.gen_Y
    cycle_loss_fn = keras.losses.MeanAbsoluteError()
    id_loss_fn = keras.losses.MeanAbsoluteError()
    discriminator_loss_fn, generator_loss_fn = get_adversarial_losses_fn(
        'lsgan')

    lr_opts = dict(
        gen_optimizer=args.generator_lr,
        disc_optimizer=args.discriminator_lr,
    )

    if args.use_horovod:
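        # Scale the learning rates by sqrt(world size), a softer alternative to
        # linear scaling that is sometimes preferred for adversarial training.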
        for k in lr_opts.keys():
            lr_opts[k] *= math.sqrt(horovod.size())

    # Callbacks
    hooks = []
    if args.use_hooks and (args.log_all_workers or is_local_master(args.use_horovod, horovod)):

        hparams = {
            'batch_size': args.batch_size,
            'precision': args.data_type,
            'epochs': args.epochs,
            'logdir': args.logdir,
            'hvd_workers': args.hvd_workers
        }
        tb = TensorBoardWithHParamsV2(
            hparams, log_dir=os.path.join(args.logdir, "train"))
        examples_per_sec = ExamplesPerSecondKerasHookV2(
            output_dir=os.path.join(args.logdir, "train"), batch_size=args.batch_size)

        # Apply the preprocessing operations to the test data
        file_writer_imgs = tf.summary.create_file_writer(
            os.path.join(args.logdir, 'imgs'))
        plotter = GANMonitor(
            file_writer_imgs, test_ds[0], test_ds[1], freq=args.monitor_freq)
        if args.steps_per_epoch is None:
            # Count the batches in the dataset; this iterates over train_ds once.
            steps_per_epoch = int(train_ds.reduce(0, lambda x, _: x + 1).numpy())
        else:
            steps_per_epoch = args.steps_per_epoch
        save_every_n_steps = args.save_freq * steps_per_epoch
        checkpoint_filename = "cyclegan_checkpoints.{epoch:03d}"
        hooks += [
            plotter, tb, examples_per_sec,
            keras.callbacks.ModelCheckpoint(
                filepath=os.path.join(args.logdir, checkpoint_filename),
                save_weights_only=True,
                save_freq=save_every_n_steps)
        ]

    if not args.no_lr_optimizer:
        from scheduling import MultiOptimizerLR, CosineDecay, Warmup
        lr_fn = CosineDecay(epochs=args.epochs - args.cosine_decay_delay, clif=args.cosine_decay_delay)
        if args.warmup_epochs > 0:
            lr_fn = Warmup(lr_fn, warmup_epochs=args.warmup_epochs)
        scheduler_hook = MultiOptimizerLR(initial_lr=lr_opts, multiplier=lr_fn)
        hooks += [scheduler_hook]

    start_epoch = 0
    if checkpoint:
        print(f'Resuming from {checkpoint}')
        start_epoch = int(re.search(r'[0-9]{3}', checkpoint)[0])
        cycle_gan_model.load_weights(checkpoint)
    else:
        print(f'Couldn\'t find checkpoint at {args.logdir}')

    pool_F = None
    pool_G = None
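    # Image pools (as in the original CycleGAN setup) buffer previously
    # generated images; the discriminators train on samples drawn from this
    # history instead of only the newest generator outputs, which helps
    # stabilize adversarial training.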
    if args.pool_size > 0:
        print('Populating pool')
        pool_F = []
        pool_G = []
        for i, (A, B) in enumerate(train_ds):
            if i >= args.pool_size // args.batch_size:
                break
            pool_F.append(gen_X(A))
            pool_G.append(gen_Y(B))
        pool_F = TFPool(tf.concat(pool_F, 0), batch_size=args.batch_size)
        pool_G = TFPool(tf.concat(pool_G, 0), batch_size=args.batch_size)
        print(
            f'Done, sample count- F: {pool_F.pool.shape[0]}, G: {pool_G.pool.shape[0]}')

    cycle_gan_model.compile(
        gen_optimizer=keras.optimizers.Adam(
            learning_rate=lr_opts["gen_optimizer"], beta_1=0.5),
        disc_optimizer=keras.optimizers.Adam(
            learning_rate=lr_opts["disc_optimizer"], beta_1=0.5),
        gen_loss_fn=generator_loss_fn,
        cycle_loss=cycle_loss_fn,
        id_loss=id_loss_fn,
        disc_loss_fn=discriminator_loss_fn,
        hvd=horovod if args.use_horovod else None,
        pool_f=pool_F,
        pool_g=pool_G)
    print('Model is compiled, setting hooks')
    if is_local_master(args.use_horovod, horovod):
        print('Saving initial checkpoint')
        cycle_gan_model.save_weights(os.path.join(
            args.logdir, f'init_checkpoint.{start_epoch:03d}'))
    if args.use_horovod:
        horovod.broadcast_variables(cycle_gan_model.variables, 0)
    print('Start model training')
    with dump_callback(args.dumps_config):
        cycle_gan_model.fit(
            train_ds,
            epochs=args.epochs,
            initial_epoch=start_epoch,
            steps_per_epoch=args.steps_per_epoch,
            callbacks=hooks,
            verbose=is_master(args.use_horovod, horovod),
        )
    if is_local_master(args.use_horovod, horovod):
        print('Saving final checkpoint')
        cycle_gan_model.save_weights(os.path.join(
            args.logdir, f'final_checkpoint.{args.epochs:03d}'))
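A recurring pattern across these examples is wrapping the training call in dump_callback, falling back to a no-op context when no dump configuration is given (コード例 #18 uses contextlib.ExitStack() for exactly that). A minimal sketch of the same idea as a reusable helper, assuming dump_callback accepts an optional config argument as in the examples above:

import contextlib


def maybe_dump(dump_config=None):
    """Return a dump_callback context when a config is given, else a no-op context."""
    if dump_config:
        # Imported lazily so the helper still works where the debug package is absent.
        from TensorFlow.common.debug import dump_callback
        return dump_callback(dump_config)
    # contextlib.nullcontext() is the stdlib equivalent of the ExitStack() trick above.
    return contextlib.nullcontext()


# Usage, mirroring the training calls in the examples above:
# with maybe_dump(args.dumps_config):
#     cycle_gan_model.fit(train_ds, epochs=args.epochs)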