def before_run(self, run_context):
    if horovod_enabled() and hvd_rank() != 0:
        return
    self.t0 = time.time()
    fetches = [
        'step_update:0', 'total_loss:0', 'learning_rate:0',
        'nsp_loss:0', 'mlm_loss:0'
    ]
    if self.num_accumulation_steps > 1:
        # Gradient accumulation exposes an extra tensor that marks real
        # optimizer updates.
        fetches.insert(1, 'update_step:0')
    if FLAGS.manual_fp16 or FLAGS.amp:
        fetches.append('loss_scale:0')
    return tf.estimator.SessionRunArgs(fetches=fetches)
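# The string fetches above ('total_loss:0', 'learning_rate:0', ...) are looked
# up by name in the graph, so the model_fn has to register tensors under those
# names. A minimal sketch of that wiring (an assumption, not the original
# model_fn; `_name_tensors_for_logging` is a hypothetical helper):
def _name_tensors_for_logging(total_loss, mlm_loss, nsp_loss, learning_rate):
    # tf.identity pins each value to a well-known op name, making '<name>:0'
    # a valid fetch for SessionRunArgs.
    return (tf.identity(total_loss, name='total_loss'),
            tf.identity(mlm_loss, name='mlm_loss'),
            tf.identity(nsp_loss, name='nsp_loss'),
            tf.identity(learning_rate, name='learning_rate'))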
    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        name_to_features = {
            "input_ids":
            tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "input_mask":
            tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "segment_ids":
            tf.io.FixedLenFeature([max_seq_length], tf.int64),
            "masked_lm_positions":
            tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
            "masked_lm_ids":
            tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
            "masked_lm_weights":
            tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
            "next_sentence_labels":
            tf.io.FixedLenFeature([1], tf.int64),
        }

        # For training, we want a lot of parallel reading and shuffling.
        # For eval, we want no shuffling and parallel reading doesn't matter.
        if is_training:
            d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
            if horovod_enabled():
                d = d.shard(hvd_size(), hvd_rank())
            d = d.repeat()
            d = d.shuffle(buffer_size=len(input_files))

            # `cycle_length` is the number of parallel files that get read.
            cycle_length = min(num_cpu_threads, len(input_files))

            # `sloppy` mode means that the interleaving is not exact. This adds
            # even more randomness to the training pipeline.
            d = d.apply(
                tf.data.experimental.parallel_interleave(
                    tf.data.TFRecordDataset,
                    sloppy=is_training,
                    cycle_length=cycle_length))
            d = d.shuffle(buffer_size=100)
        else:
            d = tf.data.TFRecordDataset(input_files)
            # Since we evaluate for a fixed number of steps we don't want to encounter
            # out-of-range exceptions.
            d = d.repeat()

        # We must `drop_remainder` on training because the TPU requires fixed
        # size dimensions. For eval, we assume we are evaluating on the CPU or GPU
        # and we *don't* want to drop the remainder, otherwise we won't cover
        # every sample.
        d = d.apply(
            tf.data.experimental.map_and_batch(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size,
                num_parallel_batches=num_cpu_threads,
                drop_remainder=is_training))
        return d
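    # The pipeline above maps `_decode_record`, which is not shown in this
    # excerpt. A minimal sketch consistent with how it is called (modeled on
    # the canonical BERT helper; treat the details as an assumption):
    def _decode_record(record, name_to_features):
        """Decodes a serialized tf.Example into a dict of tensors."""
        example = tf.io.parse_single_example(record, name_to_features)
        # tf.Example only supports int64, while accelerators prefer int32,
        # so cast every int64 feature down to int32.
        for name in list(example.keys()):
            t = example[name]
            if t.dtype == tf.int64:
                t = tf.cast(t, tf.int32)
            example[name] = t
        return example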
Example #3
def get_logger(params):
    backends = []
    worker_id = hvd_rank() if horovod_enabled() else 0
    if worker_id == 0:
        backends += [StdOutBackend(Verbosity.VERBOSE)]
        if params.log_dir:
            os.makedirs(params.log_dir, exist_ok=True)
            log_file = f"{params.log_dir}/log.json"
            backends += [JSONStreamBackend(Verbosity.VERBOSE, log_file)]
    logger.init(backends=backends)
    return logger
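# Minimal usage sketch for get_logger (dllogger's standard log/flush API; the
# metric name here is illustrative):
def _log_example(params):
    log = get_logger(params)
    log.log(step=(0,), data={"train_ce_loss": 0.7}, verbosity=Verbosity.DEFAULT)
    log.flush()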
Example #4
def prepare_model_dir(params):
    worker_id = hvd_rank() if horovod_enabled() else 0
    if params.benchmark or (not params.log_all_workers and worker_id != 0):
        return None

    model_dir = os.path.join(params.model_dir, "model_checkpoint")
    if params.log_all_workers and horovod_enabled():
        model_dir = os.path.join(model_dir, f'worker_{worker_id}')

    os.makedirs(model_dir, exist_ok=True)
    if ('train' in params.exec_mode) and (not params.resume_training):
        os.system('rm -rf {}/*'.format(model_dir))
    return model_dir
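# `os.system('rm -rf ...')` shells out and silently ignores failures. A
# pure-Python equivalent of that cleanup (a sketch, not the original code):
import shutil

def clear_model_dir(model_dir):
    # Remove every entry inside model_dir while keeping the directory itself.
    for entry in os.listdir(model_dir):
        path = os.path.join(model_dir, entry)
        if os.path.isdir(path):
            shutil.rmtree(path)
        else:
            os.remove(path)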
Example #5
def main():
    """
    Starting point of the application
    """
    params = parse_args(description="UNet-medical")
    if params.use_horovod:
        hvd_init()
    set_flags(params)

    model_dir = prepare_model_dir(params)
    params.model_dir = model_dir
    logger = get_logger(params)

    tb_logger = None
    if params.tensorboard_logging:
        log_dir = params.log_dir
        if horovod_enabled() and params.log_all_workers:
            log_dir = os.path.join(log_dir, f'worker_{hvd_rank()}')
        tb_logger = namedtuple('TBSummaryWriters', 'train_writer eval_writer')(
            tf.summary.create_file_writer(log_dir),
            tf.summary.create_file_writer(os.path.join(log_dir, 'eval')))

    model = Unet()

    dataset = Dataset(data_dir=params.data_dir,
                      batch_size=params.batch_size,
                      fold=params.fold,
                      augment=params.augment,
                      hpu_id=hvd_rank() if horovod_enabled() else 0,
                      num_hpus=hvd_size() if horovod_enabled() else 1,
                      seed=params.seed)

    if 'train' in params.exec_mode:
        with dump_callback(params.dump_config):
            train(params, model, dataset, logger, tb_logger)

    if 'evaluate' in params.exec_mode:
        evaluate(params, model, dataset, logger, tb_logger)

    if 'predict' in params.exec_mode:
        predict(params, model, dataset, logger)
def main(_):
    # Lazy compilation causes memory fragmentation for BERT, leading to OOM.
    os.environ["TF_XLA_FLAGS"] = "--tf_xla_enable_lazy_compilation=false"

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    dllogging = utils.dllogger_class.dllogger_class(FLAGS.dllog_path)

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    # In the multi-node scenario, each HLS machine must have a checkpoint directly
    # in output_dir (read by Phase 2). Only the worker with comm_local_rank() == 0
    # on each machine writes its checkpoints there; all other workers keep their
    # checkpoints in sub-directories.
    if horovod_enabled() and comm_local_rank() != 0:
        FLAGS.output_dir = os.path.join(FLAGS.output_dir, str(hvd_rank()))

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    tf.io.gfile.makedirs(FLAGS.output_dir)

    input_files = []
    for input_file_dir in FLAGS.input_files_dir.split(","):
        input_files.extend(tf.io.gfile.glob(os.path.join(input_file_dir, "*")))

    if FLAGS.horovod and len(input_files) < hvd.size():
        tf.compat.v1.logging.warning(
            "Input file count lower than expected. Using a single file for OVERFIT test."
        )
        input_files = [input_files[0] for i in range(hvd.size())]
    if FLAGS.amp and FLAGS.manual_fp16:
        raise ValueError(
            "AMP and manual mixed-precision training cannot both be enabled.")

    is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2

    # The Scoped Allocator Optimization is enabled by default unless disabled by a flag.
    # Import unconditionally: rewriter_config_pb2 is also needed below when XLA is enabled.
    from tensorflow.core.protobuf import rewriter_config_pb2  # pylint: disable=import-error

    session_config = tf.compat.v1.ConfigProto()
    if not condition_env_var('TF_DISABLE_SCOPED_ALLOCATOR', default=False):
        session_config.graph_options.rewrite_options.scoped_allocator_optimization = rewriter_config_pb2.RewriterConfig.ON

        enable_op = session_config.graph_options.rewrite_options.scoped_allocator_opts.enable_op
        del enable_op[:]
        enable_op.append("HorovodAllreduce")
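    # `condition_env_var` (imported from the framework's utilities, not shown
    # here) is assumed to read an environment variable as a boolean flag,
    # roughly:
    #     def condition_env_var(env_var, default):
    #         value = os.environ.get(env_var)
    #         return default if value is None else value.lower() in ("true", "1")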

    if FLAGS.horovod:
        session_config.gpu_options.visible_device_list = str(hvd.local_rank())
        if hvd.rank() == 0:
            tf.compat.v1.logging.info("***** Configuaration *****")
            for key in FLAGS.__flags.keys():
                tf.compat.v1.logging.info('  {}: {}'.format(
                    key, getattr(FLAGS, key)))
            tf.compat.v1.logging.info("**************************")


#    config.gpu_options.per_process_gpu_memory_fraction = 0.7
    if FLAGS.use_xla:
        session_config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
        session_config.graph_options.rewrite_options.memory_optimization = rewriter_config_pb2.RewriterConfig.NO_MEM_OPT
        if FLAGS.amp:
            tf.compat.v1.enable_resource_variables()

    run_config = tf.estimator.RunConfig(
        model_dir=FLAGS.output_dir,
        session_config=session_config,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        save_summary_steps=FLAGS.save_checkpoints_steps
        if not FLAGS.horovod else None,
        log_step_count_steps=1)

    model_fn = model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate
        if not FLAGS.horovod else FLAGS.learning_rate * hvd_size(),
        num_train_steps=FLAGS.num_train_steps,
        num_warmup_steps=FLAGS.num_warmup_steps,
        use_one_hot_embeddings=False)

    estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

    if FLAGS.do_train:

        training_hooks = []
        if horovod_enabled():
            training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

            if os.environ.get("FORCE_WEIGHT_SYNC",
                              "False").lower() in ["true", "1"]:
                # Use this hook to allreduce trainable variables before the optimizer run
                training_hooks.append(
                    TrainableVarsAllreducingHookPreOpt(
                        FLAGS.num_accumulation_steps))

        global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps
        if FLAGS.horovod:
            global_batch_size *= hvd.size()
        training_hooks.append(
            _LogSessionRunHook(global_batch_size, FLAGS.num_accumulation_steps,
                               dllogging, FLAGS.display_loss_steps,
                               FLAGS.save_checkpoints_steps,
                               FLAGS.report_loss))

        tf.compat.v1.logging.info("***** Running training *****")
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        train_input_fn = input_fn_builder(
            input_files=input_files,
            batch_size=FLAGS.train_batch_size,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=True)

        train_start_time = time.time()
        estimator.train(input_fn=train_input_fn,
                        hooks=training_hooks,
                        max_steps=FLAGS.num_train_steps)
        train_time_elapsed = time.time() - train_start_time

        if (not FLAGS.horovod or hvd_rank() == 0):
            train_time_wo_overhead = training_hooks[-1].total_time
            avg_sentences_per_second = FLAGS.num_train_steps * global_batch_size * 1.0 / train_time_elapsed
            try:
                ss_sentences_per_second = (
                    FLAGS.num_train_steps - training_hooks[-1].skipped
                ) * global_batch_size * 1.0 / train_time_wo_overhead
                throughput_avg_wo_overhead_msg = [
                    "Throughput Average (sentences/sec) = %0.2f",
                    ss_sentences_per_second
                ]
            except ZeroDivisionError:
                # total_time stays zero when every step fell inside skip_iters.
                ss_sentences_per_second = float('nan')
                throughput_avg_wo_overhead_msg = [
                    f"Throughput Average W/O Overhead is not logged when num_train_steps < {training_hooks[-1].skip_iters}"
                ]

            tf.compat.v1.logging.info("-----------------------------")
            tf.compat.v1.logging.info(
                "Total Training Time = %0.2f for Sentences = %d",
                train_time_elapsed, FLAGS.num_train_steps * global_batch_size)
            tf.compat.v1.logging.info(
                "Total Training Time W/O Overhead = %0.2f for Sentences = %d",
                train_time_wo_overhead,
                (FLAGS.num_train_steps - training_hooks[-1].skipped) *
                global_batch_size)
            tf.compat.v1.logging.info(
                "Throughput Average (sentences/sec) with overhead = %0.2f",
                avg_sentences_per_second)
            tf.compat.v1.logging.info(*throughput_avg_wo_overhead_msg)
            dllogging.logger.log(
                step=(),
                data={"throughput_train": ss_sentences_per_second},
                verbosity=Verbosity.DEFAULT)
            tf.compat.v1.logging.info("-----------------------------")

    if FLAGS.do_eval and (not FLAGS.horovod or hvd_rank() == 0):
        tf.compat.v1.logging.info("***** Running evaluation *****")
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        eval_files = []
        for eval_file_dir in FLAGS.eval_files_dir.split(","):
            eval_files.extend(
                tf.io.gfile.glob(os.path.join(eval_file_dir, "*")))

        eval_input_fn = input_fn_builder(
            input_files=eval_files,
            batch_size=FLAGS.eval_batch_size,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=False)

        eval_hooks = [LogEvalRunHook(FLAGS.eval_batch_size)]
        eval_start_time = time.time()
        result = estimator.evaluate(input_fn=eval_input_fn,
                                    steps=FLAGS.max_eval_steps,
                                    hooks=eval_hooks)

        eval_time_elapsed = time.time() - eval_start_time
        time_list = eval_hooks[-1].time_list
        time_list.sort()
        # Removing outliers (init/warmup) in throughput computation.
        eval_time_wo_overhead = sum(time_list[:int(len(time_list) * 0.99)])
        num_sentences = (int(len(time_list) * 0.99)) * FLAGS.eval_batch_size

        ss_sentences_per_second = num_sentences * 1.0 / eval_time_wo_overhead

        tf.compat.v1.logging.info("-----------------------------")
        tf.compat.v1.logging.info(
            "Total Inference Time = %0.2f for Sentences = %d",
            eval_time_elapsed, eval_hooks[-1].count * FLAGS.eval_batch_size)
        tf.compat.v1.logging.info(
            "Total Inference Time W/O Overhead = %0.2f for Sentences = %d",
            eval_time_wo_overhead, num_sentences)
        tf.compat.v1.logging.info("Summary Inference Statistics on EVAL set")
        tf.compat.v1.logging.info("Batch size = %d", FLAGS.eval_batch_size)
        tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
        tf.compat.v1.logging.info("Precision = %s",
                                  "fp16" if FLAGS.amp else "fp32")
        tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f",
                                  ss_sentences_per_second)
        dllogging.logger.log(step=(),
                             data={"throughput_val": ss_sentences_per_second},
                             verbosity=Verbosity.DEFAULT)
        tf.compat.v1.logging.info("-----------------------------")

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.io.gfile.GFile(output_eval_file, "w") as writer:
            tf.compat.v1.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.compat.v1.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
    def after_run(self, run_context, run_values):
        if horovod_enabled() and hvd_rank() != 0:
            return
        run_time = time.time() - self.t0

        if self.num_accumulation_steps <= 1:
            if FLAGS.manual_fp16 or FLAGS.amp:
                (self.global_step, total_loss, lr, nsp_loss, mlm_loss,
                 loss_scaler) = run_values.results
            else:
                self.global_step, total_loss, lr, nsp_loss, mlm_loss = run_values.results
            update_step = True
        else:
            if FLAGS.manual_fp16 or FLAGS.amp:
                (self.global_step, update_step, total_loss, lr, nsp_loss,
                 mlm_loss, loss_scaler) = run_values.results
            else:
                (self.global_step, update_step, total_loss, lr, nsp_loss,
                 mlm_loss) = run_values.results

        self.elapsed_secs += run_time
        self.step_time += run_time

        print_step = self.global_step + 1  # One-based index for printing.
        self.loss += total_loss
        self.all_count += 1
        if update_step:

            self.count += 1

            # Exclude the first `skip_iters` steps after every checkpoint save
            # from timing (checkpoint-saving/warmup overhead).
            if (self.global_step - self.init_global_step
                ) % self.save_ckpt_steps < self.skip_iters:
                print("Skipping time record for ", self.global_step,
                      " due to checkpoint-saving/warmup overhead")
                self.skipped += 1
            else:
                self.total_time += self.step_time

            self.step_time = 0.0  # Reset step time.

            if (print_step == 1 or print_step % self.display_every == 0):
                dt = self.elapsed_secs / self.count
                sent_per_sec = self.global_batch_size / dt
                avg_loss_step = self.loss / self.all_count
                data = {
                    "throughput_train": float(sent_per_sec),
                    "mlm_loss": float(mlm_loss),
                    "nsp_loss": float(nsp_loss),
                    "total_loss": float(total_loss),
                    "avg_loss_step": float(avg_loss_step),
                    "learning_rate": str(lr),
                }
                if self.hvd_rank >= 0 and FLAGS.report_loss:
                    data["Rank"] = int(self.hvd_rank)
                if FLAGS.manual_fp16 or FLAGS.amp:
                    data["loss_scaler"] = int(loss_scaler)
                self.dllogging.logger.log(step=int(print_step),
                                          data=data,
                                          verbosity=Verbosity.DEFAULT)

                self.elapsed_secs = 0.0
                self.count = 0
                self.loss = 0.0
                self.all_count = 0
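# before_run/after_run above belong to _LogSessionRunHook. A minimal sketch of
# the state they rely on, with attribute names inferred from the method bodies
# (an assumption, not the original implementation):
class _LogSessionRunHook(tf.estimator.SessionRunHook):
    def __init__(self, global_batch_size, num_accumulation_steps, dllogging,
                 display_every, save_ckpt_steps, report_loss, hvd_rank=-1):
        self.global_batch_size = global_batch_size
        self.num_accumulation_steps = num_accumulation_steps
        self.dllogging = dllogging
        self.display_every = display_every
        self.save_ckpt_steps = save_ckpt_steps
        self.report_loss = report_loss
        self.hvd_rank = hvd_rank
        self.skip_iters = 6  # steps excluded from timing after each checkpoint save
        self.elapsed_secs = self.step_time = self.total_time = 0.0
        self.loss = 0.0
        self.count = self.all_count = self.skipped = 0
        self.init_global_step = 0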
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if horovod_enabled() and hvd_rank() != 0:
        FLAGS.output_dir = os.path.join(FLAGS.output_dir, str(hvd_rank()))

    albert_config = modeling.AlbertConfig.from_json_file(
        FLAGS.albert_config_file)
    if FLAGS.deterministic_run and (albert_config.attention_probs_dropout_prob
                                    or albert_config.hidden_dropout_prob):
        albert_config.attention_probs_dropout_prob = 0.0
        albert_config.hidden_dropout_prob = 0.0

    tf.gfile.MakeDirs(FLAGS.output_dir)

    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.gfile.Glob(input_pattern))

    if FLAGS.use_horovod and len(input_files) < hvd.size():
        input_files = [input_files[0] for i in range(hvd.size())]

    tf.logging.info("*** Input Files ***")
    for input_file in input_files:
        tf.logging.info("  %s" % input_file)

    eval_files = []
    for eval_pattern in FLAGS.eval_file.split(","):
        eval_files.extend(tf.gfile.Glob(eval_pattern))

    if FLAGS.use_horovod and len(eval_files) < hvd.size():
        eval_files = [eval_files[0] for i in range(hvd.size())]

    tf.logging.info("*** Eval Files ***")
    for eval_file in eval_files:
        tf.logging.info("  %s" % eval_file)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        save_summary_steps=FLAGS.save_summary_steps,
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    num_train_steps = FLAGS.num_train_steps
    num_warmup_steps = FLAGS.num_warmup_steps
    if FLAGS.do_train and horovod_enabled():
        num_train_steps //= hvd_size()
        num_warmup_steps //= hvd_size()

    model_fn = model_fn_builder(
        albert_config=albert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate
        if not FLAGS.use_horovod else FLAGS.learning_rate * hvd_size(),
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu,
        optimizer=FLAGS.optimizer,
        poly_power=FLAGS.poly_power,
        start_warmup_step=FLAGS.start_warmup_step,
        use_einsum=FLAGS.use_einsum)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.estimator.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    write_hparams_v1(
        FLAGS.output_dir, {
            'batch_size': FLAGS.train_batch_size,
            'batch_size_per_pu': FLAGS.train_batch_size,
            **{x: getattr(FLAGS, x)
               for x in FLAGS}
        })

    if FLAGS.do_train:
        training_hooks = []
        if horovod_enabled():
            training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

        tf.logging.info("***** Running training *****")
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        train_input_fn = input_fn_builder(
            input_files=input_files,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=True)
        with dump_callback():
            estimator.train(input_fn=train_input_fn,
                            hooks=training_hooks,
                            max_steps=FLAGS.num_train_steps)

    if FLAGS.do_eval and (not FLAGS.use_horovod or hvd_rank() == 0):
        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)
        global_step = -1
        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        writer = tf.gfile.GFile(output_eval_file, "w")
        eval_input_fn = input_fn_builder(
            input_files=eval_files,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=False)
        best_perf = 0
        key_name = "masked_lm_accuracy"
        while global_step < FLAGS.num_train_steps:
            if estimator.latest_checkpoint() is None:
                tf.logging.info("No checkpoint found yet. Sleeping.")
                time.sleep(1)
            else:
                result = estimator.evaluate(input_fn=eval_input_fn,
                                            steps=FLAGS.max_eval_steps)
                global_step = result["global_step"]
                tf.logging.info("***** Eval results *****")
                checkpoint_path = estimator.latest_checkpoint()
                for key in sorted(result.keys()):
                    tf.logging.info("  %s = %s", key, str(result[key]))
                    writer.write("%s = %s\n" % (key, str(result[key])))
                # Keep a copy of the best checkpoint according to key_name.
                if result[key_name] > best_perf:
                    best_perf = result[key_name]
                    for ext in ["meta", "data-00000-of-00001", "index"]:
                        src_ckpt = checkpoint_path + ".{}".format(ext)
                        tgt_ckpt = checkpoint_path.rsplit(
                            "-", 1)[0] + "-best.{}".format(ext)
                        tf.logging.info("saving {} to {}".format(
                            src_ckpt, tgt_ckpt))
                        tf.gfile.Copy(src_ckpt, tgt_ckpt, overwrite=True)
                        writer.write("saved {} to {}\n".format(
                            src_ckpt, tgt_ckpt))
    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        name_to_features = {
            "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
            "input_mask": tf.FixedLenFeature([max_seq_length], tf.int64),
            "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
            # Note: we keep the feature name `next_sentence_labels` for
            # compatibility with the original data created by lanzhzh@; in the
            # ALBERT case it actually holds sentence-order labels.
            "next_sentence_labels": tf.FixedLenFeature([1], tf.int64),
        }

        if FLAGS.masked_lm_budget:
            name_to_features.update({
                "token_boundary":
                tf.FixedLenFeature([max_seq_length], tf.int64)
            })
        else:
            name_to_features.update({
                "masked_lm_positions":
                tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
                "masked_lm_ids":
                tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
                "masked_lm_weights":
                tf.FixedLenFeature([max_predictions_per_seq], tf.float32)
            })

        if FLAGS.deterministic_run:
            d = tf.data.TFRecordDataset(input_files)
            d = d.apply(
                tf.data.experimental.map_and_batch(
                    lambda record: _decode_record(record, name_to_features),
                    batch_size=batch_size,
                    num_parallel_calls=1,
                    drop_remainder=True))
            return d

        # For training, we want a lot of parallel reading and shuffling.
        # For eval, we want no shuffling and parallel reading doesn't matter.
        if is_training:
            d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
            if horovod_enabled():
                d = d.shard(hvd_size(), hvd_rank())
            d = d.repeat()
            d = d.shuffle(buffer_size=len(input_files))

            # `cycle_length` is the number of parallel files that get read.
            cycle_length = min(num_cpu_threads, len(input_files))

            # `sloppy` mode means that the interleaving is not exact. This adds
            # even more randomness to the training pipeline.
            d = d.apply(
                tf.data.experimental.parallel_interleave(
                    tf.data.TFRecordDataset,
                    sloppy=is_training,
                    cycle_length=cycle_length))
            d = d.shuffle(buffer_size=100)
        else:
            d = tf.data.TFRecordDataset(input_files)
            # Since we evaluate for a fixed number of steps we don't want to encounter
            # out-of-range exceptions.
            d = d.repeat()

        # We must `drop_remainder` during training because the TPU requires
        # fixed-size dimensions. Unlike the BERT pipeline above, this one also
        # drops the remainder during eval (`drop_remainder=True` is passed
        # unconditionally) to keep batch shapes static; since the eval dataset
        # is repeated, no sample is permanently lost.
        d = d.apply(
            tf.data.experimental.map_and_batch_with_legacy_function(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size,
                num_parallel_batches=num_cpu_threads,
                drop_remainder=True))
        tf.logging.info(d)
        return d
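    # `map_and_batch_with_legacy_function` differs from `map_and_batch` only in
    # tracing the map function with the legacy graph tracer. A TF2-style sketch
    # of the same decode-and-batch step (`_decode_and_batch` is a hypothetical
    # helper, not the original code):
    def _decode_and_batch(d, name_to_features, batch_size, num_cpu_threads):
        d = d.map(lambda record: _decode_record(record, name_to_features),
                  num_parallel_calls=num_cpu_threads)
        return d.batch(batch_size, drop_remainder=True)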
def main(_):
    init_bert_flags()
    # Lazy compilation causes memory fragmentation for BERT, leading to OOM.
    os.environ["TF_XLA_FLAGS"] = "--tf_xla_enable_lazy_compilation=false"

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    dllogging = utils.dllogger_class.dllogger_class(FLAGS.dllog_path)

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if horovod_enabled() and hvd_rank() != 0:
        FLAGS.output_dir = os.path.join(FLAGS.output_dir, str(hvd_rank()))

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    tf.io.gfile.makedirs(FLAGS.output_dir)

    input_files = []
    for input_file_dir in FLAGS.input_files_dir.split(","):
        input_files.extend(tf.io.gfile.glob(os.path.join(input_file_dir, "*")))

    if FLAGS.horovod and len(input_files) < hvd.size():
        input_files = [input_files[0] for i in range(hvd.size())]

    if FLAGS.amp and FLAGS.manual_fp16:
        raise ValueError(
            "AMP and manual mixed-precision training cannot both be enabled.")

    is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.compat.v1.estimator.tpu.RunConfig(
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.compat.v1.estimator.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            per_host_input_for_training=is_per_host))

    num_train_steps = FLAGS.num_train_steps
    num_warmup_steps = FLAGS.num_warmup_steps
    if FLAGS.do_train and horovod_enabled():
        num_train_steps //= hvd_size()
        num_warmup_steps //= hvd_size()

    model_fn = model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate
        if not FLAGS.horovod else FLAGS.learning_rate * hvd_size(),
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_one_hot_embeddings=False)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.compat.v1.estimator.tpu.TPUEstimator(
        use_tpu=False,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    if FLAGS.do_train:

        training_hooks = []
        if horovod_enabled():
            training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
        if not FLAGS.horovod or hvd_rank() == 0:
            global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps
            if FLAGS.horovod:
                global_batch_size *= hvd.size()
            training_hooks.append(
                _LogSessionRunHook(global_batch_size,
                                   FLAGS.num_accumulation_steps, dllogging,
                                   FLAGS.display_loss_steps,
                                   FLAGS.save_checkpoints_steps,
                                   FLAGS.report_loss))

        tf.compat.v1.logging.info("***** Running training *****")
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        train_input_fn = input_fn_builder(
            input_files=input_files,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=True)

        train_start_time = time.time()
        estimator.train(input_fn=train_input_fn,
                        hooks=training_hooks,
                        max_steps=num_train_steps)
        train_time_elapsed = time.time() - train_start_time

        if (not FLAGS.horovod or hvd_rank() == 0):
            train_time_wo_overhead = training_hooks[-1].total_time
            avg_sentences_per_second = num_train_steps * global_batch_size * 1.0 / train_time_elapsed
            ss_sentences_per_second = (
                num_train_steps - training_hooks[-1].skipped
            ) * global_batch_size * 1.0 / train_time_wo_overhead

            tf.compat.v1.logging.info("-----------------------------")
            tf.compat.v1.logging.info(
                "Total Training Time = %0.2f for Sentences = %d",
                train_time_elapsed, num_train_steps * global_batch_size)
            tf.compat.v1.logging.info(
                "Total Training Time W/O Overhead = %0.2f for Sentences = %d",
                train_time_wo_overhead,
                (num_train_steps - training_hooks[-1].skipped) *
                global_batch_size)
            tf.compat.v1.logging.info(
                "Throughput Average (sentences/sec) with overhead = %0.2f",
                avg_sentences_per_second)
            tf.compat.v1.logging.info(
                "Throughput Average (sentences/sec) = %0.2f",
                ss_sentences_per_second)
            dllogging.logger.log(
                step=(),
                data={"throughput_train": ss_sentences_per_second},
                verbosity=Verbosity.DEFAULT)
            tf.compat.v1.logging.info("-----------------------------")

    if FLAGS.do_eval and (not FLAGS.horovod or hvd_rank() == 0):
        tf.compat.v1.logging.info("***** Running evaluation *****")
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        eval_files = []
        for eval_file_dir in FLAGS.eval_files_dir.split(","):
            eval_files.extend(
                tf.io.gfile.glob(os.path.join(eval_file_dir, "*")))

        eval_input_fn = input_fn_builder(
            input_files=eval_files,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=False)

        eval_hooks = [LogEvalRunHook(FLAGS.eval_batch_size)]
        eval_start_time = time.time()
        result = estimator.evaluate(input_fn=eval_input_fn,
                                    steps=FLAGS.max_eval_steps,
                                    hooks=eval_hooks)

        eval_time_elapsed = time.time() - eval_start_time
        time_list = eval_hooks[-1].time_list
        time_list.sort()
        # Removing outliers (init/warmup) in throughput computation.
        eval_time_wo_overhead = sum(time_list[:int(len(time_list) * 0.99)])
        num_sentences = (int(len(time_list) * 0.99)) * FLAGS.eval_batch_size

        ss_sentences_per_second = num_sentences * 1.0 / eval_time_wo_overhead

        tf.compat.v1.logging.info("-----------------------------")
        tf.compat.v1.logging.info(
            "Total Inference Time = %0.2f for Sentences = %d",
            eval_time_elapsed, eval_hooks[-1].count * FLAGS.eval_batch_size)
        tf.compat.v1.logging.info(
            "Total Inference Time W/O Overhead = %0.2f for Sentences = %d",
            eval_time_wo_overhead, num_sentences)
        tf.compat.v1.logging.info("Summary Inference Statistics on EVAL set")
        tf.compat.v1.logging.info("Batch size = %d", FLAGS.eval_batch_size)
        tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
        tf.compat.v1.logging.info("Precision = %s",
                                  "fp16" if FLAGS.amp else "fp32")
        tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f",
                                  ss_sentences_per_second)
        dllogging.logger.log(step=(),
                             data={"throughput_val": ss_sentences_per_second},
                             verbosity=Verbosity.DEFAULT)
        tf.compat.v1.logging.info("-----------------------------")

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.io.gfile.GFile(output_eval_file, "w") as writer:
            tf.compat.v1.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.compat.v1.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
Example #11
def train(params, model, dataset, logger, tb_logger=None):
    np.random.seed(params.seed)
    tf.random.set_seed(params.seed)

    num_workers = hvd_size() if horovod_enabled() else 1
    worker_id = hvd_rank() if horovod_enabled() else 0
    max_steps = params.max_steps // num_workers

    optimizer = tf.keras.optimizers.Adam(learning_rate=params.learning_rate)

    ce_loss = tf.keras.metrics.Mean(name='ce_loss')
    f1_loss = tf.keras.metrics.Mean(name='dice_loss')
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    if params.resume_training and params.model_dir:
        checkpoint.restore(tf.train.latest_checkpoint(params.model_dir))

    if tb_logger is not None:
        write_hparams_v2(tb_logger.train_writer, vars(params))

    @tf.function
    def train_step(features, labels, warmup_batch=False):
        with tf.GradientTape() as tape:
            output_map = model(features)
            crossentropy_loss, dice_loss = partial_losses(output_map, labels)
            added_losses = tf.add(crossentropy_loss, dice_loss, name="total_loss_ref")
            loss = added_losses + params.weight_decay * tf.add_n(
                [tf.nn.l2_loss(v) for v in model.trainable_variables
                 if 'batch_normalization' not in v.name])

        if horovod_enabled():
            tape = hvd.DistributedGradientTape(tape)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Note: broadcast should be done after the first gradient step to ensure optimizer
        # initialization.
        if horovod_enabled() and warmup_batch:
            hvd.broadcast_variables(model.variables, root_rank=0)
            hvd.broadcast_variables(optimizer.variables(), root_rank=0)

        ce_loss(crossentropy_loss)
        f1_loss(dice_loss)
        return loss

    if params.benchmark:
        assert max_steps * num_workers > params.warmup_steps, \
            "max_steps value has to be greater than warmup_steps"
        timestamps = []
        for iteration, (images, labels) in enumerate(dataset.train_fn(drop_remainder=True)):
            loss = train_step(images, labels, warmup_batch=iteration == 0).numpy()
            if iteration > params.warmup_steps:
                timestamps.append(time())

            if iteration >= max_steps * num_workers:
                break

        if worker_id == 0:
            deltas = np.array([timestamps[i + 1] - timestamps[i] for i in range(len(timestamps) - 1)])
            stats = process_performance_stats(deltas, num_workers * params.batch_size, mode="train")
            logger.log(step=(), data=stats)
    else:
        timestamp = time()
        dataset_fn = dataset.synth_fn if params.synth_data else dataset.train_fn
        for iteration, (images, labels) in enumerate(dataset_fn()):
            # Calling .numpy() pulls the loss to the host, synchronizing the step.
            loss = train_step(images, labels, warmup_batch=iteration == 0).numpy()
            if worker_id == 0 or params.log_all_workers:
                if iteration % params.log_every == 0:
                    duration = float(time() - timestamp) / params.log_every
                    timestamp = time()
                    data = {
                        "train_ce_loss": float(ce_loss.result()),
                        "train_dice_loss": float(f1_loss.result()),
                        "train_total_loss": float(f1_loss.result() + ce_loss.result()),
                        "iter duration [ms]": 1000 * duration,
                        "IPS": params.batch_size / duration
                    }
                    logger.log(step=(iteration, max_steps), data=data)

                    if tb_logger is not None:
                        with tb_logger.train_writer.as_default():
                            for name, value in data.items():
                                tf.summary.scalar(name, value, step=iteration)
                            # Also log under generic names for cross-model consistency.
                            tf.summary.scalar("loss", data["train_total_loss"], step=iteration)
                            tf.summary.scalar("examples/sec", data["IPS"], step=iteration)
                            tf.summary.scalar("global_step/sec", 1. / duration, step=iteration)

                if (params.evaluate_every > 0) and (iteration % params.evaluate_every == 0):
                    evaluate(params, model, dataset, logger, tb_logger,
                             restore_checkpoint=False)

                f1_loss.reset_states()
                ce_loss.reset_states()

            if iteration >= max_steps:
                break

        if not params.disable_ckpt_saving and worker_id == 0:
            checkpoint.save(file_prefix=os.path.join(params.model_dir, "checkpoint"))

    logger.flush()
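# `process_performance_stats` is referenced in the benchmark branch above but
# not shown in this excerpt. A minimal sketch consistent with how it is called
# (field names are an assumption):
def process_performance_stats(deltas, batch_size, mode):
    throughput = batch_size / deltas  # deltas is a numpy array of step times
    return {
        f"throughput_{mode}": float(np.mean(throughput)),
        f"latency_{mode}_mean_ms": float(1000 * np.mean(deltas)),
        f"latency_{mode}_90_ms": float(1000 * np.percentile(deltas, 90)),
    }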