def forward(self, q_ids=None, char_ids=None, word_ids=None, token_type_ids=None, subject_ids=None,
            subject_labels=None, object_labels=None, eval_file=None, is_eval=False):
    """Joint subject/object extraction with CRF heads over a char+word encoder.

    Training (``is_eval=False``): returns a scalar loss = subject-CRF NLL +
    object-CRF NLL, both masked by non-padding positions.

    Evaluation (``is_eval=True``): decodes subject spans with the entity CRF,
    re-batches one example per predicted subject, runs the object head
    conditioned on each subject, and returns ``(qids, subject_ids, po_tensor)``
    where each row of the three tensors corresponds to one (question, subject)
    pair.  If no subjects are predicted at all, sentinel tensors with qid -1
    are returned instead.

    Args:
        q_ids: per-example question/sample ids (LongTensor, batch-sized).
        char_ids: char-level token ids; 0 is treated as padding.
        word_ids: word-level token ids aligned to char positions.
        token_type_ids: unused in this code path (kept for interface parity).
        subject_ids: (batch, 2) gold subject [start, end] indices — training only.
        subject_labels: gold subject tag sequence for the entity CRF — training only.
        object_labels: gold object tag sequence for the object CRF — training only.
        eval_file: qid -> example mapping; ``.context`` gives the raw text — eval only.
        is_eval: selects the training vs. decoding branch.
    """
    # Non-padding mask in two flavors: boolean "keep" mask for the CRFs,
    # and "is padding" mask for the encoders' key-padding convention.
    mask = char_ids != 0
    seq_mask = char_ids.eq(0)
    char_emb = self.char_emb(char_ids)
    # Project word embeddings into char-embedding space before summing.
    word_emb = self.word_convert_char(self.word_emb(word_ids))
    # word_emb = self.word_emb(word_ids)
    emb = char_emb + word_emb
    # emb = char_emb
    # subject_encoder = sent_encoder + self.token_entity_emb(token_type_id)
    sent_encoder = self.first_sentence_encoder(emb, seq_mask)
    # Per-position emission scores for the subject (entity) CRF.
    ent_emission = self.ent_emission(sent_encoder)
    if not is_eval:
        # subject_encoder = self.token_entity_emb(token_type_ids)
        # context_encoder = bert_encoder + subject_encoder
        # Gather encoder states at the gold subject start/end positions and
        # condition the sentence representation on them via LayerNorm fusion.
        sub_start_encoder = batch_gather(sent_encoder, subject_ids[:, 0])
        sub_end_encoder = batch_gather(sent_encoder, subject_ids[:, 1])
        subject = torch.cat([sub_start_encoder, sub_end_encoder], 1)
        context_encoder = self.LayerNorm(sent_encoder, subject)
        # nn.TransformerEncoder expects (seq, batch, dim) — hence the transposes.
        context_encoder = self.transformer_encoder(context_encoder.transpose(1, 0),
                                                   src_key_padding_mask=seq_mask).transpose(0, 1)
        # CRF log-likelihoods are negated to obtain losses.
        ent_loss = -self.ent_crf(ent_emission, subject_labels, mask=mask, reduction='mean')
        emission = self.emission(context_encoder)
        po_loss = -self.crf(emission, object_labels, mask=mask, reduction='mean')
        loss = ent_loss + po_loss
        return loss
    else:
        # Viterbi-decode subject tag sequences (list of per-example tag lists).
        subject_preds = self.ent_crf.decode(emissions=ent_emission, mask=mask)
        answer_list = list()
        for qid, sub_pred in zip(q_ids.cpu().numpy(), subject_preds):
            seq_len = min(len(eval_file[qid].context), self.max_len)
            tag_list = list()
            j = 0
            # Scan for spans: tag 1 marks a subject start; the span continues
            # while each following tag increments by 1 (BIO-like numeric scheme
            # inferred from this check — TODO confirm against the tag vocab).
            while j < seq_len:
                end = j
                flag = True
                if sub_pred[j] == 1:
                    start = j
                    for k in range(start + 1, seq_len):
                        if sub_pred[k] != sub_pred[start] + 1:
                            end = k - 1
                            flag = False
                            break
                    if flag:
                        # Span ran to the end of the sequence unbroken.
                        end = seq_len - 1
                    tag_list.append((start, end))
                j = end + 1
            answer_list.append(tag_list)
        # Re-batch: replicate each example once per predicted subject so the
        # object head can be run conditioned on every subject independently.
        # NOTE: this rebinds `subject_ids`/`token_type_ids` from the (unused
        # at eval) parameters to local accumulator lists.
        qid_ids, sent_encoders, pass_ids, subject_ids, token_type_ids = [], [], [], [], []
        for i, subjects in enumerate(answer_list):
            if subjects:
                qid = q_ids[i].unsqueeze(0).expand(len(subjects))
                pass_tensor = char_ids[i, :].unsqueeze(0).expand(len(subjects), char_ids.size(1))
                new_sent_encoder = sent_encoder[i, :, :].unsqueeze(0).expand(len(subjects), sent_encoder.size(1),
                                                                             sent_encoder.size(2))
                # 1 inside the subject span, 0 elsewhere (built but not fed to
                # the model below — the consuming lines are commented out).
                token_type_id = torch.zeros((len(subjects), char_ids.size(1)), dtype=torch.long)
                for index, (start, end) in enumerate(subjects):
                    token_type_id[index, start:end + 1] = 1
                qid_ids.append(qid)
                pass_ids.append(pass_tensor)
                subject_ids.append(torch.tensor(subjects, dtype=torch.long))
                sent_encoders.append(new_sent_encoder)
                token_type_ids.append(token_type_id)
        if len(qid_ids) == 0:
            # No subject predicted anywhere in the batch: return sentinel
            # tensors (qid -1) so the caller can skip this batch.
            # print('len(qid_list)==0:')
            subject_ids = torch.zeros(1, 2).long().to(sent_encoder.device)
            qid_tensor = torch.tensor([-1], dtype=torch.long).to(sent_encoder.device)
            po_tensor = torch.zeros(1, sent_encoder.size(1)).long().to(sent_encoder.device)
            return qid_tensor, subject_ids, po_tensor
        qids = torch.cat(qid_ids).to(sent_encoder.device)
        pass_ids = torch.cat(pass_ids).to(sent_encoder.device)
        sent_encoders = torch.cat(sent_encoders).to(sent_encoder.device)
        # token_type_ids = torch.cat(token_type_ids).to(bert_encoder.device)
        subject_ids = torch.cat(subject_ids).to(sent_encoder.device)
        flag = False
        # Process the expanded batch in chunks — presumably to bound peak
        # memory when many subjects were predicted (TODO confirm intent).
        split_heads = 1024
        sent_encoders_ = torch.split(sent_encoders, split_heads, dim=0)
        pass_ids_ = torch.split(pass_ids, split_heads, dim=0)
        # token_type_ids_ = torch.split(token_type_ids, split_heads, dim=0)
        subject_encoder_ = torch.split(subject_ids, split_heads, dim=0)
        po_preds = list()
        for i in range(len(subject_encoder_)):
            sent_encoders = sent_encoders_[i]
            # token_type_ids = token_type_ids_[i]
            pass_ids = pass_ids_[i]
            subject_encoder = subject_encoder_[i]
            if sent_encoders.size(0) == 1:
                # Size-1 chunk: duplicate to batch size 2 and remember (flag)
                # to keep only the second row of the prediction afterwards —
                # presumably a workaround for a batch-size-1 issue downstream.
                flag = True
                sent_encoders = sent_encoders.expand(2, sent_encoders.size(1), sent_encoders.size(2))
                subject_encoder = subject_encoder.expand(2, subject_encoder.size(1))
                pass_ids = pass_ids.expand(2, pass_ids.size(1))
            # Same subject-conditioned encoding as the training branch.
            sub_start_encoder = batch_gather(sent_encoders, subject_encoder[:, 0])
            sub_end_encoder = batch_gather(sent_encoders, subject_encoder[:, 1])
            subject = torch.cat([sub_start_encoder, sub_end_encoder], 1)
            context_encoder = self.LayerNorm(sent_encoders, subject)
            context_encoder = self.transformer_encoder(context_encoder.transpose(1, 0),
                                                       src_key_padding_mask=pass_ids.eq(0)).transpose(0, 1)
            emission = self.emission(context_encoder)
            po_pred = self.crf.decode(emissions=emission, mask=(pass_ids != 0))
            # CRF decode returns ragged per-example tag lists; right-pad with 0
            # to the chunk's sequence length before tensorizing.
            max_len = pass_ids.size(1)
            temp_tag = copy.deepcopy(po_pred)
            for line in temp_tag:
                line.extend([0] * (max_len - len(line)))
            # TODO:check
            po_pred = torch.tensor(temp_tag).to(emission.device)
            if flag:
                # Undo the size-1 duplication: keep the second (cloned) row.
                po_pred = po_pred[1, :].unsqueeze(0)
            po_preds.append(po_pred)
        po_tensor = torch.cat(po_preds).to(qids.device)
        # print(subject_ids.device)
        # print(po_tensor.device)
        # print(qids.shape)
        # print(subject_ids.shape)
        # print(po_tensor.shape)
        return qids, subject_ids, po_tensor
def forward(self, q_ids=None, passage_ids=None, segment_ids=None, token_type_ids=None, subject_ids=None,
            subject_labels=None, object_labels=None, eval_file=None, is_eval=False):
    """Joint subject/object extraction with sigmoid pointer heads over BERT.

    Training (``is_eval=False``): returns a scalar loss = masked-mean BCE on
    subject start/end pointers + masked-mean BCE on per-class object pointers.

    Evaluation (``is_eval=True``): thresholds the subject pointer scores,
    re-batches one example per predicted subject span, runs the object head
    conditioned on each subject, and returns ``(qids, subject_ids, po_tensor)``
    with one row per (question, subject) pair; sentinel tensors (qid -1) are
    returned when no subject is found in the whole batch.

    Args:
        q_ids: per-example question/sample ids (LongTensor, batch-sized).
        passage_ids: wordpiece token ids; 0 is treated as padding.
        segment_ids: BERT segment (token-type) ids.
        token_type_ids: unused in this code path (kept for interface parity).
        subject_ids: (batch, 2) gold subject [start, end] indices — training only.
        subject_labels: gold subject start/end pointer targets — training only.
        object_labels: gold object pointer targets — training only.
        eval_file: qid -> example mapping; ``.bert_tokens`` gives the
            wordpiece sequence — eval only.
        is_eval: selects the training vs. decoding branch.
    """
    mask = (passage_ids != 0).float()
    # Take the last 4 BERT layers and average them into one representation.
    bert_encoder_ = self.bert(passage_ids, token_type_ids=segment_ids, attention_mask=mask,
                              output_all_encoded_layers=True)[0][-4:]
    bert_encoder_ = torch.cat([m.unsqueeze(0) for m in bert_encoder_]).mean(0)
    bert_encoder = self.lstm_encoder(bert_encoder_, mask=passage_ids.eq(0))
    if not is_eval:
        # Condition the passage encoding on the gold subject boundaries.
        sub_start_encoder = batch_gather(bert_encoder, subject_ids[:, 0])
        sub_end_encoder = batch_gather(bert_encoder, subject_ids[:, 1])
        subject = torch.cat([sub_start_encoder, sub_end_encoder], 1)
        context_encoder = self.LayerNorm(bert_encoder, subject)
        sub_preds = self.subject_dense(bert_encoder)
        # Object logits reshaped to (batch, seq, classes_num, 2=start/end).
        po_preds = self.po_dense(context_encoder).reshape(
            passage_ids.size(0), -1, self.classes_num, 2)
        # Mean over pointer dim, then mask out padding and average over the
        # number of real tokens.
        subject_loss = self.loss_fct(sub_preds, subject_labels)
        subject_loss = subject_loss.mean(2)
        subject_loss = torch.sum(subject_loss * mask.float()) / torch.sum(
            mask.float())
        po_loss = self.loss_fct(po_preds, object_labels)
        po_loss = torch.sum(po_loss.mean(3), 2)
        po_loss = torch.sum(po_loss * mask.float()) / torch.sum(
            mask.float())
        loss = subject_loss + po_loss
        return loss
    else:
        subject_preds = nn.Sigmoid()(self.subject_dense(bert_encoder))
        answer_list = list()
        for qid, sub_pred in zip(q_ids.cpu().numpy(), subject_preds.cpu().numpy()):
            context = eval_file[qid].bert_tokens
            # Positions whose start / end probability exceeds 0.5.
            start = np.where(sub_pred[:, 0] > 0.5)[0]
            end = np.where(sub_pred[:, 1] > 0.5)[0]
            subjects = []
            for i in start:
                j = end[end >= i]
                # Skip [CLS] (position 0) and anything at/after [SEP] —
                # the last-2 guard presumably excludes the trailing special
                # token (TODO confirm tokenization layout).
                if i == 0 or i > len(context) - 2:
                    continue
                if len(j) > 0:
                    # Pair each start with the nearest end at or after it.
                    j = j[0]
                    if j > len(context) - 2:
                        continue
                    subjects.append((i, j))
            answer_list.append(subjects)
        # Re-batch: one row per predicted subject.  NOTE: this rebinds
        # `subject_ids`/`token_type_ids` from parameters to local lists.
        qid_ids, bert_encoders, pass_ids, subject_ids, token_type_ids = [], [], [], [], []
        for i, subjects in enumerate(answer_list):
            if subjects:
                qid = q_ids[i].unsqueeze(0).expand(len(subjects))
                pass_tensor = passage_ids[i, :].unsqueeze(0).expand(
                    len(subjects), passage_ids.size(1))
                new_bert_encoder = bert_encoder[i, :, :].unsqueeze(
                    0).expand(len(subjects), bert_encoder.size(1), bert_encoder.size(2))
                # 1 inside the subject span, 0 elsewhere (built but not
                # consumed below — downstream use is absent in this version).
                token_type_id = torch.zeros(
                    (len(subjects), passage_ids.size(1)), dtype=torch.long)
                for index, (start, end) in enumerate(subjects):
                    token_type_id[index, start:end + 1] = 1
                qid_ids.append(qid)
                pass_ids.append(pass_tensor)
                subject_ids.append(torch.tensor(subjects, dtype=torch.long))
                bert_encoders.append(new_bert_encoder)
                token_type_ids.append(token_type_id)
        if len(qid_ids) == 0:
            # No subject predicted in the batch: sentinel outputs (qid -1).
            subject_ids = torch.zeros(1, 2).long().to(bert_encoder.device)
            qid_tensor = torch.tensor([-1], dtype=torch.long).to(
                bert_encoder.device)
            po_tensor = torch.zeros(1, bert_encoder.size(1)).long().to(
                bert_encoder.device)
            return qid_tensor, subject_ids, po_tensor
        qids = torch.cat(qid_ids).to(bert_encoder.device)
        pass_ids = torch.cat(pass_ids).to(bert_encoder.device)
        bert_encoders = torch.cat(bert_encoders).to(bert_encoder.device)
        subject_ids = torch.cat(subject_ids).to(bert_encoder.device)
        flag = False
        # Chunk the expanded batch — presumably to bound peak memory when
        # many subjects were predicted (TODO confirm intent).
        split_heads = 1024
        bert_encoders_ = torch.split(bert_encoders, split_heads, dim=0)
        pass_ids_ = torch.split(pass_ids, split_heads, dim=0)
        subject_encoder_ = torch.split(subject_ids, split_heads, dim=0)
        po_preds = list()
        for i in range(len(bert_encoders_)):
            bert_encoders = bert_encoders_[i]
            pass_ids = pass_ids_[i]
            subject_encoder = subject_encoder_[i]
            if bert_encoders.size(0) == 1:
                # Size-1 chunk: duplicate to batch size 2 and remember (flag)
                # to keep only the second row afterwards — presumably a
                # batch-size-1 workaround.
                flag = True
                bert_encoders = bert_encoders.expand(
                    2, bert_encoders.size(1), bert_encoders.size(2))
                subject_encoder = subject_encoder.expand(
                    2, subject_encoder.size(1))
            # Same subject-conditioned encoding as the training branch.
            sub_start_encoder = batch_gather(bert_encoders, subject_encoder[:, 0])
            sub_end_encoder = batch_gather(bert_encoders, subject_encoder[:, 1])
            subject = torch.cat([sub_start_encoder, sub_end_encoder], 1)
            context_encoder = self.LayerNorm(bert_encoders, subject)
            po_pred = self.po_dense(context_encoder).reshape(
                subject_encoder.size(0), -1, self.classes_num, 2)
            if flag:
                # Undo the size-1 duplication: keep the second (cloned) row.
                po_pred = po_pred[1, :, :, :].unsqueeze(0)
            po_preds.append(po_pred)
        po_tensor = torch.cat(po_preds).to(qids.device)
        # Logits -> probabilities for downstream thresholding.
        po_tensor = nn.Sigmoid()(po_tensor)
        return qids, subject_ids, po_tensor
def forward(self, q_ids=None, char_ids=None, word_ids=None, token_type_ids=None, subject_ids=None,
            subject_labels=None, object_labels=None, eval_file=None, is_eval=False):
    """Joint subject/object extraction with sigmoid pointer heads over a
    char+word encoder and a transformer context encoder.

    Training (``is_eval=False``): returns a scalar loss = masked-mean BCE on
    subject start/end pointers + masked-mean BCE on per-class object pointers.

    Evaluation (``is_eval=True``): thresholds the subject pointer scores
    (start > 0.5, end > 0.4), re-batches one example per predicted subject
    span, runs the object head conditioned on each subject, and returns
    ``(qids, subject_ids, po_tensor)`` with one row per (question, subject)
    pair; a (-1, -1) sentinel triple is returned when no subject is found.

    Args:
        q_ids: per-example question/sample ids (LongTensor, batch-sized).
        char_ids: char-level token ids; 0 is treated as padding.
        word_ids: word-level token ids aligned to char positions.
        token_type_ids: unused in this code path (kept for interface parity).
        subject_ids: (batch, 2) gold subject [start, end] indices — training only.
        subject_labels: gold subject start/end pointer targets — training only.
        object_labels: gold object pointer targets — training only.
        eval_file: qid -> example mapping; ``.context`` gives the raw text — eval only.
        is_eval: selects the training vs. decoding branch.
    """
    # Non-padding mask (bool) and its inverse for key-padding conventions.
    mask = char_ids != 0
    seq_mask = char_ids.eq(0)
    char_emb = self.char_emb(char_ids)
    # Project word embeddings into char-embedding space before summing.
    word_emb = self.word_convert_char(self.word_emb(word_ids))
    # word_emb = self.word_emb(word_ids)
    emb = char_emb + word_emb
    # emb = char_emb
    # subject_encoder = sent_encoder + self.token_entity_emb(token_type_id)
    sent_encoder = self.first_sentence_encoder(emb, seq_mask)
    if not is_eval:
        # subject_encoder = self.token_entity_emb(token_type_ids)
        # context_encoder = bert_encoder + subject_encoder
        # Condition the sentence encoding on the gold subject boundaries.
        sub_start_encoder = batch_gather(sent_encoder, subject_ids[:, 0])
        sub_end_encoder = batch_gather(sent_encoder, subject_ids[:, 1])
        subject = torch.cat([sub_start_encoder, sub_end_encoder], 1)
        context_encoder = self.LayerNorm(sent_encoder, subject)
        # nn.TransformerEncoder expects (seq, batch, dim) — hence the transposes.
        context_encoder = self.transformer_encoder(context_encoder.transpose(1, 0),
                                                   src_key_padding_mask=seq_mask).transpose(0, 1)
        sub_preds = self.subject_dense(sent_encoder)
        # Object logits reshaped to (batch, seq, classes_num, 2=start/end).
        po_preds = self.po_dense(context_encoder).reshape(char_ids.size(0), -1, self.classes_num, 2)
        # Mean over pointer dim, mask padding, average over real tokens.
        subject_loss = self.loss_fct(sub_preds, subject_labels)
        subject_loss = subject_loss.mean(2)
        subject_loss = torch.sum(subject_loss * mask.float()) / torch.sum(mask.float())
        po_loss = self.loss_fct(po_preds, object_labels)
        po_loss = torch.sum(po_loss.mean(3), 2)
        po_loss = torch.sum(po_loss * mask.float()) / torch.sum(mask.float())
        loss = subject_loss + po_loss
        return loss
    else:
        subject_preds = nn.Sigmoid()(self.subject_dense(sent_encoder))
        answer_list = list()
        for qid, sub_pred in zip(q_ids.cpu().numpy(), subject_preds.cpu().numpy()):
            context = eval_file[qid].context
            # Asymmetric thresholds: start > 0.5, end > 0.4 (looser end —
            # presumably tuned for recall; TODO confirm).
            start = np.where(sub_pred[:, 0] > 0.5)[0]
            end = np.where(sub_pred[:, 1] > 0.4)[0]
            subjects = []
            for i in start:
                j = end[end >= i]
                if i >= len(context):
                    continue
                if len(j) > 0:
                    # Pair each start with the nearest end at or after it.
                    j = j[0]
                    if j >= len(context):
                        continue
                    subjects.append((i, j))
            answer_list.append(subjects)
        # Re-batch: one row per predicted subject.  NOTE: this rebinds
        # `subject_ids`/`token_type_ids` from parameters to local lists.
        qid_ids, sent_encoders, pass_ids, subject_ids, token_type_ids = [], [], [], [], []
        for i, subjects in enumerate(answer_list):
            if subjects:
                qid = q_ids[i].unsqueeze(0).expand(len(subjects))
                pass_tensor = char_ids[i, :].unsqueeze(0).expand(len(subjects), char_ids.size(1))
                new_sent_encoder = sent_encoder[i, :, :].unsqueeze(0).expand(len(subjects), sent_encoder.size(1),
                                                                             sent_encoder.size(2))
                # 1 inside the subject span, 0 elsewhere (built but not
                # consumed below — downstream use is commented out).
                token_type_id = torch.zeros((len(subjects), char_ids.size(1)), dtype=torch.long)
                for index, (start, end) in enumerate(subjects):
                    token_type_id[index, start:end + 1] = 1
                qid_ids.append(qid)
                pass_ids.append(pass_tensor)
                subject_ids.append(torch.tensor(subjects, dtype=torch.long))
                sent_encoders.append(new_sent_encoder)
                token_type_ids.append(token_type_id)
        if len(qid_ids) == 0:
            # No subject predicted anywhere in the batch: return the same
            # (-1, -1) sentinel tensor three times so the caller can skip it.
            # print('len(qid_list)==0:')
            qid_tensor = torch.tensor([-1, -1], dtype=torch.long).to(sent_encoder.device)
            return qid_tensor, qid_tensor, qid_tensor
        # print('len(qid_list)!=========================0:')
        qids = torch.cat(qid_ids).to(sent_encoder.device)
        pass_ids = torch.cat(pass_ids).to(sent_encoder.device)
        sent_encoders = torch.cat(sent_encoders).to(sent_encoder.device)
        # token_type_ids = torch.cat(token_type_ids).to(bert_encoder.device)
        subject_ids = torch.cat(subject_ids).to(sent_encoder.device)
        flag = False
        # Chunk the expanded batch — presumably to bound peak memory when
        # many subjects were predicted (TODO confirm intent).
        split_heads = 1024
        sent_encoders_ = torch.split(sent_encoders, split_heads, dim=0)
        pass_ids_ = torch.split(pass_ids, split_heads, dim=0)
        # token_type_ids_ = torch.split(token_type_ids, split_heads, dim=0)
        subject_encoder_ = torch.split(subject_ids, split_heads, dim=0)
        # print('len(qid_list)!=========================1:')
        po_preds = list()
        for i in range(len(subject_encoder_)):
            sent_encoders = sent_encoders_[i]
            # token_type_ids = token_type_ids_[i]
            pass_ids = pass_ids_[i]
            subject_encoder = subject_encoder_[i]
            if sent_encoders.size(0) == 1:
                # Size-1 chunk: duplicate to batch size 2 and remember (flag)
                # to keep only the second row afterwards — presumably a
                # batch-size-1 workaround.
                flag = True
                # print('flag = True**********')
                sent_encoders = sent_encoders.expand(2, sent_encoders.size(1), sent_encoders.size(2))
                subject_encoder = subject_encoder.expand(2, subject_encoder.size(1))
                pass_ids = pass_ids.expand(2, pass_ids.size(1))
            # print('len(qid_list)!=========================2:')
            # Same subject-conditioned encoding as the training branch.
            sub_start_encoder = batch_gather(sent_encoders, subject_encoder[:, 0])
            sub_end_encoder = batch_gather(sent_encoders, subject_encoder[:, 1])
            subject = torch.cat([sub_start_encoder, sub_end_encoder], 1)
            context_encoder = self.LayerNorm(sent_encoders, subject)
            context_encoder = self.transformer_encoder(context_encoder.transpose(1, 0),
                                                       src_key_padding_mask=pass_ids.eq(0)).transpose(0, 1)
            # print('len(qid_list)!=========================3')
            # context_encoder = self.LayerNorm(context_encoder)
            po_pred = self.po_dense(context_encoder).reshape(subject_encoder.size(0), -1, self.classes_num, 2)
            if flag:
                # Undo the size-1 duplication: keep the second (cloned) row.
                po_pred = po_pred[1, :, :, :].unsqueeze(0)
            po_preds.append(po_pred)
        po_tensor = torch.cat(po_preds).to(qids.device)
        # Logits -> probabilities for downstream thresholding.
        po_tensor = nn.Sigmoid()(po_tensor)
        return qids, subject_ids, po_tensor