Example #1
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, lus=None, senses=None, args=None, using_gold_fame=False, position_ids=None, head_mask=None):
        # Encode the batch; return_dict=False yields (sequence_output, pooled_output).
        sequence_output, pooled_output = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, return_dict=False)
        sequence_output = self.dropout(sequence_output)
        pooled_output = self.dropout(pooled_output)

        # frame (sense) logits from the pooled [CLS] vector; argument logits per token
        sense_logits = self.sense_classifier(pooled_output)
        arg_logits = self.arg_classifier(sequence_output)

        # 0/1 masks restricting each LU to the frames it can evoke
        lufr_masks = frameBERT_utils.get_masks(lus, self.lufrmap, num_label=self.num_senses, masking=self.masking).to(device)
        
        sense_loss = 0 # accumulated loss for frame (sense) identification
        arg_loss = 0 # accumulated loss for argument identification
        
        if senses is not None:
            for i in range(len(sense_logits)):
                sense_logit = sense_logits[i]
                arg_logit = arg_logits[i]                

                lufr_mask = lufr_masks[i]
                    
                gold_sense = senses[i]
                gold_arg = args[i]
                
                # train the sense classifier; passing the 0/1 mask as class
                # weights zeroes the loss when the gold frame is not a candidate
                loss_fct_sense = CrossEntropyLoss(weight=lufr_mask)
                loss_per_seq_for_sense = loss_fct_sense(sense_logit.view(-1, self.num_senses), gold_sense.view(-1))
                sense_loss += loss_per_seq_for_sense
                
                # train the arg classifier: the predicted frame constrains
                # the admissible argument labels via frargmap
                masked_sense_logit = frameBERT_utils.masking_logit(sense_logit, lufr_mask)
                pred_sense, sense_score = frameBERT_utils.logit2label(masked_sense_logit)

                frarg_mask = frameBERT_utils.get_masks([pred_sense], self.frargmap, num_label=self.num_args, masking=True).to(device)[0]
                loss_fct_arg = CrossEntropyLoss(weight=frarg_mask)
                # only compute the argument loss over non-padding tokens
                if attention_mask is not None:
                    active_loss = attention_mask[i].view(-1) == 1
                    active_logits = arg_logit.view(-1, self.num_args)[active_loss]
                    active_labels = gold_arg.view(-1)[active_loss]
                    loss_per_seq_for_arg = loss_fct_arg(active_logits, active_labels)
                else:
                    loss_per_seq_for_arg = loss_fct_arg(arg_logit.view(-1, self.num_args), gold_arg.view(-1))
                arg_loss += loss_per_seq_for_arg

            # equally weighted joint loss, averaged over the batch
            total_loss = 0.5*sense_loss + 0.5*arg_loss
            loss = total_loss / len(sense_logits)
            
            if self.return_pooled_output:
                return pooled_output, loss
            else:
                return loss
        else:
            if self.return_pooled_output:
                return pooled_output, sense_logits, arg_logits
            else:
                return sense_logits, arg_logits
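
The candidate masking above relies on CrossEntropyLoss's per-class weights. A minimal, self-contained sketch of that trick (all sizes and tensors below are made up for illustration): with a 0/1 candidate mask passed as weight, each example's loss is scaled by the weight of its gold class, so the loss vanishes whenever the gold frame falls outside the LU's candidate set and is unchanged otherwise.

import torch
from torch.nn import CrossEntropyLoss

# Hypothetical sizes: 5 frames in total, 3 of them candidates for this LU.
num_senses = 5
sense_logit = torch.randn(1, num_senses)        # frame logits for one sequence
gold_sense = torch.tensor([2])                  # gold frame index
lufr_mask = torch.tensor([0., 1., 1., 1., 0.])  # 1.0 only for candidate frames

# CrossEntropyLoss(weight=w) scales each example's loss by w[target]:
# gold frames outside the candidate set contribute nothing.
loss_fct = CrossEntropyLoss(weight=lufr_mask)
loss = loss_fct(sense_logit.view(-1, num_senses), gold_sense.view(-1))
print(float(loss))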
Example #2
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, lus=None, senses=None, args=None, using_gold_fame=False, position_ids=None, head_mask=None):
        # Encode the batch; return_dict=False yields (sequence_output, pooled_output).
        sequence_output, pooled_output = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, return_dict=False)
        sequence_output = self.dropout(sequence_output)

        # Frame (sense) logits are only computed in the joint setting.
        masked_sense_logits = None
        if self.joint:
            pooled_output = self.dropout(pooled_output)
            sense_logits = self.sense_classifier(pooled_output)
            # 0/1 masks restricting each LU to the frames it can evoke
            lufr_masks = frameBERT_utils.get_masks(lus, self.lufrmap, num_label=self.num_senses, masking=self.masking).to(device)
            # multiplying by the mask zeroes the logits of non-candidate frames
            masked_sense_logits = sense_logits * lufr_masks

        arg_logits = self.arg_classifier(sequence_output)

        if senses is not None:
            # train the frame identifier (joint setting only)
            if self.joint:
                loss_fct_sense = CrossEntropyLoss()
                loss_sense = loss_fct_sense(masked_sense_logits.view(-1, self.num_senses), senses.view(-1))

            # train the arg classifier
            loss_fct_arg = CrossEntropyLoss()
            if attention_mask is not None:
                # only compute the argument loss over non-padding tokens
                active_loss = attention_mask.view(-1) == 1
                active_logits = arg_logits.view(-1, self.num_args)[active_loss]
                active_labels = args.view(-1)[active_loss]
                loss_arg = loss_fct_arg(active_logits, active_labels)
            else:
                loss_arg = loss_fct_arg(arg_logits.view(-1, self.num_args), args.view(-1))

            # joint vs. argument-only identification
            if self.joint:
                loss = 0.5*loss_sense + 0.5*loss_arg
            else:
                loss = loss_arg

            if self.return_pooled_output:
                return pooled_output, loss
            else:
                return loss
        else:
            # inference: masked_sense_logits is None when self.joint is False
            if self.return_pooled_output:
                return pooled_output, masked_sense_logits, arg_logits
            else:
                return masked_sense_logits, arg_logits
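
Example #2 differs from Example #1 in two ways: it masks sense logits multiplicatively (sense_logits * lufr_masks) rather than weighting the loss, and it computes the argument loss over the whole batch at once. The active-token selection it uses is a standard pattern; a minimal runnable sketch with made-up shapes:

import torch
from torch.nn import CrossEntropyLoss

# Hypothetical batch: 2 sequences of 4 tokens, 3 argument labels.
num_args = 3
arg_logits = torch.randn(2, 4, num_args)
args = torch.randint(num_args, (2, 4))         # gold argument labels
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])  # 0 marks padding

# Flatten to (batch * seq_len, num_args) and keep only real tokens,
# mirroring the active_loss selection in the forward pass above.
active = attention_mask.view(-1) == 1
loss_fct = CrossEntropyLoss()
loss = loss_fct(arg_logits.view(-1, num_args)[active],
                args.view(-1)[active])
print(float(loss))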
Example #3
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, lus=None, senses=None, args=None, using_gold_fame=False, position_ids=None, head_mask=None):
        sequence_output, pooled_output = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, return_dict=False)
        pooled_output = self.dropout(pooled_output)

        # sense-only model: frame logits from the pooled [CLS] vector
        sense_logits = self.sense_classifier(pooled_output)

        # 0/1 masks restricting each LU to the frames it can evoke
        lufr_masks = frameBERT_utils.get_masks(lus, self.lufrmap, num_label=self.num_senses, masking=self.masking).to(device)
        
        sense_loss = 0 # accumulated loss for frame (sense) identification
        
        if senses is not None:
            for i in range(len(sense_logits)):
                sense_logit = sense_logits[i]            

                lufr_mask = lufr_masks[i]
                    
                gold_sense = senses[i]
                gold_arg = args[i]
                
                # train the sense classifier; passing the 0/1 mask as class
                # weights zeroes the loss when the gold frame is not a candidate
                loss_fct_sense = CrossEntropyLoss(weight=lufr_mask)
                loss_per_seq_for_sense = loss_fct_sense(sense_logit.view(-1, self.num_senses), gold_sense.view(-1))
                sense_loss += loss_per_seq_for_sense

            # average the sense loss over the batch
            loss = sense_loss / len(sense_logits)
            
            if self.return_pooled_output:
                return pooled_output, loss
            else:
                return loss
        else:
            # no argument classifier in this model; return sense logits only
            if self.return_pooled_output:
                return pooled_output, sense_logits
            else:
                return sense_logits
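
At inference time the sense logits are decoded with frameBERT_utils.masking_logit and logit2label. A rough, self-contained approximation of that decoding step (the helper behavior is an assumption, not the library's code):

import torch

# Hypothetical: 5 frames, of which only 3 are candidates for the LU.
num_senses = 5
sense_logit = torch.randn(num_senses)
lufr_mask = torch.tensor([0., 1., 1., 1., 0.])

# Send non-candidate logits to -inf so argmax can never select them,
# approximating masking_logit followed by logit2label.
masked = sense_logit.masked_fill(lufr_mask == 0, float('-inf'))
pred_sense = int(torch.argmax(masked))
sense_score = float(torch.softmax(masked, dim=-1)[pred_sense])
print(pred_sense, sense_score)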
Example #4
    def parser(self,
               input_d,
               sent_id=False,
               result_format=False,
               frame_candis=5):
        input_conll = dataio.preprocessor(input_d)

        # target identification
        if self.gold_pred:
            # inputs already carry gold targets; wrap a single sentence in a list
            if len(input_conll[0]) != 2:
                input_conll = [input_conll]
            tgt_data = input_conll
        else:
            if self.srl == 'framenet':
                tgt_conll = self.targetid.target_id(input_conll)
            else:
                tgt_conll = self.targetid.pred_id(input_conll)

            # add <tgt> and </tgt> to target word
            tgt_data = dataio.data2tgt_data(tgt_conll, mode='parse')

        if tgt_data:

            # convert conll to bert inputs
            bert_inputs = self.bert_io.convert_to_bert_input_JointShallowSemanticParsing(
                tgt_data)
            dataloader = DataLoader(bert_inputs, sampler=None, batch_size=1)

            pred_senses, pred_args = [], []
            sense_candis_list = []
            for batch in dataloader:
                batch = tuple(t.to(device) for t in batch)
                b_input_ids, b_orig_tok_to_maps, b_lus, b_token_type_ids, b_masks = batch

                with torch.no_grad():
                    # senses is None here, so the model returns logits, not a loss
                    sense_logits, arg_logits = self.model(
                        b_input_ids,
                        lus=b_lus,
                        token_type_ids=b_token_type_ids,
                        attention_mask=b_masks)

                lufr_masks = frameBERT_utils.get_masks(
                    b_lus,
                    self.bert_io.lufrmap,
                    num_label=len(self.bert_io.sense2idx),
                    masking=self.masking).to(device)

                b_input_ids_np = b_input_ids.detach().cpu().numpy()
                arg_logits_np = arg_logits.detach().cpu().numpy()

                # rebuild the inputs and argument logits over original tokens
                # only, dropping subword continuations and <tgt> markers
                b_input_ids, arg_logits = [], []

                for b_idx in range(len(b_orig_tok_to_maps)):
                    orig_tok_to_map = b_orig_tok_to_maps[b_idx]
                    bert_token = self.bert_io.tokenizer.convert_ids_to_tokens(
                        b_input_ids_np[b_idx])
                    tgt_idx = frameBERT_utils.get_tgt_idx(bert_token,
                                                          tgt=self.tgt)

                    input_id, sense_logit, arg_logit = [], [], []

                    for idx in orig_tok_to_map:
                        if idx != -1 and idx not in tgt_idx:
                            try:
                                input_id.append(b_input_ids_np[b_idx][idx])
                                # suppress label index 1 so it is never
                                # predicted for a real token
                                arg_logits_np[b_idx][idx][1] = -np.inf
                                arg_logit.append(arg_logits_np[b_idx][idx])
                            except KeyboardInterrupt:
                                raise
                            except IndexError:
                                pass

                    b_input_ids.append(input_id)
                    arg_logits.append(arg_logit)

                b_input_ids = torch.Tensor(b_input_ids).to(device)
                arg_logits = torch.Tensor(arg_logits).to(device)

                for b_idx in range(len(sense_logits)):
                    input_id = b_input_ids[b_idx]
                    sense_logit = sense_logits[b_idx]
                    arg_logit = arg_logits[b_idx]

                    # restrict sense logits to the LU's candidate frames,
                    # then decode the best frame
                    lufr_mask = lufr_masks[b_idx]
                    masked_sense_logit = frameBERT_utils.masking_logit(
                        sense_logit, lufr_mask)
                    pred_sense, sense_score = frameBERT_utils.logit2label(
                        masked_sense_logit)

                    sense_candis = frameBERT_utils.logit2candis(
                        masked_sense_logit,
                        candis=frame_candis,
                        idx2label=self.bert_io.idx2sense)
                    sense_candis_list.append(sense_candis)

                    if self.srl == 'framenet':
                        # restrict argument logits to the frame elements
                        # licensed by the predicted frame
                        arg_logit_np = arg_logit.detach().cpu().numpy()
                        arg_logit = []
                        frarg_mask = frameBERT_utils.get_masks(
                            [pred_sense],
                            self.bert_io.bio_frargmap,
                            num_label=len(self.bert_io.bio_arg2idx),
                            masking=True).to(device)[0]
                        for logit in arg_logit_np:
                            masked_logit = frameBERT_utils.masking_logit(
                                logit, frarg_mask)
                            arg_logit.append(np.array(masked_logit))
                        arg_logit = torch.Tensor(arg_logit).to(device)

                    # decode one argument label per token from the (masked) logits
                    pred_arg = []
                    for logit in arg_logit:
                        label, score = frameBERT_utils.logit2label(logit)
                        pred_arg.append(int(label))

                    pred_senses.append([int(pred_sense)])
                    pred_args.append(pred_arg)

            pred_sense_tags = [
                self.bert_io.idx2sense[p_i] for p in pred_senses for p_i in p
            ]
            if self.srl == 'framenet':
                pred_arg_tags = [[self.bert_io.idx2bio_arg[p_i] for p_i in p]
                                 for p in pred_args]
            elif self.srl == 'framenet-argid':
                pred_arg_tags = [[
                    self.bert_io.idx2bio_argument[p_i] for p_i in p
                ] for p in pred_args]
            else:
                pred_arg_tags = [[self.bert_io.idx2bio_arg[p_i] for p_i in p]
                                 for p in pred_args]

            conll_result = []

            for i in range(len(pred_arg_tags)):

                raw = tgt_data[i]

                # drop the <tgt> markers that were inserted around the target
                conll, toks, lus = [], [], []
                for tok, lu in zip(raw[0], raw[1]):
                    if tok not in ('<tgt>', '</tgt>'):
                        toks.append(tok)
                        lus.append(lu)
                conll.append(toks)
                conll.append(lus)

                # project the predicted frame onto the target (LU) positions
                sense_seq = ['_' for _ in range(len(conll[1]))]
                for idx in range(len(conll[1])):
                    if conll[1][idx] != '_':
                        sense_seq[idx] = pred_sense_tags[i]

                conll.append(sense_seq)
                conll.append(pred_arg_tags[i])

                conll_result.append(conll)
        else:
            conll_result = []

        # package the result in the requested format
        result = []
        if result_format == 'all':
            result = {}
            result['conll'] = conll_result

            if conll_result:
                textae = conll2textae.get_textae(conll_result)
                frdf = dataio.frame2rdf(conll_result,
                                        sent_id=sent_id,
                                        language=self.language)
                topk = dataio.topk(conll_result, sense_candis_list)
            else:
                textae = []
                frdf = []
                topk = {}
            result['textae'] = textae
            result['graph'] = frdf
            result['topk'] = topk
        elif result_format == 'textae':
            if conll_result:
                textae = conll2textae.get_textae(conll_result)
            else:
                textae = []
            result = textae
        elif result_format == 'graph':
            if conll_result:
                frdf = dataio.frame2rdf(conll_result,
                                        sent_id=sent_id,
                                        language=self.language)
            else:
                frdf = []
            result = frdf
        elif result_format == 'topk':
            if conll_result:
                topk = dataio.topk(conll_result, sense_candis_list)
            else:
                topk = {}
            result = topk
        else:
            result = conll_result

        return result
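
For reference, each entry of conll_result assembled above holds four parallel rows: tokens (with the <tgt> markers removed), lexical units, the predicted frame projected onto the LU positions, and BIO-tagged arguments. A hypothetical entry for an English sentence (the labels are illustrative, not actual parser output):

# Hypothetical shape of one conll_result entry for "He bought a car".
conll_entry = [
    ['He', 'bought', 'a', 'car'],             # tokens
    ['_', 'buy.v', '_', '_'],                 # lexical units (targets)
    ['_', 'Commerce_buy', '_', '_'],          # predicted frame at the LU position
    ['B-Buyer', 'O', 'B-Goods', 'I-Goods'],   # BIO argument (frame element) tags
]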