Beispiel #1
0
        logger.info(
            f"Restoring previous training checkpoint from {latest_checkpoint_path}"
        )
        start_epoch, global_step = load_checkpoint(model, optimizer,
                                                   latest_checkpoint_path)
        logger.info(
            f"The model is loaded from last checkpoint at epoch {start_epoch} when the global steps were at {global_step}"
        )

    logger.info("Training the model")

    for index in range(start_epoch, job_config.get_total_epoch_count()):
        logger.info(f"Training epoch: {index + 1}")

        train(index)

        if check_write_log():
            epoch_ckp_path = os.path.join(
                saved_model_path,
                "bert_encoder_epoch_{0:04d}.pt".format(index + 1))
            logger.info(
                f"Saving checkpoint of the model from epoch {index + 1} at {epoch_ckp_path}"
            )
            model.save_bert(epoch_ckp_path)
            checkpoint_model(
                os.path.join(
                    saved_model_path,
                    "training_state_checkpoint_{0:04d}.tar".format(index + 1)),
                model, optimizer, index, global_step)
    best_loss = None
    for index in range(start_epoch, args.epochs):
        logger.info(f"Training epoch: {index + 1}")

        eval_loss = train(index)

        if check_write_log():
            if best_loss is None or eval_loss is None or eval_loss < best_loss * 0.99:
                best_loss = eval_loss
                epoch_ckp_path = os.path.join(
                    saved_model_path,
                    "{}_{}_bert_encoder_epoch_{:04d}.pt".format(
                        job_id, local_rank, index + 1))
                checkpoint_model(
                    os.path.join(
                        saved_model_path,
                        "{}_{}_training_state_checkpoint_{:04d}.tar".format(
                            job_id, local_rank, index + 1)), model, optimizer,
                    index, global_step)
                logger.info(
                    f"Saving checkpoint of the model from epoch {index + 1} at {epoch_ckp_path}"
                )
                model.save_bert(epoch_ckp_path)

                #save best checkpoint in separate directory
                if args.best_cp_dir:
                    best_ckp_path = os.path.join(
                        args.best_cp_dir,
                        "{}_{}_bert_encoder_epoch_{:04d}.pt".format(
                            job_id, local_rank, index + 1))
                    shutil.rmtree(args.best_cp_dir)
                    os.makedirs(args.best_cp_dir, exist_ok=True)
Beispiel #3
0
def _reset_dir(path):
    """Leave *path* as an existing, empty directory."""
    # shutil.rmtree raises FileNotFoundError for a missing path, which would
    # abort the very first checkpoint save of a fresh run.
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


def train():
    """Pretrain the model on the Midea dataset and checkpoint after each epoch.

    Runs epochs ``start_epoch`` .. ``args.epochs - 1``.  In every epoch each
    batch from the dataloader is moved to ``device``, passed through
    ``model.network`` to obtain the loss, and back-propagated; the optimizer
    and scheduler are stepped once every ``gradient_accumulation_steps``
    batches.  After each epoch ``pretrain_validation(index)`` is evaluated and,
    when ``check_write_log()`` is true:

    * a checkpoint is written to ``saved_model_path`` (and mirrored to
      ``args.best_cp_dir`` if set) when the evaluation loss improved by more
      than 1% over the best seen so far, or when either loss is ``None``;
    * the most recent state is written to ``args.latest_cp_dir`` if set.

    The function relies on module-level state (``model``, ``optimizer``,
    ``scheduler``, ``args``, ``tokenizer``, ``device``, ...) and updates the
    module-level ``global_step`` counter.  It returns ``None``.
    """
    global global_step
    model.train()

    # Pretraining datasets
    batchs_per_dataset = []
    shuffle_numbers = 10

    midea_dataset = MideaDataset(
        tokenizer=tokenizer,
        folder=args.train_path,
        max_seq_length=max_seq_length,
        shuffle_numbers=shuffle_numbers,
        max_predictions_per_seq=max_predictions_per_seq,
        masked_lm_prob=masked_lm_prob)
    num_batches = get_effective_batch(len(midea_dataset))
    # NOTE: the message text says "Wikpedia" although the data is the Midea
    # dataset; the string is left unchanged so existing log output is stable.
    logger.info('Wikpedia data file: Number of samples {}'.format(
        len(midea_dataset)))
    batchs_per_dataset.append(num_batches)

    logger.info("Training on Midea dataset")
    dataset_batches = []
    for i, batch_count in enumerate(batchs_per_dataset):
        dataset_batches.extend([i] * batch_count)
    random.shuffle(dataset_batches)

    # We don't want the dataset to be in the form of alternate chunks if we
    # have more than one dataset type; instead we want contiguous chunks of
    # each data type, hence every entry is repeated
    # gradient_accumulation_steps times.
    dataset_picker = []
    for dataset_batch_type in dataset_batches:
        dataset_picker.extend([dataset_batch_type] *
                              gradient_accumulation_steps)
    print("dataset_picker", len(dataset_picker))

    # Counter of sequences processed across all epochs of this call
    sequences_counter = 0
    # Loss summed over the current accumulation window; it is reset after each
    # optimizer step and is not read anywhere else in this function.
    global_step_loss = 0
    dataloaders = get_dataloader(midea_dataset)
    # Number of batches (micro-steps) processed so far, across epochs
    step = 0
    best_loss = None
    for index in range(start_epoch, args.epochs):
        logger.info(f"Training epoch: {index + 1}")
        # Re-assert training mode every epoch: pretrain_validation() below may
        # have switched the model to eval mode (its body is not visible here).
        model.train()
        for batch in tqdm(dataloaders):
            sequences_counter += batch[1].shape[0]

            # Move every tensor of the batch to the target device
            batch = tuple(t.cuda(device, non_blocking=True) for t in batch)

            loss = model.network(batch)

            if n_gpu > 1:
                # this is to average loss for multi-gpu. In DistributedDataParallel
                # setting, we get tuple of losses from all processes
                loss = loss.mean()

            if gradient_accumulation_steps > 1:
                loss = loss / gradient_accumulation_steps

            # Whether this batch closes an accumulation window
            is_update_step = (step + 1) % gradient_accumulation_steps == 0

            # Enabling optimized reduction:
            # reduction only happens in backward if this method is called
            # before, when using the distributed module
            if accumulate_gradients:
                if use_multigpu_with_single_device_per_process and is_update_step:
                    model.network.enable_need_reduction()
                else:
                    model.network.disable_need_reduction()
            if fp16:
                optimizer.backward(loss)
            else:
                loss.backward()

            # Detach so the running sum does not keep the autograd graph alive
            global_step_loss += loss.detach()
            if is_update_step:
                if fp16:
                    # modify learning rate with special warm up BERT uses
                    # if fp16 is False, BertAdam is used that handles this automatically
                    lr_this_step = job_config.get_learning_rate(
                    ) * warmup_linear_decay_exp(
                        global_step, job_config.get_decay_rate(),
                        job_config.get_decay_step(),
                        job_config.get_total_training_steps(),
                        job_config.get_warmup_proportion())
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step

                    # Record the LR against global_step on tensorboard
                    if check_write_log():
                        summary_writer.add_scalar('Train/lr', lr_this_step,
                                                  global_step)

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1
                global_step_loss = 0

            # Count every batch.  Incrementing only inside the update branch
            # would leave `step` stuck at 0 whenever
            # gradient_accumulation_steps > 1, so the optimizer would never run.
            step += 1

        logger.info("Completed {} steps".format(step))
        logger.info(
            "Completed processing {} sequences".format(sequences_counter))
        eval_loss = pretrain_validation(index)
        if check_write_log():
            encoder_name = "bert_encoder_epoch_{0:04d}.pt".format(index + 1)
            state_name = "training_state_checkpoint_{0:04d}.tar".format(
                index + 1)

            # Save when there is no usable baseline/eval loss, or when the
            # loss improved by more than 1% relative to the best so far.
            if best_loss is None or eval_loss is None or eval_loss < best_loss * 0.99:
                best_loss = eval_loss
                epoch_ckp_path = os.path.join(saved_model_path, encoder_name)
                checkpoint_model(
                    os.path.join(saved_model_path, state_name),
                    model, optimizer, index, global_step)
                logger.info(
                    f"Saving checkpoint of the model from epoch {index + 1} at {epoch_ckp_path}"
                )
                model.save_bert(epoch_ckp_path)

                # save best checkpoint in separate directory
                if args.best_cp_dir:
                    best_ckp_path = os.path.join(args.best_cp_dir,
                                                 encoder_name)
                    _reset_dir(args.best_cp_dir)
                    model.save_bert(best_ckp_path)

            # Always keep the most recent state, replacing the previous one
            if args.latest_cp_dir:
                _reset_dir(args.latest_cp_dir)
                checkpoint_model(
                    os.path.join(args.latest_cp_dir, state_name),
                    model, optimizer, index, global_step)
                latest_ckp_path = os.path.join(args.latest_cp_dir,
                                               encoder_name)
                model.save_bert(latest_ckp_path)