Example #1
    def train(self):
        train_data = StreamingSparseDataset(self.train_features, self.train_labels)
        train_dataloader = DataLoader(train_data, shuffle=True, batch_size=self.args.batch_size)

        print("Number of examples: ", len(self.train_labels))
        print("Batch size:", self.args.batch_size)

        for epoch in trange(int(self.args.epochs), desc="Epoch"):
            self.train_epoch(train_dataloader)
            if self.args.evaluate_dev:
                dev_evaluator = BagOfWordsEvaluator(self.model, self.vectorizer, self.processor, self.args, split='dev')
                dev_precision, dev_recall, dev_f1, dev_acc, dev_loss = dev_evaluator.get_scores()[0][:5]

                # Print validation results
                tqdm.write(self.log_header)
                tqdm.write(self.log_template.format(epoch + 1, self.nb_train_steps, epoch + 1, self.args.epochs,
                                                    dev_acc, dev_precision, dev_recall, dev_f1, dev_loss))

                # Update validation results
                if dev_f1 > self.best_dev_f1:
                    self.unimproved_iters = 0
                    self.best_dev_f1 = dev_f1
                    torch.save(self.model, self.snapshot_path)
                else:
                    self.unimproved_iters += 1
                    if self.unimproved_iters >= self.args.patience:
                        self.early_stop = True
                        tqdm.write("Early Stopping. Epoch: {}, Best Dev F1: {}".format(epoch, self.best_dev_f1))
                        break
        # Save the final model when evaluating on the test split
        if self.args.evaluate_test:
            torch.save(self.model, self.snapshot_path)
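
The loop above combines patience-based early stopping with saving the best checkpoint so far. Below is a minimal, self-contained sketch of that pattern; the dev-F1 values and the patience setting are made up for illustration and do not come from the original trainer.

# Minimal sketch of the patience-based early-stopping pattern used above.
# dev_f1_per_epoch holds dummy numbers, not real evaluation results.
dev_f1_per_epoch = [0.41, 0.48, 0.47, 0.49, 0.49, 0.48, 0.47]
patience = 2

best_dev_f1 = 0.0
unimproved_iters = 0

for epoch, dev_f1 in enumerate(dev_f1_per_epoch, start=1):
    if dev_f1 > best_dev_f1:
        best_dev_f1 = dev_f1
        unimproved_iters = 0  # reset the counter on improvement
        # the real trainer checkpoints the model here (torch.save)
    else:
        unimproved_iters += 1
        if unimproved_iters >= patience:
            print("Early stopping at epoch {}, best dev F1: {}".format(epoch, best_dev_f1))
            break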
Example #2
def evaluate_split(model, vectorizer, processor, args, split='dev'):
    evaluator = BagOfWordsEvaluator(model, vectorizer, processor, args, split)
    accuracy, precision, recall, f1, avg_loss = evaluator.get_scores(silent=True)[0]
    print('\n' + LOG_HEADER)
    print(LOG_TEMPLATE.format(split.upper(), accuracy, precision, recall, f1, avg_loss))
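
LOG_HEADER and LOG_TEMPLATE are module-level constants that are not part of this snippet. A plausible pair of definitions, consistent with the six-argument format() call above but purely an assumption rather than the original constants, could look like this:

# Hypothetical logging constants (assumed, not from the original module).
LOG_HEADER = 'Split      Acc.   Prec.    Rec.      F1     Loss'
LOG_TEMPLATE = '{:>5s}  {:>8.4f}  {:>6.4f}  {:>6.4f}  {:>6.4f}  {:>7.4f}'

print(LOG_HEADER)
print(LOG_TEMPLATE.format('DEV', 0.9100, 0.8800, 0.9000, 0.8900, 0.3421))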
Example #3
import json


def evaluate_split(model, vectorizer, processor, args, save_file, split='dev'):
    evaluator = BagOfWordsEvaluator(model, vectorizer, processor, args, split)
    scores, score_names = evaluator.get_scores(silent=True)
    p_micro, r_micro, f1_micro, accuracy, avg_loss = scores[:5]
    print('\n' + LOG_HEADER)
    print(LOG_TEMPLATE.format(split.upper(), accuracy, p_micro, r_micro, f1_micro, avg_loss))

    # Persist every returned metric as a JSON object keyed by metric name
    scores_dict = dict(zip(score_names, scores))
    with open(save_file, 'w') as f:
        json.dump(scores_dict, f)
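
Since the scores are written as a single JSON object keyed by metric name, they can be loaded back with json.load. A short usage sketch; the file name here is illustrative, not from the original code:

import json

# Load the metrics that evaluate_split wrote to disk
# ('dev_scores.json' is an assumed file name for save_file).
with open('dev_scores.json') as f:
    scores_dict = json.load(f)

print(scores_dict)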