Example #1
0
def train(opts):
    # --------------- OPTIONS ---------------------
    total_samples = data_loader.get_dataset_files_count(opts, is_training=True)
    opts["dataset_repeat"] = math.ceil(
        (opts["num_train_steps"] * opts["global_batch_size"]) / total_samples)

    total_samples_per_epoch = total_samples / opts["duplicate_factor"]
    logger.info(f"Total samples for each epoch {total_samples_per_epoch}")
    logger.info(f"Global batch size {opts['global_batch_size']}")
    steps_per_epoch = total_samples_per_epoch // opts["global_batch_size"]
    logger.info(f"Total steps for each epoch {steps_per_epoch}")

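    # Round the logging, TensorBoard and checkpoint intervals up to the nearest
    # multiple of batches_per_step, since the host only regains control once per
    # batches_per_step device iterations.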
    steps_per_logs = math.ceil(
        opts["steps_per_logs"] /
        opts['batches_per_step']) * opts['batches_per_step']
    steps_per_tensorboard = math.ceil(
        opts["steps_per_tensorboard"] /
        opts['batches_per_step']) * opts['batches_per_step']
    steps_per_ckpts = math.ceil(
        opts["steps_per_ckpts"] /
        opts['batches_per_step']) * opts['batches_per_step']
    logger.info(f"Checkpoint will be saved every {steps_per_ckpts} steps.")

    total_steps = (opts["num_train_steps"] //
                   opts['batches_per_step']) * opts['batches_per_step']
    logger.info(
        f"The IPU synchronises with the host once every {opts['batches_per_step']} steps, "
        f"so num_train_steps must be a multiple of batches_per_step and is truncated to {total_steps}.")

    # learning rate strategy
    lr_schedule_name = opts['lr_schedule']
    logger.info(f"Using learning rate schedule {lr_schedule_name}")
    learning_rate_schedule = make_lr_schedule(lr_schedule_name, opts,
                                              total_steps)

    # variable loss scaling
    loss_scaling_schedule = LossScalingScheduler(opts['loss_scaling'],
                                                 opts['loss_scaling_by_step'])

    # -------------- BUILD TRAINING GRAPH ----------------
    train = build_graph(opts, is_training=True)
    train.session.run(train.init)
    train.session.run(train.iterator.initializer)

    is_main_worker = opts['distributed_worker_index'] == 0
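    # Only the main worker (index 0) initialises W&B and saves the periodic checkpoints.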

    step = 0
    # -------------- SAVE AND RESTORE --------------
    if opts["restore_dir"]:
        restore_path = opts['restore_dir']
        if os.path.isfile(restore_path):
            latest_checkpoint = os.path.splitext(restore_path)[0]
        else:
            latest_checkpoint = tf.train.latest_checkpoint(restore_path)
        logger.info(
            f"Restoring training from latest checkpoint: {latest_checkpoint}")
        step_pattern = re.compile(".*ckpt-([0-9]+)$")
        step = int(step_pattern.match(latest_checkpoint).groups()[0])
        train.saver.restore(train.session, latest_checkpoint)
        epoch = step / steps_per_epoch

        # restore the TensorBoard event files that accompany the checkpoint
        source_path = os.path.join(opts["restore_dir"], 'event')
        target_path = os.path.join(opts["save_path"], 'event')
        if os.path.isdir(source_path):
            copytree(source_path, target_path)
    else:
        if opts["init_checkpoint"]:
            train.saver.restore(train.session, opts["init_checkpoint"])
            logger.info(
                f'Init Model from checkpoint {opts["init_checkpoint"]}')

    if opts['save_path']:
        file_path = train.saver.save(train.session,
                                     opts["checkpoint_path"],
                                     global_step=0)
        logger.info(f"Saved checkpoint to {file_path}")

    # Initialise Weights & Biases if available
    if opts['wandb'] and is_main_worker:
        import wandb
        wandb.init(project="tf-bert",
                   sync_tensorboard=True,
                   name=opts['wandb_name'])
        wandb.config.update(opts)

    # Tensorboard logs path
    log_path = os.path.join(opts["logs_path"], 'event')
    logger.info("Tensorboard event file path {}".format(log_path))
    summary_writer = tf.summary.FileWriter(log_path,
                                           train.graph,
                                           session=train.session)

    # Exit early so that no training runs in compile-only mode
    if opts['compile_only']:

        # single warm up step without weight update or training
        # Graph gets compiled in here
        compilation_time, _, _, _, _ = training_step(train, 0, 0)

        print("Training graph successfully compiled. " +
              "Exiting as --compile-only was passed.")

        # Copying these from below, adding compile time to summary
        poplar_summary = tf.Summary()
        poplar_summary.value.add(tag='poplar/compile_time',
                                 simple_value=compilation_time)
        summary_writer.add_summary(poplar_summary)
        summary_writer.flush()

        logger.info("Compile time: {}".format(compilation_time))

        sys.exit(0)

    # ------------- TRAINING LOOP ----------------
    print_format = (
        "step: {step:6d}, epoch: {epoch:6.2f}, lr: {lr:6.7f}, mlm_loss: {mlm_loss:6.3f}, nsp_loss: {nsp_loss:6.3f}, "
        "mlm_acc: {mlm_acc:6.5f}, nsp_acc: {nsp_acc:6.5f}, samples/sec: {samples_per_sec:6.2f}, time: {iter_time:8.6f}, total_time: {total_time:8.1f}"
    )
    learning_rate = mlm_loss = nsp_loss = 0
    start_all = time.time()

    try:
        while step < total_steps:
            learning_rate = learning_rate_schedule.get_at_step(step)
            loss_scaling = loss_scaling_schedule.get_at_step(step)
            try:
                batch_time, mlm_loss, nsp_loss, mlm_acc, nsp_acc = training_step(
                    train, learning_rate, loss_scaling)
            except tf.errors.OpError as e:
                raise tf.errors.ResourceExhaustedError(e.node_def, e.op,
                                                       e.message)

            batch_time /= opts['batches_per_step']
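            # training_step ran batches_per_step device-side iterations in a single call,
            # so the division above gives the time per individual step.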

            is_log_step = (step % steps_per_logs == 0)
            is_save_tensorboard_step = (steps_per_tensorboard > 0 and
                                        (step % steps_per_tensorboard == 0))
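            # A checkpoint is also taken on the final iteration so the last weights are always saved.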
            is_save_ckpt_step = (step and
                                 (step % steps_per_ckpts == 0 or step
                                  == total_steps - opts['batches_per_step']))

            if (step == 1 and (is_main_worker or opts['log_all_workers'])):
                poplar_compile_time = time.time() - start_all
                logger.info(f"Poplar compile time: {poplar_compile_time:.2f}s")
                poplar_summary = tf.Summary()
                poplar_summary.value.add(tag='poplar/compile_time',
                                         simple_value=poplar_compile_time)
                summary_writer.add_summary(poplar_summary)

            if is_log_step:
                total_time = time.time() - start_all
                epoch = step / steps_per_epoch
                stats = OrderedDict([
                    ('step', step),
                    ('epoch', epoch),
                    ('lr', learning_rate),
                    ('loss_scaling', loss_scaling),
                    ('mlm_loss', mlm_loss),
                    ('nsp_loss', nsp_loss),
                    ('mlm_acc', mlm_acc),
                    ('nsp_acc', nsp_acc),
                    ('iter_time', batch_time),
                    ('samples_per_sec',
                     opts['global_batch_size'] / batch_time),
                    ('total_time', total_time),
                ])

                logger.info(print_format.format(**stats))

            # Log training statistics
            train_summary = tf.Summary()
            train_summary.value.add(tag='epoch', simple_value=epoch)
            train_summary.value.add(tag='loss/MLM', simple_value=mlm_loss)
            train_summary.value.add(tag='loss/NSP', simple_value=nsp_loss)
            train_summary.value.add(tag='accuracy/MLM', simple_value=mlm_acc)
            train_summary.value.add(tag='accuracy/NSP', simple_value=nsp_acc)
            train_summary.value.add(tag='learning_rate',
                                    simple_value=learning_rate)
            train_summary.value.add(tag='loss_scaling',
                                    simple_value=loss_scaling)
            train_summary.value.add(tag='samples_per_sec',
                                    simple_value=opts['global_batch_size'] /
                                    batch_time)
            train_summary.value.add(tag='samples',
                                    simple_value=step *
                                    opts['batches_per_step'] *
                                    opts['global_batch_size'])
            summary_writer.add_summary(train_summary, step)
            summary_writer.flush()

            if is_save_ckpt_step or is_save_tensorboard_step:
                if is_main_worker:
                    file_path = train.saver.save(train.session,
                                                 opts["checkpoint_path"],
                                                 global_step=step)
                    logger.info(f"Saved checkpoint to {file_path}")

                    if is_save_tensorboard_step:
                        log.save_model_statistics(file_path, summary_writer,
                                                  step)

                if opts['use_popdist']:
                    ipu_utils.barrier()
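                    # No worker moves on until worker 0 has finished writing the checkpoint.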

            step += opts['batches_per_step']
    finally:
        train.session.close()
Example #2
0
def train(bert_config, opts):
    # --------------- OPTIONS ---------------------
    epochs = opts["epochs"]
    total_samples = dataset.get_dataset_files_count(opts, is_training=True)
    logger.info("Total samples with duplications {}".format(total_samples))
    total_independent_samples = total_samples // opts['duplication_factor']
    logger.info("Total samples without duplications {}".format(
        total_independent_samples))
    steps_per_epoch = total_independent_samples // (opts['batches_per_step'] *
                                                    opts["total_batch_size"])
    iterations_per_epoch = total_independent_samples // (
        opts["total_batch_size"])

    # total iterations
    if opts['steps']:
        logger.warning("Ignoring the epochs option and using the steps value instead")
        steps = opts['steps']
    else:
        steps = epochs * steps_per_epoch
    logger.info(
        "Total training steps: {}, total number of samples to be processed: {}".format(
            steps, steps * opts['batches_per_step'] * opts['total_batch_size']))
    iterations_per_step = opts['batches_per_step']
    ckpt_per_step = opts['steps_per_ckpts']

    # Avoid a NaN issue caused by the queue length being zero.
    queue_len = iterations_per_epoch // iterations_per_step
    if queue_len == 0:
        queue_len = 1
    batch_times = deque(maxlen=queue_len)
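    # Recent per-step times are averaged at logging time to smooth the reported throughput.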

    # learning rate strategy
    lr_schedule_name = opts['lr_schedule']
    logger.info(f"Using learning rate schedule {lr_schedule_name}")
    LR = make_lr_schedule(lr_schedule_name, opts, steps)

    if opts['do_train']:
        # -------------- BUILD TRAINING GRAPH ----------------

        train = build_graph(bert_config,
                            opts,
                            iterations_per_step,
                            is_training=True,
                            feed_name="trainfeed")
        train.session.run(train.init)
        train.session.run(train.iterator.initializer)

        step = 0
        i = 0

        if opts['restore_path'] is not None:
            if os.path.isdir(opts['restore_path']):
                ckpt_file_path = tf.train.latest_checkpoint(
                    opts['restore_path'])
                logger.info("Restoring training from the latest checkpoint in the directory")
            else:
                # Assume a path to a specific checkpoint file was given
                ckpt_file_path = opts['restore_path']

            logger.info(
                f"Restoring training from checkpoint: {ckpt_file_path}")
            train.restore.restore(train.session, ckpt_file_path)

            ckpt_pattern = re.compile(".*ckpt-([0-9]+)$")
            i = int(ckpt_pattern.match(ckpt_file_path).groups()[0])
            step = int(i // iterations_per_step)

        if opts['start_from_ckpt']:
            # We use a checkpoint to initialise our model
            train.restore.restore(train.session, opts['start_from_ckpt'])
            logger.info("Starting the training from the checkpoint {}".format(
                opts['start_from_ckpt']))

        # Initialise Weights & Biases if available
        if opts['wandb']:
            import wandb
            wandb.init(project="tf-bert", sync_tensorboard=True)
            wandb.config.update(opts)

        # Tensorboard logs path
        log_path = os.path.join(opts["logs_path"], 'event')
        logger.info("Tensorboard event file path {}".format(log_path))
        summary_writer = tf.summary.FileWriter(log_path,
                                               train.graph,
                                               session=train.session)

        # ------------- TRAINING LOOP ----------------
        logger.info(
            "################################################################################"
        )
        logger.info("Start training......")
        print_format = (
            "step: {step:6d}, iteration: {iteration:6d}, epoch: {epoch:6.3f}, lr: {lr:10.3g}, mlm_loss: {mlm_loss:6.3f}, nsp_loss: {nsp_loss:6.3f}, "
            "samples/sec: {samples_per_sec:6.2f}, time: {iter_time:8.6f}, total_time: {total_time:8.1f}, mlm_acc: {mlm_acc:8.5f}, nsp_acc: {nsp_acc:8.5f}"
        )

        start_all = time.time()

        train_saver = train.saver["train_saver"]
        best_saver = train.saver["best_saver"]
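        # Two savers: "train_saver" for periodic checkpoints, "best_saver" for the best model so far.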
        # We initialize the best loss to a super large value
        best_total_loss = 1e10
        best_step = 0

        while step < steps:
            # Run Training
            learning_rate = LR.feed_dict_lr(step)
            try:
                batch_time, mlm_loss, nsp_loss, mlm_acc, nsp_acc = training_step(
                    train, learning_rate)
            except tf.errors.OpError as e:
                raise tf.errors.ResourceExhaustedError(e.node_def, e.op,
                                                       e.message)

            epoch = float(
                opts["total_batch_size"] * i) / total_independent_samples
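            # i counts device-side iterations, so the epoch reflects the number of
            # independent (non-duplicated) samples consumed so far.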

            batch_time /= iterations_per_step

            if step != 0:
                batch_times.append([batch_time])

            if step == 1:
                poplar_compile_time = time.time() - start_all
                poplar_summary = tf.Summary()
                poplar_summary.value.add(tag='poplar/compile_time',
                                         simple_value=poplar_compile_time)
                summary_writer.add_summary(poplar_summary)

            # Print loss
            if step % opts['steps_per_logs'] == 0:
                if len(batch_times) != 0:
                    avg_batch_time = np.mean(batch_times)
                else:
                    avg_batch_time = batch_time

                samples_per_sec = opts['total_batch_size'] / avg_batch_time

                # flush times every time it is reported
                batch_times.clear()

                total_time = time.time() - start_all

                stats = OrderedDict([
                    ('step', step),
                    ('iteration', i),
                    ('epoch', epoch),
                    ('lr', learning_rate),
                    ('mlm_loss', mlm_loss),
                    ('nsp_loss', nsp_loss),
                    ('mlm_acc', mlm_acc),
                    ('nsp_acc', nsp_acc),
                    ('iter_time', avg_batch_time),
                    ('samples_per_sec', samples_per_sec),
                    ('total_time', total_time),
                ])

                logger.info(print_format.format(**stats))
                bert_logging.write_to_csv(stats, i == 0, True,
                                          opts['logs_path'])

                sys_summary = tf.Summary()
                sys_summary.value.add(tag='perf/throughput_samples_per_second',
                                      simple_value=samples_per_sec)
                sys_summary.value.add(tag='perf/average_batch_time',
                                      simple_value=avg_batch_time)
                summary_writer.add_summary(sys_summary, step)

            # Log training statistics
            train_summary = tf.Summary()
            train_summary.value.add(tag='epoch', simple_value=epoch)
            train_summary.value.add(tag='loss/MLM', simple_value=mlm_loss)
            train_summary.value.add(tag='loss/NSP', simple_value=nsp_loss)
            train_summary.value.add(tag='accuracy/MLM', simple_value=mlm_acc)
            train_summary.value.add(tag='accuracy/NSP', simple_value=nsp_acc)
            train_summary.value.add(tag='defaultLearningRate',
                                    simple_value=learning_rate)
            train_summary.value.add(tag='samples',
                                    simple_value=step *
                                    opts['batches_per_step'] *
                                    opts['total_batch_size'])
            summary_writer.add_summary(train_summary, step)
            summary_writer.flush()

            if step % ckpt_per_step == 0 and step:
                filepath = train_saver.save(train.session,
                                            save_path=opts["checkpoint_path"],
                                            global_step=step)
                logger.info("Saved checkpoint to {}".format(filepath))

                if not opts['wandb']:
                    bert_logging.save_model_statistics(filepath,
                                                       summary_writer, step)

            # Mechanism to checkpoint the best model.
            # Set opts["best_ckpt_min_steps"] to 0 to disable it.
            if (opts["best_ckpt_min_steps"]
                    and mlm_loss + nsp_loss < best_total_loss
                    and step - best_step > opts["best_ckpt_min_steps"]):
                best_total_loss = mlm_loss + nsp_loss
                best_step = step
                filepath = best_saver.save(train.session,
                                           save_path=opts["checkpoint_path"] +
                                           '_best',
                                           global_step=step)
                logger.info("Saved Best checkpoint to {}".format(filepath))

            i += iterations_per_step
            step += 1

        # --------------- LAST CHECKPOINT ----------------
        filepath = train_saver.save(train.session,
                                    save_path=opts["checkpoint_path"] +
                                    '_last',
                                    global_step=step)
        logger.info("Final model saved to {}".format(filepath))

        # --------------- CLEANUP ----------------
        train.session.close()
Example #3
0
    def train(self):
        # Configure the IPU options.
        ipu_options = ipu_utils.get_ipu_config(
            ipu_id=self.opts["select_ipu"],
            num_ipus_required=len(self.opts["train"]["device_mapping"]) *
            self.opts["train"]["replicas"],
            fp_exceptions=False,
            stochastic_rounding=True,
            xla_recompute=True,
            available_memory_proportion=0.2,
            max_cross_replica_buffer_size=16 * 1024 * 1024,
            scheduler_selection="Clustering",
            compile_only=False,
            partials_type="half")
        # Configure the replication strategy
        if self.opts["use_popdist"]:
            strategy = create_popdist_strategy()
            ipu_options = strategy.update_ipu_config(ipu_options)
            ipu_options = popdist.tensorflow.set_ipu_config(
                ipu_options,
                len(self.opts["train"]["device_mapping"]),
                configure_device=False)
        ipu_options.configure_ipu_system()

        self.sess = tf.Session(config=tf.ConfigProto())

        stop_flag = []
        data_threads = []
        ds = self.get_dataset_on_the_fly(stop_flag, data_threads)

        global_step_holder = tf.placeholder(dtype=tf.int32, shape=())
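        # The host feeds the current global step into the on-device graph on every run
        # (model_func can use it, e.g. for learning-rate scheduling).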

        # This wrapper is needed because self.model_func takes "self" as a parameter,
        # which would cause an error when calling ipu_compiler.compile.
        def model_wrapper():
            self.model_func(self.model, self.opts, global_step_holder,
                            self.infeed_queue, self.outfeed_queue)

        with ExitStack() as stack:
            if self.opts["use_popdist"]:
                stack.enter_context(strategy.scope())
            self.infeed_queue = ipu_infeed_queue.IPUInfeedQueue(ds)
            self.outfeed_queue = ipu_outfeed_queue.IPUOutfeedQueue()

            with ipu.scopes.ipu_scope("/device:IPU:0"):
                if self.opts["use_popdist"]:

                    def distributed_per_replica_func():
                        return ipu_compiler.compile(model_wrapper, inputs=[])

                    compiled_model = strategy.experimental_run_v2(
                        distributed_per_replica_func, args=[])
                else:
                    compiled_model = ipu_compiler.compile(model_wrapper,
                                                          inputs=[])
            # The outfeed dequeue has to happen after the outfeed enqueue (i.e. after calling compile)
            dequeue_outfeed = self.outfeed_queue.dequeue()

            if self.opts["use_popdist"]:
                # Take the mean of all the outputs across the distributed workers
                dequeue_outfeed = [
                    strategy.reduce(tf.distribute.ReduceOp.MEAN, v)
                    for v in dequeue_outfeed
                ]

            with tf.name_scope("loader_and_saver"):
                self.loader, self.saver = self.get_loader_and_saver()

        self.sess.run(self.infeed_queue.initializer)
        self.sess.run(tf.global_variables_initializer())

        begin_epoch = 0

        if self.opts["train"]["load_type"] == "resume":
            # resume a half-trained run
            ckpts = []
            if os.path.exists("./checkpoint"):
                ckpts = sorted([
                    path for path in os.listdir("./checkpoint")
                    if "meta" in path
                ])
            if len(ckpts) == 0:
                logger.info("Failed to resume: no checkpoint was found")
                return
            ckpt_path = "./checkpoint/" + ckpts[-1].replace(".meta", "")
            logger.info("=> Resume training from: %s ... " % ckpt_path)
            self.loader.restore(self.sess, ckpt_path)
            begin_epoch = int(
                re.search("epoch=([0-9]+)", ckpt_path).groups()[0])
        elif self.opts["train"]["load_type"] in [
                "yolov3", "darknet53", "phase1"
        ]:
            # Load pretrained weights from the given checkpoint
            if self.initial_weight and os.path.exists(self.initial_weight +
                                                      ".meta"):
                logger.info("=> Restoring weights from: %s ... " %
                            self.initial_weight)
                self.loader.restore(self.sess, self.initial_weight)
            else:
                raise Exception("Cannot find a checkpoint to load")

        elif self.opts["train"]["load_type"] == "empty":
            logger.info("=> no checkpoint to load !!!")
            logger.info("=> Now it starts to train YOLOV3 from scratch ...")
        else:
            raise Exception(
                "'load_type' is not one of expected values: yolov3, darknet53, phase1, resume, empty"
            )

        total_epochs = self.epochs
        total_batch_size = self.opts["train"]["pipeline_depth"] * \
            self.batch_size * \
            self.opts["train"]["replicas"] * \
            self.opts["distributed_worker_count"]
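        # Each session.run of the compiled model performs repeat_count iterations,
        # so one "interaction" consumes total_batch_size * repeat_count samples.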
        samples_per_interaction = total_batch_size * self.repeat_count
        samples_per_epoch = len(self.trainset) * self.batch_size
        interactions_per_epoch = samples_per_epoch // samples_per_interaction
        if self.for_speed_test:
            interactions_per_epoch = 30
            total_epochs = 1
        steps_per_epoch = interactions_per_epoch * self.repeat_count
        logger.info("total epochs: {}".format(total_epochs))
        logger.info("steps_per_epoch: {}".format(steps_per_epoch))
        moving_loss = deque(maxlen=30)
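        # Rolling average of the last 30 total-loss values, used for logging and checkpoint names.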

        if self.opts["distributed_worker_index"] == 0:
            # Only the main worker writes TensorBoard logs
            summary_writer = tf.summary.FileWriter(
                "./tf_log/" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
                session=self.sess)

        train_begin_time = time.time()
        for epoch in range(begin_epoch, total_epochs):
            logger.info("epoch {}:".format(epoch + 1))

            start_time = time.time()
            for interaction_count in range(interactions_per_epoch):
                global_step = epoch * steps_per_epoch + interaction_count * self.repeat_count
                self.sess.run(compiled_model,
                              feed_dict={global_step_holder: global_step})
                result = self.sess.run(dequeue_outfeed)

                if self.opts["distributed_worker_index"] == 0:
                    giou_loss = np.mean(result[0])
                    conf_loss = np.mean(result[1])
                    prob_loss = np.mean(result[2])
                    lr = np.mean(result[3])
                    total_loss = giou_loss + conf_loss + prob_loss
                    moving_loss.append(total_loss)
                    end_time = time.time()
                    duration = end_time - start_time
                    start_time = time.time()
                    total_samples = global_step * total_batch_size
                    logger.info(
                        "epoch:{}, global_steps:{}, total_samples:{}, lr:{:.3e}, "
                        "moving_total_loss:{:.2f}, duration:{:.2f}, samples/s:{:.2f}, "
                        "total_time:{:.2f}".format(
                            epoch + 1, global_step, total_samples, lr,
                            np.mean(moving_loss), duration,
                            samples_per_interaction / duration,
                            time.time() - train_begin_time))

                    train_summary = tf.Summary()
                    train_summary.value.add(tag="giou_loss",
                                            simple_value=giou_loss)
                    train_summary.value.add(tag="conf_loss",
                                            simple_value=conf_loss)
                    train_summary.value.add(tag="prob_loss",
                                            simple_value=prob_loss)
                    train_summary.value.add(tag="total_loss",
                                            simple_value=total_loss)
                    train_summary.value.add(tag="lr", simple_value=lr)
                    train_summary.value.add(
                        tag="samples_per_sec",
                        simple_value=samples_per_interaction / duration)
                    summary_writer.add_summary(train_summary, total_samples)
                    summary_writer.flush()
            if (not self.for_speed_test) and (
                    epoch % self.opts["train"]["epochs_per_ckpt"] == 0
                    or epoch == total_epochs - 1):
                if self.opts["distributed_worker_index"] == 0:
                    ckpt_loss = np.mean(moving_loss)
                else:
                    # If save is not called on every instance there will be an all-reduce error,
                    # but recording a meaningful loss on every worker is pointless,
                    # so only the checkpoint saved by worker 0 carries the loss value in its name.
                    ckpt_loss = 0.0
                ckpt_file = "./checkpoint/yolov3-{}-epoch={}-moving_total_loss={:.4f}.ckpt".format(
                    datetime.now().strftime("%Y-%m-%d_%H:%M:%S"), epoch + 1,
                    ckpt_loss)

                logger.info("saving to: " + ckpt_file)
                model_path = self.saver.save(self.sess,
                                             ckpt_file,
                                             global_step=global_step)
                if self.opts["distributed_worker_index"] == 0:
                    log.save_model_statistics(model_path, summary_writer,
                                              global_step * total_batch_size)
        # tell threads to stop
        stop_flag.append(0)
        for data_thread in data_threads:
            data_thread.join()
        self.sess.close()