Example 1
def train(bert_config, opts):
    # --------------- OPTIONS ---------------------
    epochs = opts["epochs"]
    total_samples = dataset.get_dataset_files_count(opts, is_training=True)
    logger.info("Total samples with duplications {}".format(total_samples))
    total_independent_samples = total_samples // opts['duplication_factor']
    logger.info("Total samples without duplications {}".format(
        total_independent_samples))
    steps_per_epoch = total_independent_samples // (opts['batches_per_step'] *
                                                    opts["total_batch_size"])
    iterations_per_epoch = total_independent_samples // (
        opts["total_batch_size"])

    # total iterations
    if opts['steps']:
        logger.warn("Ignoring the epoch flag and using the steps one")
        steps = opts['steps']
    else:
        steps = epochs * steps_per_epoch
    logger.info(
        "Total training steps: {}, total number of samples to be processed: {}"
        .format(steps,
                steps * opts['batches_per_step'] * opts['total_batch_size']))
    iterations_per_step = opts['batches_per_step']
    ckpt_per_step = opts['steps_per_ckpts']

    # Avoid a NaN mean when the queue length would be zero.
    queue_len = iterations_per_epoch // iterations_per_step
    if queue_len == 0:
        queue_len = 1
    batch_times = deque(maxlen=queue_len)

    # learning rate strategy
    lr_schedule_name = opts['lr_schedule']
    logger.info(f"Using learning rate schedule {lr_schedule_name}")
    LR = make_lr_schedule(lr_schedule_name, opts, steps)

    if opts['do_train']:
        # -------------- BUILD TRAINING GRAPH ----------------

        train = build_graph(bert_config,
                            opts,
                            iterations_per_step,
                            is_training=True,
                            feed_name="trainfeed")
        train.session.run(train.init)
        train.session.run(train.iterator.initializer)

        step = 0
        i = 0

        if opts['restore_path'] is not None:
            if os.path.isdir(opts['restore_path']):
                ckpt_file_path = tf.train.latest_checkpoint(
                    opts['restore_path'])
                logger.info(f"Restoring training from latest checkpoint")
            else:
                # Assume it's a directory
                ckpt_file_path = opts['restore_path']

            logger.info(
                f"Restoring training from checkpoint: {ckpt_file_path}")
            train.restore.restore(train.session, ckpt_file_path)

            ckpt_pattern = re.compile(".*ckpt-([0-9]+)$")
            i = int(ckpt_pattern.match(ckpt_file_path).groups()[0])
            step = int(i // iterations_per_step)

        if opts['start_from_ckpt']:
            # We use a checkpoint to initialise our model
            train.restore.restore(train.session, opts['start_from_ckpt'])
            logger.info("Starting the training from the checkpoint {}".format(
                opts['start_from_ckpt']))

        # Initialise Weights & Biases if available
        if opts['wandb']:
            import wandb
            wandb.init(project="tf-bert", sync_tensorboard=True)
            wandb.config.update(opts)

        # Tensorboard logs path
        log_path = os.path.join(opts["logs_path"], 'event')
        logger.info("Tensorboard event file path {}".format(log_path))
        summary_writer = tf.summary.FileWriter(log_path,
                                               train.graph,
                                               session=train.session)

        # ------------- TRAINING LOOP ----------------
        logger.info(
            "################################################################################"
        )
        logger.info("Start training......")
        print_format = (
            "step: {step:6d}, iteration: {iteration:6d}, epoch: {epoch:6.3f}, lr: {lr:10.3g}, mlm_loss: {mlm_loss:6.3f}, nsp_loss: {nsp_loss:6.3f}, "
            "samples/sec: {samples_per_sec:6.2f}, time: {iter_time:8.6f}, total_time: {total_time:8.1f}, mlm_acc: {mlm_acc:8.5f}, nsp_acc: {nsp_acc:8.5f}"
        )

        start_all = time.time()

        train_saver = train.saver["train_saver"]
        best_saver = train.saver["best_saver"]
        # Initialise the best loss to a very large sentinel value
        best_total_loss = 1e10
        best_step = 0

        while step < steps:
            # Run Training
            learning_rate = LR.feed_dict_lr(step)
            try:
                batch_time, mlm_loss, nsp_loss, mlm_acc, nsp_acc = training_step(
                    train, learning_rate)
            except tf.errors.OpError as e:
                raise tf.errors.ResourceExhaustedError(e.node_def, e.op,
                                                       e.message)

            epoch = float(
                opts["total_batch_size"] * i) / total_independent_samples

            batch_time /= iterations_per_step

            if step != 0:
                batch_times.append([batch_time])

            if step == 1:
                poplar_compile_time = time.time() - start_all
                poplar_summary = tf.Summary()
                poplar_summary.value.add(tag='poplar/compile_time',
                                         simple_value=poplar_compile_time)
                summary_writer.add_summary(poplar_summary)

            # Print loss
            if step % opts['steps_per_logs'] == 0:
                if len(batch_times) != 0:
                    avg_batch_time = np.mean(batch_times)
                else:
                    avg_batch_time = batch_time

                samples_per_sec = opts['total_batch_size'] / avg_batch_time

                # flush times every time it is reported
                batch_times.clear()

                total_time = time.time() - start_all

                stats = OrderedDict([
                    ('step', step),
                    ('iteration', i),
                    ('epoch', epoch),
                    ('lr', learning_rate),
                    ('mlm_loss', mlm_loss),
                    ('nsp_loss', nsp_loss),
                    ('mlm_acc', mlm_acc),
                    ('nsp_acc', nsp_acc),
                    ('iter_time', avg_batch_time),
                    ('samples_per_sec', samples_per_sec),
                    ('total_time', total_time),
                ])

                logger.info(print_format.format(**stats))
                bert_logging.write_to_csv(stats, i == 0, True,
                                          opts['logs_path'])

                sys_summary = tf.Summary()
                sys_summary.value.add(tag='perf/throughput_samples_per_second',
                                      simple_value=samples_per_sec)
                sys_summary.value.add(tag='perf/average_batch_time',
                                      simple_value=avg_batch_time)
                summary_writer.add_summary(sys_summary, step)

            # Log training statistics
            train_summary = tf.Summary()
            train_summary.value.add(tag='epoch', simple_value=epoch)
            train_summary.value.add(tag='loss/MLM', simple_value=mlm_loss)
            train_summary.value.add(tag='loss/NSP', simple_value=nsp_loss)
            train_summary.value.add(tag='accuracy/MLM', simple_value=mlm_acc)
            train_summary.value.add(tag='accuracy/NSP', simple_value=nsp_acc)
            train_summary.value.add(tag='defaultLearningRate',
                                    simple_value=learning_rate)
            train_summary.value.add(tag='samples',
                                    simple_value=step *
                                    opts['batches_per_step'] *
                                    opts['total_batch_size'])
            summary_writer.add_summary(train_summary, step)
            summary_writer.flush()

            if step % ckpt_per_step == 0 and step:
                filepath = train_saver.save(train.session,
                                            save_path=opts["checkpoint_path"],
                                            global_step=step)
                logger.info("Saved checkpoint to {}".format(filepath))

                if not opts['wandb']:
                    bert_logging.save_model_statistics(filepath,
                                                       summary_writer, step)

            # Mechanism to checkpoint the best model.
            # set opts["best_ckpt_min_steps"] to 0 to disable
            if best_total_loss > mlm_loss + nsp_loss and step - best_step > opts[
                    "best_ckpt_min_steps"] and opts["best_ckpt_min_steps"]:
                best_total_loss = mlm_loss + nsp_loss
                best_step = step
                filepath = best_saver.save(train.session,
                                           save_path=opts["checkpoint_path"] +
                                           '_best',
                                           global_step=step)
                logger.info("Saved Best checkpoint to {}".format(filepath))

            i += iterations_per_step
            step += 1

        # --------------- LAST CHECKPOINT ----------------
        filepath = train_saver.save(train.session,
                                    save_path=opts["checkpoint_path"] +
                                    '_last',
                                    global_step=step)
        logger.info("Final model saved to to {}".format(filepath))

        # --------------- CLEANUP ----------------
        train.session.close()
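
The restore branch above recovers the iteration counter from the checkpoint filename and converts it back into a step count. Below is a minimal, self-contained sketch of that bookkeeping; the checkpoint path and the batches_per_step value are made up for illustration and are not taken from the example above.

import re

# Hypothetical checkpoint path; in the example it comes from tf.train.latest_checkpoint().
ckpt_file_path = "/tmp/checkpoints/model.ckpt-12000"
iterations_per_step = 4  # stands in for opts['batches_per_step']

ckpt_pattern = re.compile(".*ckpt-([0-9]+)$")
match = ckpt_pattern.match(ckpt_file_path)
# The trailing number is the device iteration at which the checkpoint was written.
i = int(match.groups()[0]) if match else 0
step = i // iterations_per_step
print(i, step)  # 12000 3000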
Example 2
def main(opts):
    tf.logging.set_verbosity(tf.logging.INFO)
    """
    Set up for synthetic data.
    """
    if opts["synthetic_data"] or opts["generated_data"]:
        opts['task_name'] = 'synthetic'
        if opts['task_type'] == 'regression':
            opts['task_name'] = 'synthetic_regression'
    print(opts['task_name'])
    print(opts['task_type'])
    processors = {
        "cola": glue_data.ColaProcessor,
        "mnli": glue_data.MnliProcessor,
        "mrpc": glue_data.MrpcProcessor,
        "sst2": glue_data.Sst2Processor,
        "stsb": glue_data.StsbProcessor,
        "qqp": glue_data.QqpProcessor,
        "qnli": glue_data.QnliProcessor,
        "rte": glue_data.RteProcessor,
        "wnli": glue_data.WnliProcessor,
        "mnli-mm": glue_data.MnliMismatchProcessor,
        "ax": glue_data.AxProcessor,
        "synthetic": glue_data.SyntheticProcessor,
        "synthetic_regression": glue_data.SyntheticProcessorRegression
    }

    tokenization.validate_case_matches_checkpoint(
        do_lower_case=opts["do_lower_case"],
        init_checkpoint=opts["init_checkpoint"])

    tf.gfile.MakeDirs(opts["output_dir"])

    task_name = opts["task_name"].lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(vocab_file=opts["vocab_file"],
                                           do_lower_case=opts["do_lower_case"])
    opts["pass_in"] = (processor, label_list, tokenizer)

    train_examples = None
    num_warmup_steps = None
    # Each host-side step runs this many device iterations.
    iterations_per_step = opts['batches_per_step']
    if opts["do_training"]:
        train_examples = processor.get_train_examples(opts["data_dir"])
        num_train_steps = int(
            len(train_examples) / opts["total_batch_size"] * opts['epochs'])
        iterations_per_epoch = len(train_examples) // opts["total_batch_size"]
        if opts.get('num_train_steps'):
            # total iterations
            iterations = opts['num_train_steps'] * opts['batches_per_step']
        else:
            iterations = iterations_per_epoch * opts['epochs']
        num_warmup_steps = int(iterations * opts["warmup"])

        tf.logging.info("***** Running training *****")
        tf.logging.info(f"  Num examples = {len(train_examples)}")
        tf.logging.info(f"  Micro batch size = {opts['micro_batch_size']}")
        tf.logging.info(f"  Num steps / epoch = {iterations_per_epoch}")
        tf.logging.info(f"  Num iterations = {iterations}")
        tf.logging.info(f"  Num steps = {num_train_steps}")
        tf.logging.info(f"  Warm steps = {num_warmup_steps}")
        tf.logging.info(f"  Warm frac = {opts['warmup']}")
        # Learning rate schedule
        lr_schedule_name = opts['lr_schedule']
        logger.info(f"Using learning rate schedule {lr_schedule_name}")
        learning_rate_schedule = make_lr_schedule(lr_schedule_name, opts,
                                                  iterations)

    if opts["do_training"]:
        log_iterations = opts['batches_per_step'] * opts["steps_per_logs"]

        # -------------- BUILD TRAINING GRAPH ----------------
        opts['current_mode'] = 'train'
        train = build_graph(opts, iterations_per_step, is_training=True)
        train.session.run(train.init)
        train.session.run(train.iterator.initializer)

        # Checkpoints load and save
        init_checkpoint_path = opts['init_checkpoint']
        if init_checkpoint_path:
            if os.path.isfile(init_checkpoint_path):
                init_checkpoint_path = os.path.splitext(
                    init_checkpoint_path)[0]

            (assignment_map, initialized_variable_names
             ) = bert_ipu.get_assignment_map_from_checkpoint(
                 train.tvars, init_checkpoint_path)

            for var in train.tvars:
                if var.name in initialized_variable_names:
                    mark = "*"
                else:
                    mark = " "
                logger.info("%-60s [%s]\t%s (%s)", var.name, mark, var.shape,
                            var.dtype.name)

            reader = tf.train.NewCheckpointReader(init_checkpoint_path)
            load_vars = reader.get_variable_to_shape_map()

            saver_restore = tf.train.Saver(assignment_map)
            saver_restore.restore(train.session, init_checkpoint_path)

        if opts['steps_per_ckpts']:
            filepath = train.saver.save(train.session,
                                        opts["checkpoint_path"],
                                        global_step=0)
            logger.info(f"Saved checkpoint to {filepath}")
            ckpt_iterations = opts['batches_per_step'] * \
                opts["steps_per_ckpts"]

        else:
            i = 0

        # Tensorboard logs path
        log_path = os.path.join(opts["logs_path"], 'event')
        logger.info("Tensorboard event file path {}".format(log_path))
        summary_writer = tf.summary.FileWriter(log_path,
                                               train.graph,
                                               session=train.session)
        start_time = datetime.datetime.now()
        # Training loop
        if opts['task_type'] == 'regression':
            print_format = (
                "step: {step:6d}, iteration: {iteration:6d} ({percent_done:.3f}%),  epoch: {epoch:6.2f}, lr: {lr:6.4g}, loss: {loss:6.3f}, pearson: {pearson:6.3f}, spearman: {spearman:6.3f}, "
                "throughput {throughput_samples_per_sec:6.2f} samples/sec, batch time: {avg_batch_time:8.6f} s, total_time: {total_time:8.1f} s"
            )
        else:
            print_format = (
                "step: {step:6d}, iteration: {iteration:6d} ({percent_done:.3f}%),  epoch: {epoch:6.2f}, lr: {lr:6.4g}, loss: {loss:6.3f}, acc: {acc:6.3f}, "
                "throughput {throughput_samples_per_sec:6.2f} samples/sec, batch time: {avg_batch_time:8.6f} s, total_time: {total_time:8.1f} s"
            )
        step = 0
        start_all = time.time()
        i = 0
        total_samples = len(train_examples)

        while i < iterations:
            step += 1
            epoch = float(opts["total_batch_size"] * i) / total_samples

            learning_rate = learning_rate_schedule.get_at_step(step)

            try:
                if opts['task_type'] == 'regression':
                    loss, pred, batch_time, pearson, spearman = training_step(
                        train, learning_rate, i, opts)
                else:
                    loss, batch_time, acc, mean_preds = training_step(
                        train, learning_rate, i, opts)
            except tf.errors.OpError as e:
                raise tf.errors.ResourceExhaustedError(e.node_def, e.op,
                                                       e.message)

            batch_time /= iterations_per_step

            avg_batch_time = batch_time

            if i % log_iterations == 0:
                throughput = opts['total_batch_size'] / avg_batch_time


                total_time = time.time() - start_all
                if opts['task_type'] == 'regression':
                    stats = OrderedDict([
                        ('step', step),
                        ('iteration', i + iterations_per_step),
                        ('percent_done', i / iterations * 100),
                        ('epoch', epoch),
                        ('lr', learning_rate),
                        ('loss', loss),
                        ('pearson', pearson),
                        ('spearman', spearman),
                        ('avg_batch_time', avg_batch_time),
                        ('throughput_samples_per_sec', throughput),
                        ('total_time', total_time),
                        ('learning_rate', learning_rate),
                    ])
                else:
                    stats = OrderedDict([
                        ('step', step),
                        ('iteration', i + iterations_per_step),
                        ('percent_done', i / iterations * 100),
                        ('epoch', epoch),
                        ('lr', learning_rate),
                        ('loss', loss),
                        ('acc', acc),
                        ('avg_batch_time', avg_batch_time),
                        ('throughput_samples_per_sec', throughput),
                        ('total_time', total_time),
                        ('learning_rate', learning_rate),
                    ])
                logger.info(print_format.format(**stats))

                train_summary = tf.Summary()
                train_summary.value.add(tag='epoch', simple_value=epoch)
                train_summary.value.add(tag='loss', simple_value=loss)
                if opts['task_type'] == 'regression':
                    train_summary.value.add(tag='pearson',
                                            simple_value=pearson)
                    train_summary.value.add(tag='spearman',
                                            simple_value=spearman)
                else:
                    train_summary.value.add(tag='acc', simple_value=acc)
                train_summary.value.add(tag='learning_rate',
                                        simple_value=learning_rate)
                train_summary.value.add(tag='throughput',
                                        simple_value=throughput)

                if opts['wandb']:
                    wandb.log(dict(stats))

                summary_writer.add_summary(train_summary, step)
                summary_writer.flush()

            if i % ckpt_iterations == 0 and i > 1:
                filepath = train.saver.save(train.session,
                                            opts["checkpoint_path"],
                                            global_step=i +
                                            iterations_per_step)
                logger.info(f"Saved checkpoint to {filepath}")

            i += iterations_per_step

        # We save the final checkpoint
        finetuned_checkpoint_path = train.saver.save(train.session,
                                                     opts["checkpoint_path"],
                                                     global_step=i +
                                                     iterations_per_step)
        logger.info(f"Saved checkpoint to {finetuned_checkpoint_path}")
        train.session.close()
        end_time = datetime.datetime.now()
        consume_time = (end_time - start_time).seconds
        logger.info(f"training times: {consume_time} s")

    if opts["do_eval"]:
        eval_examples = processor.get_dev_examples(opts["data_dir"])
        num_actual_eval_examples = len(eval_examples)
        opts["eval_batch_size"] = opts['micro_batch_size'] * \
            opts['gradient_accumulation_count']

        eval_file = os.path.join(opts["output_dir"], "eval.tf_record")

        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(eval_examples), num_actual_eval_examples,
                        len(eval_examples) - num_actual_eval_examples)
        tf.logging.info("  Evaluate batch size = %d", opts["eval_batch_size"])

        iterations_per_step = 1
        opts['current_mode'] = 'eval'
        predict = build_graph(opts, iterations_per_step, is_training=False)
        predict.session.run(predict.init)
        predict.session.run(predict.iterator.initializer)

        if opts["init_checkpoint"] and not opts['do_training'] and opts[
                'do_eval']:
            finetuned_checkpoint_path = opts['init_checkpoint']

        if finetuned_checkpoint_path:
            print("********** RESTORING FROM CHECKPOINT *************")
            (assignment_map, _initialized_variable_names
             ) = bert_ipu.get_assignment_map_from_checkpoint(
                 predict.tvars, finetuned_checkpoint_path)
            saver_restore = tf.train.Saver(assignment_map)
            saver_restore.restore(predict.session, finetuned_checkpoint_path)
            print("Done.")

        i = 0
        all_time_consumption = []

        iterations = int(
            len(eval_examples) //
            (opts['micro_batch_size'] * opts['gradient_accumulation_count']) +
            1)

        all_accs = []
        all_pearson = []
        all_spearman = []
        all_loss = []
        while i < iterations:
            try:
                start = time.time()
                tmp_output = predict_step(predict)
                if opts['task_type'] == 'regression':
                    all_pearson.append(tmp_output['pearson'])
                    all_spearman.append(tmp_output['spearman'])
                else:
                    all_accs.append(tmp_output['acc'])
                all_loss.append(tmp_output['loss'])
                output_eval_file = os.path.join(opts['output_dir'],
                                                "eval_results.txt")
                duration = time.time() - start
                all_time_consumption.append(duration /
                                            opts["batches_per_step"])
            except tf.errors.OpError as e:
                raise tf.errors.ResourceExhaustedError(e.node_def, e.op,
                                                       e.message)

            i += iterations_per_step

            if len(all_loss) % 1000 == 0:
                logger.info(f"Procesing example: {len(all_loss)}")
        if opts['task_type'] == 'regression':
            tmp_output['average_pearson'] = np.mean(all_pearson)
            tmp_output['average_spearman'] = np.mean(all_spearman)
        else:
            tmp_output['average_acc'] = np.mean(all_accs)
        tmp_output['average_loss'] = np.mean(all_loss)

        with tf.gfile.GFile(output_eval_file, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(tmp_output.keys()):
                tf.logging.info("  %s = %s", key, str(tmp_output[key]))
                writer.write("%s = %s\n" % (key, str(tmp_output[key])))
        # The first 10 steps are discarded when enough measurements exist,
        # since their timings are not stable.
        if len(all_time_consumption) >= 10 * 2:
            all_time_consumption = np.array(all_time_consumption[10:])
        else:
            logger.warning(
                "Fewer than 20 timed steps: the first 10 steps are included, so "
                "the throughput and latency measurements may not be accurate.")
            all_time_consumption = np.array(all_time_consumption)

        logger.info((
            f"inference throughput: { (opts['micro_batch_size'] * opts['gradient_accumulation_count'] ) / all_time_consumption.mean() } "
            f"exmples/sec - Latency: {all_time_consumption.mean()} {all_time_consumption.min()} "
            f"{all_time_consumption.max()} (mean min max) sec "))
        # Done evaluations

    if opts["do_predict"]:
        predict_examples = processor.get_test_examples(opts["data_dir"])
        num_actual_predict_examples = len(predict_examples)
        opts["predict_batch_size"] = opts['micro_batch_size'] * \
            opts['gradient_accumulation_count']
        tf.logging.info("***** Running prediction *****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(predict_examples), num_actual_predict_examples,
                        len(predict_examples) - num_actual_predict_examples)
        tf.logging.info("  Predict batch size = %d",
                        opts["predict_batch_size"])

        iterations_per_step = 1
        opts['current_mode'] = 'predict'
        prediction = build_graph(opts, iterations_per_step, is_training=False)
        prediction.session.run(prediction.init)
        prediction.session.run(prediction.iterator.initializer)

        if opts["init_checkpoint"] and not opts['do_training'] and opts[
                'do_predict']:
            finetuned_checkpoint_path = opts['init_checkpoint']
        else:
            finetuned_checkpoint_path = False

        if finetuned_checkpoint_path:
            print("********** RESTORING FROM CHECKPOINT *************")
            (assignment_map, _initialized_variable_names
             ) = bert_ipu.get_assignment_map_from_checkpoint(
                 prediction.tvars, finetuned_checkpoint_path)
            saver_restore = tf.train.Saver(assignment_map)
            saver_restore.restore(prediction.session,
                                  finetuned_checkpoint_path)
            print("Done.")

        all_results = []
        i = 0
        all_time_consumption = []

        iterations = int(
            len(predict_examples) //
            (opts['micro_batch_size'] * opts['gradient_accumulation_count']) +
            1)

        all_preds = []
        while i < iterations:
            try:
                start = time.time()
                tmp_output = predict_step(prediction)
                all_preds.append(tmp_output['preds'])

                output_predict_file = os.path.join(opts['output_dir'],
                                                   "predict_results.txt")
                duration = time.time() - start
                all_time_consumption.append(duration /
                                            opts["batches_per_step"])
            except tf.errors.OpError as e:
                raise tf.errors.ResourceExhaustedError(e.node_def, e.op,
                                                       e.message)

            i += iterations_per_step

        all_preds = np.array(all_preds)
        all_preds = all_preds.flatten()
        headers = ["index", "prediction"]
        name_list = ["mnli", "mnli-mm", "ax", "qnli", "rte"]
        if task_name in name_list:
            all_preds = glue_data.get_output_labels(opts, all_preds)

        with tf.gfile.GFile(output_predict_file, "w") as writer:
            tf.logging.info("***** Predict results writing*****")
            for i in range(len(predict_examples)):
                if i == 0:
                    writer.write("%s\t%s\n" %
                                 (str(headers[0]), str(headers[1])))
                output_line = "%s\t%s\n" % (i, all_preds[i])
                writer.write(output_line)
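
As a quick illustration of the step/iteration/epoch arithmetic the training loop above relies on (each host-side step launches batches_per_step device iterations, and each iteration consumes one global batch of total_batch_size samples), here is a minimal sketch with made-up numbers; none of the values below come from the script.

# All numbers below are hypothetical.
total_batch_size = 32
batches_per_step = 100
num_train_examples = 67349
epochs = 3

iterations_per_epoch = num_train_examples // total_batch_size
iterations = iterations_per_epoch * epochs
num_warmup_steps = int(iterations * 0.1)  # assuming a warmup fraction of 0.1

i, step = 0, 0
while i < iterations:
    step += 1
    epoch = float(total_batch_size * i) / num_train_examples
    # ... one session.run here would execute batches_per_step device iterations ...
    i += batches_per_step

print(iterations, step, round(epoch, 2))  # 6312 64 2.99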
Example 3
def train(opts):
    # --------------- OPTIONS ---------------------
    total_samples = data_loader.get_dataset_files_count(opts, is_training=True)
    opts["dataset_repeat"] = math.ceil(
        (opts["num_train_steps"] * opts["global_batch_size"]) / total_samples)

    total_samples_per_epoch = total_samples / opts["duplicate_factor"]
    logger.info(f"Total samples for each epoch {total_samples_per_epoch}")
    logger.info(f"Global batch size {opts['global_batch_size']}")
    steps_per_epoch = total_samples_per_epoch // opts["global_batch_size"]
    logger.info(f"Total steps for each epoch {steps_per_epoch}")

    steps_per_logs = math.ceil(
        opts["steps_per_logs"] /
        opts['batches_per_step']) * opts['batches_per_step']
    steps_per_tensorboard = math.ceil(
        opts["steps_per_tensorboard"] /
        opts['batches_per_step']) * opts['batches_per_step']
    steps_per_ckpts = math.ceil(
        opts["steps_per_ckpts"] /
        opts['batches_per_step']) * opts['batches_per_step']
    logger.info(f"Checkpoint will be saved every {steps_per_ckpts} steps.")

    total_steps = (opts["num_train_steps"] //
                   opts['batches_per_step']) * opts['batches_per_step']
    logger.info(
        f"{opts['batches_per_step']} steps are run between each IPU-to-host synchronisation, "
        f"so num_train_steps must be a multiple of batches_per_step and is truncated to {total_steps}.")

    # learning rate strategy
    lr_schedule_name = opts['lr_schedule']
    logger.info(f"Using learning rate schedule {lr_schedule_name}")
    learning_rate_schedule = make_lr_schedule(lr_schedule_name, opts,
                                              total_steps)

    # variable loss scaling
    loss_scaling_schedule = LossScalingScheduler(opts['loss_scaling'],
                                                 opts['loss_scaling_by_step'])

    # -------------- BUILD TRAINING GRAPH ----------------
    train = build_graph(opts, is_training=True)
    train.session.run(train.init)
    train.session.run(train.iterator.initializer)

    is_main_worker = opts['distributed_worker_index'] == 0

    step = 0
    # -------------- SAVE AND RESTORE --------------
    if opts["restore_dir"]:
        restore_path = opts['restore_dir']
        if os.path.isfile(restore_path):
            latest_checkpoint = os.path.splitext(restore_path)[0]
        else:
            latest_checkpoint = tf.train.latest_checkpoint(restore_path)
        logger.info(
            f"Restoring training from latest checkpoint: {latest_checkpoint}")
        step_pattern = re.compile(".*ckpt-([0-9]+)$")
        step = int(step_pattern.match(latest_checkpoint).groups()[0])
        train.saver.restore(train.session, latest_checkpoint)
        epoch = step / steps_per_epoch

        # restore event files
        source_path = os.path.join(opts["restore_dir"], 'event')
        target_path = os.path.join(opts["save_path"], 'event')
        if os.path.isdir(source_path):
            copytree(source_path, target_path)
    else:
        if opts["init_checkpoint"]:
            train.saver.restore(train.session, opts["init_checkpoint"])
            logger.info(
                f'Init Model from checkpoint {opts["init_checkpoint"]}')

    if opts['save_path']:
        file_path = train.saver.save(train.session,
                                     opts["checkpoint_path"],
                                     global_step=0)
        logger.info(f"Saved checkpoint to {file_path}")

    # Initialise Weights & Biases if available
    if opts['wandb'] and is_main_worker:
        import wandb
        wandb.init(project="tf-bert",
                   sync_tensorboard=True,
                   name=opts['wandb_name'])
        wandb.config.update(opts)

    # Tensorboard logs path
    log_path = os.path.join(opts["logs_path"], 'event')
    logger.info("Tensorboard event file path {}".format(log_path))
    summary_writer = tf.summary.FileWriter(log_path,
                                           train.graph,
                                           session=train.session)

    # Exit before any training if running in compile-only mode
    if opts['compile_only']:

        # single warm up step without weight update or training
        # Graph gets compiled in here
        compilation_time, _, _, _, _ = training_step(train, 0, 0)

        print("Training graph successfully compiled. " +
              "Exiting as --compile-only was passed.")

        # Copying these from below, adding compile time to summary
        poplar_summary = tf.Summary()
        poplar_summary.value.add(tag='poplar/compile_time',
                                 simple_value=compilation_time)
        summary_writer.add_summary(poplar_summary)
        summary_writer.flush()

        logger.info("Compile time: {}".format(compilation_time))

        sys.exit(0)

    # ------------- TRAINING LOOP ----------------
    print_format = (
        "step: {step:6d}, epoch: {epoch:6.2f}, lr: {lr:6.7f}, mlm_loss: {mlm_loss:6.3f}, nsp_loss: {nsp_loss:6.3f}, "
        "mlm_acc: {mlm_acc:6.5f}, nsp_acc: {nsp_acc:6.5f}, samples/sec: {samples_per_sec:6.2f}, time: {iter_time:8.6f}, total_time: {total_time:8.1f}"
    )
    learning_rate = mlm_loss = nsp_loss = 0
    start_all = time.time()

    try:
        while step < total_steps:
            learning_rate = learning_rate_schedule.get_at_step(step)
            loss_scaling = loss_scaling_schedule.get_at_step(step)
            try:
                batch_time, mlm_loss, nsp_loss, mlm_acc, nsp_acc = training_step(
                    train, learning_rate, loss_scaling)
            except tf.errors.OpError as e:
                raise tf.errors.ResourceExhaustedError(e.node_def, e.op,
                                                       e.message)

            batch_time /= opts['batches_per_step']

            is_log_step = (step % steps_per_logs == 0)
            is_save_tensorboard_step = (steps_per_tensorboard > 0 and
                                        (step % steps_per_tensorboard == 0))
            is_save_ckpt_step = (step and
                                 (step % steps_per_ckpts == 0 or step
                                  == total_steps - opts['batches_per_step']))

            if (step == 1 and (is_main_worker or opts['log_all_workers'])):
                poplar_compile_time = time.time() - start_all
                logger.info(f"Poplar compile time: {poplar_compile_time:.2f}s")
                poplar_summary = tf.Summary()
                poplar_summary.value.add(tag='poplar/compile_time',
                                         simple_value=poplar_compile_time)
                summary_writer.add_summary(poplar_summary)

            if is_log_step:
                total_time = time.time() - start_all
                epoch = step / steps_per_epoch
                stats = OrderedDict([
                    ('step', step),
                    ('epoch', epoch),
                    ('lr', learning_rate),
                    ('loss_scaling', loss_scaling),
                    ('mlm_loss', mlm_loss),
                    ('nsp_loss', nsp_loss),
                    ('mlm_acc', mlm_acc),
                    ('nsp_acc', nsp_acc),
                    ('iter_time', batch_time),
                    ('samples_per_sec',
                     opts['global_batch_size'] / batch_time),
                    ('total_time', total_time),
                ])

                logger.info(print_format.format(**stats))

            # Log training statistics
            train_summary = tf.Summary()
            train_summary.value.add(tag='epoch', simple_value=epoch)
            train_summary.value.add(tag='loss/MLM', simple_value=mlm_loss)
            train_summary.value.add(tag='loss/NSP', simple_value=nsp_loss)
            train_summary.value.add(tag='accuracy/MLM', simple_value=mlm_acc)
            train_summary.value.add(tag='accuracy/NSP', simple_value=nsp_acc)
            train_summary.value.add(tag='learning_rate',
                                    simple_value=learning_rate)
            train_summary.value.add(tag='loss_scaling',
                                    simple_value=loss_scaling)
            train_summary.value.add(tag='samples_per_sec',
                                    simple_value=opts['global_batch_size'] /
                                    batch_time)
            train_summary.value.add(tag='samples',
                                    simple_value=step *
                                    opts['batches_per_step'] *
                                    opts['global_batch_size'])
            summary_writer.add_summary(train_summary, step)
            summary_writer.flush()

            if is_save_ckpt_step or is_save_tensorboard_step:
                if is_main_worker:
                    file_path = train.saver.save(train.session,
                                                 opts["checkpoint_path"],
                                                 global_step=step)
                    logger.info(f"Saved checkpoint to {file_path}")

                    if is_save_tensorboard_step:
                        log.save_model_statistics(file_path, summary_writer,
                                                  step)

                if opts['use_popdist']:
                    ipu_utils.barrier()

            step += opts['batches_per_step']
    finally:
        train.session.close()
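
The interval arithmetic at the top of this example rounds the logging, TensorBoard and checkpoint intervals up to multiples of batches_per_step, and truncates num_train_steps down to one, because the device only returns control to the host once per batches_per_step steps. A minimal sketch with assumed values (not taken from the opts dictionary):

import math

# Hypothetical values standing in for the opts dictionary.
batches_per_step = 64
steps_per_logs = 100
steps_per_ckpts = 1000
num_train_steps = 10050

# Round intervals up to the nearest multiple of batches_per_step.
steps_per_logs = math.ceil(steps_per_logs / batches_per_step) * batches_per_step
steps_per_ckpts = math.ceil(steps_per_ckpts / batches_per_step) * batches_per_step
# Truncate the total step count down to a multiple of batches_per_step.
total_steps = (num_train_steps // batches_per_step) * batches_per_step

print(steps_per_logs, steps_per_ckpts, total_steps)  # 128 1024 10048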
Example 4
def training_loop(opts):

    consume_time = None

    if opts["version_2_with_negative"]:
        base_name_train = f"{opts['seq_length']}_{opts['doc_stride']}_{opts['max_query_length']}_SQuAD20"
    else:
        base_name_train = f"{opts['seq_length']}_{opts['doc_stride']}_{opts['max_query_length']}_SQuAD11"

    train_metafile = os.path.join(opts["tfrecord_dir"],
                                  "train_" + base_name_train + ".metadata")
    if os.path.exists(train_metafile):
        with open(train_metafile) as f:
            total_samples = int(f.readline())
    else:
        if opts["version_2_with_negative"]:
            logger.info(
                f"SQUAD 2.0 DATASET SIZE 131944 (based on no. of features).")
            total_samples = 131944
        else:
            logger.info(
                f"SQUAD 1.1 DATASET SIZE 88641 (based on no. of features).")
            total_samples = 88641

    logger.info(f"Total samples {total_samples}")
    iterations_per_epoch = total_samples // opts["total_batch_size"]
    log_iterations = opts['batches_per_step'] * opts["steps_per_logs"]
    ckpt_iterations = opts['batches_per_step'] * opts["steps_per_ckpts"]

    if opts.get('num_train_steps'):
        # total iterations
        iterations = opts['num_train_steps'] * opts['batches_per_step']
    elif opts.get('epochs'):
        iterations = iterations_per_epoch * opts['epochs']
    else:
        logger.error("One between epochs and num_train_step must be set")
        sys.exit(os.EX_OK)

    logger.info(
        f"Training will run for {iterations} iterations ({iterations//opts['batches_per_step']} steps)."
    )

    # Each host-side step runs this many device iterations.
    iterations_per_step = opts['batches_per_step']
    # Avoid a NaN mean when the queue length would be zero.
    queue_len = iterations_per_epoch // iterations_per_step
    if queue_len == 0:
        queue_len = 1
    batch_times = deque(maxlen=queue_len)

    total_steps = (iterations //
                   opts['batches_per_step']) * opts['batches_per_step']

    # Learning rate schedule
    lr_schedule_name = opts['lr_schedule']
    logger.info(f"Using learning rate schedule {lr_schedule_name}")
    learning_rate_schedule = make_lr_schedule(lr_schedule_name, opts,
                                              total_steps)

    # -------------- BUILD TRAINING GRAPH ----------------
    train = build_graph(opts, iterations_per_step, is_training=True)
    train.session.run(train.init)
    train.session.run(train.iterator.initializer)

    # Checkpoints restore and save
    init_checkpoint_path = opts['init_checkpoint']
    if init_checkpoint_path and not opts.get('generated_data', False):
        if os.path.isfile(init_checkpoint_path):
            init_checkpoint_path = os.path.splitext(init_checkpoint_path)[0]

        (assignment_map, initialized_variable_names
         ) = bert_ipu.get_assignment_map_from_checkpoint(
             train.tvars, init_checkpoint_path)

        for var in train.tvars:
            if var.name in initialized_variable_names:
                mark = "*"
            else:
                mark = " "
            logger.info("%-60s [%s]\t%s (%s)", var.name, mark, var.shape,
                        var.dtype.name)

        reader = tf.train.NewCheckpointReader(init_checkpoint_path)
        load_vars = reader.get_variable_to_shape_map()

        saver_restore = tf.train.Saver(assignment_map)
        saver_restore.restore(train.session, init_checkpoint_path)

    if opts['steps_per_ckpts']:
        filepath = train.saver.save(train.session,
                                    opts["checkpoint_path"],
                                    global_step=0)
        logger.info(f"Saved checkpoint to {filepath}")

    if opts.get('restore_dir'):
        restore_path = opts['restore_dir']
        if os.path.isfile(restore_path):
            latest_checkpoint = os.path.splitext(restore_path)[0]
        else:
            latest_checkpoint = tf.train.latest_checkpoint(restore_path)
        ckpt_pattern = re.compile(".*ckpt-([0-9]+)$")
        i = int(ckpt_pattern.match(latest_checkpoint).groups()[0]) + 1
        train.saver.restore(train.session, latest_checkpoint)
        epoch = float(opts["total_batch_size"] *
                      (i + iterations_per_step)) / total_samples
    else:
        i = 0

    # Tensorboard logs path
    log_path = os.path.join(opts["logs_path"], 'event')
    logger.info("Tensorboard event file path {}".format(log_path))
    summary_writer = tf.summary.FileWriter(log_path,
                                           train.graph,
                                           session=train.session)
    start_time = datetime.datetime.now()

    # Training loop
    print_format = (
        "step: {step:6d}, iteration: {iteration:6d}, epoch: {epoch:6.2f}, lr: {lr:6.4g}, loss: {loss:6.3f}, "
        "throughput {throughput_samples_per_sec:6.2f} samples/sec, batch time: {avg_batch_time:8.6f} s, total_time: {total_time:8.1f} s"
    )
    step = 0
    start_all = time.time()

    while i < iterations:
        step += 1
        epoch = float(opts["total_batch_size"] * i) / total_samples

        learning_rate = learning_rate_schedule.get_at_step(step)

        try:
            loss, batch_time = training_step(train, learning_rate)
        except tf.errors.OpError as e:
            raise tf.errors.ResourceExhaustedError(e.node_def, e.op, e.message)

        batch_time /= iterations_per_step

        if i != 0:
            batch_times.append([batch_time])
            avg_batch_time = np.mean(batch_times)
        else:
            avg_batch_time = batch_time

        if i % log_iterations == 0:
            throughput = opts['total_batch_size'] / avg_batch_time

            # flush times every time it is reported
            batch_times.clear()

            total_time = time.time() - start_all

            stats = OrderedDict([('step', step),
                                 ('iteration', i + iterations_per_step),
                                 ('epoch', epoch), ('lr', learning_rate),
                                 ('loss', loss),
                                 ('avg_batch_time', avg_batch_time),
                                 ('throughput_samples_per_sec', throughput),
                                 ('total_time', total_time),
                                 ('learning_rate', learning_rate)])
            logger.info(print_format.format(**stats))

            train_summary = tf.Summary()
            train_summary.value.add(tag='epoch', simple_value=epoch)
            train_summary.value.add(tag='loss', simple_value=loss)
            train_summary.value.add(tag='learning_rate',
                                    simple_value=learning_rate)
            train_summary.value.add(tag='throughput', simple_value=throughput)

            if opts['wandb']:
                wandb.log(dict(stats))

            summary_writer.add_summary(train_summary, step)
            summary_writer.flush()

        if i % ckpt_iterations == 0:
            filepath = train.saver.save(train.session,
                                        opts["checkpoint_path"],
                                        global_step=i + iterations_per_step)
            logger.info(f"Saved checkpoint to {filepath}")

        i += iterations_per_step

    # We save the final checkpoint
    finetuned_checkpoint_path = train.saver.save(train.session,
                                                 opts["checkpoint_path"],
                                                 global_step=i +
                                                 iterations_per_step)
    logger.info(f"Saved checkpoint to {finetuned_checkpoint_path}")
    train.session.close()
    end_time = datetime.datetime.now()
    consume_time = (end_time - start_time).seconds
    logger.info(f"training times: {consume_time} s")
    return finetuned_checkpoint_path
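
All four examples report throughput from a window of recent batch times. The sketch below isolates that pattern: batch times are kept in a bounded deque, averaged at each log point, and the window is flushed after reporting. The batch size, window length and simulated step time are assumptions for illustration only.

import time
from collections import deque

import numpy as np

# Hypothetical values.
total_batch_size = 256
queue_len = max(1, 10)  # clamp to at least 1 so np.mean never sees an empty window
batch_times = deque(maxlen=queue_len)

for _ in range(25):
    start = time.time()
    time.sleep(0.005)  # stand-in for one training step
    batch_times.append(time.time() - start)

avg_batch_time = np.mean(batch_times)  # only the most recent queue_len times survive
throughput = total_batch_size / avg_batch_time
batch_times.clear()  # flush after each report, as the loops above do
print(f"{throughput:.1f} samples/sec, {avg_batch_time:.4f} s/batch")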