def forward(self, input_ids, token_type_ids=None, attention_mask=None, lus=None,
            senses=None, args=None, using_gold_fame=False, position_ids=None, head_mask=None):
    sequence_output, pooled_output = self.bert(input_ids,
                                               token_type_ids=token_type_ids,
                                               attention_mask=attention_mask,
                                               return_dict=False)
    sequence_output = self.dropout(sequence_output)
    pooled_output = self.dropout(pooled_output)

    sense_logits = self.sense_classifier(pooled_output)
    arg_logits = self.arg_classifier(sequence_output)
    lufr_masks = frameBERT_utils.get_masks(lus, self.lufrmap,
                                           num_label=self.num_senses,
                                           masking=self.masking).to(device)

    sense_loss = 0  # loss for sense id
    arg_loss = 0    # loss for arg id

    if senses is not None:
        for i in range(len(sense_logits)):
            sense_logit = sense_logits[i]
            arg_logit = arg_logits[i]
            lufr_mask = lufr_masks[i]
            gold_sense = senses[i]
            gold_arg = args[i]

            # train sense classifier
            loss_fct_sense = CrossEntropyLoss(weight=lufr_mask)
            loss_per_seq_for_sense = loss_fct_sense(sense_logit.view(-1, self.num_senses),
                                                    gold_sense.view(-1))
            sense_loss += loss_per_seq_for_sense

            # train arg classifier
            masked_sense_logit = frameBERT_utils.masking_logit(sense_logit, lufr_mask)
            pred_sense, sense_score = frameBERT_utils.logit2label(masked_sense_logit)
            frarg_mask = frameBERT_utils.get_masks([pred_sense], self.frargmap,
                                                   num_label=self.num_args,
                                                   masking=True).to(device)[0]
            loss_fct_arg = CrossEntropyLoss(weight=frarg_mask)

            # only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask[i].view(-1) == 1
                active_logits = arg_logit.view(-1, self.num_args)[active_loss]
                active_labels = gold_arg.view(-1)[active_loss]
                loss_per_seq_for_arg = loss_fct_arg(active_logits, active_labels)
            else:
                loss_per_seq_for_arg = loss_fct_arg(arg_logit.view(-1, self.num_args),
                                                    gold_arg.view(-1))
            arg_loss += loss_per_seq_for_arg

        total_loss = 0.5 * sense_loss + 0.5 * arg_loss
        loss = total_loss / len(sense_logits)

        if self.return_pooled_output:
            return pooled_output, loss
        else:
            return loss
    else:
        if self.return_pooled_output:
            return pooled_output, sense_logits, arg_logits
        else:
            return sense_logits, arg_logits
def forward(self, input_ids, token_type_ids=None, attention_mask=None, lus=None,
            senses=None, args=None, using_gold_fame=False, position_ids=None, head_mask=None):
    sense_loss = 0  # loss for sense id
    arg_loss = 0    # loss for arg id

    sequence_output, pooled_output = self.bert(input_ids,
                                               token_type_ids=token_type_ids,
                                               attention_mask=attention_mask,
                                               return_dict=False)
    sequence_output = self.dropout(sequence_output)

    if self.joint:
        pooled_output = self.dropout(pooled_output)
        sense_logits = self.sense_classifier(pooled_output)
        lufr_masks = frameBERT_utils.get_masks(lus, self.lufrmap,
                                               num_label=self.num_senses,
                                               masking=self.masking).to(device)
        masked_sense_logits = sense_logits * lufr_masks

    arg_logits = self.arg_classifier(sequence_output)

    # train frame identifier
    if self.joint:
        if senses is not None:
            loss_fct_sense = CrossEntropyLoss()
            loss_sense = loss_fct_sense(masked_sense_logits.view(-1, self.num_senses),
                                        senses.view(-1))

    # train arg classifier
    if senses is not None:
        loss_fct_arg = CrossEntropyLoss()
        if attention_mask is not None:
            active_loss = attention_mask.view(-1) == 1
            active_logits = arg_logits.view(-1, self.num_args)[active_loss]
            active_labels = args.view(-1)[active_loss]
            loss_arg = loss_fct_arg(active_logits, active_labels)
        else:
            loss_arg = loss_fct_arg(arg_logits.view(-1, self.num_args), args.view(-1))

    if senses is not None:
        # joint vs. argument-only identification
        if self.joint:
            loss = 0.5 * loss_sense + 0.5 * loss_arg
        else:
            loss = loss_arg

        if self.return_pooled_output:
            return pooled_output, loss
        else:
            return loss
    else:
        if self.return_pooled_output:
            return pooled_output, masked_sense_logits, arg_logits
        else:
            return masked_sense_logits, arg_logits
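# Not part of the model code: a minimal, self-contained sketch (illustrative
# values only) contrasting the two ways the forward passes above constrain
# frame prediction with the LU-to-frame (lufr) mask.
import torch
from torch.nn import CrossEntropyLoss

num_senses = 4
sense_logits = torch.randn(1, num_senses)      # sense logits for one instance
lufr_mask = torch.tensor([1., 0., 1., 0.])     # 1 = frame is licensed by the LU
gold_sense = torch.tensor([2])

# (a) per-instance loop (first forward): the mask is passed as class weights,
#     so each instance's loss term is scaled by the weight of its gold frame.
loss_a = CrossEntropyLoss(weight=lufr_mask)(sense_logits, gold_sense)

# (b) batched variant (second forward): logits of unlicensed frames are zeroed
#     before a plain cross entropy over the masked logits.
loss_b = CrossEntropyLoss()(sense_logits * lufr_mask, gold_sense)

print(loss_a.item(), loss_b.item())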
def forward(self, input_ids, token_type_ids=None, attention_mask=None, lus=None,
            senses=None, args=None, using_gold_fame=False, position_ids=None, head_mask=None):
    sequence_output, pooled_output = self.bert(input_ids,
                                               token_type_ids=token_type_ids,
                                               attention_mask=attention_mask,
                                               return_dict=False)
    pooled_output = self.dropout(pooled_output)

    sense_logits = self.sense_classifier(pooled_output)
    lufr_masks = frameBERT_utils.get_masks(lus, self.lufrmap,
                                           num_label=self.num_senses,
                                           masking=self.masking).to(device)

    sense_loss = 0  # loss for sense id

    if senses is not None:
        for i in range(len(sense_logits)):
            sense_logit = sense_logits[i]
            lufr_mask = lufr_masks[i]
            gold_sense = senses[i]

            # train sense classifier
            loss_fct_sense = CrossEntropyLoss(weight=lufr_mask)
            loss_per_seq_for_sense = loss_fct_sense(sense_logit.view(-1, self.num_senses),
                                                    gold_sense.view(-1))
            sense_loss += loss_per_seq_for_sense

        total_loss = sense_loss
        loss = total_loss / len(sense_logits)

        if self.return_pooled_output:
            return pooled_output, loss
        else:
            return loss
    else:
        if self.return_pooled_output:
            return pooled_output, sense_logits
        else:
            return sense_logits
def parser(self, input_d, sent_id=False, result_format=False, frame_candis=5):
    input_conll = dataio.preprocessor(input_d)

    # target identification
    if self.gold_pred:
        if len(input_conll[0]) != 2:
            input_conll = [input_conll]
        tgt_data = input_conll
    else:
        if self.srl == 'framenet':
            tgt_conll = self.targetid.target_id(input_conll)
        else:
            tgt_conll = self.targetid.pred_id(input_conll)
        # add <tgt> and </tgt> around the target word
        tgt_data = dataio.data2tgt_data(tgt_conll, mode='parse')

    if tgt_data:
        # convert conll to bert inputs
        bert_inputs = self.bert_io.convert_to_bert_input_JointShallowSemanticParsing(tgt_data)
        dataloader = DataLoader(bert_inputs, sampler=None, batch_size=1)

        pred_senses, pred_args = [], []
        sense_candis_list = []

        for batch in dataloader:
            # torch.cuda.set_device(device)
            batch = tuple(t.to(device) for t in batch)
            b_input_ids, b_orig_tok_to_maps, b_lus, b_token_type_ids, b_masks = batch

            with torch.no_grad():
                sense_logits, arg_logits = self.model(b_input_ids, lus=b_lus,
                                                      token_type_ids=b_token_type_ids,
                                                      attention_mask=b_masks)

            lufr_masks = frameBERT_utils.get_masks(b_lus, self.bert_io.lufrmap,
                                                   num_label=len(self.bert_io.sense2idx),
                                                   masking=self.masking).to(device)

            b_input_ids_np = b_input_ids.detach().cpu().numpy()
            arg_logits_np = arg_logits.detach().cpu().numpy()
            b_input_ids, arg_logits = [], []

            # map wordpiece positions back to original tokens and drop <tgt> markers
            for b_idx in range(len(b_orig_tok_to_maps)):
                orig_tok_to_map = b_orig_tok_to_maps[b_idx]
                bert_token = self.bert_io.tokenizer.convert_ids_to_tokens(b_input_ids_np[b_idx])
                tgt_idx = frameBERT_utils.get_tgt_idx(bert_token, tgt=self.tgt)

                input_id, arg_logit = [], []
                for idx in orig_tok_to_map:
                    if idx != -1:
                        if idx not in tgt_idx:
                            try:
                                input_id.append(b_input_ids_np[b_idx][idx])
                                arg_logits_np[b_idx][idx][1] = -np.inf
                                arg_logit.append(arg_logits_np[b_idx][idx])
                            except Exception:
                                # ignore positions that fall outside the tensors
                                pass
                b_input_ids.append(input_id)
                arg_logits.append(arg_logit)

            b_input_ids = torch.Tensor(b_input_ids).to(device)
            arg_logits = torch.Tensor(arg_logits).to(device)

            for b_idx in range(len(sense_logits)):
                input_id = b_input_ids[b_idx]
                sense_logit = sense_logits[b_idx]
                arg_logit = arg_logits[b_idx]
                lufr_mask = lufr_masks[b_idx]

                masked_sense_logit = frameBERT_utils.masking_logit(sense_logit, lufr_mask)
                pred_sense, sense_score = frameBERT_utils.logit2label(masked_sense_logit)
                sense_candis = frameBERT_utils.logit2candis(masked_sense_logit,
                                                            candis=frame_candis,
                                                            idx2label=self.bert_io.idx2sense)
                sense_candis_list.append(sense_candis)

                if self.srl == 'framenet':
                    arg_logit_np = arg_logit.detach().cpu().numpy()
                    arg_logit = []
                    frarg_mask = frameBERT_utils.get_masks([pred_sense],
                                                           self.bert_io.bio_frargmap,
                                                           num_label=len(self.bert_io.bio_arg2idx),
                                                           masking=True).to(device)[0]
                    for logit in arg_logit_np:
                        masked_logit = frameBERT_utils.masking_logit(logit, frarg_mask)
                        arg_logit.append(np.array(masked_logit))
                    arg_logit = torch.Tensor(arg_logit).to(device)

                pred_arg = []
                for logit in arg_logit:
                    label, score = frameBERT_utils.logit2label(logit)
                    pred_arg.append(int(label))

                pred_senses.append([int(pred_sense)])
                pred_args.append(pred_arg)

        pred_sense_tags = [self.bert_io.idx2sense[p_i] for p in pred_senses for p_i in p]

        if self.srl == 'framenet':
            pred_arg_tags = [[self.bert_io.idx2bio_arg[p_i] for p_i in p] for p in pred_args]
        elif self.srl == 'framenet-argid':
            pred_arg_tags = [[self.bert_io.idx2bio_argument[p_i] for p_i in p] for p in pred_args]
        else:
            pred_arg_tags = [[self.bert_io.idx2bio_arg[p_i] for p_i in p] for p in pred_args]

        conll_result = []
        for i in range(len(pred_arg_tags)):
            raw = tgt_data[i]
            conll, toks, lus = [], [], []
            for idx in range(len(raw[0])):
                tok, lu = raw[0][idx], raw[1][idx]
                if tok not in ('<tgt>', '</tgt>'):
                    toks.append(tok)
                    lus.append(lu)
            conll.append(toks)
            conll.append(lus)

            sense_seq = ['_' for _ in range(len(conll[1]))]
            for idx in range(len(conll[1])):
                if conll[1][idx] != '_':
                    sense_seq[idx] = pred_sense_tags[i]
            conll.append(sense_seq)
            conll.append(pred_arg_tags[i])
            conll_result.append(conll)
    else:
        conll_result = []

    result = []
    if result_format == 'all':
        result = {}
        result['conll'] = conll_result
        if conll_result:
            textae = conll2textae.get_textae(conll_result)
            frdf = dataio.frame2rdf(conll_result, sent_id=sent_id)
            topk = dataio.topk(conll_result, sense_candis_list)
        else:
            textae = []
            frdf = []
            topk = {}
        result['textae'] = textae
        result['graph'] = frdf
        result['topk'] = topk
    elif result_format == 'textae':
        if conll_result:
            textae = conll2textae.get_textae(conll_result)
        else:
            textae = []
        result = textae
    elif result_format == 'graph':
        if conll_result:
            frdf = dataio.frame2rdf(conll_result, sent_id=sent_id, language=self.language)
        else:
            frdf = []
        result = frdf
    elif result_format == 'topk':
        if conll_result:
            topk = dataio.topk(conll_result, sense_candis_list)
        else:
            topk = {}
        result = topk
    else:
        result = conll_result

    return result
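# Hypothetical usage sketch for the parser method above. The module path,
# class name (`FrameParser`), and constructor arguments below are assumptions
# for illustration, not confirmed parts of the codebase; only the parser(...)
# call and the 'all' result format mirror the code above.
from frameBERT import frame_parser  # assumed module layout

fp = frame_parser.FrameParser(srl='framenet', gold_pred=False, masking=True)

text = 'Hemingway was born on 21 July 1899 in Illinois.'
parsed = fp.parser(text, sent_id='example-1', result_format='all')
# With result_format='all', the result is a dict with 'conll', 'textae',
# 'graph', and 'topk' entries.
print(parsed['conll'])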