def evaluation(self):
        if self.hyper.is_bert in ("nyt_bert_tokenizer",
                                  "nyt11_bert_tokenizer",
                                  "nyt10_bert_tokenizer"):
            dev_set = Selection_bert_Nyt_Dataset(self.hyper, self.hyper.dev)
            loader = Selection_bert_loader(dev_set,
                                           batch_size=100,
                                           pin_memory=True)

        elif self.hyper.is_bert == "bert_bilstem_crf":
            dev_set = Selection_Nyt_Dataset(self.hyper, self.hyper.dev)
            loader = Selection_loader(dev_set, batch_size=100, pin_memory=True)

        else:
            dev_set = Selection_Dataset(self.hyper, self.hyper.dev)
            loader = Selection_loader(dev_set, batch_size=100, pin_memory=True)

        self.metrics.reset()
        self.model.eval()

        pbar = tqdm(enumerate(BackgroundGenerator(loader)), total=len(loader))

        with torch.no_grad():
            for batch_ndx, sample in pbar:
                output = self.model(sample, is_train=False)
                self.metrics(output['selection_triplets'], output['spo_gold'])

            result = self.metrics.get_metric()
            print(', '.join([
                "%s: %.4f" % (name, value)
                for name, value in result.items() if not name.startswith("_")
            ]) + " ||")
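The loop above only assumes that self.metrics exposes reset(), a __call__(predicted, gold) update, and get_metric() returning a name-to-value dict. A minimal sketch of such a triplet F1 metric is shown below; the class name and counters are illustrative assumptions, not this repository's actual implementation.

class F1TripletSketch:
    """Illustrative metric with the reset / __call__ / get_metric interface."""

    def reset(self):
        self.correct, self.predicted, self.gold = 0, 0, 0

    def __call__(self, batch_pred, batch_gold):
        # each element is a list of {'subject', 'predicate', 'object'} dicts
        for pred, gold in zip(batch_pred, batch_gold):
            pred_set = {tuple(sorted(t.items())) for t in pred}
            gold_set = {tuple(sorted(t.items())) for t in gold}
            self.correct += len(pred_set & gold_set)
            self.predicted += len(pred_set)
            self.gold += len(gold_set)

    def get_metric(self):
        p = self.correct / self.predicted if self.predicted else 0.0
        r = self.correct / self.gold if self.gold else 0.0
        f = 2 * p * r / (p + r) if p + r else 0.0
        return {"precision": p, "recall": r, "fscore": f}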
Example #2
    def train(self):
        train_set = Selection_Dataset(self.hyper, self.hyper.train)
        loader = Selection_loader(train_set,
                                  batch_size=self.hyper.train_batch,
                                  pin_memory=True)

        for epoch in range(self.hyper.epoch_num):
            self.model.train()
            pbar = tqdm(enumerate(BackgroundGenerator(loader)),
                        total=len(loader))

            for batch_idx, sample in pbar:

                self.optimizer.zero_grad()
                output = self.model(sample, is_train=True)
                loss = output['loss']
                loss.backward()
                self.optimizer.step()

                pbar.set_description(output['description'](
                    epoch, self.hyper.epoch_num))

                # break

            self.save_model(epoch)

            if epoch % self.hyper.print_epoch == 0 and epoch > 3:
                self.evaluation()
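output['description'] is called with (epoch, epoch_num) and its return value becomes the progress-bar text. One plausible way the model's forward pass could build that callable, shown as a hedged sketch (the helper name, format string, and partial usage are assumptions):

from functools import partial

def _description(epoch, epoch_num, loss_value):
    return "L: {:.4f}, epoch: {}/{}".format(loss_value, epoch, epoch_num)

# inside the model's forward pass, once `loss` is available:
# output['description'] = partial(_description, loss_value=loss.item())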
Example #3
    def evaluation(self):
        dev_set = Selection_Dataset(self.hyper, self.hyper.dev)
        loader = Selection_loader(dev_set,
                                  batch_size=self.hyper.eval_batch,
                                  pin_memory=True)
        self.triplet_metrics.reset()
        self.ner_metrics.reset()
        self.model.eval()

        pbar = tqdm(enumerate(BackgroundGenerator(loader)), total=len(loader))

        with torch.no_grad():
            for batch_ndx, sample in pbar:
                output = self.model(sample, is_train=False)
                self.triplet_metrics(output['selection_triplets'],
                                     output['spo_gold'])
                self.ner_metrics(output['gold_tags'], output['decoded_tag'])

            triplet_result = self.triplet_metrics.get_metric()
            ner_result = self.ner_metrics.get_metric()
            print('Triplets-> ' + ', '.join([
                "%s: %.4f" % (name[0], value)
                for name, value in triplet_result.items()
                if not name.startswith("_")
            ]) + ' ||' + 'NER->' + ', '.join([
                "%s: %.4f" % (name[0], value)
                for name, value in ner_result.items()
                if not name.startswith("_")
            ]))
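BackgroundGenerator in these loops is presumably the one from the prefetch_generator package, which produces batches in a background thread so data loading overlaps with the forward pass. A minimal standalone usage sketch, assuming that package:

from prefetch_generator import BackgroundGenerator

loader = range(10)  # stands in for the DataLoader; any iterable works
for batch_ndx, sample in enumerate(BackgroundGenerator(loader, max_prefetch=2)):
    pass  # the next batch is already being prepared in the background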
Example #4
    def postcheck(self):
        dev_set = Selection_Dataset(self.hyper,
                                    self.hyper.xgb_train_root,
                                    is_xgb=True)
        loader = Selection_loader(dev_set,
                                  batch_size=self.hyper.eval_batch,
                                  pin_memory=True)
        self.model.eval()

        pbar = tqdm(enumerate(BackgroundGenerator(loader)), total=len(loader))

        self.save_rc.reset('pos_neg_score.json')

        with torch.no_grad():
            for batch_ndx, sample in pbar:
                output = self.model(sample, is_train=False)
                neg_list = output['neg_list']
                # save positive scores only when the current batch provides them;
                # carrying a flag across batches would reuse a stale pos_list
                pos_list = output.get('pos_list')
                if pos_list:
                    self.save_rc.save(neg_list, pos_list)
                else:
                    self.save_rc.save(neg_list)
Example #5
    def evaluation(self):
        dev_set = Selection_Dataset(self.hyper, self.hyper.dev)
        loader = Selection_loader(dev_set,
                                  batch_size=self.hyper.eval_batch,
                                  pin_memory=True)
        self.selection_metrics.reset()
        self.model.eval()

        pbar = tqdm(enumerate(BackgroundGenerator(loader)), total=len(loader))

        with torch.no_grad():
            for batch_ndx, sample in pbar:
                pred = self.model(sample, is_train=False)
                pred = torch.sigmoid(pred) > 0.5
                labels = sample.selection_id
                self.selection_metrics(
                    pred.cpu().numpy().astype(int).tolist(),
                    labels.cpu().numpy().astype(int).tolist())

            triplet_result = self.selection_metrics.get_metric()

            print('Triplets-> ' + ', '.join([
                "%s: %.4f" % (name[0], value)
                for name, value in triplet_result.items()
                if not name.startswith("_")
            ]))
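A tiny standalone illustration of the thresholding step in the evaluation above: raw selection logits are squashed with a sigmoid, cut at 0.5, and converted to 0/1 lists for the metric.

import torch

logits = torch.tensor([[2.0, -1.5], [0.3, -0.3]])   # toy selection scores
pred = torch.sigmoid(logits) > 0.5                   # [[True, False], [True, False]]
as_ints = pred.cpu().numpy().astype(int).tolist()    # [[1, 0], [1, 0]]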
    def train(self):

        if self.hyper.is_bert == "bert_bilstem_crf":
            train_set = Selection_Nyt_Dataset(self.hyper, self.hyper.train)
            loader = Selection_loader(train_set,
                                      batch_size=100,
                                      pin_memory=True)

        elif self.hyper.is_bert in ("nyt_bert_tokenizer",
                                    "nyt11_bert_tokenizer",
                                    "nyt10_bert_tokenizer"):
            train_set = Selection_bert_Nyt_Dataset(self.hyper,
                                                   self.hyper.train)
            loader = Selection_bert_loader(train_set,
                                           batch_size=100,
                                           pin_memory=True)

        else:
            train_set = Selection_Dataset(self.hyper, self.hyper.train)
            loader = Selection_loader(train_set,
                                      batch_size=100,
                                      pin_memory=True)

        for epoch in range(self.hyper.epoch_num):
            self.model.train()
            pbar = tqdm(enumerate(BackgroundGenerator(loader)),
                        total=len(loader))

            for batch_idx, sample in pbar:
                self.optimizer.zero_grad()
                output = self.model(sample, is_train=True)
                loss = output['loss']
                loss.backward()
                self.optimizer.step()

                pbar.set_description(output['description'](
                    epoch, self.hyper.epoch_num))

            self.save_model(epoch)
            if epoch > 3:
                self.evaluation()

            #if epoch >= 6:
            #    self.evaluation()
Example #7
    def evaluation(self):
        dev_set = Selection_Dataset(self.hyper, self.hyper.dev)
        loader = Selection_loader(dev_set,
                                  batch_size=self.hyper.eval_batch,
                                  pin_memory=True)
        self.triplet_metrics.reset()
        self.ner_metrics.reset()
        self.model.eval()
        self.model._init_spo_search(SPO_searcher(self.hyper))

        pbar = tqdm(enumerate(BackgroundGenerator(loader)), total=len(loader))

        self.save_rc.reset('dev.json')

        with torch.no_grad():
            for batch_ndx, sample in pbar:
                output = self.model(sample, is_train=False)
                """
                attn = self.model.attn.cpu()
                if batch_ndx == 41:
                    print(sample.text[0])
                    attn_multi_plot(attn,sample.text[0])
                """

                # distant search
                self.triplet_metrics(output['selection_triplets'],
                                     output['spo_gold'])
                self.ner_metrics(output['gold_tags'], output['decoded_tag'])
                #self.save_err.save(batch_ndx,output['selection_triplets'], output['spo_gold'])
                self.save_rc.save(output['selection_triplets'])

            triplet_result = self.triplet_metrics.get_metric()
            ner_result = self.ner_metrics.get_metric()
            print('Triplets-> ' + ', '.join([
                "%s: %.4f" % (name[0], value)
                for name, value in triplet_result.items()
                if not name.startswith("_")
            ]) + ' ||' + 'NER->' + ', '.join([
                "%s: %.4f" % (name[0], value)
                for name, value in ner_result.items()
                if not name.startswith("_")
            ]))
Example #8
    def train_pso(self, epoch=None):
        train_set = Selection_Dataset(self.hyper, self.hyper.train)
        loader = Selection_loader(train_set,
                                  batch_size=self.hyper.train_batch,
                                  pin_memory=True)

        if epoch is None:
            epoch = 0

        # len(loader) already counts batches per epoch, so the total number of
        # optimizer steps is batches * remaining epochs (no extra division by
        # the batch size).
        num_train_steps = len(loader) * (self.hyper.epoch_num - epoch + 1)
        num_warmup_steps = int(num_train_steps * self.hyper.warmup_prop)

        self.scheduler = self._scheduler(self.optimizer, num_warmup_steps,
                                         num_train_steps)

        while epoch <= self.hyper.epoch_num:
            pbar = tqdm(enumerate(BackgroundGenerator(loader)),
                        total=len(loader))

            for batch_idx, sample in pbar:
                self.model.train()
                output = self.model(sample, is_train=True)
                loss = output['loss']
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

                self.optimizer.step()
                self.scheduler.step()
                self.model.zero_grad()

                pbar.set_description(output['description'](
                    epoch, self.hyper.epoch_num))

            self.save_model(epoch)
            epoch += 1
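self._scheduler is not shown in this snippet; a hedged sketch of what it could look like if it builds the usual linear warmup followed by linear decay (an assumption about the repository; transformers.get_linear_schedule_with_warmup is an equivalent off-the-shelf choice):

from torch.optim.lr_scheduler import LambdaLR

def make_linear_warmup_scheduler(optimizer, num_warmup_steps, num_train_steps):
    # learning rate ramps from 0 to the base LR during warmup, then decays to 0
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return current_step / max(1, num_warmup_steps)
        return max(0.0, (num_train_steps - current_step) /
                   max(1, num_train_steps - num_warmup_steps))
    return LambdaLR(optimizer, lr_lambda)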
Example #9
    def train(self, epoch=None):
        train_set = Selection_Dataset(self.hyper, self.hyper.train)
        #sampler = torch.utils.data.sampler.WeightedRandomSampler(train_set.weight,len(train_set.weight))
        loader = Selection_loader(train_set,
                                  batch_size=self.hyper.train_batch,
                                  pin_memory=True)
        #loader = Selection_loader(train_set, batch_size=self.hyper.train_batch,sampler = sampler, pin_memory=True)

        if epoch is None:
            epoch = 0

        while epoch <= self.hyper.epoch_num:
            #for epoch in range(self.hyper.epoch_num):
            self.model.train()
            pbar = tqdm(enumerate(BackgroundGenerator(loader)),
                        total=len(loader))

            for batch_idx, sample in pbar:

                self.optimizer.zero_grad()
                output = self.model(sample, is_train=True)
                loss = output['loss']
                loss.backward()
                self.optimizer.step()

                pbar.set_description(output['description'](
                    epoch, self.hyper.epoch_num))

            self.save_model(epoch)
            """
            if epoch % self.hyper.print_epoch == 0 and epoch > 3:
                if self.model_name == 'selection':
                    self.evaluation()
                elif self.model_name == 'pso_1':
                    self.evaluation_pso()
            """
            epoch += 1
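The commented-out lines above hint at weighted sampling. A standalone sketch of what enabling it amounts to, with toy data and weights (a sampler replaces shuffling, so shuffle=True must not also be passed to the DataLoader):

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

data = TensorDataset(torch.arange(10).float().unsqueeze(1))
weights = [0.1] * 8 + [1.0] * 2   # oversample the last two examples
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
loader = DataLoader(data, batch_size=4, sampler=sampler)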
Example #10
    def evaluation(self):
        dev_set = Selection_Dataset(self.hyper, self.hyper.dev)
        loader = Selection_loader(dev_set,
                                  batch_size=self.hyper.eval_batch,
                                  pin_memory=True)
        self.triplet_metrics.reset()
        self.ner_metrics.reset()
        self.model.eval()

        pbar = tqdm(enumerate(BackgroundGenerator(loader)), total=len(loader))

        with torch.no_grad():
            with open('./output/sub00.csv', 'w') as file:
                row_id = 0
                for batch_ndx, sample in pbar:
                    # only the raw text is needed below; the remaining sample
                    # fields are consumed inside the model's forward pass
                    text_list = sample.text

                    output = self.model(sample, is_train=False)

                    self.triplet_metrics(output['selection_triplets'],
                                         output['spo_gold'])
                    self.ner_metrics(output['gold_tags'],
                                     output['decoded_tag'])

                    for i in range(len(output['decoded_tag'])):
                        file.write(str(8001 + row_id) + ',')
                        if len(output['selection_triplets'][i]) != 0:
                            file.write(output['selection_triplets'][i][0]
                                       ['predicate'] + ',')
                            file.write(
                                output['selection_triplets'][i][0]['subject'] +
                                ',')
                            file.write(
                                output['selection_triplets'][i][0]['object'] +
                                '\n')
                        else:
                            if output['decoded_tag'][i].count('B') < 2:
                                file.write('Other' + ',' + 'Other' + ',' +
                                           'Other')
                            else:
                                # fall back: first 'B' token goes to the object
                                # column, last 'B' token to the subject column
                                BIO = output['decoded_tag'][i]
                                index1 = BIO.index('B')
                                index2 = len(BIO) - 1 - BIO[::-1].index('B')
                                file.write('Other' + ',' +
                                           text_list[i][index2] + ',' +
                                           text_list[i][index1])
                            file.write('\n')
                        row_id += 1
                        # file.write('sentence {} BIO:\n'.format(i))
                        # for j in range(len(text_list[i])):
                        #     file.write(text_list[i][j]+' ')
                        # file.write('\n')
                        # file.writelines(bio_text[i])
                        # file.write('\n')
                        #
                        # file.writelines(output['decoded_tag'][i])
                        # file.write('\n')
                        # file.writelines(output['gold_tags'][i])
                        # file.write('\n')
                        # file.write('sentence {} relation:\n'.format(i))
                        # file.write('\n')
                        # if len(output['selection_triplets']) == 0:
                        #     file.write('empty')
                        # else:
                        #     file.writelines(str(output['selection_triplets'][i]))
                        # file.write('\n')
                        # file.writelines(str(output['spo_gold'][i]))
                        # file.write('\n')

            triplet_result = self.triplet_metrics.get_metric()
            ner_result = self.ner_metrics.get_metric()
            # print('triplet_result=', triplet_result)
            # print('ner_result=', ner_result)

            print('Triplets-> ' + ', '.join([
                "%s: %.4f" % (name[0], value)
                for name, value in triplet_result.items()
                if not name.startswith("_")
            ]) + ' ||' + 'NER->' + ', '.join([
                "%s: %.4f" % (name[0], value)
                for name, value in ner_result.items()
                if not name.startswith("_")
            ]))
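A design note on the submission file written above: the manual file.write(... + ',') calls break if a subject or object contains a comma. Python's csv module handles quoting automatically; a hedged sketch with a toy row in the same id/predicate/subject/object column order:

import csv

with open('./output/sub00.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow([8001, 'Other', 'toy subject, with comma', 'toy object'])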
Example #11
    def evaluation_pso(self):
        dev_set = Selection_Dataset(self.hyper, self.hyper.dev)
        loader = Selection_loader(dev_set,
                                  batch_size=self.hyper.eval_batch,
                                  pin_memory=True)
        self.p_metrics.reset()
        self.triplet_metrics.reset()
        self.model.eval()
        if self.model_name == 'pso_2':
            self.model_p.eval()
        all_labels, all_logits = [], []
        self.model._init_spo_search(SPO_searcher(self.hyper))

        pbar = tqdm(enumerate(BackgroundGenerator(loader)), total=len(loader))

        if self.model_name == 'pso_1':
            with torch.no_grad():
                for batch_ndx, sample in pbar:
                    output = self.model(sample, is_train=False)
                    """
                    if batch_ndx == 44:
                        attn = self.model.attn.cpu()
                        print(attn)
                        attn_plot(attn,sample.text[0])
                    """

                    self.p_metrics(output['p_decode'], output['p_golden'])
                    all_labels.extend(output['p_golden'])
                    all_logits.extend(output['p_logits'])
                """
                with open('labels.txt','w',encoding = 'utf-8') as f:
                    for l in all_labels:
                        f.write(' '.join(l))
                        f.write('\n')
                
                with open('scores.txt','w',encoding ='utf-8') as t:
                    for l in all_logits:
                        t.write(' '.join([str(i) for i in l]))
                        t.write('\n')
                """

                p_result = self.p_metrics.get_metric()

                print('P-> ' + ', '.join([
                    "%s: %.4f" % (name[0], value)
                    for name, value in p_result.items()
                    if not name.startswith("_")
                ]))

                #roc_auc_class(all_labels,all_logits)

        elif self.model_name == 'pso_2':

            self.save_rc.reset('dev.json')

            with torch.no_grad():
                for batch_ndx, sample in pbar:
                    output_p = self.model_p(sample, is_train=False)
                    self.p_metrics(output_p['p_decode'], output_p['p_golden'])
                    # collect gold labels and decoded predicates, mirroring the pso_1 branch
                    all_labels.extend(output_p['p_golden'])
                    all_logits.extend(output_p['p_decode'])

                    dev_set.schema_transformer(output_p, sample)
                    output = self.model(sample, is_train=False)
                    """
                    if batch_ndx == 63:
                        print(sample.text[0])
                        
                        attn = self.model.attn[2].cpu()
                        attn_pso_plot_sub(attn,sample.text[0])

                        for id in range(9,12):
                            attn = self.model.attn[id-9].cpu()
                            attn_pso_plot(attn,sample.text[0],id)
                            #attn_pso_plot_stack(attn,sample.text[0],id)
                        
                        attn = self.model.attn[-1].cpu()
                        attn_pso_plot_sub(attn,sample.text[0])
                        """
                    self.triplet_metrics(output['selection_triplets'],
                                         output['spo_gold'])
                    self.save_err.save(batch_ndx, output['selection_triplets'],
                                       output['spo_gold'])
                    self.save_rc.save(output['selection_triplets'])

                p_result = self.p_metrics.get_metric()
                triplet_result = self.triplet_metrics.get_metric()

                print('Triplets-> ' + ', '.join([
                    "%s: %.4f" % (name[0], value)
                    for name, value in triplet_result.items()
                    if not name.startswith("_")
                ]))

                print('P-> ' + ', '.join([
                    "%s: %.4f" % (name[0], value)
                    for name, value in p_result.items()
                    if not name.startswith("_")
                ]))
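The commented-out roc_auc_class(all_labels, all_logits) call above suggests a per-class ROC-AUC over the collected labels and scores. A hedged standalone sketch of that computation with toy arrays, assuming the labels can be binarized to the same shape as the score matrix:

import numpy as np
from sklearn.metrics import roc_auc_score

y_true = np.array([[1, 0, 0],
                   [0, 1, 0],
                   [0, 0, 1]])          # toy one-hot predicate labels
y_score = np.array([[0.9, 0.05, 0.05],
                    [0.2, 0.7, 0.1],
                    [0.1, 0.2, 0.7]])   # toy per-class scores
print(roc_auc_score(y_true, y_score, average=None))  # one AUC per class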