def _train_pairwise_(self, data, optimizer, loss_fn, device):
    '''
        Given data, passes it through the model (initialized in the constructor),
        updates the weights and returns the loss.

        :param data: dict {batch of question, pos paths, neg paths and dummy y labels}
        :param optimizer: torch.optim object
        :param loss_fn: torch.nn loss object
        :param device: torch.device object

        returns loss
    '''
    self.encoder.train()

    assert 'ques_dep_batch' in data

    dep_batch, dep_mask_batch, \
    pos_1_batch, pos_2_batch, pos_3_batch, pos_4_batch, \
    neg_1_batch, neg_2_batch, neg_3_batch, neg_4_batch, y_label = \
        data['ques_dep_batch'], data['ques_dep_mask_batch'], \
        data['pos_rel1_batch'], data['pos_rel2_batch'], data['pos_rel3_batch'], data['pos_rel4_batch'], \
        data['neg_rel1_batch'], data['neg_rel2_batch'], data['neg_rel3_batch'], data['neg_rel4_batch'], \
        data['y_label']

    optimizer.zero_grad()

    # Make sure none of the higher-hop relation slots is completely empty for this batch.
    pos_2_batch = tu.no_one_left_behind(pos_2_batch)
    pos_3_batch = tu.no_one_left_behind(pos_3_batch)
    pos_4_batch = tu.no_one_left_behind(pos_4_batch)
    neg_2_batch = tu.no_one_left_behind(neg_2_batch)
    neg_3_batch = tu.no_one_left_behind(neg_3_batch)
    neg_4_batch = tu.no_one_left_behind(neg_4_batch)

    dep_batch = tu.trim(dep_batch)

    pos_scores = self.encoder(dep_batch, dep_mask_batch,
                              tu.trim(pos_1_batch), tu.trim(pos_2_batch),
                              tu.trim(pos_3_batch), tu.trim(pos_4_batch))
    neg_scores = self.encoder(dep_batch, dep_mask_batch,
                              tu.trim(neg_1_batch), tu.trim(neg_2_batch),
                              tu.trim(neg_3_batch), tu.trim(neg_4_batch))

    loss = loss_fn(pos_scores, neg_scores, y_label)
    loss.backward()

    # Gradient clipping: cap the gradient norm at 0.5 so backpropagation
    # cannot produce excessively large weight updates.
    torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), .5)

    optimizer.step()

    return loss
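# `tu.trim` and `tu.no_one_left_behind` are project-specific tensor utilities whose
# implementations are not shown in this excerpt. Below is a minimal sketch of what
# they plausibly do, assuming batches are 2-D LongTensors of token ids padded with
# zeros; both the behaviour and the dummy token id are assumptions, not the
# repository's actual code.
import torch

def trim_sketch(batch):
    # Assumed behaviour: drop trailing columns that are zero-padding in every row,
    # so encoders only see up to the longest real sequence in the batch.
    nonzero_cols = (batch != 0).any(dim=0)
    if not nonzero_cols.any():
        return batch[:, :1]                       # keep at least one column
    last = int(nonzero_cols.nonzero().max().item())
    return batch[:, :last + 1]

def no_one_left_behind_sketch(batch):
    # Assumed behaviour: if a relation slot is completely empty for the whole batch
    # (all zeros), insert one dummy token so the encoder gets a non-degenerate input.
    if (batch != 0).any():
        return batch
    patched = batch.clone()
    patched[:, 0] = 1                             # hypothetical dummy/UNK token id
    return patched

# Example: a batch of two sequences padded to length 6 is trimmed to length 3.
_batch = torch.tensor([[4, 8, 15, 0, 0, 0],
                       [16, 23, 0, 0, 0, 0]])
assert trim_sketch(_batch).shape == (2, 3)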
def _train_pairwise_(self, data, optimizer, loss_fn, device):
    '''
        Given data, passes it through the model (initialized in the constructor),
        updates the weights and returns the loss.

        :param data: dict {batch of question, pos paths, neg paths and dummy y labels}
        :param optimizer: torch.optim object
        :param loss_fn: torch.nn loss object
        :param device: torch.device object

        returns loss
    '''
    self.encoder.train()

    # Unpacking the data
    ques_batch, pos_batch, pos_batch_words, neg_batch, neg_batch_words, y_label = \
        data['ques_batch'], data['pos_batch'], data['pos_batch_words'], \
        data['neg_batch'], data['neg_batch_words'], \
        data['y_label']

    optimizer.zero_grad()

    ques_batch = tu.trim(ques_batch)
    pos_batch_words = tu.trim(pos_batch_words)
    neg_batch_words = tu.trim(neg_batch_words)

    pos_scores = self.encoder(ques_batch, pos_batch_words)
    neg_scores = self.encoder(ques_batch, neg_batch_words)

    loss = loss_fn(pos_scores, neg_scores, y_label)
    loss.backward()

    torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), .5)

    optimizer.step()

    return loss
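# Both _train_pairwise_ variants above expect a pairwise ranking loss such as
# torch.nn.MarginRankingLoss, with y_label set to ones so that positive-path scores
# are pushed above negative-path scores. A minimal, self-contained sketch on dummy
# scores follows; the margin value and the score values are illustrative assumptions.
import torch

loss_fn = torch.nn.MarginRankingLoss(margin=1.0)    # margin chosen for illustration

pos_scores = torch.tensor([2.0, 0.5, 1.2, 3.0])     # dummy scores, batch of 4 pairs
neg_scores = torch.tensor([1.0, 1.5, 1.1, 0.0])
y_label = torch.ones_like(pos_scores)               # y == 1: pos should outrank neg

# Per pair: max(0, -y * (pos - neg) + margin), averaged over the batch.
loss = loss_fn(pos_scores, neg_scores, y_label)
print(loss.item())                                  # 0.725 for the values above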
def predict(self, question, question_dep=None, paths=None, paths_words=None,
            paths_rel1=None, paths_rel2=None, paths_rel3=None, paths_rel4=None,
            attention_value=False, device=None):
    """
        Same code works for both pairwise and pointwise models.
    """
    with torch.no_grad():
        self.encoder_q.eval()
        self.encoder_p.eval()

        # Have to manually check if the 2nd (and 3rd/4th) relation slots hold anything
        # in this batch. If not, we have to pad everything up with zeros, or even call
        # a limited part of the comparison module.
        paths_rel2 = tu.no_one_left_behind(paths_rel2)
        paths_rel3 = tu.no_one_left_behind(paths_rel3)
        paths_rel4 = tu.no_one_left_behind(paths_rel4)

        # Encoding all the data
        ques_encoded, attention_score = self.encoder_q(tu.trim(question))
        path_encoded = self.encoder_p(tu.trim(paths_rel1), tu.trim(paths_rel2),
                                      tu.trim(paths_rel3), tu.trim(paths_rel4))

        # Pass them to the comparison module: a batched dot product.
        score = torch.sum(ques_encoded * path_encoded, dim=-1)

        self.encoder_q.train()
        self.encoder_p.train()

        if attention_value:
            return score, attention_score
        else:
            return score
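# The comparison step in predict() is a batched dot product between the question
# encoding and each candidate-path encoding. A small self-contained sketch of the
# equivalence; the shapes are illustrative.
import torch

q = torch.randn(8, 64)                       # 8 (question, path) pairs, 64-dim encodings
p = torch.randn(8, 64)

score_a = torch.sum(q * p, dim=-1)           # as written in predict()
score_b = torch.einsum('bd,bd->b', q, p)     # equivalent row-wise dot product

assert torch.allclose(score_a, score_b, atol=1e-6)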
def predict(self, question, question_dep=None, paths=None, paths_words=None,
            paths_rel1=None, paths_rel2=None, paths_rel3=None, paths_rel4=None,
            attention_value=False, device=None):
    """
        Same code works for both pairwise and pointwise models.
    """
    with torch.no_grad():
        self.encoder.eval()

        paths_rel2 = tu.no_one_left_behind(paths_rel2)
        paths_rel3 = tu.no_one_left_behind(paths_rel3)
        paths_rel4 = tu.no_one_left_behind(paths_rel4)

        score = self.encoder(tu.trim(question), tu.trim(paths_rel1), tu.trim(paths_rel2),
                             tu.trim(paths_rel3), tu.trim(paths_rel4))

        self.encoder.train()

        return score
def predict(self, question=None, question_dep=None, question_dep_mask=None, paths=None,
            paths_words=None, paths_rel1=None, paths_rel2=None, paths_rel3=None,
            paths_rel4=None, attention_value=False, device=None):
    """
        Same code works for both pairwise and pointwise models.
    """
    with torch.no_grad():
        self.encoder.eval()

        paths_rel2 = tu.no_one_left_behind(paths_rel2)
        paths_rel3 = tu.no_one_left_behind(paths_rel3)
        paths_rel4 = tu.no_one_left_behind(paths_rel4)

        score = self.encoder(tu.trim(question), question_dep, question_dep_mask,
                             tu.trim(paths_rel1), tu.trim(paths_rel2),
                             tu.trim(paths_rel3), tu.trim(paths_rel4))
        # score = score.squeeze()  # v0.2: was needed downstream for mrr = mrr_output.index(positive_path_index) + 1.0

        self.encoder.train()

        return score
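# The scores returned by predict() are typically used to rank candidate paths, e.g.
# picking the top-scoring path or computing the reciprocal rank of the known positive
# path. The score values and the positive index below are dummy stand-ins, not real
# model output.
import torch

scores = torch.tensor([0.2, 1.7, 0.9, 1.1])                  # one score per candidate path
positive_path_index = 2                                       # gold path (illustrative)

ranking = torch.argsort(scores, descending=True).tolist()     # [1, 3, 2, 0]
rank = ranking.index(positive_path_index) + 1.0               # 1-based rank -> 3.0
reciprocal_rank = 1.0 / rank
print(rank, reciprocal_rank)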
def _train_pairwise_(self, data, optimizer, loss_fn, device):
    ques_batch, pos_1_batch, pos_2_batch, pos_3_batch, pos_4_batch, \
    neg_1_batch, neg_2_batch, neg_3_batch, neg_4_batch, y_label = \
        data['ques_batch'], \
        data['pos_rel1_batch'], data['pos_rel2_batch'], data['pos_rel3_batch'], data['pos_rel4_batch'], \
        data['neg_rel1_batch'], data['neg_rel2_batch'], data['neg_rel3_batch'], data['neg_rel4_batch'], \
        data['y_label']

    optimizer.zero_grad()

    # Have to manually check if the 2nd (and 3rd/4th) relation slots hold anything in
    # this batch. If not, we have to pad everything up with zeros, or even call a
    # limited part of the comparison module.
    pos_2_batch = tu.no_one_left_behind(pos_2_batch)
    pos_3_batch = tu.no_one_left_behind(pos_3_batch)
    pos_4_batch = tu.no_one_left_behind(pos_4_batch)
    neg_2_batch = tu.no_one_left_behind(neg_2_batch)
    neg_3_batch = tu.no_one_left_behind(neg_3_batch)
    neg_4_batch = tu.no_one_left_behind(neg_4_batch)

    # Encoding all the data
    ques_encoded, _ = self.encoder_q(tu.trim(ques_batch))
    pos_encoded = self.encoder_p(tu.trim(pos_1_batch), tu.trim(pos_2_batch),
                                 tu.trim(pos_3_batch), tu.trim(pos_4_batch))
    neg_encoded = self.encoder_p(tu.trim(neg_1_batch), tu.trim(neg_2_batch),
                                 tu.trim(neg_3_batch), tu.trim(neg_4_batch))

    # Pass them to the comparison module: a batched dot product.
    pos_scores = torch.sum(ques_encoded * pos_encoded, dim=-1)
    neg_scores = torch.sum(ques_encoded * neg_encoded, dim=-1)

    # If `y == 1`, the loss assumes the first input (pos_scores) should be ranked
    # higher (have a larger value) than the second input, and vice-versa for `y == -1`.
    loss = loss_fn(pos_scores, neg_scores, y_label)
    loss.backward()

    optimizer.step()

    return loss
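# To show how the pieces above (question/path encoders, dot-product comparison, margin
# ranking loss, optimizer step) fit together, here is a self-contained toy version of
# one pairwise training step. ToyEncoder, the dimensions, and the random data are
# illustrative stand-ins, not the repository's actual encoder classes or inputs.
import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    # Stand-in encoder: embeds token ids and mean-pools them into a single vector.
    def __init__(self, vocab=100, dim=16):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim, padding_idx=0)

    def forward(self, ids):
        return self.emb(ids).mean(dim=1)

encoder_q, encoder_p = ToyEncoder(), ToyEncoder()
optimizer = torch.optim.Adam(list(encoder_q.parameters()) + list(encoder_p.parameters()), lr=1e-3)
loss_fn = nn.MarginRankingLoss(margin=1.0)

# Toy batch: 2 questions, one positive and one negative path each.
ques = torch.randint(1, 100, (2, 7))
pos_path = torch.randint(1, 100, (2, 4))
neg_path = torch.randint(1, 100, (2, 4))
y_label = torch.ones(2)

optimizer.zero_grad()
q = encoder_q(ques)
pos_scores = torch.sum(q * encoder_p(pos_path), dim=-1)   # batched dot product
neg_scores = torch.sum(q * encoder_p(neg_path), dim=-1)
loss = loss_fn(pos_scores, neg_scores, y_label)
loss.backward()
optimizer.step()
print(float(loss))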