Example #1
class Framework(object):
    """A framework wrapping the Relational Graph Extraction model. This framework allows to train, predict, evaluate,
    saving and loading the model with a single line of code.
    """
    def __init__(self, **config):
        super().__init__()

        self.config = config

        self.grad_acc = self.config.get('grad_acc', 1)
        self.device = torch.device(self.config['device'])
        if isinstance(self.config['model'], str):
            self.model = MODELS[self.config['model']](**self.config)
        else:
            self.model = self.config['model']

        self.class_weights = (torch.tensor(self.config['class_weights']).float()
                              if 'class_weights' in self.config
                              else torch.ones(self.config['n_rel']))
        if 'lambda' in self.config:
            self.class_weights[0] = self.config['lambda']
        self.loss_fn = nn.CrossEntropyLoss(weight=self.class_weights.to(
            self.device),
                                           reduction='mean')
        if self.config['optimizer'] == 'SGD':
            self.optimizer = torch.optim.SGD(
                self.model.get_parameters(self.config.get('l2', .01)),
                lr=self.config['lr'],
                momentum=self.config.get('momentum', 0),
                nesterov=self.config.get('nesterov', False))
        elif self.config['optimizer'] == 'Adam':
            self.optimizer = torch.optim.Adam(self.model.get_parameters(
                self.config.get('l2', .01)),
                                              lr=self.config['lr'])
        elif self.config['optimizer'] == 'AdamW':
            self.optimizer = AdamW(self.model.get_parameters(
                self.config.get('l2', .01)),
                                   lr=self.config['lr'])
        else:
            raise Exception('The optimizer must be SGD, Adam or AdamW')

    def _train_step(self, dataset, epoch, scheduler=None):
        print("Training:")
        self.model.train()

        total_loss = 0
        predictions, labels, positions = [], [], []
        precision = recall = fscore = 0.0
        progress = tqdm(
            enumerate(dataset),
            desc=
            f"Epoch: {epoch} - Loss: {0.0} - P/R/F: {precision}/{recall}/{fscore}",
            total=len(dataset))
        for i, batch in progress:
            # uncompress the batch
            seq, mask, ent, label = batch
            seq = seq.to(self.device)
            mask = mask.to(self.device)
            ent = ent.to(self.device)
            label = label.to(self.device)

            #self.optimizer.zero_grad()
            output = self.model(seq, mask, ent)
            loss = self.loss_fn(output, label)
            total_loss += loss.item()

            if self.config['half']:
                with amp.scale_loss(loss, self.optimizer) as scale_loss:
                    scale_loss.backward()
            else:
                loss.backward()

            if (i + 1) % self.grad_acc == 0:
                if self.config.get('grad_clip', False):
                    if self.config['half']:
                        nn.utils.clip_grad_norm_(
                            amp.master_params(self.optimizer),
                            self.config['grad_clip'])
                    else:
                        nn.utils.clip_grad_norm_(self.model.parameters(),
                                                 self.config['grad_clip'])

                self.optimizer.step()
                self.model.zero_grad()
                if scheduler:
                    scheduler.step()

            # Evaluate results
            pre, lab, pos = dataset.evaluate(
                i,
                output.detach().numpy() if self.config['device'] == 'cpu' else
                output.detach().cpu().numpy())

            predictions.extend(pre)
            labels.extend(lab)
            positions.extend(pos)

            if (i + 1) % 10 == 0:
                precision, recall, fscore, _ = precision_recall_fscore_support(
                    np.array(labels),
                    np.array(predictions),
                    average='micro',
                    labels=list(range(1, self.model.n_rel)))

            progress.set_description(
                f"Epoch: {epoch} - Loss: {total_loss/(i+1):.3f} - P/R/F: {precision:.2f}/{recall:.2f}/{fscore:.2f}"
            )

        # For last iteration
        #self.optimizer.step()
        #self.optimizer.zero_grad()

        predictions, labels = np.array(predictions), np.array(labels)
        precision, recall, fscore, _ = precision_recall_fscore_support(
            labels,
            predictions,
            average='micro',
            labels=list(range(1, self.model.n_rel)))
        print(
            f"Precision: {precision:.3f} - Recall: {recall:.3f} - F-Score: {fscore:.3f}"
        )
        precision, recall, fscore, _ = precision_recall_fscore_support(
            labels, predictions, average='micro')
        print(
            f"[with NO-RELATION] Precision: {precision:.3f} - Recall: {recall:.3f} - F-Score: {fscore:.3f}"
        )

        return total_loss / (i + 1)

    def _val_step(self, dataset, epoch):
        print("Validating:")
        self.model.eval()

        predictions, labels, positions = [], [], []
        total_loss = 0
        with torch.no_grad():
            progress = tqdm(enumerate(dataset),
                            desc=f"Epoch: {epoch} - Loss: {0.0}",
                            total=len(dataset))
            for i, batch in progress:
                # uncompress the batch
                seq, mask, ent, label = batch
                seq = seq.to(self.device)
                mask = mask.to(self.device)
                ent = ent.to(self.device)
                label = label.to(self.device)

                output = self.model(seq, mask, ent)
                loss = self.loss_fn(output, label)
                total_loss += loss.item()

                # Evaluate results
                pre, lab, pos = dataset.evaluate(
                    i,
                    output.detach().numpy() if self.config['device'] == 'cpu'
                    else output.detach().cpu().numpy())

                predictions.extend(pre)
                labels.extend(lab)
                positions.extend(pos)

                progress.set_description(
                    f"Epoch: {epoch} - Loss: {total_loss/(i+1):.3f}")

        predictions, labels = np.array(predictions), np.array(labels)
        precision, recall, fscore, _ = precision_recall_fscore_support(
            labels,
            predictions,
            average='micro',
            labels=list(range(1, self.model.n_rel)))
        print(
            f"Precision: {precision:.3f} - Recall: {recall:.3f} - F-Score: {fscore:.3f}"
        )
        noprecision, norecall, nofscore, _ = precision_recall_fscore_support(
            labels, predictions, average='micro')
        print(
            f"[with NO-RELATION] Precision: {noprecision:.3f} - Recall: {norecall:.3f} - F-Score: {nofscore:.3f}"
        )

        return total_loss / (i + 1), precision, recall, fscore

    def _save_checkpoint(self, dataset, epoch, loss, val_loss):
        print(f"Saving checkpoint ({dataset.name}.pth) ...")
        PATH = os.path.join('checkpoints', f"{dataset.name}.pth")
        config_PATH = os.path.join('checkpoints',
                                   f"{dataset.name}_config.json")
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': loss,
                'val_loss': val_loss
            }, PATH)
        with open(config_PATH, 'wt') as f:
            json.dump(self.config, f)

    def _load_checkpoint(self, PATH: str, config_PATH: str):
        checkpoint = torch.load(PATH)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']

        with open(config_PATH, 'rt') as f:
            self.config = json.load(f)

        return epoch, loss

    def fit(self,
            dataset,
            validation=True,
            batch_size=1,
            patience=3,
            delta=0.):
        """ Fits the model to the given dataset.

        Usage:
        ```
        >>> rge = Framework(**config)
        >>> rge.fit(train_data)
        ```
        """
        self.model.to(self.device)
        train_data = dataset.get_train(batch_size)

        if self.config['half']:
            self.model, self.optimizer = amp.initialize(
                self.model,
                self.optimizer,
                opt_level='O2',
                keep_batchnorm_fp32=True)

        if self.config['linear_scheduler']:
            num_training_steps = int(
                len(train_data) // self.grad_acc * self.config['epochs'])
            scheduler = get_linear_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=self.config.get('warmup_steps', 0),
                num_training_steps=num_training_steps)
        else:
            scheduler = None

        early_stopping = EarlyStopping(patience, delta, self._save_checkpoint)

        for epoch in range(self.config['epochs']):
            self.optimizer.zero_grad()
            loss = self._train_step(train_data, epoch, scheduler=scheduler)
            if validation:
                val_loss, _, _, _ = self._val_step(dataset.get_val(batch_size),
                                                   epoch)
                if early_stopping(val_loss,
                                  dataset=dataset,
                                  epoch=epoch,
                                  loss=loss):
                    break

        # Recover the best epoch
        path = os.path.join("checkpoints", f"{dataset.name}.pth")
        config_path = os.path.join("checkpoints",
                                   f"{dataset.name}_config.json")
        _, _ = self._load_checkpoint(path, config_path)

    def predict(self, dataset, return_proba=False):
        """ Predicts the relations graph for the given dataset.
        """
        self.model.to(self.device)
        self.model.eval()

        predictions, instances = [], []
        with torch.no_grad():
            progress = tqdm(enumerate(dataset), total=len(dataset))
            for i, batch in progress:
                # uncompress the batch
                seq, mask, ent, label = batch
                seq = seq.to(self.device)
                mask = mask.to(self.device)
                ent = ent.to(self.device)
                label = label.to(self.device)

                output = self.model(seq, mask, ent)
                if not return_proba:
                    pred = np.argmax(output.detach().cpu().numpy(),
                                     axis=1).tolist()
                else:
                    pred = output.detach().cpu().numpy().tolist()
                inst = dataset.get_instances(i)

                predictions.extend(pred)
                instances.extend(inst)

        return predictions, instances

    def evaluate(self, dataset: Dataset, batch_size=1):
        """ Evaluates the model on the given dataset.
        """
        loss, precision, recall, fscore = self._val_step(
            dataset.get_val(batch_size), 0)
        return loss, precision, recall, fscore

    def save_model(self, path: str):
        """ Saves the model to a file.

        Usage:
        ``` 
        >>> rge = Framework(**config)
        >>> rge.fit(train_data)

        >>> rge.save_model("path/to/file")
        ```

        TODO
        """
        self.model.save_pretrained(path)
        with open(f"{path}/fine_tunning.config.json", 'wt') as f:
            json.dump(self.config, f, indent=4)

    @classmethod
    def load_model(cls,
                   path: str,
                   config_path: str = None,
                   from_checkpoint=False):
        """ Loads the model from a file.

        Args:
            path: str Path to the file that stores the model.

        Returns:
            Framework instance with the loaded model.

        Usage:
        ```
        >>> rge = Framework.load_model("path/to/model")
        ```

        TODO
        """
        if not from_checkpoint:
            config_path = path + '/fine_tunning.config.json'
            with open(config_path) as f:
                config = json.load(f)
            config['pretrained_model'] = path
            rge = cls(**config)

        else:
            if config_path is None:
                raise Exception(
                    'Loading the model from a checkpoint requires the config_path argument.'
                )
            with open(config_path) as f:
                config = json.load(f)
            rge = cls(**config)
            rge._load_checkpoint(path, config_path)

        return rge
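
A minimal usage sketch for the Framework above, assuming a hypothetical config (the 'rgcn' model key, the number of relations and the dataset object are assumptions; the exact keys required depend on the MODELS registry and the dataset class):

# Hypothetical configuration for illustration only.
config = {
    'model': 'rgcn',          # assumed key in the MODELS registry
    'device': 'cuda',
    'n_rel': 42,              # assumed number of relation classes
    'optimizer': 'AdamW',
    'lr': 2e-5,
    'epochs': 5,
    'half': False,
    'linear_scheduler': True,
}

rge = Framework(**config)
rge.fit(train_dataset, batch_size=8)   # train_dataset must provide get_train()/get_val() and a name
loss, precision, recall, fscore = rge.evaluate(train_dataset, batch_size=8)
rge.save_model("path/to/model")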
Example #2
def main(train_file,
         dev_file,
         target_dir,
         epochs=10,
         batch_size=32,
         lr=2e-05,
         patience=3,
         max_grad_norm=10.0,
         checkpoint=None):
    bert_tokenizer = XLNetTokenizer.from_pretrained('hfl/chinese-xlnet-base',
                                                    do_lower_case=True)
    device = torch.device("cuda")
    print(20 * "=", " Preparing for training ", 20 * "=")
    # Path where the model will be saved
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    # -------------------- Data loading ------------------- #
    print("\t* Loading training data...")
    train_data = DataPrecessForSentence(bert_tokenizer, train_file)
    train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
    print("\t* Loading validation data...")
    dev_data = DataPrecessForSentence(bert_tokenizer, dev_file)
    dev_loader = DataLoader(dev_data, shuffle=True, batch_size=batch_size)
    # -------------------- Model definition ------------------- #
    print("\t* Building model...")
    model = XlnetModel().to(device)
    # -------------------- Preparation for training  ------------------- #
    # Parameters to be optimized
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [p for n, p in param_optimizer
                   if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params': [p for n, p in param_optimizer
                   if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
    #optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode="max",
                                                           factor=0.85,
                                                           patience=0)
    best_score = 0.0
    start_epoch = 1
    # Data for loss curves plot
    epochs_count = []
    train_losses = []
    valid_losses = []
    # Continuing training from a checkpoint if one was given as argument
    if checkpoint:
        checkpoint = torch.load(checkpoint, map_location=torch.device("cpu"))
        start_epoch = checkpoint["epoch"] + 1
        best_score = checkpoint["best_score"]
        print("\t* Training will continue on existing model from epoch {}...".
              format(start_epoch))
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        epochs_count = checkpoint["epochs_count"]
        train_losses = checkpoint["train_losses"]
        valid_losses = checkpoint["valid_losses"]
    # Compute loss and accuracy before starting (or resuming) training.
    _, valid_loss, valid_accuracy, auc = validate(model, dev_loader)
    print(
        "\t* Validation loss before training: {:.4f}, accuracy: {:.4f}%, auc: {:.4f}"
        .format(valid_loss, (valid_accuracy * 100), auc))
    # -------------------- Training epochs ------------------- #
    print("\n", 20 * "=", "Training Xlnet model on device: {}".format(device),
          20 * "=")
    patience_counter = 0
    for epoch in range(start_epoch, epochs + 1):
        epochs_count.append(epoch)
        print("* Training epoch {}:".format(epoch))
        epoch_time, epoch_loss, epoch_accuracy = train(model, train_loader,
                                                       optimizer, epoch,
                                                       max_grad_norm)
        train_losses.append(epoch_loss)
        print("-> Training time: {:.4f}s, loss = {:.4f}, accuracy: {:.4f}%".
              format(epoch_time, epoch_loss, (epoch_accuracy * 100)))
        print("* Validation for epoch {}:".format(epoch))
        epoch_time, epoch_loss, epoch_accuracy, epoch_auc = validate(
            model, dev_loader)
        valid_losses.append(epoch_loss)
        print(
            "-> Valid. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%, auc: {:.4f}\n"
            .format(epoch_time, epoch_loss, (epoch_accuracy * 100), epoch_auc))
        # Update the optimizer's learning rate with the scheduler.
        scheduler.step(epoch_accuracy)
        # Early stopping on validation accuracy.
        if epoch_accuracy < best_score:
            patience_counter += 1
        else:
            best_score = epoch_accuracy
            patience_counter = 0
            torch.save(
                {
                    "epoch": epoch,
                    "model": model.state_dict(),
                    "best_score": best_score,
                    "epochs_count": epochs_count,
                    "train_losses": train_losses,
                    "valid_losses": valid_losses
                }, os.path.join(target_dir, "best.pth.tar"))
        if patience_counter >= patience:
            print("-> Early stopping: patience limit reached, stopping...")
            break
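
A minimal call sketch for main() above; the file paths are hypothetical placeholders:

if __name__ == "__main__":
    # Hypothetical paths; the files must follow the format expected by DataPrecessForSentence.
    main(train_file="data/train.tsv",
         dev_file="data/dev.tsv",
         target_dir="output/xlnet_checkpoints",
         epochs=10,
         batch_size=32,
         lr=2e-5,
         patience=3)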
Example #3
            p for n, p in multi_choice_model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay': 0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr)

    # load trained model from checkpoint
    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint, map_location='cuda:0')
        if checkpoint["name"] == args.bert_model:
            logger.info("***** Loading saved model based on '%s' *****",
                        checkpoint["name"])
            multi_choice_model.load_state_dict(checkpoint["model"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            args.start_epoch = checkpoint["epoch"]
        else:
            raise Exception(
                "The loaded model does not match the pre-trained model",
                checkpoint["name"])

    # train and evaluate
    if args.do_train and args.do_eval:
        global_step, tr_loss = train(args, train_datasets, multi_choice_model,
                                     optimizer, evaluate_dataset)
        logger.info("***** End of training *****")
        logger.info(" global_step = %s, average loss = %s", global_step,
                    tr_loss)

    # only evaluate
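
The fragment above begins in the middle of the usual weight-decay parameter grouping. For context, a minimal sketch of the full pattern (mirroring Examples #2 and #4, with multi_choice_model, args and AdamW assumed to be defined and imported) would look like this:

no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [{
    # parameters that receive weight decay
    'params': [p for n, p in multi_choice_model.named_parameters()
               if not any(nd in n for nd in no_decay)],
    'weight_decay': 0.01
}, {
    # biases and LayerNorm weights are excluded from weight decay
    'params': [p for n, p in multi_choice_model.named_parameters()
               if any(nd in n for nd in no_decay)],
    'weight_decay': 0.0
}]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr)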
Example #4
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--bert_model",
        default='bert-base-uncased',
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        '--task',
        type=str,
        default=None,
        required=True,
        help="Task code in {hotpot_open, hotpot_distractor, squad, nq}")

    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=378,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=1,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=5,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam. (def: 5e-5)")
    parser.add_argument("--num_train_epochs",
                        default=5.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument('--local_rank', default=-1, type=int)

    # RNN graph retriever-specific parameters
    parser.add_argument("--example_limit", default=None, type=int)

    parser.add_argument("--max_para_num", default=10, type=int)
    parser.add_argument(
        "--neg_chunk",
        default=8,
        type=int,
        help="The chunk size of negative examples during training (to "
        "reduce GPU memory consumption with negative sampling)")
    parser.add_argument(
        "--eval_chunk",
        default=100000,
        type=int,
        help=
        "The chunk size of evaluation examples (to reduce RAM consumption during evaluation)"
    )
    parser.add_argument(
        "--split_chunk",
        default=300,
        type=int,
        help=
        "The chunk size of BERT encoding during inference (to reduce GPU memory consumption)"
    )

    parser.add_argument('--train_file_path',
                        type=str,
                        default=None,
                        help="File path to the training data")
    parser.add_argument('--dev_file_path',
                        type=str,
                        default=None,
                        help="File path to the eval data")

    parser.add_argument('--beam', type=int, default=1, help="Beam size")
    parser.add_argument('--min_select_num',
                        type=int,
                        default=1,
                        help="Minimum number of selected paragraphs")
    parser.add_argument('--max_select_num',
                        type=int,
                        default=3,
                        help="Maximum number of selected paragraphs")
    parser.add_argument(
        "--use_redundant",
        action='store_true',
        help="Whether to use simulated seqs (only for training)")
    parser.add_argument(
        "--use_multiple_redundant",
        action='store_true',
        help="Whether to use multiple simulated seqs (only for training)")
    parser.add_argument(
        '--max_redundant_num',
        type=int,
        default=100000,
        help=
        "Maximum number of simulated redundant seqs per example (only for training)"
    )
    parser.add_argument(
        "--no_links",
        action='store_true',
        help=
        "Whether to omit any links (or in other words, only use TF-IDF-based paragraphs)"
    )
    parser.add_argument("--pruning_by_links",
                        action='store_true',
                        help="Whether to do pruning by links (and top 1)")
    parser.add_argument(
        "--expand_links",
        action='store_true',
        help=
        "Whether to expand links with paragraphs in the same article (for NQ)")
    parser.add_argument(
        '--tfidf_limit',
        type=int,
        default=None,
        help=
        "Limit on the size of the initial TF-IDF pool (only for open-domain eval)"
    )

    parser.add_argument("--pred_file",
                        default=None,
                        type=str,
                        help="File name to write paragraph selection results")
    parser.add_argument("--tagme",
                        action='store_true',
                        help="Whether to use tagme at inference")
    parser.add_argument(
        '--topk',
        type=int,
        default=2,
        help="Whether to use how many paragraphs from the previous steps")

    parser.add_argument(
        "--model_suffix",
        default=None,
        type=str,
        help="Suffix to load a model file ('pytorch_model_' + suffix +'.bin')")

    parser.add_argument("--db_save_path",
                        default=None,
                        type=str,
                        help="File path to DB")
    parser.add_argument("--fp16", default=False, action='store_true')
    parser.add_argument("--fp16_opt_level", default="O1", type=str)
    parser.add_argument("--do_label",
                        default=False,
                        action='store_true',
                        help="For pre-processing features only.")

    parser.add_argument("--oss_cache_dir", default=None, type=str)
    parser.add_argument("--cache_dir", default=None, type=str)
    parser.add_argument("--dist",
                        default=False,
                        action='store_true',
                        help='use distributed training.')
    parser.add_argument("--save_steps", default=5000, type=int)
    parser.add_argument("--resume", default=None, type=int)
    parser.add_argument("--oss_pretrain", default=None, type=str)
    parser.add_argument("--model_version", default='v1', type=str)
    parser.add_argument("--disable_rnn_layer_norm",
                        default=False,
                        action='store_true')

    args = parser.parse_args()

    if args.dist:
        dist.init_process_group(backend='nccl')
        print(f"local rank: {args.local_rank}")
        print(f"global rank: {dist.get_rank()}")
        print(f"world size: {dist.get_world_size()}")

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')

    if args.dist:
        global_rank = dist.get_rank()
        world_size = dist.get_world_size()
        if world_size > 1:
            args.local_rank = global_rank

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if args.train_file_path is not None:
        do_train = True

        if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".
                format(args.output_dir))
        if args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir, exist_ok=True)

    elif args.dev_file_path is not None:
        do_train = False

    else:
        raise ValueError(
            'One of train_file_path: {} or dev_file_path: {} must be non-None'.
            format(args.train_file_path, args.dev_file_path))

    processor = DataProcessor()

    # Configurations of the graph retriever
    graph_retriever_config = GraphRetrieverConfig(
        example_limit=args.example_limit,
        task=args.task,
        max_seq_length=args.max_seq_length,
        max_select_num=args.max_select_num,
        max_para_num=args.max_para_num,
        tfidf_limit=args.tfidf_limit,
        train_file_path=args.train_file_path,
        use_redundant=args.use_redundant,
        use_multiple_redundant=args.use_multiple_redundant,
        max_redundant_num=args.max_redundant_num,
        dev_file_path=args.dev_file_path,
        beam=args.beam,
        min_select_num=args.min_select_num,
        no_links=args.no_links,
        pruning_by_links=args.pruning_by_links,
        expand_links=args.expand_links,
        eval_chunk=args.eval_chunk,
        tagme=args.tagme,
        topk=args.topk,
        db_save_path=args.db_save_path,
        disable_rnn_layer_norm=args.disable_rnn_layer_norm)

    logger.info(graph_retriever_config)
    logger.info(args)

    tokenizer = AutoTokenizer.from_pretrained(args.bert_model)

    if args.model_version == 'roberta':
        from modeling_graph_retriever_roberta import RobertaForGraphRetriever
    elif args.model_version == 'v3':
        from modeling_graph_retriever_roberta import RobertaForGraphRetrieverIterV3 as RobertaForGraphRetriever
    else:
        raise RuntimeError()

    ##############################
    # Training                   #
    ##############################
    if do_train:
        _model_state_dict = None
        if args.oss_pretrain is not None:
            _model_state_dict = torch.load(load_pretrain_from_oss(
                args.oss_pretrain),
                                           map_location='cpu')
            logger.info(f"Loaded pretrained model from {args.oss_pretrain}")

        if args.resume is not None:
            _model_state_dict = torch.load(load_buffer_from_oss(
                os.path.join(args.oss_cache_dir,
                             f"pytorch_model_{args.resume}.bin")),
                                           map_location='cpu')

        model = RobertaForGraphRetriever.from_pretrained(
            args.bert_model,
            graph_retriever_config=graph_retriever_config,
            state_dict=_model_state_dict)

        model.to(device)

        global_step = 0

        POSITIVE = 1.0
        NEGATIVE = 0.0

        _cache_file_name = f"cache_roberta_train_{args.max_seq_length}_{args.max_para_num}"
        _examples_cache_file_name = f"examples_{_cache_file_name}"
        _features_cache_file_name = f"features_{_cache_file_name}"

        # Load training examples
        logger.info(f"Loading training examples and features.")
        try:
            if args.cache_dir is not None and os.path.exists(
                    os.path.join(args.cache_dir, _features_cache_file_name)):
                logger.info(
                    f"Loading pre-processed features from {os.path.join(args.cache_dir, _features_cache_file_name)}"
                )
                train_features = torch.load(
                    os.path.join(args.cache_dir, _features_cache_file_name))
            else:
                # train_examples = torch.load(load_buffer_from_oss(os.path.join(oss_features_cache_dir,
                #                                                               _examples_cache_file_name)))
                train_features = torch.load(
                    load_buffer_from_oss(
                        os.path.join(oss_features_cache_dir,
                                     _features_cache_file_name)))
                logger.info(
                    f"Pre-processed features are loaded from oss: "
                    f"{os.path.join(oss_features_cache_dir, _features_cache_file_name)}"
                )
        except Exception:
            train_examples = processor.get_train_examples(
                graph_retriever_config)
            train_features = convert_examples_to_features(
                train_examples,
                args.max_seq_length,
                args.max_para_num,
                graph_retriever_config,
                tokenizer,
                train=True)
            logger.info(
                f"Saving pre-processed features into oss: {oss_features_cache_dir}"
            )
            torch_save_to_oss(
                train_examples,
                os.path.join(oss_features_cache_dir,
                             _examples_cache_file_name))
            torch_save_to_oss(
                train_features,
                os.path.join(oss_features_cache_dir,
                             _features_cache_file_name))

        if args.do_label:
            logger.info("Finished.")
            return

        # len(train_examples) and len(train_features) can be different, depending on the redundant setting
        num_train_steps = int(
            len(train_features) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)

        # Prepare optimizer
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'layer_norm']
        optimizer_grouped_parameters = [{
            'params': [p for n, p in param_optimizer
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': 0.01
        }, {
            'params': [p for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }]
        t_total = num_train_steps
        if args.local_rank != -1:
            t_total = t_total // dist.get_world_size()

        optimizer = AdamW(optimizer_grouped_parameters,
                          betas=(0.9, 0.98),
                          lr=args.learning_rate)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, int(t_total * args.warmup_proportion), t_total)

        logger.info(optimizer)
        if args.fp16:
            from apex import amp
            amp.register_half_function(torch, "einsum")

            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.fp16_opt_level)

        if args.local_rank != -1:
            if args.fp16_opt_level == 'O2':
                try:
                    import apex
                    model = apex.parallel.DistributedDataParallel(
                        model, delay_allreduce=True)
                except ImportError:
                    model = torch.nn.parallel.DistributedDataParallel(
                        model, find_unused_parameters=True)
            else:
                model = torch.nn.parallel.DistributedDataParallel(
                    model, find_unused_parameters=True)

        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        if args.resume is not None:
            _amp_state_dict = os.path.join(args.oss_cache_dir,
                                           f"amp_{args.resume}.bin")
            _optimizer_state_dict = os.path.join(
                args.oss_cache_dir, f"optimizer_{args.resume}.pt")
            _scheduler_state_dict = os.path.join(
                args.oss_cache_dir, f"scheduler_{args.resume}.pt")

            amp.load_state_dict(
                torch.load(load_buffer_from_oss(_amp_state_dict)))
            optimizer.load_state_dict(
                torch.load(load_buffer_from_oss(_optimizer_state_dict)))
            scheduler.load_state_dict(
                torch.load(load_buffer_from_oss(_scheduler_state_dict)))

            logger.info(f"Loaded resumed state dict of step {args.resume}")

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_features))
        logger.info("  Instantaneous batch size per GPU = %d",
                    args.train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            args.train_batch_size * args.gradient_accumulation_steps *
            (dist.get_world_size() if args.local_rank != -1 else 1),
        )
        logger.info("  Gradient Accumulation steps = %d",
                    args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        model.train()
        epc = 0
        # test
        if args.local_rank in [-1, 0]:
            if args.fp16:
                amp_file = os.path.join(args.oss_cache_dir,
                                        f"amp_{global_step}.bin")
                torch_save_to_oss(amp.state_dict(), amp_file)
            optimizer_file = os.path.join(args.oss_cache_dir,
                                          f"optimizer_{global_step}.pt")
            torch_save_to_oss(optimizer.state_dict(), optimizer_file)
            scheduler_file = os.path.join(args.oss_cache_dir,
                                          f"scheduler_{global_step}.pt")
            torch_save_to_oss(scheduler.state_dict(), scheduler_file)

        tr_loss = 0
        for _ in range(int(args.num_train_epochs)):
            logger.info('Epoch ' + str(epc + 1))

            TOTAL_NUM = len(train_features)
            train_start_index = 0
            CHUNK_NUM = 8
            train_chunk = TOTAL_NUM // CHUNK_NUM
            chunk_index = 0

            random.shuffle(train_features)

            save_retry = False
            while train_start_index < TOTAL_NUM:
                train_end_index = min(train_start_index + train_chunk - 1,
                                      TOTAL_NUM - 1)
                chunk_len = train_end_index - train_start_index + 1

                if args.resume is not None and global_step < args.resume:
                    _chunk_steps = int(
                        math.ceil(chunk_len * 1.0 / args.train_batch_size /
                                  (1 if args.local_rank == -1 else
                                   dist.get_world_size())))
                    _chunk_steps = _chunk_steps // args.gradient_accumulation_steps
                    if global_step + _chunk_steps <= args.resume:
                        global_step += _chunk_steps
                        train_start_index = train_end_index + 1
                        continue

                train_features_ = train_features[
                    train_start_index:train_start_index + chunk_len]

                all_input_ids = torch.tensor(
                    [f.input_ids for f in train_features_], dtype=torch.long)
                all_input_masks = torch.tensor(
                    [f.input_masks for f in train_features_], dtype=torch.long)
                all_segment_ids = torch.tensor(
                    [f.segment_ids for f in train_features_], dtype=torch.long)
                all_output_masks = torch.tensor(
                    [f.output_masks for f in train_features_],
                    dtype=torch.float)
                all_num_paragraphs = torch.tensor(
                    [f.num_paragraphs for f in train_features_],
                    dtype=torch.long)
                all_num_steps = torch.tensor(
                    [f.num_steps for f in train_features_], dtype=torch.long)
                train_data = TensorDataset(all_input_ids, all_input_masks,
                                           all_segment_ids, all_output_masks,
                                           all_num_paragraphs, all_num_steps)

                if args.local_rank != -1:
                    train_sampler = torch.utils.data.DistributedSampler(
                        train_data)
                else:
                    train_sampler = RandomSampler(train_data)
                train_dataloader = DataLoader(train_data,
                                              sampler=train_sampler,
                                              batch_size=args.train_batch_size,
                                              pin_memory=True,
                                              num_workers=4)

                if args.local_rank != -1:
                    train_dataloader.sampler.set_epoch(epc)

                logger.info('Examples from ' + str(train_start_index) +
                            ' to ' + str(train_end_index))
                for step, batch in enumerate(
                        tqdm(train_dataloader,
                             desc="Iteration",
                             disable=args.local_rank not in [-1, 0])):
                    if args.resume is not None and global_step < args.resume:
                        if (step + 1) % args.gradient_accumulation_steps == 0:
                            global_step += 1
                        continue

                    input_masks = batch[1]
                    batch_max_len = input_masks.sum(dim=2).max().item()

                    num_paragraphs = batch[4]
                    batch_max_para_num = num_paragraphs.max().item()

                    num_steps = batch[5]
                    batch_max_steps = num_steps.max().item()

                    # output_masks_cpu = (batch[3])[:, :batch_max_steps, :batch_max_para_num + 1]

                    batch = tuple(t.to(device) for t in batch)
                    input_ids, input_masks, segment_ids, output_masks, _, _ = batch
                    B = input_ids.size(0)

                    input_ids = input_ids[:, :batch_max_para_num, :
                                          batch_max_len]
                    input_masks = input_masks[:, :batch_max_para_num, :
                                              batch_max_len]
                    segment_ids = segment_ids[:, :batch_max_para_num, :
                                              batch_max_len]
                    output_masks = output_masks[:, :batch_max_steps, :
                                                batch_max_para_num +
                                                1]  # 1 for EOE

                    target = torch.zeros(output_masks.size()).fill_(
                        NEGATIVE)  # (B, NUM_STEPS, |P|+1) <- 1 for EOE
                    for i in range(B):
                        output_masks[i, :num_steps[i], -1] = 1.0  # for EOE

                        for j in range(num_steps[i].item() - 1):
                            target[i, j, j].fill_(POSITIVE)

                        target[i, num_steps[i] - 1, -1].fill_(POSITIVE)
                    target = target.to(device)

                    neg_start = batch_max_steps - 1
                    while neg_start < batch_max_para_num:
                        neg_end = min(neg_start + args.neg_chunk - 1,
                                      batch_max_para_num - 1)
                        neg_len = (neg_end - neg_start + 1)

                        input_ids_ = torch.cat(
                            (input_ids[:, :batch_max_steps - 1, :],
                             input_ids[:, neg_start:neg_start + neg_len, :]),
                            dim=1)
                        input_masks_ = torch.cat(
                            (input_masks[:, :batch_max_steps - 1, :],
                             input_masks[:, neg_start:neg_start + neg_len, :]),
                            dim=1)
                        segment_ids_ = torch.cat(
                            (segment_ids[:, :batch_max_steps - 1, :],
                             segment_ids[:, neg_start:neg_start + neg_len, :]),
                            dim=1)
                        output_masks_ = torch.cat(
                            (output_masks[:, :, :batch_max_steps - 1],
                             output_masks[:, :, neg_start:neg_start + neg_len],
                             output_masks[:, :, batch_max_para_num:
                                          batch_max_para_num + 1]),
                            dim=2)
                        target_ = torch.cat(
                            (target[:, :, :batch_max_steps - 1],
                             target[:, :, neg_start:neg_start + neg_len],
                             target[:, :,
                                    batch_max_para_num:batch_max_para_num +
                                    1]),
                            dim=2)

                        if neg_start != batch_max_steps - 1:
                            output_masks_[:, :, :batch_max_steps - 1] = 0.0
                            output_masks_[:, :, -1] = 0.0

                        loss = model(input_ids_, segment_ids_, input_masks_,
                                     output_masks_, target_, batch_max_steps)

                        if n_gpu > 1:
                            loss = loss.mean()  # mean() to average on multi-gpu.
                        if args.gradient_accumulation_steps > 1:
                            loss = loss / args.gradient_accumulation_steps

                        if args.fp16:
                            with amp.scale_loss(loss,
                                                optimizer) as scaled_loss:
                                scaled_loss.backward()
                        else:
                            loss.backward()

                        tr_loss += loss.item()
                        neg_start = neg_end + 1

                        # del input_ids_
                        # del input_masks_
                        # del segment_ids_
                        # del output_masks_
                        # del target_

                    if (step + 1) % args.gradient_accumulation_steps == 0:

                        if args.fp16:
                            torch.nn.utils.clip_grad_norm_(
                                amp.master_params(optimizer), 1.0)
                        else:
                            torch.nn.utils.clip_grad_norm_(
                                model.parameters(), 1.0)

                        optimizer.step()
                        scheduler.step()
                        # optimizer.zero_grad()
                        model.zero_grad()
                        global_step += 1

                        if global_step % 50 == 0:
                            _cur_steps = global_step if args.resume is None else global_step - args.resume
                            logger.info(
                                f"Training loss: {tr_loss / _cur_steps}\t"
                                f"Learning rate: {scheduler.get_lr()[0]}\t"
                                f"Global step: {global_step}")

                        if global_step % args.save_steps == 0:
                            if args.local_rank in [-1, 0]:
                                model_to_save = model.module if hasattr(
                                    model, 'module') else model
                                output_model_file = os.path.join(
                                    args.oss_cache_dir,
                                    f"pytorch_model_{global_step}.bin")
                                torch_save_to_oss(model_to_save.state_dict(),
                                                  output_model_file)

                            _suffix = "" if args.local_rank == -1 else f"_{args.local_rank}"
                            if args.fp16:
                                amp_file = os.path.join(
                                    args.oss_cache_dir,
                                    f"amp_{global_step}{_suffix}.bin")
                                torch_save_to_oss(amp.state_dict(), amp_file)
                            optimizer_file = os.path.join(
                                args.oss_cache_dir,
                                f"optimizer_{global_step}{_suffix}.pt")
                            torch_save_to_oss(optimizer.state_dict(),
                                              optimizer_file)
                            scheduler_file = os.path.join(
                                args.oss_cache_dir,
                                f"scheduler_{global_step}{_suffix}.pt")
                            torch_save_to_oss(scheduler.state_dict(),
                                              scheduler_file)

                            logger.info(
                                f"checkpoint of step {global_step} is saved to oss."
                            )

                    # del input_ids
                    # del input_masks
                    # del segment_ids
                    # del output_masks
                    # del target
                    # del batch

                chunk_index += 1
                train_start_index = train_end_index + 1

                # Save the model at the half of the epoch
                if (chunk_index == CHUNK_NUM // 2
                        or save_retry) and args.local_rank in [-1, 0]:
                    status = save(model, args.output_dir, str(epc + 0.5))
                    save_retry = (not status)

                del train_features_
                del all_input_ids
                del all_input_masks
                del all_segment_ids
                del all_output_masks
                del all_num_paragraphs
                del all_num_steps
                del train_data
                del train_sampler
                del train_dataloader
                gc.collect()

            # Save the model at the end of the epoch
            if args.local_rank in [-1, 0]:
                save(model, args.output_dir, str(epc + 1))
                # model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
                # output_model_file = os.path.join(args.oss_cache_dir, "pytorch_model_" + str(epc + 1) + ".bin")
                # torch_save_to_oss(model_to_save.state_dict(), output_model_file)

            epc += 1

    if do_train:
        return

    ##############################
    # Evaluation                 #
    ##############################
    assert args.model_suffix is not None

    if graph_retriever_config.db_save_path is not None:
        import sys
        sys.path.append('../')
        from pipeline.tfidf_retriever import TfidfRetriever
        tfidf_retriever = TfidfRetriever(graph_retriever_config.db_save_path,
                                         None)
    else:
        tfidf_retriever = None

    if args.oss_cache_dir is not None:
        file_name = 'pytorch_model_' + args.model_suffix + '.bin'
        model_state_dict = torch.load(
            load_buffer_from_oss(os.path.join(args.oss_cache_dir, file_name)))
    else:
        model_state_dict = load(args.output_dir, args.model_suffix)

    model = RobertaForGraphRetriever.from_pretrained(
        args.bert_model,
        state_dict=model_state_dict,
        graph_retriever_config=graph_retriever_config)
    model.to(device)

    model.eval()

    if args.pred_file is not None:
        pred_output = []

    eval_examples = processor.get_dev_examples(graph_retriever_config)

    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %d", len(eval_examples))
    logger.info("  Batch size = %d", args.eval_batch_size)

    TOTAL_NUM = len(eval_examples)
    eval_start_index = 0

    while eval_start_index < TOTAL_NUM:
        eval_end_index = min(
            eval_start_index + graph_retriever_config.eval_chunk - 1,
            TOTAL_NUM - 1)
        chunk_len = eval_end_index - eval_start_index + 1

        eval_features = convert_examples_to_features(
            eval_examples[eval_start_index:eval_start_index + chunk_len],
            args.max_seq_length, args.max_para_num, graph_retriever_config,
            tokenizer)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_masks = torch.tensor([f.input_masks for f in eval_features],
                                       dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_output_masks = torch.tensor(
            [f.output_masks for f in eval_features], dtype=torch.float)
        all_num_paragraphs = torch.tensor(
            [f.num_paragraphs for f in eval_features], dtype=torch.long)
        all_num_steps = torch.tensor([f.num_steps for f in eval_features],
                                     dtype=torch.long)
        all_ex_indices = torch.tensor([f.ex_index for f in eval_features],
                                      dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_masks,
                                  all_segment_ids, all_output_masks,
                                  all_num_paragraphs, all_num_steps,
                                  all_ex_indices)

        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        for input_ids, input_masks, segment_ids, output_masks, num_paragraphs, num_steps, ex_indices in tqdm(
                eval_dataloader, desc="Evaluating"):
            batch_max_len = input_masks.sum(dim=2).max().item()
            batch_max_para_num = num_paragraphs.max().item()

            batch_max_steps = num_steps.max().item()

            input_ids = input_ids[:, :batch_max_para_num, :batch_max_len]
            input_masks = input_masks[:, :batch_max_para_num, :batch_max_len]
            segment_ids = segment_ids[:, :batch_max_para_num, :batch_max_len]
            output_masks = output_masks[:, :batch_max_para_num +
                                        2, :batch_max_para_num + 1]
            output_masks[:, 1:, -1] = 1.0  # Ignore EOE in the first step

            input_ids = input_ids.to(device)
            input_masks = input_masks.to(device)
            segment_ids = segment_ids.to(device)
            output_masks = output_masks.to(device)

            examples = [
                eval_examples[eval_start_index + ex_indices[i].item()]
                for i in range(input_ids.size(0))
            ]

            with torch.no_grad():
                pred, prob, topk_pred, topk_prob = model.beam_search(
                    input_ids,
                    segment_ids,
                    input_masks,
                    examples=examples,
                    tokenizer=tokenizer,
                    retriever=tfidf_retriever,
                    split_chunk=args.split_chunk)

            for i in range(len(pred)):
                e = examples[i]
                titles = [e.title_order[p] for p in pred[i]]

                # Output predictions to a file
                if args.pred_file is not None:
                    pred_output.append({})
                    pred_output[-1]['q_id'] = e.guid

                    pred_output[-1]['titles'] = titles
                    pred_output[-1]['probs'] = []
                    for prob_ in prob[i]:
                        entry = {'EOE': prob_[-1]}
                        for j in range(len(e.title_order)):
                            entry[e.title_order[j]] = prob_[j]
                        pred_output[-1]['probs'].append(entry)

                    topk_titles = [[e.title_order[p] for p in topk_pred[i][j]]
                                   for j in range(len(topk_pred[i]))]
                    pred_output[-1]['topk_titles'] = topk_titles

                    topk_probs = []
                    for k in range(len(topk_prob[i])):
                        topk_probs.append([])
                        for prob_ in topk_prob[i][k]:
                            entry = {'EOE': prob_[-1]}
                            for j in range(len(e.title_order)):
                                entry[e.title_order[j]] = prob_[j]
                            topk_probs[-1].append(entry)
                    pred_output[-1]['topk_probs'] = topk_probs

                    # Output the selected paragraphs
                    context = {}
                    for ts in topk_titles:
                        for t in ts:
                            context[t] = e.all_paras[t]
                    pred_output[-1]['context'] = context

        eval_start_index = eval_end_index + 1

        del eval_features
        del all_input_ids
        del all_input_masks
        del all_segment_ids
        del all_output_masks
        del all_num_paragraphs
        del all_num_steps
        del all_ex_indices
        del eval_data

    if args.pred_file is not None:
        json.dump(pred_output, open(args.pred_file, 'w'))
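
For reference, a minimal sketch of one entry that the loop above appends to pred_output and serializes to args.pred_file. The key names are taken from the code; the titles and probabilities below are illustrative placeholders only:

entry = {
    "q_id": "example-guid",                                        # e.guid
    "titles": ["Title A", "Title B"],                              # predicted paragraph titles, in prediction order
    "probs": [{"EOE": 0.01, "Title A": 0.90, "Title B": 0.09}],    # one distribution per retrieval step
    "topk_titles": [["Title A", "Title B"]],                       # one title sequence per beam
    "topk_probs": [[{"EOE": 0.01, "Title A": 0.90, "Title B": 0.09}]],
    "context": {"Title A": "paragraph text"},                      # e.all_paras[title] for every selected title
}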
Example #5
0
class Trainer(object):
    def __init__(self, proto, stage="train"):
        # model config
        model_cfg = proto["model"]
        model_name = model_cfg["name"]
        self.model_name = model_name

        # dataset config
        data_cfg = proto["data"]
        train_data_path = data_cfg.get("train_path", None)
        val_data_path = data_cfg.get("val_path", None)
        pad = data_cfg.get("pad", 32)
        train_bs = data_cfg.get("train_batch_size", None)
        val_bs = data_cfg.get("val_batch_size", None)
        self.val_bs = val_bs
        self.skip_first = data_cfg.get("skip_first", False)
        self.delimiter = data_cfg.get("delimiter", "\t")

        # assorted config
        optim_cfg = proto.get("optimizer", {"lr": 0.00003})
        sched_cfg = proto.get("schedulers", None)
        loss = proto.get("loss", "CE")
        self.device = proto.get("device", None)

        model_cfg.pop("name")

        if torch.cuda.is_available() and self.device is not None:
            print("Using device: %d." % self.device)
            self.device = torch.device(self.device)
            self.gpu = True
        else:
            print("Using cpu device.")
            self.device = torch.device("cpu")
            self.gpu = False

        if stage == "train":
            if train_data_path is None or val_data_path is None:
                raise ValueError("Please specify both train and val data path.")
            if train_bs is None or val_bs is None:
                raise ValueError("Please specify both train and val batch size.")
            # loading model
            self.model = fetch_nn(model_name)(**model_cfg)
            self.model = self.model.cuda(self.device)

            # loading dataset and converting into dataloader
            self.train_data = ChineseTextSet(
                path=train_data_path, tokenizer=self.model.tokenizer, pad=pad,
                delimiter=self.delimiter, skip_first=self.skip_first)
            self.train_loader = DataLoader(
                self.train_data, train_bs, shuffle=True, num_workers=4)
            self.val_data = ChineseTextSet(
                path=val_data_path, tokenizer=self.model.tokenizer, pad=pad,
                delimiter=self.delimiter, skip_first=self.skip_first)
            self.val_loader = DataLoader(
                self.val_data, val_bs, shuffle=True, num_workers=4)

            time_format = "%Y-%m-%d...%H.%M.%S"
            id = time.strftime(time_format, time.localtime(time.time()))
            self.record_path = os.path.join(arg.record, model_name, id)

            os.makedirs(self.record_path)
            sys.stdout = Logger(os.path.join(self.record_path, 'records.txt'))
            print("Writing proto file to file directory: %s." % self.record_path)
            yaml.dump(proto, open(os.path.join(self.record_path, 'protocol.yml'), 'w'))

            print("*" * 25, " PROTO BEGINS ", "*" * 25)
            pprint(proto)
            print("*" * 25, " PROTO ENDS ", "*" * 25)

            self.optimizer = AdamW(self.model.parameters(), **optim_cfg)
            self.scheduler = fetch_scheduler(self.optimizer, sched_cfg)

            self.loss = fetch_loss(loss)

            self.best_f1 = 0.0
            self.best_step = 1
            self.start_step = 1

            self.num_steps = proto["num_steps"]
            self.num_epoch = math.ceil(self.num_steps / len(self.train_loader))

            # number of steps between writing training-log entries
            self.log_steps = proto["log_steps"]
            # number of steps between validations on the val dataset
            self.val_steps = proto["val_steps"]

            self.f1_meter = AverageMeter()
            self.p_meter = AverageMeter()
            self.r_meter = AverageMeter()
            self.acc_meter = AverageMeter()
            self.loss_meter = AverageMeter()

        if stage == "test":
            if val_data_path is None:
                raise ValueError("Please specify the val data path.")
            if val_bs is None:
                raise ValueError("Please specify the val batch size.")
            id = proto["id"]
            ckpt_fold = proto.get("ckpt_fold", "runs")
            self.record_path = os.path.join(ckpt_fold, model_name, id)
            sys.stdout = Logger(os.path.join(self.record_path, 'tests.txt'))

            config, state_dict, fc_dict = self._load_ckpt(best=True, train=False)
            weights = {"config": config, "state_dict": state_dict}
            # loading trained model using config and state_dict
            self.model = fetch_nn(model_name)(weights=weights)
            # loading the weights for the final fc layer
            self.model.load_state_dict(fc_dict, strict=False)
            # loading model to gpu device if specified
            if self.gpu:
                self.model = self.model.cuda(self.device)

            print("Testing directory: %s." % self.record_path)
            print("*" * 25, " PROTO BEGINS ", "*" * 25)
            pprint(proto)
            print("*" * 25, " PROTO ENDS ", "*" * 25)

            self.val_path = val_data_path
            self.test_data = ChineseTextSet(
                path=val_data_path, tokenizer=self.model.tokenizer, pad=pad,
                delimiter=self.delimiter, skip_first=self.skip_first)
            self.test_loader = DataLoader(
                self.test_data, val_bs, shuffle=True, num_workers=4)

    def _save_ckpt(self, step, best=False, f=None, p=None, r=None):
        save_dir = os.path.join(self.record_path, "best_model.bin" if best else "latest_model.bin")
        torch.save({
            "step": step,
            "f1": f,
            "precision": p,
            "recall": r,
            "best_step": self.best_step,
            "best_f1": self.best_f1,
            "model": self.model.state_dict(),
            "config": self.model.config,
            "optimizer": self.optimizer.state_dict(),
            "schedulers": self.scheduler.state_dict(),
        }, save_dir)

    def _load_ckpt(self, best=False, train=False):
        load_dir = os.path.join(self.record_path, "best_model.bin" if best else "latest_model.bin")
        load_dict = torch.load(load_dir, map_location=self.device)
        self.start_step = load_dict["step"]
        self.best_step = load_dict["best_step"]
        self.best_f1 = load_dict["best_f1"]
        if train:
            self.optimizer.load_state_dict(load_dict["optimizer"])
            self.scheduler.load_state_dict(load_dict["schedulers"])
        print("Loading checkpoint from %s, best step: %d, best f1: %.4f."
              % (load_dir, self.best_step, self.best_f1))
        if not best:
            print("Checkpoint step %s, f1: %.4f, precision: %.4f, recall: %.4f."
                  % (self.start_step, load_dict["f1"],
                     load_dict["precision"], load_dict["recall"]))
        fc_dict = {
            "fc.weight": load_dict["model"]["fc.weight"],
            "fc.bias": load_dict["model"]["fc.bias"]
        }
        return load_dict["config"], load_dict["model"], fc_dict

    def to_cuda(self, *args):
        return [obj.cuda(self.device) for obj in args]

    @staticmethod
    def fixed_randomness():
        random.seed(0)
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)
        torch.cuda.manual_seed_all(0)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    @staticmethod
    def update_metrics(gt, pre, f1_m, p_m, r_m, acc_m):
        f1_value = f1(gt, pre, average="micro")
        f1_m.update(f1_value)
        p_value = precision(gt, pre, average="micro", zero_division=0)
        p_m.update(p_value)
        r_value = recall(gt, pre, average="micro")
        r_m.update(r_value)
        acc_value = accuracy(gt, pre)
        acc_m.update(acc_value)
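        # Note: with average="micro" in single-label classification, precision, recall and F1
        # all coincide with plain accuracy, so the four meters above track the same value;
        # switch to average="macro" (or a per-class report) if class imbalance matters.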

    def train(self):
        timer = Timer()
        writer = SummaryWriter(self.record_path)
        print("*" * 25, " TRAINING BEGINS ", "*" * 25)
        start_epoch = self.start_step // len(self.train_loader) + 1
        for epoch_idx in range(start_epoch, self.num_epoch + 1):
            self.f1_meter.reset()
            self.p_meter.reset()
            self.r_meter.reset()
            self.acc_meter.reset()
            self.loss_meter.reset()
            self.optimizer.step()
            self.scheduler.step()
            train_generator = tqdm(enumerate(self.train_loader, 1), position=0, leave=True)

            for batch_idx, data in train_generator:
                global_step = (epoch_idx - 1) * len(self.train_loader) + batch_idx
                self.model.train()
                id, label, _, mask = data[:4]
                if self.gpu:
                    id, mask, label = self.to_cuda(id, mask, label)
                pre = self.model((id, mask))
                loss = self.loss(pre, label)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                lbl = label.cpu().numpy()
                yp = pre.argmax(1).cpu().numpy()
                self.update_metrics(
                    lbl, yp, self.f1_meter, self.p_meter,
                    self.r_meter, self.acc_meter
                )
                self.loss_meter.update(loss.item())

                if global_step % self.log_steps == 0 and writer is not None:
                    writer.add_scalar("train/f1", self.f1_meter.avg, global_step)
                    writer.add_scalar("train/loss", self.loss_meter.avg, global_step)
                    writer.add_scalar("train/lr", self.scheduler.get_lr()[0], global_step)

                train_generator.set_description(
                    "Train Epoch %d (%d/%d), "
                    "Global Step %d, Loss %.4f, f1 %.4f, p %.4f, r %.4f, acc %.4f, LR %.6f" % (
                        epoch_idx, batch_idx, len(self.train_loader), global_step,
                        self.loss_meter.avg, self.f1_meter.avg,
                        self.p_meter.avg, self.r_meter.avg,
                        self.acc_meter.avg,
                        self.scheduler.get_lr()[0]
                    )
                )

                # validating process
                if global_step % self.val_steps == 0:
                    print()
                    self.validate(epoch_idx, global_step, timer, writer)

                # if num_steps is set and reached before the final epoch,
                # stop training early.
                if self.num_steps is not None and global_step == self.num_steps:
                    if writer is not None:
                        writer.close()
                    print()
                    print("*" * 25, " TRAINING ENDS ", "*" * 25)
                    return

            train_generator.close()
            print()
        writer.close()
        print("*" * 25, " TRAINING ENDS ", "*" * 25)

    def validate(self, epoch, step, timer, writer):
        with torch.no_grad():
            f1_meter = AverageMeter()
            p_meter = AverageMeter()
            r_meter = AverageMeter()
            acc_meter = AverageMeter()
            loss_meter = AverageMeter()
            val_generator = tqdm(enumerate(self.val_loader, 1), position=0, leave=True)
            for val_idx, data in val_generator:
                self.model.eval()
                id, label, _, mask = data[:4]
                if self.gpu:
                    id, mask, label = self.to_cuda(id, mask, label)
                pre = self.model((id, mask))
                loss = self.loss(pre, label)

                lbl = label.cpu().numpy()
                yp = pre.argmax(1).cpu().numpy()
                self.update_metrics(lbl, yp, f1_meter, p_meter, r_meter, acc_meter)
                loss_meter.update(loss.item())

                val_generator.set_description(
                    "Eval Epoch %d (%d/%d), Global Step %d, Loss %.4f, "
                    "f1 %.4f, p %.4f, r %.4f, acc %.4f" % (
                        epoch, val_idx, len(self.val_loader), step,
                        loss_meter.avg, f1_meter.avg,
                        p_meter.avg, r_meter.avg, acc_meter.avg
                    )
                )

            print("Eval Epoch %d, f1 %.4f" % (epoch, f1_meter.avg))
            if writer is not None:
                writer.add_scalar("val/loss", loss_meter.avg, step)
                writer.add_scalar("val/f1", f1_meter.avg, step)
                writer.add_scalar("val/precision", p_meter.avg, step)
                writer.add_scalar("val/recall", r_meter.avg, step)
                writer.add_scalar("val/acc", acc_meter.avg, step)
            if f1_meter.avg > self.best_f1:
                self.best_f1 = f1_meter.avg
                self.best_step = step
                self._save_ckpt(step, best=True)
            print("Best Step %d, Best f1 %.4f, Running Time: %s, Estimated Time: %s" % (
                self.best_step, self.best_f1, timer.measure(), timer.measure(step / self.num_steps)
            ))
            self._save_ckpt(step, best=False, f=f1_meter.avg, p=p_meter.avg, r=r_meter.avg)

    def test(self):
        # t_idx = random.randint(0, self.val_bs)
        t_idx = random.randint(0, 5)
        with torch.no_grad():
            self.fixed_randomness()  # for reproduction

            # for writing the total predictions to disk
            data_idxs = list()
            all_preds = list()

            # for plotting the P-R curve
            predicts = list()
            truths = list()

            # for showing predicted samples
            show_ctxs = list()
            pred_lbls = list()
            targets = list()

            f1_meter = AverageMeter()
            p_meter = AverageMeter()
            r_meter = AverageMeter()
            accuracy_meter = AverageMeter()
            test_generator = tqdm(enumerate(self.test_loader, 1))
            for idx, data in test_generator:
                self.model.eval()
                id, label, _, mask, data_idx = data
                if self.gpu:
                    id, mask, label = self.to_cuda(id, mask, label)
                pre = self.model((id, mask))

                lbl = label.cpu().numpy()
                yp = pre.argmax(1).cpu().numpy()
                self.update_metrics(lbl, yp, f1_meter, p_meter, r_meter, accuracy_meter)

                test_generator.set_description(
                    "Test %d/%d, f1 %.4f, p %.4f, r %.4f, acc %.4f"
                    % (idx, len(self.test_loader), f1_meter.avg,
                       p_meter.avg, r_meter.avg, accuracy_meter.avg)
                )

                data_idxs.append(data_idx.numpy())
                all_preds.append(yp)

                predicts.append(torch.select(pre, dim=1, index=1).cpu().numpy())
                truths.append(lbl)

                # show one sample from this batch
                ctx = torch.select(id, dim=0, index=t_idx).detach()
                ctx = self.model.tokenizer.convert_ids_to_tokens(ctx)
                ctx = "".join([_ for _ in ctx if _ not in [PAD, CLS]])
                yp = yp[t_idx]
                lbl = lbl[t_idx]

                show_ctxs.append(ctx)
                pred_lbls.append(yp)
                targets.append(lbl)

            print("*" * 25, " SAMPLE BEGINS ", "*" * 25)
            for c, t, l in zip(show_ctxs, targets, pred_lbls):
                print("ctx: ", c, " gt: ", t, " est: ", l)
            print("*" * 25, " SAMPLE ENDS ", "*" * 25)
            print("Test, FINAL f1 %.4f, "
                  "p %.4f, r %.4f, acc %.4f\n" %
                  (f1_meter.avg, p_meter.avg, r_meter.avg, accuracy_meter.avg))

            # output the final results to disk
            data_idxs = np.concatenate(data_idxs, axis=0)
            all_preds = np.concatenate(all_preds, axis=0)
            write_predictions(
                self.val_path, os.path.join(self.record_path, "results.txt"),
                data_idxs, all_preds, delimiter=self.delimiter, skip_first=self.skip_first
            )

            # output the p-r values for plotting the P-R curve later
            predicts = np.concatenate(predicts, axis=0)
            truths = np.concatenate(truths, axis=0)
            values = precision_recall_curve(truths, predicts)
            with open(os.path.join(self.record_path, "pr.values"), "wb") as f:
                pickle.dump(values, f)
            p_value, r_value, _ = values

            # plot P-R Curve if specified
            if arg.image:
                plt.figure()
                plt.plot(
                    p_value, r_value,
                    label="%s (ACC: %.2f, F1: %.2f)"
                          % (self.model_name, accuracy_meter.avg, f1_meter.avg)
                )
                plt.legend(loc="best")
                plt.title("2-class P-R curve")
                plt.xlabel("precision")
                plt.ylabel("recall")
                plt.savefig(os.path.join(self.record_path, "P-R.png"))
                plt.show()
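
A minimal sketch of the proto dictionary this Trainer expects, limited to the keys read in __init__. The model name, file paths, and hyper-parameter values are placeholders, and the accepted model names depend on what fetch_nn() supports:

proto = {
    "model": {"name": "bert"},        # "name" is popped; remaining keys go to fetch_nn(name)(**model_cfg)
    "data": {
        "train_path": "data/train.tsv",
        "val_path": "data/val.tsv",
        "pad": 32,
        "train_batch_size": 32,
        "val_batch_size": 64,
        "skip_first": True,
        "delimiter": "\t",
    },
    "optimizer": {"lr": 3e-5},        # forwarded to AdamW(**optim_cfg)
    "schedulers": None,               # forwarded to fetch_scheduler()
    "loss": "CE",
    "device": 0,                      # GPU index; omit or set to None to fall back to CPU
    "num_steps": 10000,
    "log_steps": 50,
    "val_steps": 500,
}
trainer = Trainer(proto, stage="train")
trainer.train()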
Example #6
0
class LengthDropTrainer(Trainer):
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer = None,
        best_metric: str = 'acc',
        length_drop_args: LengthDropArguments = None,
        **kwargs,
    ):
        super(LengthDropTrainer, self).__init__(**kwargs)
        self.tokenizer = tokenizer
        self.best_metric = best_metric
        if length_drop_args is None:
            length_drop_args = LengthDropArguments()
        self.length_drop_args = length_drop_args

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.
        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through :obj:`optimizers`, or subclass and override this method.
        """
        if self.optimizer is None:
            no_decay = ["bias", "LayerNorm.weight"]
            retention_params = []
            wd_params = []
            no_wd_params = []
            for n, p in self.model.named_parameters():
                if "retention" in n:
                    retention_params.append(p)
                elif any(nd in n for nd in no_decay):
                    no_wd_params.append(p)
                else:
                    wd_params.append(p)
            optimizer_grouped_parameters = [
                {"params": wd_params, "weight_decay": self.args.weight_decay, "lr": self.args.learning_rate},
                {"params": no_wd_params, "weight_decay": 0.0, "lr": self.args.learning_rate}
            ]
            if len(retention_params) > 0:
                optimizer_grouped_parameters.append(
                    {"params": retention_params, "weight_decay": 0.0, "lr": self.length_drop_args.lr_soft_extract}
                )
            self.optimizer = AdamW(
                optimizer_grouped_parameters,
                lr=self.args.learning_rate,
                betas=(self.args.adam_beta1, self.args.adam_beta2),
                eps=self.args.adam_epsilon,
            )
        if self.lr_scheduler is None:
            if self.args.warmup_ratio is not None:
                num_warmup_steps = int(self.args.warmup_ratio * num_training_steps)
            else:
                num_warmup_steps = self.args.warmup_steps
            self.lr_scheduler = get_linear_schedule_with_warmup(
                self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
            )
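        # As the docstring above notes, these defaults can be bypassed by handing the base
        # Trainer a ready-made pair via `optimizers=(optimizer, lr_scheduler)` at construction
        # time; a hypothetical sketch (values are illustrative):
        #
        #   optimizer = AdamW(model.parameters(), lr=3e-5)
        #   scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0,
        #                                               num_training_steps=10000)
        #   trainer = LengthDropTrainer(model=model, args=training_args,
        #                               optimizers=(optimizer, scheduler))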

    def div_loss(self, loss):
        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps
        return loss
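        # Example: with n_gpu=2 and gradient_accumulation_steps=4, a loss tensor is first
        # averaged over the two replicas and then divided by 4, so the gradients accumulated
        # over four micro-batches correspond to a single update on the full effective batch.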

    def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
        """
        Main training entry point.
        Args:
            model_path (:obj:`str`, `optional`):
                Local path to the model if the model to train has been instantiated from a local path. If present,
                training will resume from the optimizer/scheduler states loaded here.
            trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
                The trial run or the hyperparameter dictionary for hyperparameter search.
        """
        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)

        # Model re-init
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            set_seed(self.args.seed)
            model = self.model_init()
            self.model = model.to(self.args.device)

            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()
        num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
                self.args.max_steps % num_update_steps_per_epoch > 0
            )
        else:
            t_total = int(num_update_steps_per_epoch * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs
            self.args.max_steps = t_total

        self.create_optimizer_and_scheduler(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            self.optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16 and _use_apex:
            if not is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])

                epochs_trained = self.global_step // num_update_steps_per_epoch
                steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss_sum = 0.0
        loss_sum = defaultdict(float)
        best = {self.best_metric: None}
        model.zero_grad()
        disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
        train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
        for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))):
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
                epoch_iterator = parallel_loader
            else:
                epoch_iterator = train_dataloader

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None

            epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    epoch_pbar.update(1)
                    continue

                model.train()
                inputs = self._prepare_inputs(inputs)

                inputs["output_attentions"] = self.length_drop_args.length_config is not None

                layer_config = sample_layer_configuration(
                    model.config.num_hidden_layers,
                    layer_dropout_prob=self.length_drop_args.layer_dropout_prob,
                    layer_dropout=0,
                )
                inputs["layer_config"] = layer_config

                inputs["length_config"] = self.length_drop_args.length_config

                outputs = model(**inputs)
                # Save past state if it exists
                if self.args.past_index >= 0:
                    self._past = outputs[self.args.past_index]
                task_loss = self.div_loss(outputs[0])
                if self.length_drop_args.length_adaptive:
                    loss_sum["full"] += task_loss.item()
                loss = task_loss
                if self.length_drop_args.length_adaptive:
                    loss = loss / (self.length_drop_args.num_sandwich + 2)

                tr_loss_sum += loss.item()
                if self.args.fp16 and _use_native_amp:
                    self.scaler.scale(loss).backward()
                elif self.args.fp16 and _use_apex:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # inplace distillation
                if self.length_drop_args.length_adaptive:
                    logits = outputs[1].detach()

                    for i in range(self.length_drop_args.num_sandwich + 1):
                        inputs["output_attentions"] = True

                        layer_config = sample_layer_configuration(
                            model.config.num_hidden_layers,
                            layer_dropout_prob=self.length_drop_args.layer_dropout_prob,
                            layer_dropout=(self.length_drop_args.layer_dropout_bound if i == 0 else None),
                            layer_dropout_bound=self.length_drop_args.layer_dropout_bound,
                        )
                        inputs["layer_config"] = layer_config

                        length_config = sample_length_configuration(
                            self.args.max_seq_length,
                            model.config.num_hidden_layers,
                            layer_config,
                            length_drop_ratio=(self.length_drop_args.length_drop_ratio_bound if i == 0 else None),
                            length_drop_ratio_bound=self.length_drop_args.length_drop_ratio_bound,
                        )
                        inputs["length_config"] = length_config

                        outputs_sub = model(**inputs)
                        task_loss_sub = self.div_loss(outputs_sub[0])
                        if i == 0:
                            loss_sum["smallest"] += task_loss_sub.item()
                            loss_sum["sub"] += 0
                        else:
                            loss_sum["sub"] += task_loss_sub.item() / self.length_drop_args.num_sandwich

                        logits_sub = outputs_sub[1]
                        loss_fct = KLDivLoss(reduction="batchmean")
                        kl_loss = loss_fct(F.log_softmax(logits, -1), F.softmax(logits_sub, -1))
                        loss = self.div_loss(kl_loss)
                        loss_sum["kl"] += loss.item() / (self.length_drop_args.num_sandwich + 1)
                        loss = loss / (self.length_drop_args.num_sandwich + 2)

                        tr_loss_sum += loss.item()
                        if self.args.fp16 and _use_native_amp:
                            self.scaler.scale(loss).backward()
                        elif self.args.fp16 and _use_apex:
                            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                                scaled_loss.backward()
                        else:
                            loss.backward()

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    (step + 1) == len(epoch_iterator) <= self.args.gradient_accumulation_steps
                ):
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(self.optimizer)
                    elif self.args.fp16 and _use_native_amp:
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        self.optimizer.step()

                    self.lr_scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                        self.global_step == 1 and self.args.logging_first_step
                    ):
                        # backward compatibility for pytorch schedulers
                        lr = (
                            self.lr_scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >= version.parse("1.4")
                            else self.lr_scheduler.get_lr()[0]
                        )
                        loss = tr_loss_sum / self.args.logging_steps
                        tr_loss_sum = 0.0
                        logs = {"lr": lr, "loss": loss}
                        log_str = f"[{self.global_step:5d}] lr {lr:g} | loss {loss:2.3f}"

                        for key, value in loss_sum.items():
                            value /= self.args.logging_steps
                            loss_sum[key] = 0.0
                            logs[f"{key}_loss"] = value
                            log_str += f" | {key}_loss {value:2.3f}"

                        self.log(logs, "train")
                        logger.info(log_str)

                    '''
                    if (
                        self.args.evaluation_strategy == EvaluationStrategy.STEPS
                        and self.global_step % self.args.eval_steps == 0
                    ):
                        results = self.evaluate()
                        self._report_to_hp_search(trial, epoch, results)
                    '''

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert (
                                model.module is self.model
                            ), f"Module {model.module} should be a reference to self.model"
                        else:
                            assert model is self.model, f"Model {model} should be a reference to self.model"

                        if self.args.evaluate_during_training:
                            results = self.evaluate()
                            results = {k[5:]: v for k, v in results.items() if k.startswith("eval_")}
                            self.log(results, "dev")
                            msg = " | ".join([f"{k} {v:.3f}" for k, v in results.items()])
                            logger.info(f"  [{self.global_step:5d}] {msg}")

                        # Save model checkpoint
                        if self.args.save_only_best:
                            output_dirs = []
                        else:
                            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
                            if self.hp_search_backend is not None and trial is not None:
                                run_id = (
                                    trial.number
                                    if self.hp_search_backend == HPSearchBackend.OPTUNA
                                    else tune.get_trial_id()
                                )
                                checkpoint_folder += f"-run-{run_id}"
                            output_dirs = [os.path.join(self.args.output_dir, checkpoint_folder)]
                            
                        if self.args.evaluate_during_training:
                            if best[self.best_metric] is None or results[self.best_metric] > best[self.best_metric]:
                                logger.info("Congratulations, best model so far!")
                                output_dirs.append(os.path.join(self.args.output_dir, "checkpoint-best"))
                                best = results

                        for output_dir in output_dirs:
                            self.save_model(output_dir)

                            if self.is_world_master() and self.tokenizer is not None:
                                self.tokenizer.save_pretrained(output_dir)

                            if self.is_world_process_zero():
                                self._rotate_checkpoints(use_mtime=True)

                            '''
                            if is_torch_tpu_available():
                                xm.rendezvous("saving_optimizer_states")
                                xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                            elif self.is_world_process_zero():
                                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                            '''

                epoch_pbar.update(1)
                if 0 < self.args.max_steps <= self.global_step:
                    break
            epoch_pbar.close()
            train_pbar.update(1)

            '''
            if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
                results = self.evaluate()
                self._report_to_hp_search(trial, epoch, results)
            '''

            if self.args.tpu_metrics_debug or self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if 0 < self.args.max_steps <= self.global_step:
                break

        train_pbar.close()
        if self.tb_writer:
            self.tb_writer.close()
        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        return self.global_step, best

    def log(self, logs, mode="train"):
        self._setup_loggers()
        if self.global_step is None:
            # when logging evaluation metrics without training
            self.global_step = 0
        if self.tb_writer:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, self.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self.tb_writer.flush()
        if is_wandb_available():
            if self.is_world_process_zero():
                wandb.log(logs, step=self.global_step)
        if is_comet_available():
            if self.is_world_process_zero():
                experiment = comet_ml.config.get_global_experiment()
                if experiment is not None:
                    experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
        output = {**logs, **{"step": self.global_step}}
        if self.is_world_process_zero():
            self.log_history.append(output)

    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
        """
        Run evaluation and returns metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are
        task-dependent (pass it to the init :obj:`compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
                columns not accepted by the ``model.forward()`` method are automatically removed.

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        output = self.prediction_loop(eval_dataloader, description="Evaluation")

        if self.args.tpu_metrics_debug or self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        return output.metrics
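        # A typical (illustrative) result, with metric names depending on the compute_metrics
        # callback supplied to the Trainer:
        #   {"eval_loss": 0.37, "eval_acc": 0.91}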

    def prediction_loop(
        self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

        Works both with or without labels.
        """
        if hasattr(self, "_prediction_loop"):
            warnings.warn(
                "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
                FutureWarning,
            )
            return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)

        prediction_loss_only = (
            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
        )

        '''
        assert not getattr(
            self.model.config, "output_attentions", False
        ), "The prediction loop does not work with `output_attentions=True`."
        assert not getattr(
            self.model.config, "output_hidden_states", False
        ), "The prediction loop does not work with `output_hidden_states=True`."
        '''

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        else:
            model = self.model
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        '''
        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)
        '''
        eval_losses: List[float] = []
        preds: torch.Tensor = None
        label_ids: torch.Tensor = None
        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

        if self.args.past_index >= 0:
            self._past = None

        disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
        for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
            batch_size = inputs[list(inputs.keys())[0]].shape[0]
            if loss is not None:
                eval_losses.extend([loss] * batch_size)
            if logits is not None:
                preds = logits if preds is None else nested_concat(preds, logits, dim=0)
            if labels is not None:
                label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0)

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
            if label_ids is not None:
                label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = nested_xla_mesh_reduce(preds, "eval_preds")
            if label_ids is not None:
                label_ids = nested_xla_mesh_reduce(label_ids, "eval_label_ids")
            if eval_losses is not None:
                eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = nested_numpify(preds)
        if label_ids is not None:
            label_ids = nested_numpify(label_ids)

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            if self.args.local_rank != -1:
                metrics["eval_loss"] = (
                    distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader))
                    .mean()
                    .item()
                )
            else:
                metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

    def prediction_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
            A tuple with the loss, logits and labels (each being optional).
        """
        has_labels = all(inputs.get(k) is not None for k in self.args.label_names)
        inputs = self._prepare_inputs(inputs)

        # `inputs` is a dict, so optional flags have to be read with .get(); getattr() on a dict
        # would always fall through to the default.
        output_attentions = inputs.get('output_attentions', None)
        output_hidden_states = inputs.get('output_hidden_states', None)

        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states

        num_additional_outputs = int(bool(output_attentions)) + int(bool(output_hidden_states))

        with torch.no_grad():
            outputs = model(**inputs)
            if has_labels:
                # The .mean() is to reduce in case of distributed training
                loss = outputs[0].mean().item()
                logits = outputs[1:(len(outputs) - num_additional_outputs)]
            else:
                loss = None
                # Slicing so we get a tuple even if `outputs` is a `ModelOutput`.
                logits = outputs[:(len(outputs) - num_additional_outputs)]
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = tuple(logit.detach() for logit in logits)
        if len(logits) == 1:
            logits = logits[0]

        if has_labels:
            labels = tuple(inputs.get(name).detach() for name in self.args.label_names)
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        return (loss, logits, labels)
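        # For a single-label classification batch this typically returns something like
        # (0.42, logits of shape [batch_size, num_labels], labels of shape [batch_size]);
        # loss is None when the batch carries no labels, and logits/labels stay tuples when
        # more than one of them is produced.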

    def init_evolution(self, lower_constraint=0, upper_constraint=None):
        size = (1, self.args.max_seq_length)
        self.dummy_inputs = (
            torch.ones(size, dtype=torch.long).to(self.args.device),
            torch.ones(size, dtype=torch.long).to(self.args.device),
            torch.zeros(size, dtype=torch.long).to(self.args.device),
        )
        if self.model.config.model_type == "distilbert":
            self.dummy_inputs = self.dummy_inputs[:2]


        self.lower_constraint = lower_constraint
        self.upper_constraint = upper_constraint

        self.store = {}  # gene: (macs, score, method, parent(s))
        self.population = []

    def load_store(self, store_file):
        if not os.path.isfile(store_file):
            return
        with open(store_file, 'r') as f:
            for row in csv.reader(f, delimiter='\t'):
                row = tuple(eval(x) for x in row[:3])
                self.store[row[0]] = row[1:3] + (0, None)

    def save_store(self, store_file):
        store_keys = sorted(self.store.keys(), key=lambda x: self.store[x][0])
        with open(store_file, 'w') as f:
            writer = csv.writer(f, delimiter='\t')
            for gene in store_keys:
                writer.writerow([str(gene)] + [str(x) for x in self.store[gene]])

    def save_population(self, population_file, population):
        with open(population_file, 'w') as f:
            writer = csv.writer(f, delimiter='\t')
            for gene in population:
                writer.writerow([str(gene)] + [str(x) for x in self.store[gene]])

    def ccw(self, gene0, gene1, gene2):
        x0, y0 = self.store[gene0][:2]
        x1, y1 = self.store[gene1][:2]
        x2, y2 = self.store[gene2][:2]
        return (x0 * y1 + x1 * y2 + x2 * y0) - (x0 * y2 + x1 * y0 + x2 * y1)
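        # The expression above is the shoelace formula: twice the signed area of the triangle
        # spanned by the three genes' (macs, score) points, positive for a counter-clockwise
        # turn and negative for a clockwise one; convex_hull() below uses its sign to drop
        # points that fall inside the hull.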

    def convex_hull(self):
        hull = self.population[:2]
        for gene in self.population[2:]:
            if self.store[hull[-1]][1] >= self.store[gene][1]:
                continue
            while len(hull) >= 2 and self.ccw(hull[-2], hull[-1], gene) >= 0:
                del hull[-1]
            hull.append(gene)
        return hull

    def pareto_frontier(self):
        self.population = sorted(self.population, key=lambda x: self.store[x][:2])

        frontier = [self.population[0]]
        for gene in self.population[1:-1]:
            if self.store[gene][1] > self.store[frontier[-1]][1]:
                if self.store[gene][0] == self.store[frontier[-1]][0]:  # compare stored MACs, not the gene tuple itself
                    del frontier[-1]
                frontier.append(gene)
        frontier.append(self.population[-1])
        self.population = frontier

        area = 0
        for gene0, gene1 in zip(self.population[:-1], self.population[1:]):
            x0, y0 = self.store[gene0][:2]
            x1, y1 = self.store[gene1][:2]
            area += (x1 - x0) * y0
        area /= (self.upper_constraint - self.lower_constraint)
        return self.population, area

    def add_gene(self, gene, macs=None, score=None, method=0, parents=None):
        if gene not in self.store:
            self.model.eval()
            if self.model.config.model_type == "distilbert":
                bert = self.model.distilbert
            else:
                assert hasattr(self.model, "bert")
                bert = self.model.bert
            bert.set_length_config(gene)
            macs = macs or torchprofile.profile_macs(self.model, args=self.dummy_inputs)
            # logger.info(gene, macs)
            if macs < self.lower_constraint:
                return False
            score = score or self.evaluate()["eval_" + self.best_metric]
            self.store[gene] = (macs, score, method, parents)
            logger.info(store2str(gene, macs, score, method, parents))

        macs = self.store[gene][0]
        if macs >= self.lower_constraint \
                and (self.upper_constraint is None or macs <= self.upper_constraint) \
                and gene not in self.population:
            self.population.append(gene)
            return True
        return False

    def mutate(self, mutation_prob):
        gene = random.choice(self.population)
        mutated_gene = ()
        for i in range(self.model.config.num_hidden_layers):
            if np.random.uniform() < mutation_prob:
                prev = (self.args.max_seq_length if i == 0 else mutated_gene[i - 1])
                next = (2 if i == self.model.config.num_hidden_layers - 1 else gene[i + 1])
                mutated_gene += (random.randrange(next, prev + 1),)
            else:
                mutated_gene += (gene[i],)
        return self.add_gene(mutated_gene, method=1, parents=(gene,))

    def crossover(self):
        gene0, gene1 = random.sample(self.population, 2)
        crossovered_gene = tuple((g0 + g1 + 1) // 2 for g0, g1 in zip(gene0, gene1))
        return self.add_gene(crossovered_gene, method=2, parents=(gene0, gene1))
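
In this evolutionary search a gene is a tuple with one retained sequence length per transformer layer, non-increasing from the first layer to the last: mutate() draws each new length between the following layer's length and the preceding one's, and crossover() averages two parents element-wise. A hypothetical gene for a 6-layer model with max_seq_length=128, registered on some trainer instance:

gene = (128, 96, 80, 48, 32, 16)    # tokens kept after each of the 6 layers (illustrative values)
trainer.add_gene(gene, method=0)    # profiles MACs, evaluates, and keeps the gene if it meets the constraints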
Example #7
0
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--bert_model",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )
    parser.add_argument(
        "--from_pretrained",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )
    parser.add_argument(
        "--output_dir",
        default="save",
        type=str,
        help="The output directory where the model checkpoints will be written.",
    )
    parser.add_argument(
        "--config_file",
        default="config/bert_base_6layer_6conect.json",
        type=str,
        help="The config file which specified the model details.",
    )
    parser.add_argument(
        "--num_train_epochs",
        default=20,
        type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--train_iter_multiplier",
        default=1.0,
        type=float,
        help="multiplier for the multi-task training.",
    )
    parser.add_argument(
        "--train_iter_gap",
        default=4,
        type=int,
        help="forward every n iteration is the validation score is not improving over the last 3 epoch, -1 means will stop",
    )
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help="Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.",
    )
    parser.add_argument(
        "--no_cuda", action="store_true", help="Whether not to use CUDA when available"
    )
    parser.add_argument(
        "--do_lower_case",
        default=True,
        type=bool,
        help="Whether to lower case the input text. True for uncased models, False for cased models.",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local_rank for distributed training on gpus",
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="random seed for initialization"
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumualte before performing a backward/update pass.",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit float precision instead of 32-bit",
    )
    parser.add_argument(
        "--loss_scale",
        type=float,
        default=0,
        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=16,
        help="Number of workers in the dataloader.",
    )
    parser.add_argument(
        "--save_name", default="", type=str, help="save name for training."
    )
    parser.add_argument(
        "--in_memory",
        default=False,
        type=bool,
        help="whether use chunck for parallel training.",
    )
    parser.add_argument(
        "--optim", default="AdamW", type=str, help="what to use for the optimization."
    )
    parser.add_argument(
        "--tasks", default="", type=str, help="1-2-3... training task separate by -"
    )
    parser.add_argument(
        "--freeze",
        default=-1,
        type=int,
        help="till which layer of textual stream of vilbert need to fixed.",
    )
    parser.add_argument(
        "--vision_scratch",
        action="store_true",
        help="whether pre-trained the image or not.",
    )
    parser.add_argument(
        "--evaluation_interval", default=1, type=int, help="evaluate very n epoch."
    )
    parser.add_argument(
        "--lr_scheduler",
        default="mannul",
        type=str,
        help="whether use learning rate scheduler.",
    )
    parser.add_argument(
        "--baseline", action="store_true", help="whether use single stream baseline."
    )
    parser.add_argument(
        "--resume_file", default="", type=str, help="Resume from checkpoint"
    )
    parser.add_argument(
        "--dynamic_attention",
        action="store_true",
        help="whether use dynamic attention.",
    )
    parser.add_argument(
        "--clean_train_sets",
        default=True,
        type=bool,
        help="whether clean train sets for multitask data.",
    )
    parser.add_argument(
        "--visual_target",
        default=0,
        type=int,
        help="which target to use for visual branch. \
        0: soft label, \
        1: regress the feature, \
        2: NCE loss.",
    )
    parser.add_argument(
        "--task_specific_tokens",
        action="store_true",
        help="whether to use task specific tokens for the multi-task learning.",
    )

    args = parser.parse_args()
    with open("vilbert_tasks.yml", "r") as f:
        task_cfg = edict(yaml.safe_load(f))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.baseline:
        from pytorch_transformers.modeling_bert import BertConfig
        from vilbert.basebert import BaseBertForVLTasks
    else:
        from vilbert.vilbert import BertConfig
        from vilbert.vilbert import VILBertForVLTasks

    task_names = []
    task_lr = []
    for i, task_id in enumerate(args.tasks.split("-")):
        task = "TASK" + task_id
        name = task_cfg[task]["name"]
        task_names.append(name)
        task_lr.append(task_cfg[task]["lr"])

    base_lr = min(task_lr)
    loss_scale = {}
    for i, task_id in enumerate(args.tasks.split("-")):
        task = "TASK" + task_id
        loss_scale[task] = task_lr[i] / base_lr
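    # e.g. with task_lr = [4e-5, 2e-5] the base_lr is 2e-5 and
    # loss_scale = {"TASK1": 2.0, "TASK2": 1.0}, so tasks configured with a
    # larger lr contribute a proportionally larger (scaled) loss.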

    if args.save_name:
        prefix = "-" + args.save_name
    else:
        prefix = ""
    timeStamp = (
        "-".join(task_names)
        + "_"
        + args.config_file.split("/")[1].split(".")[0]
        + prefix
    )
    savePath = os.path.join(args.output_dir, timeStamp)

    bert_weight_name = json.load(
        open("config/" + args.bert_model + "_weight_name.json", "r")
    )

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
        )
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        torch.distributed.init_process_group(backend="nccl")

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
            device, n_gpu, bool(args.local_rank != -1), args.fp16
        )
    )

    default_gpu = False
    if dist.is_available() and args.local_rank != -1:
        rank = dist.get_rank()
        if rank == 0:
            default_gpu = True
    else:
        default_gpu = True

    if default_gpu:
        if not os.path.exists(savePath):
            os.makedirs(savePath)

    config = BertConfig.from_json_file(args.config_file)
    if default_gpu:
        # save all the hidden parameters.
        with open(os.path.join(savePath, "command.txt"), "w") as f:
            print(args, file=f)  # Python 3.x
            print("\n", file=f)
            print(config, file=f)

    task_batch_size, task_num_iters, task_ids, task_datasets_train, task_datasets_val, task_dataloader_train, task_dataloader_val = LoadDatasets(
        args, task_cfg, args.tasks.split("-")
    )

    logdir = os.path.join(savePath, "logs")
    tbLogger = utils.tbLogger(
        logdir,
        savePath,
        task_names,
        task_ids,
        task_num_iters,
        args.gradient_accumulation_steps,
    )

    if args.visual_target == 0:
        config.v_target_size = 1601
        config.visual_target = args.visual_target
    else:
        config.v_target_size = 2048
        config.visual_target = args.visual_target

    if args.task_specific_tokens:
        config.task_specific_tokens = True

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_ave_iter = {}
    task_stop_controller = {}
    for task_id, num_iter in task_num_iters.items():
        task_ave_iter[task_id] = int(
            task_cfg[task]["num_epoch"]
            * num_iter
            * args.train_iter_multiplier
            / args.num_train_epochs
        )
        task_stop_controller[task_id] = utils.MultiTaskStopOnPlateau(
            mode="max",
            patience=1,
            continue_threshold=0.005,
            cooldown=1,
            threshold=0.001,
        )

    task_ave_iter_list = sorted(task_ave_iter.values())
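    # NOTE: despite the name, this is the largest per-task average iteration count
    # (last element of the sorted list), so the step loop below spans the longest task.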
    median_num_iter = task_ave_iter_list[-1]
    num_train_optimization_steps = (
        median_num_iter * args.num_train_epochs // args.gradient_accumulation_steps
    )
    num_labels = max([dataset.num_labels for dataset in task_datasets_train.values()])

    if args.dynamic_attention:
        config.dynamic_attention = True
    if "roberta" in args.bert_model:
        config.model = "roberta"

    if args.baseline:
        model = BaseBertForVLTasks.from_pretrained(
            args.from_pretrained,
            config=config,
            num_labels=num_labels,
            default_gpu=default_gpu,
        )
    else:
        model = VILBertForVLTasks.from_pretrained(
            args.from_pretrained,
            config=config,
            num_labels=num_labels,
            default_gpu=default_gpu,
        )

    task_losses = LoadLosses(args, task_cfg, args.tasks.split("-"))

    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]

    if args.freeze != -1:
        bert_weight_name_filtered = []
        for name in bert_weight_name:
            if "embeddings" in name:
                bert_weight_name_filtered.append(name)
            elif "encoder" in name:
                layer_num = name.split(".")[2]
                if int(layer_num) <= args.freeze:
                    bert_weight_name_filtered.append(name)

        optimizer_grouped_parameters = []
        for key, value in dict(model.named_parameters()).items():
            if key[12:] in bert_weight_name_filtered:
                value.requires_grad = False

        if default_gpu:
            print("filtered weight")
            print(bert_weight_name_filtered)

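    # Build one parameter group per trainable weight: ViLBERT-specific heads
    # ("vil_" in the name) use a fixed lr of 1e-4, BERT weights use base_lr,
    # with --vision_scratch the non-BERT (visual) weights also get 1e-4, and
    # no_decay parameters (biases / LayerNorm) get weight_decay = 0.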
    optimizer_grouped_parameters = []
    for key, value in dict(model.named_parameters()).items():
        if value.requires_grad:
            if "vil_" in key:
                lr = 1e-4
            else:
                if args.vision_scratch:
                    if key[12:] in bert_weight_name:
                        lr = base_lr
                    else:
                        lr = 1e-4
                else:
                    lr = base_lr
            if any(nd in key for nd in no_decay):
                optimizer_grouped_parameters += [
                    {"params": [value], "lr": lr, "weight_decay": 0.0}
                ]
            if not any(nd in key for nd in no_decay):
                optimizer_grouped_parameters += [
                    {"params": [value], "lr": lr, "weight_decay": 0.01}
                ]

    if default_gpu:
        print(len(list(model.named_parameters())), len(optimizer_grouped_parameters))

    if args.optim == "AdamW":
        optimizer = AdamW(optimizer_grouped_parameters, lr=base_lr, correct_bias=False)
    elif args.optim == "RAdam":
        optimizer = RAdam(optimizer_grouped_parameters, lr=base_lr)

    warmup_steps = args.warmup_proportion * num_train_optimization_steps

    if args.lr_scheduler == "warmup_linear":
        warmup_scheduler = WarmupLinearSchedule(
            optimizer, warmup_steps=warmup_steps, t_total=num_train_optimization_steps
        )
    else:
        warmup_scheduler = WarmupConstantSchedule(optimizer, warmup_steps=warmup_steps)

    lr_reduce_list = np.array([5, 7])
    if args.lr_scheduler == "automatic":
        lr_scheduler = ReduceLROnPlateau(
            optimizer, mode="max", factor=0.2, patience=1, cooldown=1, threshold=0.001
        )
    elif args.lr_scheduler == "cosine":
        lr_scheduler = CosineAnnealingLR(
            optimizer, T_max=median_num_iter * args.num_train_epochs
        )
    elif args.lr_scheduler == "cosine_warm":
        lr_scheduler = CosineAnnealingWarmRestarts(
            optimizer, T_0=median_num_iter * args.num_train_epochs
        )
    elif args.lr_scheduler == "mannul":

        def lr_lambda_fun(epoch):
            return pow(0.2, np.sum(lr_reduce_list <= epoch))

        lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda_fun)

    startIterID = 0
    global_step = 0
    start_epoch = 0

    if args.resume_file != "" and os.path.exists(args.resume_file):
        checkpoint = torch.load(args.resume_file, map_location="cpu")
        new_dict = {}
        for attr in checkpoint["model_state_dict"]:
            if attr.startswith("module."):
                new_dict[attr.replace("module.", "", 1)] = checkpoint[
                    "model_state_dict"
                ][attr]
            else:
                new_dict[attr] = checkpoint["model_state_dict"][attr]
        model.load_state_dict(new_dict)
        warmup_scheduler.load_state_dict(checkpoint["warmup_scheduler_state_dict"])
        # lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        global_step = checkpoint["global_step"]
        start_epoch = int(checkpoint["epoch_id"]) + 1
        task_stop_controller = checkpoint["task_stop_controller"]
        tbLogger = checkpoint["tb_logger"]
        del checkpoint

    model.to(device)

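    # Move any optimizer state loaded from a CPU checkpoint onto the GPU so
    # that optimizer.step() does not mix devices.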
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.cuda()

    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model, delay_allreduce=True)

    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if default_gpu:
        print("***** Running training *****")
        print("  Num Iters: ", task_num_iters)
        print("  Batch size: ", task_batch_size)
        print("  Num steps: %d" % num_train_optimization_steps)

    task_iter_train = {name: None for name in task_ids}
    task_count = {name: 0 for name in task_ids}
    for epochId in tqdm(range(start_epoch, args.num_train_epochs), desc="Epoch"):
        model.train()
        for step in range(median_num_iter):
            iterId = startIterID + step + (epochId * median_num_iter)
            first_task = True
            for task_id in task_ids:
                is_forward = False
                if (not task_stop_controller[task_id].in_stop) or (
                    iterId % args.train_iter_gap == 0
                ):
                    is_forward = True

                if is_forward:
                    loss, score = ForwardModelsTrain(
                        args,
                        task_cfg,
                        device,
                        task_id,
                        task_count,
                        task_iter_train,
                        task_dataloader_train,
                        model,
                        task_losses,
                    )

                    loss = loss * loss_scale[task_id]
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    loss.backward()
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        if args.fp16:
                            lr_this_step = args.learning_rate * warmup_linear(
                                global_step / num_train_optimization_steps,
                                args.warmup_proportion,
                            )
                            for param_group in optimizer.param_groups:
                                param_group["lr"] = lr_this_step

                        if first_task and (
                            global_step < warmup_steps
                            or args.lr_scheduler == "warmup_linear"
                        ):
                            warmup_scheduler.step()

                        optimizer.step()
                        model.zero_grad()
                        if first_task:
                            global_step += 1
                            first_task = False

                        if default_gpu:
                            tbLogger.step_train(
                                epochId,
                                iterId,
                                float(loss),
                                float(score),
                                optimizer.param_groups[0]["lr"],
                                task_id,
                                "train",
                            )

            if "cosine" in args.lr_scheduler and global_step > warmpu_steps:
                lr_scheduler.step()

            if (
                step % (20 * args.gradient_accumulation_steps) == 0
                and step != 0
                and default_gpu
            ):
                tbLogger.showLossTrain()

            # decide whether to evaluate on each task.
            for task_id in task_ids:
                if (iterId != 0 and iterId % task_num_iters[task_id] == 0) or (
                    epochId == args.num_train_epochs - 1 and step == median_num_iter - 1
                ):
                    evaluate(
                        args,
                        task_dataloader_val,
                        task_stop_controller,
                        task_cfg,
                        device,
                        task_id,
                        model,
                        task_losses,
                        epochId,
                        default_gpu,
                        tbLogger,
                    )

        if args.lr_scheduler == "automatic":
            lr_scheduler.step(sum(val_scores.values()))
            logger.info("best average score is %3f" % lr_scheduler.best)
        elif args.lr_scheduler == "mannul":
            lr_scheduler.step()

        if epochId in lr_reduce_list:
            for task_id in task_ids:
                # reset the task_stop_controller once the lr drop
                task_stop_controller[task_id]._reset()

        if default_gpu:
            # Save a trained model
            logger.info("** ** * Saving fine - tuned model ** ** * ")
            model_to_save = (
                model.module if hasattr(model, "module") else model
            )  # Only save the model it-self
            output_model_file = os.path.join(
                savePath, "pytorch_model_" + str(epochId) + ".bin"
            )
            output_checkpoint = os.path.join(savePath, "pytorch_ckpt_latest.tar")
            torch.save(model_to_save.state_dict(), output_model_file)
            torch.save(
                {
                    "model_state_dict": model_to_save.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "warmup_scheduler_state_dict": warmup_scheduler.state_dict(),
                    # 'lr_scheduler_state_dict': lr_scheduler.state_dict(),
                    "global_step": global_step,
                    "epoch_id": epochId,
                    "task_stop_controller": task_stop_controller,
                    "tb_logger": tbLogger,
                },
                output_checkpoint,
            )
    tbLogger.txt_close()
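
# A standalone toy sketch of the step-decay ("manual") schedule selected above:
# LambdaLR scales the base lr by 0.2 every time the epoch index passes one of the
# milestones in the reduce list. The tiny Linear model and the base lr of 1e-4 are
# placeholders, not values taken from the script.
import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR

toy_lr_reduce_list = np.array([5, 7])

def toy_lr_lambda(epoch):
    return pow(0.2, np.sum(toy_lr_reduce_list <= epoch))

toy_model = torch.nn.Linear(4, 2)
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=1e-4)
toy_scheduler = LambdaLR(toy_optimizer, lr_lambda=toy_lr_lambda)

for epoch in range(10):
    print(epoch, toy_optimizer.param_groups[0]["lr"])
    # ... one epoch of training would run here ...
    toy_scheduler.step()
# prints roughly 1e-4 for epochs 0-4, 2e-5 for epochs 5-6 and 4e-6 from epoch 7.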
Пример #8
0
def trainer(args, train_dataloader, valid_dataloader, model):
    """Train model"""
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    if args.total_num_update > 0:
        t_total = args.total_num_update
        length = len(train_dataloader)
        args.epochs = args.total_num_update // (
            length // args.gradient_accumulation_steps) + 1
    else:
        length = len(train_dataloader)
        t_total = length // args.gradient_accumulation_steps * args.epochs

    # Prepare optimizer and schedule(polynomial_decay and warmup)
    no_decay = ["bias", "LayerNorm.weight"]  # 优化器这一块需要和fairseq进行对比和修正

    optimizer_grouped_parameters = [{
        "params": [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay":
        args.weight_decay
    }, {
        "params": [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        "weight_decay":
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_eps)
    scheduler = get_polynomial_decay_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    if os.path.isfile(os.path.join(
            args.output_dir, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.output_dir, "scheduler.pt")):
        # load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.output_dir, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.output_dir, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex/ to use "
                "fp16 training")

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training(should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization); this part is not fully understood yet
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # training start
    logger.info("****** Running training ******")
    logger.info(" Num example = %d ",
                len(train_dataloader) * args.per_gpu_train_batch_size)
    logger.info(" Num Epoch = %d", args.epochs)
    logger.info(" Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)

    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.per_gpu_train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))

    logger.info(" Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info(" Total optimizer steps = %d", t_total)

    global_step = 1
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.output_dir):
        try:
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (length //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                length // args.gradient_accumulation_steps)

            logger.info(
                " Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info(" Continue training from epoch %d", epochs_trained)
            logger.info(" Continue training from global steps %d", global_step)
            logger.info(" Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info(" Starting fine-tuning")

    tr_loss, logging_loss = 0.0, 0.0  # cumulative training loss and loss at the last logging step
    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.epochs),
                            desc="Epochs",
                            disable=args.local_rank not in [-1, 0])
    # Added here for reproducibility
    set_seed(args)

    best_valid_sum_loss = float("inf")
    best_valid_loss = float("inf")
    # load the two dataloaders simultaneously during training
    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iterator",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            # Skip past any already-trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()

            source_ids = batch["input_ids"].long.to(args.device)
            attention_mask = batch["attention_mask"].long.to(args.device)
            decoder_input_ids = batch["target_ids"][:, :-1].long().to(
                args.device)
            summary_labels = batch["target_ids"][:, :-1].contiguous().long(
            ).to(args.device)
            segment_ids = batch["segment_ids"].long().to(args.device)
            start_positions = batch["start_positions"].long().to(args.device)
            end_positions = batch["end_positions"].long().to(args.device)
            sentence_start_positions = batch["sentence_start_positions"].long(
            ).to(args.device)
            sentence_end_positions = batch["sentence_end_positions"].long().to(
                args.device)

            inputs = {
                "input_ids": source_ids,
                "attention_mask": attention_mask,
                "segment_ids": segment_ids,
                "decoder_input_ids": decoder_input_ids,
                "summary_labels": summary_labels,
                "start_positions": start_positions,
                "end_positions": end_positions,
                "sentence_start_positions": sentence_start_positions,
                "sentence_end_positions": sentence_end_positions
            }

            outputs = model(**inputs)

            loss, qa_loss, sum_loss = outputs["qa_loss"], outputs[
                "summary_loss"], outputs["loss"]
            if args.n_gpu > 1:
                loss.mean()

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            epoch_iterator.set_description(
                "epoch:{}, global_step:{}, qa_loss:{}, sum_loss:{}, loss:{}".
                format(epoch, global_step, qa_loss, sum_loss, loss))

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1

                # Log Metrics
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps,
                                         global_step)  # average loss over the logging interval
                    logging_loss = tr_loss

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)

                    # Take care of distribution/parallel training
                    model_to_save = model.module if hasattr(
                        model, "module") else model
                    model_to_save.save_pretrained(output_dir)

                    valid_loss, valid_sum_loss, valid_qa_loss = validation(
                        valid_dataloader, model, args)

                    # save the model that performs best on the summarization loss
                    if valid_sum_loss < best_valid_sum_loss:
                        best_valid_sum_loss = valid_sum_loss
                        logger.info('Saving best valid summary loss model')
                        best_sum_loss_dir = os.path.join(
                            args.output_dir, "best_sum")
                        if not os.path.exists(best_sum_loss_dir):
                            os.makedirs(best_sum_loss_dir)
                        model_to_save = model.module if hasattr(
                            model, "module") else model
                        model_to_save.save_pretrained(best_sum_loss_dir)

                    # save the model with the lowest overall validation loss
                    if valid_loss < best_valid_loss:
                        best_valid_loss = valid_loss
                        logger.info('Saving best valid loss model')
                        best_valid_loss_dir = os.path.join(
                            args.output_dir, "best_loss")
                        if not os.path.exists(best_valid_loss_dir):
                            os.makedirs(best_valid_loss_dir)

                        model_to_save = model.module if hasattr(
                            model, "module") else model
                        model_to_save.save_pretrained(best_valid_loss_dir)

            if (args.total_num_update > 0) and (global_step >
                                                args.total_num_update):
                epoch_iterator.close()
                break

        if (args.total_num_update > 0) and (global_step >
                                            args.total_num_update):
            epoch_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()
    logger.info("training is done!")
Пример #9
0
def model_train_validate_test(train_df,
                              dev_df,
                              test_df,
                              target_dir,
                              max_seq_len=50,
                              epochs=3,
                              batch_size=32,
                              lr=2e-05,
                              patience=1,
                              max_grad_norm=10.0,
                              if_save_model=True,
                              checkpoint=None):
    """
    Parameters
    ----------
    train_df : pandas dataframe of train set.
    dev_df : pandas dataframe of dev set.
    test_df : pandas dataframe of test set.
    target_dir : the path where you want to save model.
    max_seq_len : the maximum truncated sequence length.
    epochs : the default is 3.
    batch_size : the default is 32.
    lr : learning rate, the default is 2e-05.
    patience : the default is 1.
    max_grad_norm : the default is 10.0.
    if_save_model : whether to save the trained model to the target dir.
    checkpoint : the default is None.

    """

    bertmodel = BertModel(requires_grad=True)
    tokenizer = bertmodel.tokenizer

    print(20 * "=", " Preparing for training ", 20 * "=")
    # Path to save the model, create a folder if not exist.
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    # -------------------- Data loading --------------------------------------#

    print("\t* Loading training data...")
    train_data = DataPrecessForSentence(tokenizer,
                                        train_df,
                                        max_seq_len=max_seq_len)
    train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)

    print("\t* Loading validation data...")
    dev_data = DataPrecessForSentence(tokenizer,
                                      dev_df,
                                      max_seq_len=max_seq_len)
    dev_loader = DataLoader(dev_data, shuffle=True, batch_size=batch_size)

    print("\t* Loading test data...")
    test_data = DataPrecessForSentence(tokenizer,
                                       test_df,
                                       max_seq_len=max_seq_len)
    test_loader = DataLoader(test_data, shuffle=False, batch_size=batch_size)

    # -------------------- Model definition ----------------------------------#

    print("\t* Building model...")
    device = torch.device("cuda")
    model = bertmodel.to(device)

    # -------------------- Preparation for training  -------------------------#

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters, lr=lr)

    ## Implement of warm up
    ## total_steps = len(train_loader) * epochs
    ## scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=60, num_training_steps=total_steps)

    # Reduce the learning rate when the monitored value stops improving.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode="max",
                                                           factor=0.85,
                                                           patience=0)

    best_score = 0.0
    start_epoch = 1
    # Data for loss curves plot
    epochs_count = []
    train_losses = []
    train_accuracies = []
    valid_losses = []
    valid_accuracies = []
    valid_aucs = []

    # Continuing training from a checkpoint if one was given as argument
    if checkpoint:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint["epoch"] + 1
        best_score = checkpoint["best_score"]
        print("\t* Training will continue on existing model from epoch {}...".
              format(start_epoch))
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        epochs_count = checkpoint["epochs_count"]
        train_losses = checkpoint["train_losses"]
        train_accuracies = checkpoint["train_accuracy"]
        valid_losses = checkpoint["valid_losses"]
        valid_accuracies = checkpoint["valid_accuracy"]
        valid_aucs = checkpoint["valid_auc"]

    # Compute loss and accuracy before starting (or resuming) training.
    _, valid_loss, valid_accuracy, auc, _, = validate(model, dev_loader)
    print(
        "\n* Validation loss before training: {:.4f}, accuracy: {:.4f}%, auc: {:.4f}"
        .format(valid_loss, (valid_accuracy * 100), auc))

    # -------------------- Training epochs -----------------------------------#

    print("\n", 20 * "=", "Training bert model on device: {}".format(device),
          20 * "=")
    patience_counter = 0
    for epoch in range(start_epoch, epochs + 1):
        epochs_count.append(epoch)

        print("* Training epoch {}:".format(epoch))
        epoch_time, epoch_loss, epoch_accuracy = train(model, train_loader,
                                                       optimizer, epoch,
                                                       max_grad_norm)
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_accuracy)
        print("-> Training time: {:.4f}s, loss = {:.4f}, accuracy: {:.4f}%".
              format(epoch_time, epoch_loss, (epoch_accuracy * 100)))

        print("* Validation for epoch {}:".format(epoch))
        epoch_time, epoch_loss, epoch_accuracy, epoch_auc, _, = validate(
            model, dev_loader)
        valid_losses.append(epoch_loss)
        valid_accuracies.append(epoch_accuracy)
        valid_aucs.append(epoch_auc)
        print(
            "-> Valid. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%, auc: {:.4f}\n"
            .format(epoch_time, epoch_loss, (epoch_accuracy * 100), epoch_auc))

        # Update the optimizer's learning rate with the scheduler.
        scheduler.step(epoch_accuracy)
        ## scheduler.step()

        # Early stopping on validation accuracy.
        if epoch_accuracy < best_score:
            patience_counter += 1
        else:
            best_score = epoch_accuracy
            patience_counter = 0
            if (if_save_model):
                torch.save(
                    {
                        "epoch": epoch,
                        "model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "best_score": best_score,
                        "epochs_count": epochs_count,
                        "train_losses": train_losses,
                        "train_accuracy": train_accuracies,
                        "valid_losses": valid_losses,
                        "valid_accuracy": valid_accuracies,
                        "valid_auc": valid_aucs
                    }, os.path.join(target_dir, "best.pth.tar"))
                print("save model succesfully!\n")

            # run model on test set and save the prediction result to csv
            print("* Test for epoch {}:".format(epoch))
            _, _, test_accuracy, _, all_prob = validate(model, test_loader)
            print("Test accuracy: {:.4f}%\n".format(test_accuracy))
            test_prediction = pd.DataFrame({'prob_1': all_prob})
            test_prediction['prob_0'] = 1 - test_prediction['prob_1']
            test_prediction['prediction'] = test_prediction.apply(
                lambda x: 0 if (x['prob_0'] > x['prob_1']) else 1, axis=1)
            test_prediction = test_prediction[[
                'prob_0', 'prob_1', 'prediction'
            ]]
            test_prediction.to_csv(os.path.join(target_dir,
                                                "test_prediction.csv"),
                                   index=False)

        if patience_counter >= patience:
            print("-> Early stopping: patience limit reached, stopping...")
            break
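
# A standalone toy sketch of the validation-driven loop used above: ReduceLROnPlateau
# monitors validation accuracy (mode="max") and a simple patience counter triggers
# early stopping once accuracy stops improving. The fake accuracy curve is only for
# illustration; real values would come from validate().
import torch

toy_model = torch.nn.Linear(8, 2)
toy_optimizer = torch.optim.AdamW(toy_model.parameters(), lr=2e-5)
toy_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    toy_optimizer, mode="max", factor=0.85, patience=0)

best_score, patience, patience_counter = 0.0, 1, 0
fake_valid_accuracies = [0.71, 0.78, 0.80, 0.79, 0.79]  # pretend validation results

for epoch, epoch_accuracy in enumerate(fake_valid_accuracies, start=1):
    # ... train(...) and validate(...) would run here ...
    toy_scheduler.step(epoch_accuracy)  # shrink the lr when accuracy plateaus
    if epoch_accuracy < best_score:
        patience_counter += 1
    else:
        best_score, patience_counter = epoch_accuracy, 0
        # torch.save({...}, "best.pth.tar") would happen here
    if patience_counter >= patience:
        print("-> Early stopping at epoch {}".format(epoch))
        break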
Пример #10
0
    def train(
        self,
        train_dataset,
        output_dir,
        show_running_loss=True,
        eval_data=None,
        verbose=True,
        **kwargs,
    ):
        """
        Trains the model on train_dataset.

        Utility function to be used by the train_model() method. Not intended to be used directly.
        """

        model = self.model
        args = self.args
        device = self.device

        tb_writer = SummaryWriter(logdir=args.tensorboard_dir)
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(
            train_dataset,
            sampler=train_sampler,
            batch_size=args.train_batch_size,
            num_workers=self.args.dataloader_num_workers,
        )

        if args.max_steps > 0:
            t_total = args.max_steps
            args.num_train_epochs = (
                args.max_steps
                // (len(train_dataloader) // args.gradient_accumulation_steps)
                + 1
            )
        else:
            t_total = (
                len(train_dataloader)
                // args.gradient_accumulation_steps
                * args.num_train_epochs
            )

        no_decay = ["bias", "LayerNorm.weight"]

        optimizer_grouped_parameters = []
        custom_parameter_names = set()
        for group in self.args.custom_parameter_groups:
            params = group.pop("params")
            custom_parameter_names.update(params)
            param_group = {**group}
            param_group["params"] = [
                p for n, p in model.named_parameters() if n in params
            ]
            optimizer_grouped_parameters.append(param_group)

        for group in self.args.custom_layer_parameters:
            layer_number = group.pop("layer")
            layer = f"layer.{layer_number}."
            group_d = {**group}
            group_nd = {**group}
            group_nd["weight_decay"] = 0.0
            params_d = []
            params_nd = []
            for n, p in model.named_parameters():
                if n not in custom_parameter_names and layer in n:
                    if any(nd in n for nd in no_decay):
                        params_nd.append(p)
                    else:
                        params_d.append(p)
                    custom_parameter_names.add(n)
            group_d["params"] = params_d
            group_nd["params"] = params_nd

            optimizer_grouped_parameters.append(group_d)
            optimizer_grouped_parameters.append(group_nd)

        if not self.args.train_custom_parameters_only:
            optimizer_grouped_parameters.extend(
                [
                    {
                        "params": [
                            p
                            for n, p in model.named_parameters()
                            if n not in custom_parameter_names
                            and not any(nd in n for nd in no_decay)
                        ],
                        "weight_decay": args.weight_decay,
                    },
                    {
                        "params": [
                            p
                            for n, p in model.named_parameters()
                            if n not in custom_parameter_names
                            and any(nd in n for nd in no_decay)
                        ],
                        "weight_decay": 0.0,
                    },
                ]
            )

        warmup_steps = math.ceil(t_total * args.warmup_ratio)
        args.warmup_steps = (
            warmup_steps if args.warmup_steps == 0 else args.warmup_steps
        )

        if args.optimizer == "AdamW":
            optimizer = AdamW(
                optimizer_grouped_parameters,
                lr=args.learning_rate,
                eps=args.adam_epsilon,
            )
        elif args.optimizer == "Adafactor":
            optimizer = Adafactor(
                optimizer_grouped_parameters,
                lr=args.learning_rate,
                eps=args.adafactor_eps,
                clip_threshold=args.adafactor_clip_threshold,
                decay_rate=args.adafactor_decay_rate,
                beta1=args.adafactor_beta1,
                weight_decay=args.weight_decay,
                scale_parameter=args.adafactor_scale_parameter,
                relative_step=args.adafactor_relative_step,
                warmup_init=args.adafactor_warmup_init,
            )
            print("Using Adafactor for T5")
        else:
            raise ValueError(
                "{} is not a valid optimizer class. Please use one of ('AdamW', 'Adafactor') instead.".format(
                    args.optimizer
                )
            )

        if args.scheduler == "constant_schedule":
            scheduler = get_constant_schedule(optimizer)

        elif args.scheduler == "constant_schedule_with_warmup":
            scheduler = get_constant_schedule_with_warmup(
                optimizer, num_warmup_steps=args.warmup_steps
            )

        elif args.scheduler == "linear_schedule_with_warmup":
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=args.warmup_steps,
                num_training_steps=t_total,
            )

        elif args.scheduler == "cosine_schedule_with_warmup":
            scheduler = get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=args.warmup_steps,
                num_training_steps=t_total,
                num_cycles=args.cosine_schedule_num_cycles,
            )

        elif args.scheduler == "cosine_with_hard_restarts_schedule_with_warmup":
            scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
                optimizer,
                num_warmup_steps=args.warmup_steps,
                num_training_steps=t_total,
                num_cycles=args.cosine_schedule_num_cycles,
            )

        elif args.scheduler == "polynomial_decay_schedule_with_warmup":
            scheduler = get_polynomial_decay_schedule_with_warmup(
                optimizer,
                num_warmup_steps=args.warmup_steps,
                num_training_steps=t_total,
                lr_end=args.polynomial_decay_schedule_lr_end,
                power=args.polynomial_decay_schedule_power,
            )

        else:
            raise ValueError("{} is not a valid scheduler.".format(args.scheduler))

        if (
            args.model_name
            and os.path.isfile(os.path.join(args.model_name, "optimizer.pt"))
            and os.path.isfile(os.path.join(args.model_name, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(args.model_name, "optimizer.pt"))
            )
            scheduler.load_state_dict(
                torch.load(os.path.join(args.model_name, "scheduler.pt"))
            )

        if args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        logger.info(" Training started")

        global_step = 0
        training_progress_scores = None
        tr_loss, logging_loss = 0.0, 0.0
        model.zero_grad()
        train_iterator = trange(
            int(args.num_train_epochs), desc="Epoch", disable=args.silent, mininterval=0
        )
        epoch_number = 0
        best_eval_metric = None
        early_stopping_counter = 0
        steps_trained_in_current_epoch = 0
        epochs_trained = 0

        if args.model_name and os.path.exists(args.model_name):
            try:
                # set global_step to global_step of last saved checkpoint from model path
                checkpoint_suffix = args.model_name.split("/")[-1].split("-")
                if len(checkpoint_suffix) > 2:
                    checkpoint_suffix = checkpoint_suffix[1]
                else:
                    checkpoint_suffix = checkpoint_suffix[-1]
                global_step = int(checkpoint_suffix)
                epochs_trained = global_step // (
                    len(train_dataloader) // args.gradient_accumulation_steps
                )
                steps_trained_in_current_epoch = global_step % (
                    len(train_dataloader) // args.gradient_accumulation_steps
                )

                logger.info(
                    "   Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("   Continuing training from epoch %d", epochs_trained)
                logger.info("   Continuing training from global step %d", global_step)
                logger.info(
                    "   Will skip the first %d steps in the current epoch",
                    steps_trained_in_current_epoch,
                )
            except ValueError:
                logger.info("   Starting fine-tuning.")

        if args.evaluate_during_training:
            training_progress_scores = self._create_training_progress_scores(**kwargs)

        if args.wandb_project:
            wandb.init(
                project=args.wandb_project,
                config={**asdict(args)},
                **args.wandb_kwargs,
            )
            wandb.run._label(repo="simpletransformers")
            wandb.watch(self.model)

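        # Native PyTorch AMP: autocast (used in the forward pass below) runs ops in
        # mixed precision and GradScaler scales the loss to avoid fp16 underflow.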
        if args.fp16:
            from torch.cuda import amp

            scaler = amp.GradScaler()

        for current_epoch in train_iterator:
            model.train()
            if epochs_trained > 0:
                epochs_trained -= 1
                continue
            train_iterator.set_description(
                f"Epoch {epoch_number + 1} of {args.num_train_epochs}"
            )
            batch_iterator = tqdm(
                train_dataloader,
                desc=f"Running Epoch {epoch_number} of {args.num_train_epochs}",
                disable=args.silent,
                mininterval=0,
            )
            for step, batch in enumerate(batch_iterator):
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                inputs = self._get_inputs_dict(batch)
                if args.fp16:
                    with amp.autocast():
                        outputs = model(**inputs)
                        # model outputs are always tuple in pytorch-transformers (see doc)
                        loss = outputs[0]
                else:
                    outputs = model(**inputs)
                    # model outputs are always tuple in pytorch-transformers (see doc)
                    loss = outputs[0]

                if args.n_gpu > 1:
                    loss = (
                        loss.mean()
                    )  # mean() to average on multi-gpu parallel training

                current_loss = loss.item()

                if show_running_loss:
                    batch_iterator.set_description(
                        f"Epochs {epoch_number}/{args.num_train_epochs}. Running Loss: {current_loss:9.4f}"
                    )

                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

                tr_loss += loss.item()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        scaler.unscale_(optimizer)
                    if args.optimizer == "AdamW":
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(), args.max_grad_norm
                        )

                    if args.fp16:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1

                    if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                        # Log metrics
                        tb_writer.add_scalar(
                            "lr", scheduler.get_last_lr()[0], global_step
                        )
                        tb_writer.add_scalar(
                            "loss",
                            (tr_loss - logging_loss) / args.logging_steps,
                            global_step,
                        )
                        logging_loss = tr_loss
                        if args.wandb_project or self.is_sweeping:
                            wandb.log(
                                {
                                    "Training loss": current_loss,
                                    "lr": scheduler.get_last_lr()[0],
                                    "global_step": global_step,
                                }
                            )

                    if args.save_steps > 0 and global_step % args.save_steps == 0:
                        # Save model checkpoint
                        output_dir_current = os.path.join(
                            output_dir, "checkpoint-{}".format(global_step)
                        )

                        self.save_model(
                            output_dir_current, optimizer, scheduler, model=model
                        )

                    if args.evaluate_during_training and (
                        args.evaluate_during_training_steps > 0
                        and global_step % args.evaluate_during_training_steps == 0
                    ):
                        # Only evaluate when single GPU otherwise metrics may not average well
                        results = self.eval_model(
                            eval_data,
                            verbose=verbose and args.evaluate_during_training_verbose,
                            silent=args.evaluate_during_training_silent,
                            **kwargs,
                        )
                        for key, value in results.items():
                            try:
                                tb_writer.add_scalar(
                                    "eval_{}".format(key), value, global_step
                                )
                            except (NotImplementedError, AssertionError):
                                pass

                        output_dir_current = os.path.join(
                            output_dir, "checkpoint-{}".format(global_step)
                        )

                        if args.save_eval_checkpoints:
                            self.save_model(
                                output_dir_current,
                                optimizer,
                                scheduler,
                                model=model,
                                results=results,
                            )

                        training_progress_scores["global_step"].append(global_step)
                        training_progress_scores["train_loss"].append(current_loss)
                        for key in results:
                            training_progress_scores[key].append(results[key])
                        report = pd.DataFrame(training_progress_scores)
                        report.to_csv(
                            os.path.join(
                                args.output_dir, "training_progress_scores.csv"
                            ),
                            index=False,
                        )

                        if args.wandb_project or self.is_sweeping:
                            wandb.log(self._get_last_metrics(training_progress_scores))

                        if not best_eval_metric:
                            best_eval_metric = results[args.early_stopping_metric]
                            self.save_model(
                                args.best_model_dir,
                                optimizer,
                                scheduler,
                                model=model,
                                results=results,
                            )
                        if best_eval_metric and args.early_stopping_metric_minimize:
                            if (
                                results[args.early_stopping_metric] - best_eval_metric
                                < args.early_stopping_delta
                            ):
                                best_eval_metric = results[args.early_stopping_metric]
                                self.save_model(
                                    args.best_model_dir,
                                    optimizer,
                                    scheduler,
                                    model=model,
                                    results=results,
                                )
                                early_stopping_counter = 0
                            else:
                                if args.use_early_stopping:
                                    if (
                                        early_stopping_counter
                                        < args.early_stopping_patience
                                    ):
                                        early_stopping_counter += 1
                                        if verbose:
                                            logger.info(
                                                f" No improvement in {args.early_stopping_metric}"
                                            )
                                            logger.info(
                                                f" Current step: {early_stopping_counter}"
                                            )
                                            logger.info(
                                                f" Early stopping patience: {args.early_stopping_patience}"
                                            )
                                    else:
                                        if verbose:
                                            logger.info(
                                                f" Patience of {args.early_stopping_patience} steps reached"
                                            )
                                            logger.info(" Training terminated.")
                                            train_iterator.close()
                                        return (
                                            global_step,
                                            tr_loss / global_step
                                            if not self.args.evaluate_during_training
                                            else training_progress_scores,
                                        )
                        else:
                            if (
                                results[args.early_stopping_metric] - best_eval_metric
                                > args.early_stopping_delta
                            ):
                                best_eval_metric = results[args.early_stopping_metric]
                                self.save_model(
                                    args.best_model_dir,
                                    optimizer,
                                    scheduler,
                                    model=model,
                                    results=results,
                                )
                                early_stopping_counter = 0
                            else:
                                if args.use_early_stopping:
                                    if (
                                        early_stopping_counter
                                        < args.early_stopping_patience
                                    ):
                                        early_stopping_counter += 1
                                        if verbose:
                                            logger.info(
                                                f" No improvement in {args.early_stopping_metric}"
                                            )
                                            logger.info(
                                                f" Current step: {early_stopping_counter}"
                                            )
                                            logger.info(
                                                f" Early stopping patience: {args.early_stopping_patience}"
                                            )
                                    else:
                                        if verbose:
                                            logger.info(
                                                f" Patience of {args.early_stopping_patience} steps reached"
                                            )
                                            logger.info(" Training terminated.")
                                            train_iterator.close()
                                        return (
                                            global_step,
                                            tr_loss / global_step
                                            if not self.args.evaluate_during_training
                                            else training_progress_scores,
                                        )
                        model.train()

            epoch_number += 1
            output_dir_current = os.path.join(
                output_dir, "checkpoint-{}-epoch-{}".format(global_step, epoch_number)
            )

            if args.save_model_every_epoch or args.evaluate_during_training:
                os.makedirs(output_dir_current, exist_ok=True)

            if args.save_model_every_epoch:
                self.save_model(output_dir_current, optimizer, scheduler, model=model)

            if args.evaluate_during_training and args.evaluate_each_epoch:
                results = self.eval_model(
                    eval_data,
                    verbose=verbose and args.evaluate_during_training_verbose,
                    silent=args.evaluate_during_training_silent,
                    **kwargs,
                )

                if args.save_eval_checkpoints:
                    self.save_model(
                        output_dir_current, optimizer, scheduler, results=results
                    )

                training_progress_scores["global_step"].append(global_step)
                training_progress_scores["train_loss"].append(current_loss)
                for key in results:
                    training_progress_scores[key].append(results[key])
                report = pd.DataFrame(training_progress_scores)
                report.to_csv(
                    os.path.join(args.output_dir, "training_progress_scores.csv"),
                    index=False,
                )

                if args.wandb_project or self.is_sweeping:
                    wandb.log(self._get_last_metrics(training_progress_scores))

                if not best_eval_metric:
                    best_eval_metric = results[args.early_stopping_metric]
                    self.save_model(
                        args.best_model_dir,
                        optimizer,
                        scheduler,
                        model=model,
                        results=results,
                    )
                if best_eval_metric and args.early_stopping_metric_minimize:
                    if (
                        results[args.early_stopping_metric] - best_eval_metric
                        < args.early_stopping_delta
                    ):
                        best_eval_metric = results[args.early_stopping_metric]
                        self.save_model(
                            args.best_model_dir,
                            optimizer,
                            scheduler,
                            model=model,
                            results=results,
                        )
                        early_stopping_counter = 0
                    else:
                        if (
                            args.use_early_stopping
                            and args.early_stopping_consider_epochs
                        ):
                            if early_stopping_counter < args.early_stopping_patience:
                                early_stopping_counter += 1
                                if verbose:
                                    logger.info(
                                        f" No improvement in {args.early_stopping_metric}"
                                    )
                                    logger.info(
                                        f" Current step: {early_stopping_counter}"
                                    )
                                    logger.info(
                                        f" Early stopping patience: {args.early_stopping_patience}"
                                    )
                            else:
                                if verbose:
                                    logger.info(
                                        f" Patience of {args.early_stopping_patience} steps reached"
                                    )
                                    logger.info(" Training terminated.")
                                    train_iterator.close()
                                return (
                                    global_step,
                                    tr_loss / global_step
                                    if not self.args.evaluate_during_training
                                    else training_progress_scores,
                                )
                else:
                    if (
                        results[args.early_stopping_metric] - best_eval_metric
                        > args.early_stopping_delta
                    ):
                        best_eval_metric = results[args.early_stopping_metric]
                        self.save_model(
                            args.best_model_dir,
                            optimizer,
                            scheduler,
                            model=model,
                            results=results,
                        )
                        early_stopping_counter = 0
                    else:
                        if (
                            args.use_early_stopping
                            and args.early_stopping_consider_epochs
                        ):
                            if early_stopping_counter < args.early_stopping_patience:
                                early_stopping_counter += 1
                                if verbose:
                                    logger.info(
                                        f" No improvement in {args.early_stopping_metric}"
                                    )
                                    logger.info(
                                        f" Current step: {early_stopping_counter}"
                                    )
                                    logger.info(
                                        f" Early stopping patience: {args.early_stopping_patience}"
                                    )
                            else:
                                if verbose:
                                    logger.info(
                                        f" Patience of {args.early_stopping_patience} steps reached"
                                    )
                                    logger.info(" Training terminated.")
                                    train_iterator.close()
                                return (
                                    global_step,
                                    tr_loss / global_step
                                    if not self.args.evaluate_during_training
                                    else training_progress_scores,
                                )

        return (
            global_step,
            tr_loss / global_step
            if not self.args.evaluate_during_training
            else training_progress_scores,
        )
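The branches above repeat the same bookkeeping: compare the new metric against the best one using early_stopping_delta, save on improvement, otherwise bump a patience counter and stop once it reaches early_stopping_patience. A minimal stand-alone sketch of that logic (the class name and defaults are illustrative, not part of the code above):

# Illustrative sketch only: the delta/patience bookkeeping used by the loop above.
class EarlyStoppingTracker:
    def __init__(self, patience=3, delta=0.0, minimize=True):
        self.patience = patience
        self.delta = delta
        self.minimize = minimize
        self.best = None
        self.counter = 0

    def step(self, metric):
        """Return (improved, should_stop) for the latest evaluation metric."""
        if self.best is None:
            self.best = metric
            return True, False
        diff = metric - self.best
        improved = diff < self.delta if self.minimize else diff > self.delta
        if improved:
            self.best = metric
            self.counter = 0
        else:
            self.counter += 1
        return improved, self.counter >= self.patience

# usage sketch:
#   improved, stop = tracker.step(results[args.early_stopping_metric])
#   if improved: save the best model; if stop: end training
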
def main(train_file=os.path.join(Config.root_path, 'data/ranking/train.tsv'),
         dev_file=os.path.join(Config.root_path, 'data/ranking/dev.tsv'),
         model_path=Config.bert_model,
         epochs=10,
         batch_size=32,
         lr=2e-05,
         patience=3,
         max_grad_norm=10.0,
         checkpoint=None):
    logging.info(20 * "=" + " Preparing for training " + 20 * "=")
    bert_tokenizer = BertTokenizer.from_pretrained(Config.vocab_path,
                                                   do_lower_case=True)
    device = torch.device("cuda") if Config.is_cuda else torch.device("cpu")
    if not os.path.exists(os.path.dirname(model_path)):
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
    logging.info("\t* Loading training data...")
    train_dataset = DataPrecessForSentence(bert_tokenizer=bert_tokenizer,
                                           file=train_file,
                                           max_char_len=Config.max_seq_len)
    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    logging.info("\t* Loading validation data...")
    dev_dataset = DataPrecessForSentence(bert_tokenizer=bert_tokenizer,
                                         file=dev_file,
                                         max_char_len=Config.max_seq_len)
    dev_dataloader = DataLoader(
        dataset=dev_dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    logging.info("\t* Building model...")
    model = BertModelTrain().to(device)

    # Parameters to optimize (bias and LayerNorm weights are excluded from weight decay)
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in param_optimizer
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': 0.01
        },
        {
            'params': [p for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode="max",
                                                           factor=0.85,
                                                           patience=0)
    best_score = 0.0
    start_epoch = 1
    # Data for loss curves plot
    epochs_count = []
    train_losses = []
    valid_losses = []
    # Continuing training from a checkpoint if one was given as argument
    if checkpoint:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint["epoch"] + 1
        best_score = checkpoint["best_score"]
        logging.info(
            "\t* Training will continue on existing model from epoch {}...".
            format(start_epoch))
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        epochs_count = checkpoint["epochs_count"]
        train_losses = checkpoint["train_losses"]
        valid_losses = checkpoint["valid_losses"]
        # Compute loss and accuracy before starting (or resuming) training.
        _, valid_loss, valid_accuracy, auc = validate(model, dev_dataloader)
        logging.info(
            "\t* Validation loss before training: {:.4f}, accuracy: {:.4f}%, \
            auc: {:.4f}".format(valid_loss, (valid_accuracy * 100), auc))

    # -------------------- Training epochs ------------------- #
    logging.info("\n" + 20 * "=" +
                 "Training Bert model on device: {}".format(device) + 20 * "=")
    patience_counter = 0
    for i in range(start_epoch, epochs + 1):
        epochs_count.append(i)
        logging.info("* Starting training epoch {}".format(i))
        train_time, train_loss, train_acc = train(
            model=model,
            dataloader=train_dataloader,
            optimizer=optimizer,
            epoch_number=i,
            max_gradient_norm=max_grad_norm)
        train_losses.append(train_loss)
        logging.info("-> Training time: {:.4f}s, loss = {:.4f}, \
            accuracy: {:.4f}%".format(train_time, train_loss,
                                      (train_acc * 100)))

        logging.info("* Validation for epoch {}:".format(i))
        val_time, val_loss, val_acc, score = validate(
            model=model, dataloader=dev_dataloader)
        valid_losses.append(val_loss)
        logging.info(
            "-> Valid. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%, \
            auc: {:.4f}\n".format(val_time, val_loss, (val_acc * 100), score))
        scheduler.step(val_acc)
        # Early stopping on validation accuracy.
        if val_acc < best_score:
            patience_counter += 1
        else:
            best_score = val_acc
            patience_counter = 0
            torch.save(
                {
                    "epoch": i,
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "best_score": best_score,
                    "epochs_count": epochs_count,
                    "train_losses": train_losses,
                    "valid_losses": valid_losses
                }, model_path)
        if patience_counter >= patience:
            logging.info(
                "-> Early stopping: patience limit reached, stopping...")
            break
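The main function above resumes training from a single torch.save'd dictionary holding the epoch, model and optimizer state, best score, and loss history. A minimal sketch of that save/resume round trip; the helper names and the "history" key are assumptions for illustration:

import torch

# Illustrative sketch: bundle everything needed to resume training into one file.
def save_checkpoint(path, epoch, model, optimizer, best_score, history):
    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "best_score": best_score,
        "history": history,  # e.g. {"epochs_count": [...], "train_losses": [...], "valid_losses": [...]}
    }, path)

def load_checkpoint(path, model, optimizer, device="cpu"):
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    # resume from the epoch after the one stored in the checkpoint
    return ckpt["epoch"] + 1, ckpt["best_score"], ckpt["history"]
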
Example #12
0
def train(model, config, args):
    # starting epoch
    num_params = count_parameters(model)
    logger.info("Total Parameter: %d" % num_params)

    train_dataset, _ = load_and_cache_examples(config, tokenizer, 'train')
    test_dataset, test_features = load_and_cache_examples(
        config, tokenizer, 'test')

    num_train_optimization_steps = int(
        len(train_dataset) / train_batch_size) * num_train_epochs

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in param_optimizer
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': 0.01
        },
        {
            'params': [p for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=learning_rate,
                      eps=adam_epsilon)
    scheduler = WarmupLinearSchedule(
        optimizer,
        warmup_steps=num_train_optimization_steps * 0.1,
        t_total=num_train_optimization_steps)
    global_step = 0
    start_epoch = 0
    start_step = 0
    # Reformer NER model
    args.load_ner_checkpoint = False
    if os.path.isfile(f'{config.checkpoint_path}/{config.model_name}.pth'
                      ) and args.load_ner_checkpoint is False:
        checkpoint = torch.load(
            f'{config.checkpoint_path}/{config.model_name}.pth',
            map_location=device)
        model.reformer.load_state_dict(checkpoint['model_state_dict'],
                                       strict=False)
        logger.info(f'Load Reformer Model')

    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=train_batch_size,
                                  drop_last=True)

    if args.load_ner_checkpoint:
        checkpoint = torch.load(config.ner_checkpoint_path,
                                map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        start_epoch = checkpoint['epoch']
        global_step = checkpoint['train_step']
        start_step = global_step if start_epoch == 0 else global_step * train_batch_size % len(
            train_dataloader)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        logger.info(f'Load Reformer[NER] Model')

    logger.info("***** Running training *****")
    logger.info("  Batch size = %d", train_batch_size)
    logger.info("  Num steps = %d", num_train_optimization_steps)
    num_train_step = num_train_optimization_steps

    test_texts = get_test_texts(config.data_dir, config.test_file)

    for epoch in range(start_epoch, int(num_train_epochs)):
        global_step, model = train_on_epoch(model, train_dataloader, optimizer,
                                            scheduler, start_step, global_step,
                                            epoch, num_train_step)
        start_step = 0
        eval(model, config, epoch, test_dataset, test_texts, train_batch_size)
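Example #12, like several snippets here, splits the parameters into two AdamW groups so that biases and LayerNorm weights receive no weight decay, then warms the learning rate up linearly over the first ~10% of steps. A minimal sketch with the current transformers scheduler (get_linear_schedule_with_warmup, the successor of the older WarmupLinearSchedule); the 10% warmup ratio is an assumption mirroring the snippet above:

from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

# Illustrative sketch: no weight decay for bias/LayerNorm parameters + linear warmup.
def build_optimizer_and_scheduler(model, lr, weight_decay, total_steps, warmup_ratio=0.1):
    no_decay = ("bias", "LayerNorm.weight")
    grouped = [
        {"params": [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)],
         "weight_decay": weight_decay},
        {"params": [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0},
    ]
    optimizer = AdamW(grouped, lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(total_steps * warmup_ratio),
        num_training_steps=total_steps,
    )
    return optimizer, scheduler
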
Example #13
0
def mtl_train(args, config, train_set, dev_set, label_map, bert_model,
              clf_head):
    save_dir = "./models/{}".format(utils.get_savedir_name())
    tb_writer = SummaryWriter(os.path.join(save_dir, "logs"))

    train_set = ConcatDataset(train_set)
    train_loader = DataLoader(
        dataset=train_set,
        sampler=utils.BalancedTaskSampler(dataset=train_set,
                                          batch_size=config.batch_size),
        batch_size=config.batch_size,
        collate_fn=utils.collate_fn,
        shuffle=False,
        num_workers=0,
    )
    dev_set = ConcatDataset(dev_set)
    dev_loader = DataLoader(
        dataset=dev_set,
        batch_size=config.batch_size,
        collate_fn=utils.collate_fn,
        shuffle=False,
        num_workers=0,
    )
    num_epochs = config.num_epochs

    if not config.finetune_enc:
        for param in bert_model.parameters():
            param.requires_grad = False
        extra = []
    else:
        extra = list(bert_model.named_parameters())

    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in list(clf_head.named_parameters()) + extra
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            config.weight_decay,
        },
        {
            "params": [
                p for n, p in list(clf_head.named_parameters()) + extra
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0,
        },
    ]
    opt = AdamW(optimizer_grouped_parameters, eps=1e-8, lr=config.outer_lr)

    best_dev_error = np.inf
    if args.load_from:
        state_obj = torch.load(os.path.join(args.load_from, "optim.th"))
        opt.load_state_dict(state_obj["optimizer"])
        num_epochs = num_epochs - state_obj["last_epoch"]
        bert_model = bert_model.eval()
        clf_head = clf_head.eval()
        dev_loss, dev_metrics = utils.compute_loss_metrics(
            dev_loader,
            bert_model,
            clf_head,
            label_map,
            grad_required=False,
            return_metrics=False,
        )
        best_dev_error = dev_loss.mean()

    patience_ctr = 0
    for epoch in range(num_epochs):
        running_loss = 0.0
        epoch_iterator = tqdm(train_loader, desc="Training")
        for (
                train_step,
            (input_ids, attention_mask, token_type_ids, labels, _, _),
        ) in enumerate(epoch_iterator):
            # train
            bert_model.train()
            clf_head.train()
            opt.zero_grad()
            bert_output = bert_model(input_ids, attention_mask, token_type_ids)
            output = clf_head(bert_output,
                              labels=labels,
                              attention_mask=attention_mask)
            loss = output.loss.mean()
            loss.backward()
            if config.finetune_enc:
                torch.nn.utils.clip_grad_norm_(bert_model.parameters(),
                                               config.max_grad_norm)
            torch.nn.utils.clip_grad_norm_(clf_head.parameters(),
                                           config.max_grad_norm)
            opt.step()
            running_loss += loss.item()
            # eval at the beginning of every epoch and after every `config.eval_freq` steps
            if train_step % config.eval_freq == 0:
                bert_model.eval()
                clf_head.eval()
                dev_loss, dev_metrics = utils.compute_loss_metrics(
                    dev_loader,
                    bert_model,
                    clf_head,
                    label_map,
                    grad_required=False,
                    return_metrics=False,
                )
                dev_loss = dev_loss.mean()

                tb_writer.add_scalar("metrics/loss", dev_loss, epoch)
                if dev_metrics is not None:
                    tb_writer.add_scalar("metrics/precision",
                                         dev_metrics["precision"], epoch)
                    tb_writer.add_scalar("metrics/recall",
                                         dev_metrics["recall"], epoch)
                    tb_writer.add_scalar("metrics/f1", dev_metrics["f1"],
                                         epoch)
                    logger.info(
                        "Dev. metrics (p/r/f): {:.3f} {:.3f} {:.3f}".format(
                            dev_metrics["precision"],
                            dev_metrics["recall"],
                            dev_metrics["f1"],
                        ))

                if dev_loss < best_dev_error:
                    logger.info("Found new best model!")
                    best_dev_error = dev_loss
                    save(clf_head, opt, args.config_path, epoch, bert_model)
                    patience_ctr = 0
                else:
                    patience_ctr += 1
                    if patience_ctr == config.patience:
                        logger.info(
                            "Ran out of patience. Stopping training early...")
                        return

        logger.info(
            f"Finished epoch {epoch+1} with avg. training loss: {running_loss/(train_step + 1)}"
        )

    logger.info(f"Best validation loss = {best_dev_error}")
    logger.info("Best model saved at: {}".format(utils.get_savedir_name()))
def model_train_validate_test(train_df, dev_df, test_df, target_dir,
                              max_seq_len=64,
                              num_labels=2,
                              epochs=10,
                              batch_size=32,
                              lr=2e-05,
                              patience=1,
                              max_grad_norm=10.0,
                              if_save_model=True,
                              checkpoint=None):

    bertmodel = DistilBertModel(requires_grad=True, num_labels=num_labels)
    tokenizer = bertmodel.tokenizer
    
    print(20 * "=", " Preparing for training ", 20 * "=")
    # Create the directory for saving the model if it does not already exist
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    # -------------------- Data loading ------------------- #
    print("\t* Loading training data...")
    train_data = DataPrecessForSentence(tokenizer, train_df, max_seq_len)
    train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)

    print("\t* Loading validation data...")
    dev_data = DataPrecessForSentence(tokenizer, dev_df, max_seq_len)
    dev_loader = DataLoader(dev_data, shuffle=True, batch_size=batch_size)

    print("\t* Loading test data...")
    test_data = DataPrecessForSentence(tokenizer, test_df, max_seq_len)
    test_loader = DataLoader(test_data, shuffle=False, batch_size=batch_size)
    # -------------------- Model definition ------------------- #
    print("\t* Building model...")
    device = torch.device("cuda")
    model = bertmodel.to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f'{total_params:,} total parameters.')
    total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'{total_trainable_params:,} training parameters.')
    # -------------------- Preparation for training  ------------------- #
    # Parameters to optimize
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
            {
                    'params':[p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                    'weight_decay':0.01
            },
            {
                    'params':[p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                    'weight_decay':0.0
            }
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=lr)

    # When the evaluation metric stops improving, lowering the learning rate can further improve performance.
    # warmup_steps = math.ceil(len(train_loader) * epochs * 0.1)
    # total_steps = len(train_loader) * epochs
    # scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
    
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.85, patience=2, verbose=True)

    best_score = 0.0
    start_epoch = 1
    # Data for loss curves plot
    epochs_count = []
    train_losses = []
    train_accuracies = []
    valid_losses = []
    valid_accuracies = []
    # Continuing training from a checkpoint if one was given as argument
    if checkpoint:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint["epoch"] + 1
        best_score = checkpoint["best_score"]
        print("\t* Training will continue on existing model from epoch {}...".format(start_epoch))
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        epochs_count = checkpoint["epochs_count"]
        train_losses = checkpoint["train_losses"]
        train_accuracies = checkpoint["train_accuracy"]
        valid_losses = checkpoint["valid_losses"]
        valid_accuracies = checkpoint["valid_accuracy"]
    # Compute loss and accuracy before starting (or resuming) training.
    _, valid_loss, valid_accuracy, _ = validate(model, dev_loader)
    print("\n* Validation loss before training: {:.4f}, accuracy: {:.4f}%".format(valid_loss, (valid_accuracy * 100)))
    # -------------------- Training epochs ------------------- #
    print("\n", 20 * "=", "Training roberta model on device: {}".format(device), 20 * "=")
    patience_counter = 0
    for epoch in range(start_epoch, epochs + 1):
        epochs_count.append(epoch)

        print("* Training epoch {}:".format(epoch))
        epoch_time, epoch_loss, epoch_accuracy = train(model, train_loader, optimizer, epoch, max_grad_norm)
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_accuracy)
        
        print("-> Training time: {:.4f}s, loss = {:.4f}, accuracy: {:.4f}%".format(epoch_time, epoch_loss, (epoch_accuracy*100)))
        
        print("* Validation for epoch {}:".format(epoch))
        epoch_time, epoch_loss, epoch_accuracy, _ = validate(model, dev_loader)
        valid_losses.append(epoch_loss)
        valid_accuracies.append(epoch_accuracy)
        print("-> Valid. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%\n"
              .format(epoch_time, epoch_loss, (epoch_accuracy*100)))
        
        # Update the optimizer's learning rate with the scheduler.
        # scheduler.step()
        scheduler.step(epoch_accuracy)
        # Early stopping on validation accuracy.
        if epoch_accuracy < best_score:
            patience_counter += 1
        else:
            best_score = epoch_accuracy
            patience_counter = 0
            
            if (if_save_model):
                torch.save({"epoch": epoch, 
                        "model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "best_score": best_score, # 验证集上的最优准确率
                        "epochs_count": epochs_count,
                        "train_losses": train_losses,
                        "train_accuracy": train_accuracies,
                        "valid_losses": valid_losses,
                        "valid_accuracy": valid_accuracies
                        },
                        os.path.join(target_dir, "best.pth.tar"))
                print("save model succesfully!\n")
            
            print("* Test for epoch {}:".format(epoch))
            _, _, test_accuracy, predictions = validate(model, test_loader)
            print("Test accuracy: {:.4f}%\n".format(test_accuracy))
            test_prediction = pd.DataFrame({'prediction':predictions})
            test_prediction.to_csv(os.path.join(target_dir,"test_prediction.csv"), index=False)
             
        if patience_counter >= patience:
            print("-> Early stopping: patience limit reached, stopping...")
            break
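Both BERT ranking examples call scheduler.step(val_metric) on a ReduceLROnPlateau in mode="max" rather than stepping per batch, so the learning rate only shrinks when validation accuracy stops improving. A minimal runnable sketch with a toy model; the synthetic accuracy values are purely illustrative:

import torch

# Illustrative sketch: drive ReduceLROnPlateau with a "higher is better" metric.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.85, patience=0)

for val_acc in [0.70, 0.72, 0.72, 0.71]:  # synthetic validation accuracies
    scheduler.step(val_acc)               # lr drops whenever accuracy fails to improve
    print(optimizer.param_groups[0]["lr"])
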
Example #15
0
            print("==========================================")

        if ckpt_every > 0 and len(total_score_history) > ckpt_lookback:
            current_score = np.mean(total_score_history[-ckpt_lookback:])

            if time.time() - time_ckpt > ckpt_every:
                revert_ckpt = best_ckpt_score is not None and current_score < min(
                    1.2 * best_ckpt_score,
                    0.8 * best_ckpt_score)  # Could be negative or positive
                print("================================== CKPT TIME, " +
                      str(datetime.now()) +
                      " =================================")
                print("Previous best:", best_ckpt_score)
                print("Current Score:", current_score)
                print("[CKPT] Am I reverting?",
                      ("yes" if revert_ckpt else "no! BEST CKPT"))
                if revert_ckpt:
                    summarizer.model.load_state_dict(torch.load(ckpt_file))
                    optimizer.load_state_dict(torch.load(ckpt_optimizer_file))
                time_ckpt = time.time()
                print(
                    "=============================================================================="
                )

            if best_ckpt_score is None or current_score > best_ckpt_score:
                print("[CKPT] Saved new best at: %.3f %s" %
                      (current_score, "[" + str(datetime.now()) + "]"))
                best_ckpt_score = current_score
                torch.save(summarizer.model.state_dict(), ckpt_file)
                torch.save(optimizer.state_dict(), ckpt_optimizer_file)
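Example #15 checkpoints on a rolling mean of recent scores: it saves whenever the window beats the previous best and reloads the saved weights when the window clearly degrades (the min(1.2 * best, 0.8 * best) form also covers scores that may be negative). A minimal sketch of that decision; the file paths and the combined save/revert helper are assumptions:

import numpy as np
import torch

# Illustrative sketch: save on a new best rolling score, revert when the score regresses.
def checkpoint_or_revert(model, optimizer, score_history, lookback, best_score,
                         ckpt_file="ckpt_model.pt", ckpt_optimizer_file="ckpt_optim.pt"):
    current = float(np.mean(score_history[-lookback:]))
    if best_score is not None and current < min(1.2 * best_score, 0.8 * best_score):
        # score clearly regressed: reload the last good weights and optimizer state
        model.load_state_dict(torch.load(ckpt_file))
        optimizer.load_state_dict(torch.load(ckpt_optimizer_file))
    elif best_score is None or current > best_score:
        best_score = current
        torch.save(model.state_dict(), ckpt_file)
        torch.save(optimizer.state_dict(), ckpt_optimizer_file)
    return best_score
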
Example #16
0
def train(args, model, train_features, dev_features, test_features):
    def logging(s, print_=True, log_=True):
        if print_:
            print(s)
        if log_ and args.log_dir != '':
            with open(args.log_dir, 'a+') as f_log:
                f_log.write(s + '\n')
    def finetune(features, optimizer, num_epoch, num_steps, model):
        cur_model = model.module if hasattr(model, 'module') else model
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if args.train_from_saved_model != '':
            best_score = torch.load(args.train_from_saved_model)["best_f1"]
            epoch_delta = torch.load(args.train_from_saved_model)["epoch"] + 1
        else:
            epoch_delta = 0
            best_score = -1
        train_dataloader = DataLoader(features, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True)
        train_iterator = [epoch + epoch_delta for epoch in range(num_epoch)]
        total_steps = int(len(train_dataloader) * num_epoch // args.gradient_accumulation_steps)
        warmup_steps = int(total_steps * args.warmup_ratio)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
        print("Total steps: {}".format(total_steps))
        print("Warmup steps: {}".format(warmup_steps))
        global_step = 0
        log_step = 100
        total_loss = 0

        #scaler = GradScaler()
        for epoch in train_iterator:
            start_time = time.time()
            optimizer.zero_grad()

            for step, batch in enumerate(train_dataloader):
                model.train()

                inputs = {'input_ids': batch[0].to(device),
                          'attention_mask': batch[1].to(device),
                          'labels': batch[2],
                          'entity_pos': batch[3],
                          'hts': batch[4],
                          }
                #with autocast():
                outputs = model(**inputs)
                loss = outputs[0] / args.gradient_accumulation_steps
                total_loss += loss.item()
                #    scaler.scale(loss).backward()
               

                loss.backward()

                if step % args.gradient_accumulation_steps == 0:
                    #scaler.unscale_(optimizer)
                    if args.max_grad_norm > 0:
                        # torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                        torch.nn.utils.clip_grad_norm_(cur_model.parameters(), args.max_grad_norm)
                    #scaler.step(optimizer)
                    #scaler.update()
                    #scheduler.step()
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1
                    num_steps += 1
                    if global_step % log_step == 0:
                        cur_loss = total_loss / log_step
                        elapsed = time.time() - start_time
                        logging(
                            '| epoch {:2d} | step {:4d} | min/b {:5.2f} | lr {} | train loss {:5.3f}'.format(
                                epoch, global_step, elapsed / 60, scheduler.get_last_lr(), cur_loss * 1000))
                        total_loss = 0
                        start_time = time.time()

                if (step + 1) == len(train_dataloader) - 1 or (args.evaluation_steps > 0 and num_steps % args.evaluation_steps == 0 and step % args.gradient_accumulation_steps == 0):
                # if step ==0:
                    logging('-' * 89)
                    eval_start_time = time.time()
                    dev_score, dev_output = evaluate(args, model, dev_features, tag="dev")

                    logging(
                        '| epoch {:3d} | time: {:5.2f}s | dev_result:{}'.format(epoch, time.time() - eval_start_time,
                                                                                dev_output))
                    logging('-' * 89)
                    if dev_score > best_score:
                        best_score = dev_score
                        logging(
                            '| epoch {:3d} | best_f1:{}'.format(epoch, best_score))
                        if args.save_path != "":
                            torch.save({
                                'epoch': epoch,
                                'checkpoint': cur_model.state_dict(),
                                'best_f1': best_score,
                                'optimizer': optimizer.state_dict()
                            }, args.save_path
                            , _use_new_zipfile_serialization=False)
                            logging(
                                '| successfully save model at: {}'.format(args.save_path))
                            logging('-' * 89)
        return num_steps

    cur_model = model.module if hasattr(model, 'module') else model
    extract_layer = ["extractor", "bilinear"]
    bert_layer = ['bert_model']
    optimizer_grouped_parameters = [
        {"params": [p for n, p in cur_model.named_parameters() if any(nd in n for nd in bert_layer)], "lr": args.bert_lr},
        {"params": [p for n, p in cur_model.named_parameters() if any(nd in n for nd in extract_layer)], "lr": 1e-4},
        {"params": [p for n, p in cur_model.named_parameters() if not any(nd in n for nd in extract_layer + bert_layer)]},
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    if args.train_from_saved_model != '':
        optimizer.load_state_dict(torch.load(args.train_from_saved_model)["optimizer"])
        print("load saved optimizer from {}.".format(args.train_from_saved_model))
    

    num_steps = 0
    set_seed(args)
    model.zero_grad()
    finetune(train_features, optimizer, args.num_train_epochs, num_steps, model)
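Example #16 assigns the BERT encoder and the extractor/bilinear head their own learning rates by putting an "lr" key on each parameter group; any group without one falls back to the optimizer-level default. A minimal sketch of that grouping (the name substrings and rates are illustrative):

from torch.optim import AdamW

# Illustrative sketch: per-group learning rates selected by parameter-name substrings.
def grouped_lr_parameters(model, encoder_keys=("bert_model",),
                          head_keys=("extractor", "bilinear"),
                          encoder_lr=3e-5, head_lr=1e-4):
    enc, head, rest = [], [], []
    for name, param in model.named_parameters():
        if any(k in name for k in encoder_keys):
            enc.append(param)
        elif any(k in name for k in head_keys):
            head.append(param)
        else:
            rest.append(param)
    return [
        {"params": enc, "lr": encoder_lr},
        {"params": head, "lr": head_lr},
        {"params": rest},  # uses the optimizer's default lr
    ]

# usage sketch:
#   optimizer = AdamW(grouped_lr_parameters(model), lr=5e-5, eps=1e-6)
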
Example #17
0
    def train(self):
        #########################################################################################################################################
        # create the ELECTRA config object
        electra_config = ElectraConfig.from_pretrained(
            "/home/mongjin/KuELECTRA_base",
            num_labels=self.config["senti_labels"],
            cache_dir=self.config["cache_dir_path"])

        # create the ELECTRA tokenizer object
        electra_tokenizer = ElectraTokenizer.from_pretrained(
            "/home/mongjin/KuELECTRA_base",
            do_lower_case=False,
            cache_dir=self.config["cache_dir_path"])

        # create the ELECTRA model object
        electra_model = ElectraForSequenceClassification.from_pretrained(
            "/home/mongjin/KuELECTRA_base",
            config=electra_config,
            lstm_hidden=self.config['lstm_hidden'],
            label_emb_size=self.config['lstm_hidden'] * 2,
            score_emb_size=self.config['lstm_hidden'] * 2,
            score_size=self.config['score_labels'],
            num_layer=self.config['lstm_num_layer'],
            bilstm_flag=self.config['bidirectional_flag'],
            cache_dir=self.config["cache_dir_path"],
            from_tf=True)
        #########################################################################################################################################

        electra_model.cuda()

        # read the training data
        train_datas = preprocessing.read_data(
            file_path=self.config["train_data_path"], mode=self.config["mode"])

        # preprocess the training data
        train_dataset = preprocessing.convert_data2dataset(
            datas=train_datas,
            tokenizer=electra_tokenizer,
            max_length=self.config["max_length"],
            labels=self.config["senti_labels"],
            score_labels=self.config["score_labels"],
            mode=self.config["mode"])

        # create a DataLoader that yields the training data in batches
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=self.config["batch_size"])

        # read the evaluation data
        test_datas = preprocessing.read_data(
            file_path=self.config["test_data_path"], mode=self.config["mode"])

        # preprocess the evaluation data
        test_dataset = preprocessing.convert_data2dataset(
            datas=test_datas,
            tokenizer=electra_tokenizer,
            max_length=self.config["max_length"],
            labels=self.config["senti_labels"],
            score_labels=self.config["score_labels"],
            mode=self.config["mode"])

        # create a DataLoader that yields the evaluation data in batches
        test_sampler = SequentialSampler(test_dataset)
        test_dataloader = DataLoader(test_dataset,
                                     sampler=test_sampler,
                                     batch_size=100)

        # total number of training steps (in batches)
        t_total = len(train_dataloader) // self.config[
            "gradient_accumulation_steps"] * self.config["epoch"]

        # optimizer for model training
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer = AdamW([
            {
                'params': [p for n, p in electra_model.named_parameters()
                           if not any(nd in n for nd in no_decay)],
                'lr': 5e-5,
                'weight_decay': self.config['weight_decay']
            },
            {
                'params': [p for n, p in electra_model.named_parameters()
                           if any(nd in n for nd in no_decay)],
                'lr': 5e-5,
                'weight_decay': 0.0
            },
        ])
        # optimizer = AdamW(lan.parameters(), lr=self.config['learning_rate'], eps=self.config['adam_epsilon'])
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.config["warmup_steps"],
            num_training_steps=t_total)

        if os.path.isfile(
                os.path.join(self.config["model_dir_path"],
                             "optimizer.pt")) and os.path.isfile(
                                 os.path.join(self.config["model_dir_path"],
                                              "scheduler.pt")):
            # load the previously trained optimizer and scheduler state
            optimizer.load_state_dict(
                torch.load(
                    os.path.join(self.config["model_dir_path"],
                                 "optimizer.pt")))
            scheduler.load_state_dict(
                torch.load(
                    os.path.join(self.config["model_dir_path"],
                                 "scheduler.pt")))
            print(
                "#######################     Successfully loaded optimizer and scheduler state     ###########################"
            )

        global_step = 0
        electra_model.zero_grad()
        max_test_accuracy = 0
        for epoch in range(self.config["epoch"]):
            electra_model.train()

            # accuracy and average loss on the training data
            train_accuracy, average_loss, global_step, score_acc = self.do_train(
                electra_model=electra_model,
                optimizer=optimizer,
                scheduler=scheduler,
                train_dataloader=train_dataloader,
                epoch=epoch + 1,
                global_step=global_step)

            print("train_accuracy : {}\taverage_loss : {}\n".format(
                round(train_accuracy, 4), round(average_loss, 4)))
            print("train_score_accuracy :", "{:.6f}".format(score_acc))

            electra_model.eval()

            # accuracy on the evaluation data
            test_accuracy, score_acc = self.do_evaluate(
                electra_model=electra_model,
                test_dataloader=test_dataloader,
                mode=self.config["mode"])

            print("test_accuracy : {}\n".format(round(test_accuracy, 4)))
            print("test_score_accuracy :", "{:.6f}".format(score_acc))

            # save the model files when the current accuracy beats the previous best
            if (max_test_accuracy < test_accuracy):
                max_test_accuracy = test_accuracy

                output_dir = os.path.join(self.config["model_dir_path"],
                                          "checkpoint-{}".format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)

                electra_config.save_pretrained(output_dir)
                electra_tokenizer.save_pretrained(output_dir)
                electra_model.save_pretrained(output_dir)
                # torch.save(lan.state_dict(), os.path.join(output_dir, "lan.pt"))
                torch.save(optimizer.state_dict(),
                           os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           os.path.join(output_dir, "scheduler.pt"))

            print("max_test_accuracy :",
                  "{:.6f}".format(round(max_test_accuracy, 4)))