Example #1
    def evaluate(self, loader, punct=False):
        self.parser.eval()
        metric = Metric()
        pbar = tqdm(total=len(loader))

        for words, tags, masks, heads, rels, mask_heads in loader:
            states = [
                State(mask, tags.device, self.vocab.bert_index,
                      self.config.input_graph) for mask in masks
            ]

            states = self.parser(words, tags, masks, states)

            # collect (head, relation) predictions per sentence, dropping the first
            # (root) entry, then flatten them to align with the filtered gold tensors
            pred_heads = []
            pred_rels = []
            for state in states:
                pred_heads.append([h[0] for h in state.head][1:])
                pred_rels.append([h[1] for h in state.head][1:])
            pred_heads = [item for sublist in pred_heads for item in sublist]
            pred_rels = [item for sublist in pred_rels for item in sublist]

            pred_heads = torch.tensor(pred_heads).to(heads.device)
            pred_rels = torch.tensor(pred_rels).to(heads.device)

            heads = heads[mask_heads]
            rels = rels[mask_heads]
            pbar.update(1)
            metric(pred_heads, pred_rels, heads, rels)
            del states

        return metric
Example #2
    def predict(self, loader):
        self.parser.eval()
        metric = Metric()
        pbar = tqdm(total=len(loader))
        all_arcs, all_rels = [], []
        for words, tags, masks, heads, rels, mask_heads in loader:
            states = [
                State(mask, tags.device, self.vocab.bert_index,
                      self.config.input_graph) for mask in masks
            ]
            states = self.parser(words, tags, masks, states)

            pred_heads = []
            pred_rels = []
            for state in states:
                pred_heads.append([h[0] for h in state.head][1:])
                pred_rels.append([h[1] for h in state.head][1:])

            pred_heads = [item for sublist in pred_heads for item in sublist]
            pred_rels = [item for sublist in pred_rels for item in sublist]

            pred_heads = torch.tensor(pred_heads).to(heads.device)
            pred_rels = torch.tensor(pred_rels).to(heads.device)

            heads = heads[mask_heads]
            rels = rels[mask_heads]

            metric(pred_heads, pred_rels, heads, rels)

            lens = masks.sum(1).tolist()

            all_arcs.extend(torch.split(pred_heads, lens))
            all_rels.extend(torch.split(pred_rels, lens))
            pbar.update(1)
        all_arcs = [seq.tolist() for seq in all_arcs]
        all_rels = [self.vocab.id2rel(seq) for seq in all_rels]

        return all_arcs, all_rels, metric
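A small standalone sketch of how the predict method above regroups the flat predictions into per-sentence sequences with torch.split; the values are made up (two sentences of lengths 3 and 2):

    import torch

    pred_heads = torch.tensor([2, 0, 2, 0, 1])    # flat head predictions (hypothetical)
    lens = [3, 2]                                  # tokens per sentence, as from masks.sum(1).tolist()
    per_sentence = torch.split(pred_heads, lens)   # (tensor([2, 0, 2]), tensor([0, 1]))
    print([seq.tolist() for seq in per_sentence])  # [[2, 0, 2], [0, 1]]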
Example #3
    def evaluate(self, loader, punct=False):
        self.parser.eval()

        loss, metric = 0, Metric()

        for words, tags, arcs, rels in loader:
            mask = words.ne(self.vocab.pad_index)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            # ignore all punctuation if not specified
            if not punct:
                puncts = words.new_tensor(self.vocab.puncts)
                mask &= words.unsqueeze(-1).ne(puncts).all(-1)
            s_arc, s_rel = self.parser(words, tags)
            s_arc, s_rel = s_arc[mask], s_rel[mask]
            gold_arcs, gold_rels = arcs[mask], rels[mask]
            pred_arcs, pred_rels = self.decode(s_arc, s_rel)

            loss += self.get_loss(s_arc, s_rel, gold_arcs, gold_rels)
            metric(pred_arcs, pred_rels, gold_arcs, gold_rels)
        loss /= len(loader)

        return loss, metric
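To make the masking in Example #3 concrete, here is a minimal standalone sketch with toy tensors and hypothetical values (pad index 0, punctuation ids 7 and 9):

    import torch

    words = torch.tensor([[3, 5, 7, 0],    # sentence 1, padded with 0
                          [3, 9, 4, 6]])   # sentence 2
    mask = words.ne(0)                      # keep non-padding tokens
    mask[:, 0] = 0                          # ignore the first token of each sentence
    puncts = words.new_tensor([7, 9])
    mask &= words.unsqueeze(-1).ne(puncts).all(-1)  # ignore punctuation tokens
    print(mask)  # [[False, True, False, False], [False, False, True, True]]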
Example #4
    def __call__(self, config):
        print("Preprocess the data")
        train = Corpus.load(config.ftrain)
        dev = Corpus.load(config.fdev)
        test = Corpus.load(config.ftest)
        if os.path.exists(config.vocab):
            vocab = torch.load(config.vocab)
        else:
            vocab = Vocab.from_corpus(corpus=train, min_freq=2)
            vocab.read_embeddings(Embedding.load(config.fembed, config.unk))
            torch.save(vocab, config.vocab)
        config.update({
            'n_words': vocab.n_train_words,
            'n_tags': vocab.n_tags,
            'n_rels': vocab.n_rels,
            'pad_index': vocab.pad_index,
            'unk_index': vocab.unk_index
        })
        print(vocab)

        print("Load the dataset")
        trainset = TextDataset(vocab.numericalize(train))
        devset = TextDataset(vocab.numericalize(dev))
        testset = TextDataset(vocab.numericalize(test))
        # set the data loaders
        train_loader = batchify(dataset=trainset,
                                batch_size=config.batch_size,
                                n_buckets=config.buckets,
                                shuffle=True)
        dev_loader = batchify(dataset=devset,
                              batch_size=config.batch_size,
                              n_buckets=config.buckets)
        test_loader = batchify(dataset=testset,
                               batch_size=config.batch_size,
                               n_buckets=config.buckets)
        print(f"{'train:':6} {len(trainset):5} sentences in total, "
              f"{len(train_loader):3} batches provided")
        print(f"{'dev:':6} {len(devset):5} sentences in total, "
              f"{len(dev_loader):3} batches provided")
        print(f"{'test:':6} {len(testset):5} sentences in total, "
              f"{len(test_loader):3} batches provided")

        print("Create the model")
        parser = BiaffineParser(config, vocab.embeddings)
        if torch.cuda.is_available():
            parser = parser.cuda()
        print(f"{parser}\n")

        model = Model(vocab, parser)

        total_time = timedelta()
        best_e, best_metric = 1, Metric()
        model.optimizer = Adam(model.parser.parameters(),
                               config.lr,
                               (config.beta_1, config.beta_2),
                               config.epsilon)
        model.scheduler = ExponentialLR(model.optimizer,
                                        config.decay ** (1 / config.steps))

        for epoch in range(1, config.epochs + 1):
            start = datetime.now()
            # train one epoch and update the parameters
            model.train(train_loader)

            print(f"Epoch {epoch} / {config.epochs}:")
            loss, train_metric = model.evaluate(train_loader, config.punct)
            print(f"{'train:':6} Loss: {loss:.4f} {train_metric}")
            loss, dev_metric = model.evaluate(dev_loader, config.punct)
            print(f"{'dev:':6} Loss: {loss:.4f} {dev_metric}")
            loss, test_metric = model.evaluate(test_loader, config.punct)
            print(f"{'test:':6} Loss: {loss:.4f} {test_metric}")

            t = datetime.now() - start
            # save the model if it is the best so far
            if dev_metric > best_metric and epoch > config.patience:
                best_e, best_metric = epoch, dev_metric
                model.parser.save(config.model + f".{best_e}")
                print(f"{t}s elapsed (saved)\n")
            else:
                print(f"{t}s elapsed\n")
            total_time += t
            if epoch - best_e >= config.patience:
                break
        model.parser = BiaffineParser.load(config.model + f".{best_e}")
        loss, metric = model.evaluate(test_loader, config.punct)

        print(f"max score of dev is {best_metric.score:.2%} at epoch {best_e}")
        print(f"the score of test at epoch {best_e} is {metric.score:.2%}")
        print(f"average time of each epoch is {total_time / epoch}s")
        print(f"{total_time}s elapsed")
Example #5
    def __call__(self, config):
        print("Preprocess the data")
        train = Corpus.load(config.ftrain)
        dev = Corpus.load(config.fdev)
        test = Corpus.load(config.ftest)

        if not path.exists(config.model):
            os.mkdir(config.model)

        if not path.exists("model/"):
            os.mkdir("model/")

        if not path.exists(config.model + config.modelname):
            os.mkdir(config.model + config.modelname)

        if config.checkpoint:
            vocab = torch.load(config.main_path + config.vocab +
                               config.modelname + "/vocab.tag")
        else:
            vocab = Vocab.from_corpus(config=config,
                                      corpus=train,
                                      corpus_dev=dev,
                                      corpus_test=test,
                                      min_freq=0)
        train_seq = read_seq(config.ftrain_seq, vocab)
        total_act = sum(len(x) for x in train_seq)
        print("number of transitions:{}".format(total_act))

        torch.save(vocab, config.vocab + config.modelname + "/vocab.tag")

        config.update({
            'n_words': vocab.n_train_words,
            'n_tags': vocab.n_tags,
            'n_rels': vocab.n_rels,
            'n_trans': vocab.n_trans,
            'pad_index': vocab.pad_index,
            'unk_index': vocab.unk_index
        })

        print("Load the dataset")
        trainset = TextDataset(vocab.numericalize(train, train_seq))
        devset = TextDataset(vocab.numericalize(dev))
        testset = TextDataset(vocab.numericalize(test))

        # set the data loaders
        train_loader, _ = batchify(dataset=trainset,
                                   batch_size=config.batch_size,
                                   n_buckets=config.buckets,
                                   shuffle=True)
        dev_loader, _ = batchify(dataset=devset,
                                 batch_size=config.batch_size,
                                 n_buckets=config.buckets)
        test_loader, _ = batchify(dataset=testset,
                                  batch_size=config.batch_size,
                                  n_buckets=config.buckets)

        print(f"{'train:':6} {len(trainset):5} sentences in total, "
              f"{len(train_loader):3} batches provided")
        print(f"{'dev:':6} {len(devset):5} sentences in total, "
              f"{len(dev_loader):3} batches provided")
        print(f"{'test:':6} {len(testset):5} sentences in total, "
              f"{len(test_loader):3} batches provided")
        print("Create the model")

        if config.checkpoint:
            parser = Parser.load(config.main_path + config.model +
                                 config.modelname + "/parser-checkpoint")
        else:
            parser = Parser(config, vocab.bertmodel)

        print("number of parameters:{}".format(
            sum(p.numel() for p in parser.parameters() if p.requires_grad)))
        if torch.cuda.is_available():
            print('Train/Evaluate on GPU')
            device = torch.device('cuda')
            parser = parser.to(device)

        model = Model(vocab, parser, config, vocab.n_rels)
        total_time = timedelta()
        best_e, best_metric = 1, Metric()

        ## prepare optimisers
        num_train_optimization_steps = int(config.epochs * len(train_loader))
        warmup_steps = int(config.warmupproportion *
                           num_train_optimization_steps)
        ## one for parsing parameters, one for BERT parameters
        if config.use_two_opts:
            model_nonbert = []
            model_bert = []
            layernorm_params = [
                'layernorm_key_layer', 'layernorm_value_layer',
                'dp_relation_k', 'dp_relation_v'
            ]
            for name, param in parser.named_parameters():
                if 'bert' in name and not any(nd in name
                                              for nd in layernorm_params):
                    model_bert.append((name, param))
                else:
                    model_nonbert.append((name, param))

            # Prepare optimizer and schedule (linear warmup and decay) for Non-bert parameters
            no_decay = ['bias', 'LayerNorm.weight']
            optimizer_grouped_parameters_nonbert = [{
                'params': [
                    p for n, p in model_nonbert
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                config.weight_decay
            }, {
                'params': [
                    p for n, p in model_nonbert
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.0
            }]
            model.optimizer_nonbert = AdamW(
                optimizer_grouped_parameters_nonbert, lr=config.lr2)

            model.scheduler_nonbert = get_linear_schedule_with_warmup(
                model.optimizer_nonbert,
                num_warmup_steps=warmup_steps,
                num_training_steps=num_train_optimization_steps)

            # Prepare optimizer and schedule (linear warmup and decay) for Bert parameters
            optimizer_grouped_parameters_bert = [{
                'params': [
                    p for n, p in model_bert
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                config.weight_decay
            }, {
                'params':
                [p for n, p in model_bert if any(nd in n for nd in no_decay)],
                'weight_decay':
                0.0
            }]

            model.optimizer_bert = AdamW(optimizer_grouped_parameters_bert,
                                         lr=config.lr)
            model.scheduler_bert = get_linear_schedule_with_warmup(
                model.optimizer_bert,
                num_warmup_steps=warmup_steps,
                num_training_steps=num_train_optimization_steps)

        else:
            # Prepare optimizer and schedule (linear warmup and decay)
            no_decay = ['bias', 'LayerNorm.weight']
            optimizer_grouped_parameters = [{
                'params': [
                    p for n, p in parser.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                config.weight_decay
            }, {
                'params': [
                    p for n, p in parser.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.0
            }]
            model.optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr)
            model.scheduler = get_linear_schedule_with_warmup(
                model.optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=num_train_optimization_steps)

        start_epoch = 1

        ## load model, optimiser, and other parameters from a checkpoint
        if config.checkpoint:
            check_load = torch.load(config.main_path + config.model +
                                    config.modelname + "/checkpoint")
            if config.use_two_opts:
                model.optimizer_bert.load_state_dict(
                    check_load['optimizer_bert'])
                model.optimizer_nonbert.load_state_dict(
                    check_load['optimizer_nonbert'])
                model.scheduler_bert.load_state_dict(
                    check_load['lr_schedule_bert'])
                model.scheduler_nonbert.load_state_dict(
                    check_load['lr_schedule_nonbert'])
                start_epoch = check_load['epoch'] + 1
                best_e = check_load['best_e']
                best_metric = check_load['best_metric']
            else:
                model.optimizer.load_state_dict(check_load['optimizer'])
                model.scheduler.load_state_dict(check_load['lr_schedule'])
                start_epoch = check_load['epoch'] + 1
                best_e = check_load['best_e']
                best_metric = check_load['best_metric']

        with open(config.model + config.modelname + "/baseline.txt", "a") as f1:
            f1.write("New Model:\n")
        for epoch in range(start_epoch, config.epochs + 1):
            start = datetime.now()
            # train one epoch and update the parameters
            model.train(train_loader)
            print(f"Epoch {epoch} / {config.epochs}:")
            dev_metric = model.evaluate(dev_loader, config.punct)
            print(f"{'dev:':6} {dev_metric}")
            with open(config.model + config.modelname + "/baseline.txt", "a") as f1:
                f1.write(f"{epoch}\n")
                f1.write(f"{'dev:':6} {dev_metric}\n")

            t = datetime.now() - start
            # save the model if it is the best so far
            if dev_metric > best_metric:
                best_e, best_metric = epoch, dev_metric
                print(config.model + config.modelname + "/model_weights")
                model.parser.save(config.model + config.modelname +
                                  "/model_weights")
                print(f"{t}s elapsed (saved)\n")
            else:
                print(f"{t}s elapsed\n")
            total_time += t
            if epoch - best_e >= config.patience:
                break

            ## save checkpoint
            if config.use_two_opts:
                checkpoint = {
                    "epoch": epoch,
                    "optimizer_bert": model.optimizer_bert.state_dict(),
                    "lr_schedule_bert": model.scheduler_bert.state_dict(),
                    "lr_schedule_nonbert":
                    model.scheduler_nonbert.state_dict(),
                    "optimizer_nonbert": model.optimizer_nonbert.state_dict(),
                    'best_metric': best_metric,
                    'best_e': best_e
                }
                torch.save(
                    checkpoint, config.main_path + config.model +
                    config.modelname + "/checkpoint")
                parser.save(config.main_path + config.model +
                            config.modelname + "/parser-checkpoint")
            else:
                checkpoint = {
                    "epoch": epoch,
                    "optimizer": model.optimizer.state_dict(),
                    "lr_schedule": model.scheduler.state_dict(),
                    'best_metric': best_metric,
                    'best_e': best_e
                }
                torch.save(
                    checkpoint, config.main_path + config.model +
                    config.modelname + "/checkpoint")
                parser.save(config.main_path + config.model +
                            config.modelname + "/parser-checkpoint")
        model.parser = Parser.load(config.model + config.modelname +
                                   "/model_weights")
        metric = model.evaluate(test_loader, config.punct)
        print(metric)
        print(f"max score of dev is {best_metric.score:.2%} at epoch {best_e}")
        print(f"the score of test at epoch {best_e} is {metric.score:.2%}")
        print(f"average time of each epoch is {total_time / epoch}s")
        print(f"{total_time}s elapsed")
Example #6
    def __call__(self, config):
        print("Preprocess the data")

        if config.input_type == "conllu":
            train = UniversalDependenciesDatasetReader()
            train.load(config.ftrain)
            dev = UniversalDependenciesDatasetReader()
            dev.load(config.fdev)
            test = UniversalDependenciesDatasetReader()
            test.load(config.ftest)
        else:
            train = Corpus.load(config.ftrain)
            dev = Corpus.load(config.fdev)
            test = Corpus.load(config.ftest)

        if config.use_predicted:
            if config.input_type == "conllu":
                train_predicted = UniversalDependenciesDatasetReader()
                train_predicted.load(config.fpredicted_train)
                dev_predicted = UniversalDependenciesDatasetReader()
                dev_predicted.load(config.fpredicted_dev)
                test_predicted = UniversalDependenciesDatasetReader()
                test_predicted.load(config.fpredicted_test)
            else:
                train_predicted = Corpus.load(config.fpredicted_train)
                dev_predicted = Corpus.load(config.fpredicted_dev)
                test_predicted = Corpus.load(config.fpredicted_test)

        if not path.exists(config.main_path + "/exp"):
            os.mkdir(config.main_path + "/exp")

        if not path.exists(config.main_path + "/model"):
            os.mkdir(config.main_path + "/model")

        if not path.exists(config.main_path + config.model + config.modelname):
            os.mkdir(config.main_path + config.model + config.modelname)

        vocab = Vocab.from_corpus(config=config, corpus=train, min_freq=2)

        torch.save(vocab, config.main_path + config.vocab + config.modelname + "/vocab.tag")

        config.update({
            'n_words': vocab.n_train_words,
            'n_tags': vocab.n_tags,
            'n_rels': vocab.n_rels,
            'pad_index': vocab.pad_index,
            'unk_index': vocab.unk_index
        })

        print("Load the dataset")

        if config.use_predicted:
            trainset = TextDataset(vocab.numericalize(train, train_predicted))
            devset = TextDataset(vocab.numericalize(dev, dev_predicted))
            testset = TextDataset(vocab.numericalize(test, test_predicted))
        else:
            trainset = TextDataset(vocab.numericalize(train))
            devset = TextDataset(vocab.numericalize(dev))
            testset = TextDataset(vocab.numericalize(test))

        # set the data loaders
        train_loader, _ = batchify(dataset=trainset,
                                   batch_size=config.batch_size,
                                   n_buckets=config.buckets,
                                   shuffle=True)
        dev_loader, _ = batchify(dataset=devset,
                                 batch_size=config.batch_size,
                                 n_buckets=config.buckets)
        test_loader, _ = batchify(dataset=testset,
                                  batch_size=config.batch_size,
                                  n_buckets=config.buckets)
        print(f"{'train:':6} {len(trainset):5} sentences in total, "
              f"{len(train_loader):3} batches provided")
        print(f"{'dev:':6} {len(devset):5} sentences in total, "
              f"{len(dev_loader):3} batches provided")
        print(f"{'test:':6} {len(testset):5} sentences in total, "
              f"{len(test_loader):3} batches provided")

        print("Create the model")
        parser = BiaffineParser(config, vocab.n_rels, vocab.bertmodel)

        print("number of pars:{}".format(sum(p.numel() for p in parser.parameters()
                                             if p.requires_grad)))
        if torch.cuda.is_available():
            print('device:cuda')
            device = torch.device('cuda')
            parser = parser.to(device)
        # print(f"{parser}\n")

        model = Model(vocab, parser, config, vocab.n_rels)
        total_time = timedelta()
        best_e, best_metric = 1, Metric()

        num_train_optimization_steps = int(config.num_iter_encoder * config.epochs * len(train_loader))
        warmup_steps = int(config.warmupproportion * num_train_optimization_steps)

        if config.use_two_opts:
            model_nonbert = []
            model_bert = []
            layernorm_params = ['layernorm_key_layer', 'layernorm_value_layer', 'dp_relation_k', 'dp_relation_v']
            for name, param in parser.named_parameters():
                if 'bert' in name and not any(nd in name for nd in layernorm_params):
                    model_bert.append((name, param))
                else:
                    model_nonbert.append((name, param))

            # Prepare optimizer and schedule (linear warmup and decay) for Non-bert parameters
            no_decay = ['bias', 'LayerNorm.weight']
            optimizer_grouped_parameters_nonbert = [
                {'params': [p for n, p in model_nonbert if not any(nd in n for nd in no_decay)],
                 'weight_decay': config.weight_decay},
                {'params': [p for n, p in model_nonbert if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
            ]
            model.optimizer_nonbert = AdamW(optimizer_grouped_parameters_nonbert, lr=config.lr2)

            model.scheduler_nonbert = get_linear_schedule_with_warmup(model.optimizer_nonbert,
                                                                      num_warmup_steps=warmup_steps,
                                                                      num_training_steps=num_train_optimization_steps)

            # Prepare optimizer and schedule (linear warmup and decay) for Bert parameters
            optimizer_grouped_parameters_bert = [
                {'params': [p for n, p in model_bert if not any(nd in n for nd in no_decay)],
                 'weight_decay': config.weight_decay},
                {'params': [p for n, p in model_bert if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
            ]

            model.optimizer_bert = AdamW(optimizer_grouped_parameters_bert, lr=config.lr1)
            model.scheduler_bert = get_linear_schedule_with_warmup(
                model.optimizer_bert, num_warmup_steps=warmup_steps, num_training_steps=num_train_optimization_steps
            )

        else:
            # Prepare optimizer and schedule (linear warmup and decay)
            no_decay = ['bias', 'LayerNorm.weight']
            optimizer_grouped_parameters = [
                {'params': [p for n, p in parser.named_parameters() if not any(nd in n for nd in no_decay)],
                 'weight_decay': config.weight_decay},
                {'params': [p for n, p in parser.named_parameters() if any(nd in n for nd in no_decay)],
                 'weight_decay': 0.0}
            ]
            model.optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr1)
            model.scheduler = get_linear_schedule_with_warmup(
                model.optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_optimization_steps
            )

        for epoch in range(1, config.epochs + 1):
            start = datetime.now()
            # train one epoch and update the parameters
            if config.use_predicted:
                model.train_predicted(train_loader)
            else:
                model.train(train_loader)
            print(f"Epoch {epoch} / {config.epochs}:")

            if config.use_predicted:
                loss, dev_metric = model.evaluate_predicted(dev_loader, config.punct)
            else:
                loss, dev_metric = model.evaluate(dev_loader, config.punct)

            print(f"{'dev:':6} Loss: {loss:.4f} {dev_metric}")
            if config.use_predicted:
                loss, test_metric = model.evaluate_predicted(test_loader, config.punct)
            else:
                loss, test_metric = model.evaluate(test_loader, config.punct)
            print(f"{'test:':6} Loss: {loss:.4f} {test_metric}")

            t = datetime.now() - start
            # save the model if it is the best so far
            if dev_metric > best_metric:
                best_e, best_metric = epoch, dev_metric
                print(config.model + config.modelname + "/model_weights")
                model.parser.save(config.main_path + config.model + config.modelname + "/model_weights")
                print(f"{t}s elapsed (saved)\n")
            else:
                print(f"{t}s elapsed\n")
            total_time += t
            if epoch - best_e >= config.patience:
                break
        model.parser = BiaffineParser.load(config.main_path + config.model + config.modelname + "/model_weights")
        if config.use_predicted:
            loss, metric = model.evaluate_predicted(test_loader, config.punct)
        else:
            loss, metric = model.evaluate(test_loader, config.punct)
        print(metric)
        print(f"max score of dev is {best_metric.score:.2%} at epoch {best_e}")
        print(f"the score of test at epoch {best_e} is {metric.score:.2%}")
        print(f"average time of each epoch is {total_time / epoch}s")
        print(f"{total_time}s elapsed")
Example #7
    def predict(self, loader):
        self.parser.eval()
        metric = Metric()
        all_arcs, all_rels = [], []
        for words, tags, arcs, rels, mask, sbert_arc, sbert_rel, offsets in loader:

            stop_sign = torch.ones(len(words)).long().to(words.device)
            mask_gold = arcs > 0
            mask_unused = mask.clone()
            ## iterate over encoder
            for counter in range(0, self.config.num_iter_encoder):

                self.counter_ref = counter

                if counter == 0:
                    s_arc, s_rel = self.parser(words, tags)
                    s_arc_final = s_arc
                    s_rel_final = s_rel
                else:
                    if self.config.use_mst_train or counter == 1:
                        graph_arc, graph_rel = \
                            self.prepare_mst(s_arc, s_rel, mask, sbert_arc.clone(), sbert_rel.clone(), stop_sign, False)
                    else:
                        graph_arc, graph_rel = \
                            self.prepare_argmax(s_arc, s_rel, mask, sbert_arc.clone(), sbert_rel.clone(), stop_sign,
                                                False)
                    s_arc, s_rel = self.parser(words, tags, stop_sign,
                                               graph_arc, graph_rel)

                if self.config.use_mst_train or counter == 1:
                    new_arcs, new_rels = self.prepare_mst(
                        s_arc, s_rel, mask, sbert_arc.clone(),
                        sbert_rel.clone(), stop_sign, True)
                else:
                    new_arcs, new_rels = self.prepare_argmax(
                        s_arc, s_rel, mask, sbert_arc.clone(),
                        sbert_rel.clone(), stop_sign, True)

                if counter > 0:
                    stop_sign = self.check_stop(stop_sign, new_arcs, prev_arcs,
                                                new_rels, prev_rels, mask)
                    if stop_sign.sum() == 0:
                        # print('All Dependency Graphs are converged in this batch')
                        break
                    mask = (stop_sign.unsqueeze(1) * mask.long()).bool()

                prev_arcs = new_arcs
                prev_rels = new_rels

                index = stop_sign.nonzero()
                s_arc_final[index] = s_arc[index]
                s_rel_final[index] = s_rel[index]

            gold_arcs, gold_rels = arcs[mask_gold], rels[mask_gold]

            # note: this predict path assumes config.use_mst_eval is enabled; otherwise
            # the pred_* variables below would be left undefined
            if self.config.use_mst_eval:
                pred_rels, pred_arcs_org, pred_arcs = self.decode_mst(
                    s_arc_final,
                    s_rel_final,
                    mask_unused,
                    prepare=False,
                    do_predict=True)

            metric(pred_arcs, pred_rels, gold_arcs, gold_rels)

            lens = mask_unused.sum(1).tolist()

            all_arcs.extend(torch.split(pred_arcs_org, lens))
            all_rels.extend(torch.split(pred_rels, lens))

        all_arcs = [seq.tolist() for seq in all_arcs]
        all_rels = [self.vocab.id2rel(seq) for seq in all_rels]

        return all_arcs, all_rels, metric
Example #8
    def evaluate(self, loader, punct=False):
        self.parser.eval()

        loss, metric = 0, Metric()
        pbar = tqdm(total=len(loader))
        for words, tags, arcs, rels, mask, sbert_arc, sbert_rel in loader:
            stop_sign = torch.ones(len(words)).long().to(words.device)
            mask_gold = arcs > 0
            mask_unused = mask.clone()
            ## iterate over encoder
            for counter in range(self.config.num_iter_encoder):

                self.counter_ref = counter

                if counter == 0:
                    s_arc, s_rel = self.parser(words, tags)
                    s_arc_final = s_arc
                    s_rel_final = s_rel
                else:
                    if self.config.use_mst_train or counter == 1:
                        graph_arc, graph_rel = self.prepare_mst(
                            s_arc, s_rel, mask, sbert_arc.clone(),
                            sbert_rel.clone(), stop_sign, False)
                    else:
                        graph_arc, graph_rel = self.prepare_argmax(
                            s_arc, s_rel, mask, sbert_arc.clone(),
                            sbert_rel.clone(), stop_sign, False)

                    s_arc, s_rel = self.parser(words, tags, stop_sign,
                                               graph_arc, graph_rel)

                if self.config.use_mst_train or counter == 1:
                    new_arcs, new_rels = self.prepare_mst(
                        s_arc, s_rel, mask, sbert_arc.clone(),
                        sbert_rel.clone(), stop_sign, True)
                else:
                    new_arcs, new_rels = self.prepare_argmax(
                        s_arc, s_rel, mask, sbert_arc.clone(),
                        sbert_rel.clone(), stop_sign, True)

                if counter > 0:
                    stop_sign = self.check_stop(stop_sign, new_arcs, prev_arcs,
                                                new_rels, prev_rels, mask)
                    if stop_sign.sum() == 0:
                        # print('All Dependency Graphs are converged in this batch')
                        break
                    mask = (stop_sign.unsqueeze(1) * mask.long()).bool()

                prev_arcs = new_arcs
                prev_rels = new_rels

                index = stop_sign.nonzero()
                s_arc_final[index] = s_arc[index]
                s_rel_final[index] = s_rel[index]

                if self.config.show_refinement:
                    if counter == 0:
                        self.initial_refinement_total(new_arcs.clone(),
                                                      new_rels.clone(),
                                                      mask_unused.clone(),
                                                      arcs.clone(),
                                                      rels.clone(),
                                                      mask_gold.clone())
                    elif counter > 0:
                        self.refinement_total(new_arcs.clone(),
                                              new_rels.clone(),
                                              mask_unused.clone(),
                                              stop_sign.clone())

            gold_arcs, gold_rels = arcs[mask_gold], rels[mask_gold]

            if self.config.use_mst_eval:
                pred_arcs, pred_rels = self.decode_mst(s_arc_final,
                                                       s_rel_final,
                                                       mask_unused,
                                                       prepare=False)
            else:
                pred_arcs, pred_rels = self.decode(s_arc_final, s_rel_final,
                                                   mask_unused)

            s_arc_mask = s_arc_final[mask_unused]
            s_rel_mask = s_rel_final[mask_unused]

            loss += self.get_loss(s_arc_mask, s_rel_mask, gold_arcs, gold_rels)

            metric(pred_arcs, pred_rels, gold_arcs, gold_rels)
            pbar.update(1)

        loss /= len(loader)
        return loss, metric
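The refinement loops in Examples #7 and #8 keep iterating only while check_stop reports that some sentence in the batch is still changing. A rough illustration of that idea (an assumption about check_stop's behaviour, not the repository's implementation):

    import torch

    def toy_check_stop(stop_sign, new_arcs, prev_arcs, new_rels, prev_rels, mask):
        # a sentence keeps refining only if any of its masked arcs or relations
        # changed compared to the previous iteration (hypothetical logic)
        changed = (((new_arcs != prev_arcs) | (new_rels != prev_rels)) & mask).any(dim=1)
        return stop_sign * changed.long()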