예제 #1
0
    def train(self, train_set, dev_set):
        """Train the model with per-epoch validation and patience-based early stopping.

        Each epoch calls ``self.train_an_epoch`` on a randomly sampled loader
        built from ``train_set``. It then calls ``self.eval`` on a sequentially
        sampled loader built from ``dev_set``.

        If ``eval_result[self.args.valid_metric]`` strictly exceeds the best
        value seen so far:
            * the best value is updated;
            * the metrics are printed;
            * the model is saved to ``self.args.model_save_dir``.
        Otherwise ``self.unimproved_iters`` is incremented. Once it reaches
        ``self.args.patience``, ``self.early_stop`` is set and the loop stops.

        Args:
            train_set: Training data in whatever form ``convert_df_to_ids``
                accepts (presumably a DataFrame -- TODO confirm).
            dev_set: Validation data, same form as ``train_set``.

        Raises:
            ValueError: If ``self.args.gradient_accumulation_steps`` is < 1.
        """
        self.iterations, self.nb_tr_steps, self.tr_loss = 0, 0, 0
        # NOTE(review): starting from 0 means a run whose validation metric
        # never rises above 0 is never saved. Consider float("-inf") if the
        # metric can be <= 0 -- confirm against the metrics in use.
        self.best_valid_metric, self.unimproved_iters = 0, 0
        self.early_stop = False
        if self.args.gradient_accumulation_steps < 1:
            raise ValueError(
                "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
                .format(self.args.gradient_accumulation_steps))

        # Shrink the batch size by the number of accumulation steps.
        # Presumably this makes the accumulated micro-batches add up to the
        # configured size.
        # NOTE(review): this mutates self.args in place. Calling train() again
        # divides a second time, and predict()/test() afterwards also use the
        # reduced size. It is left as-is because init_optimizer() and
        # train_an_epoch() may read the updated value -- confirm before changing.
        self.args.batch_size = (self.args.batch_size //
                                self.args.gradient_accumulation_steps)
        self.init_optimizer()

        train_dataset = convert_df_to_ids(train_set, self.word2id,
                                          self.args.max_seq_length)

        dev_dataset = convert_df_to_ids(dev_set, self.word2id,
                                        self.args.max_seq_length)

        # Shuffle training data each epoch; keep validation order fixed.
        train_dataloader = DataLoader(
            train_dataset,
            sampler=RandomSampler(train_dataset),
            batch_size=self.args.batch_size,
        )
        dev_dataloader = DataLoader(
            dev_dataset,
            sampler=SequentialSampler(dev_dataset),
            batch_size=self.args.batch_size,
        )

        for epoch in tqdm(range(int(self.args.epochs))):
            # The first element returned by train_an_epoch is logged as the
            # epoch loss.
            self.tr_loss = self.train_an_epoch(train_dataloader)[0]
            # Plain f-string. A trailing .format() here would re-parse any
            # braces in the interpolated loss as format fields.
            tqdm.write(f"[Epoch {epoch}] loss: {self.tr_loss}")
            self.tr_loss = 0
            eval_result = self.eval(dev_dataloader)
            # Update validation results
            if eval_result[self.args.valid_metric] > self.best_valid_metric:
                self.unimproved_iters = 0
                self.best_valid_metric = eval_result[self.args.valid_metric]
                print_dict_as_table(
                    eval_result,
                    tag=f"[Epoch {epoch}]performance on validation set",
                    columns=["metrics", "values"],
                )
                # ensureDir presumably creates the directory if it is missing.
                ensureDir(self.args.model_save_dir)
                self.save_pretrained(self.args.model_save_dir)
            else:
                self.unimproved_iters += 1
                if self.unimproved_iters >= self.args.patience:
                    self.early_stop = True
                    tqdm.write(
                        "Early Stopping. Epoch: {}, best_valid_metric ({}): {}"
                        .format(epoch, self.args.valid_metric,
                                self.best_valid_metric))
                    break
예제 #2
0
    def predict(self, test_set):
        """Return model scores for every example in ``test_set``, in input order.

        Args:
            test_set: Examples to score. Earlier docs described this as a list
                of ``InputExample``. It is handed to ``convert_df_to_ids``,
                which by name expects a DataFrame -- TODO confirm which.

        Returns:
            The first element of ``self.scores(...)`` for a sequential loader
            over ``test_set``. It was previously documented as an ndarray of
            predicted label scores; confirm against ``scores``.
        """
        encoded = convert_df_to_ids(
            test_set, self.word2id, self.args.max_seq_length
        )
        # A sequential sampler keeps the output rows aligned with the input.
        ordered_loader = DataLoader(
            encoded,
            sampler=SequentialSampler(encoded),
            batch_size=self.args.batch_size,
        )
        score_output = self.scores(ordered_loader)
        return score_output[0]
예제 #3
0
    def test(self, test_set):
        """Evaluate the model on a held-out set and return the evaluation result.

        Args:
            test_set: Data to evaluate, in whatever form ``convert_df_to_ids``
                accepts (presumably a DataFrame -- TODO confirm).

        Returns:
            Whatever ``self.eval`` returns when run on a sequential loader over
            ``test_set``. ``train`` indexes that result by metric name.
        """
        encoded = convert_df_to_ids(
            test_set, self.word2id, self.args.max_seq_length
        )
        ordered_loader = DataLoader(
            encoded,
            sampler=SequentialSampler(encoded),
            batch_size=self.args.batch_size,
        )
        evaluation = self.eval(ordered_loader)
        return evaluation