コード例 #1
0
ファイル: network.py プロジェクト: nju-websoft/SkeletonKBQA
 def _train_pairwise_(self, data, optimizer, loss_fn, device):
     """
         Run one pairwise training step on a batch and update the encoder weights.

         The trimmed question dependency batch (plus its mask) is scored by
         ``self.encoder`` twice:
         - once against the positive relation paths;
         - once against the negative relation paths (four relation slots each).

         Both score tensors and ``y_label`` are passed to ``loss_fn``. The loss is
         back-propagated, gradients are clipped to a total norm of 0.5, and the
         optimizer is stepped.

         :param data: dict with keys 'ques_dep_batch', 'ques_dep_mask_batch',
             'pos_rel1_batch' .. 'pos_rel4_batch', 'neg_rel1_batch' .. 'neg_rel4_batch'
             and 'y_label'
         :param optimizer: torch.optim object
         :param loss_fn: torch.nn loss object, called as loss_fn(pos_scores, neg_scores, y_label)
         :param device: torch.device object (not used in this method)
         :returns: the loss tensor for this batch
         :raises KeyError: if a required key is missing from ``data``
     """
     self.encoder.train()
     # Explicit check instead of ``assert``, which is stripped under ``python -O``.
     if 'ques_dep_batch' not in data:
         raise KeyError("data is missing required key 'ques_dep_batch'")

     dep_batch = data['ques_dep_batch']
     dep_mask_batch = data['ques_dep_mask_batch']
     pos_1_batch, pos_2_batch, pos_3_batch, pos_4_batch = (
         data['pos_rel1_batch'], data['pos_rel2_batch'],
         data['pos_rel3_batch'], data['pos_rel4_batch'])
     neg_1_batch, neg_2_batch, neg_3_batch, neg_4_batch = (
         data['neg_rel1_batch'], data['neg_rel2_batch'],
         data['neg_rel3_batch'], data['neg_rel4_batch'])
     y_label = data['y_label']

     optimizer.zero_grad()

     # Relation slots 2-4 may hold nothing for a whole batch. no_one_left_behind guards
     # against that (presumably by padding -- see tensor utils).
     pos_2_batch = tu.no_one_left_behind(pos_2_batch)
     pos_3_batch = tu.no_one_left_behind(pos_3_batch)
     pos_4_batch = tu.no_one_left_behind(pos_4_batch)
     neg_2_batch = tu.no_one_left_behind(neg_2_batch)
     neg_3_batch = tu.no_one_left_behind(neg_3_batch)
     neg_4_batch = tu.no_one_left_behind(neg_4_batch)

     # Trim the question once and feed the same tensor to both forward passes.
     dep_batch = tu.trim(dep_batch)
     # NOTE(review): dep_mask_batch is not trimmed alongside dep_batch; confirm the
     # encoder expects (or tolerates) the possible length difference.
     pos_scores = self.encoder(dep_batch, dep_mask_batch, tu.trim(pos_1_batch), tu.trim(pos_2_batch),
                               tu.trim(pos_3_batch), tu.trim(pos_4_batch))
     neg_scores = self.encoder(dep_batch, dep_mask_batch, tu.trim(neg_1_batch), tu.trim(neg_2_batch),
                               tu.trim(neg_3_batch), tu.trim(neg_4_batch))

     loss = loss_fn(pos_scores, neg_scores, y_label)
     loss.backward()
     # Gradient clipping guards against *exploding* gradients. If the total (L2) norm of
     # all encoder gradients exceeds 0.5, they are scaled down to a norm of ~0.5.
     # Smaller gradients are left untouched. It does nothing about vanishing gradients.
     torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), .5)
     optimizer.step()
     return loss
コード例 #2
0
ファイル: network.py プロジェクト: nju-websoft/SkeletonKBQA
 def _train_pairwise_(self, data, optimizer, loss_fn, device):
     """
         Run one pairwise training step on a batch of word-level paths.

         The trimmed question batch is scored by ``self.encoder`` twice:
         - against the trimmed positive path words;
         - against the trimmed negative path words.

         Both score tensors and ``y_label`` are passed to ``loss_fn``. The loss is
         back-propagated, gradients are clipped to a total norm of 0.5, and the
         optimizer is stepped.

         :param data: dict providing at least 'ques_batch', 'pos_batch_words',
             'neg_batch_words' and 'y_label'
         :param optimizer: torch.optim object
         :param loss_fn: torch.nn loss object, called as loss_fn(pos_scores, neg_scores, y_label)
         :param device: torch.device object (not used in this method)
         :returns: the loss tensor for this batch
     """
     self.encoder.train()
     # Only the word-level path batches are consumed here; 'pos_batch' / 'neg_batch'
     # entries of ``data`` are not needed by this model.
     ques_batch = data['ques_batch']
     pos_batch_words = data['pos_batch_words']
     neg_batch_words = data['neg_batch_words']
     y_label = data['y_label']

     optimizer.zero_grad()
     ques_batch = tu.trim(ques_batch)
     pos_batch_words = tu.trim(pos_batch_words)
     neg_batch_words = tu.trim(neg_batch_words)

     pos_scores = self.encoder(ques_batch, pos_batch_words)
     neg_scores = self.encoder(ques_batch, neg_batch_words)

     loss = loss_fn(pos_scores, neg_scores, y_label)
     loss.backward()
     # Rescale gradients whose total norm exceeds 0.5 (guards against exploding gradients).
     torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), .5)
     optimizer.step()
     return loss
コード例 #3
0
ファイル: network.py プロジェクト: nju-websoft/SkeletonKBQA
    def predict(self, question, question_dep=None, paths=None, paths_words=None,
                paths_rel1=None, paths_rel2=None, paths_rel3=None, paths_rel4=None,
                attention_value=False, device=None):
        """
            Score questions against candidate paths (same code for pairwise and pointwise).

            The trimmed question is encoded by ``self.encoder_q`` and the trimmed path
            relations by ``self.encoder_p``. The score is the dot product of the two
            encodings over the last dimension.

            Runs under ``torch.no_grad()`` with both encoders in eval mode. Both encoders
            are always switched back to train mode afterwards, even if encoding raises.

            :param question: question batch
            :param question_dep: unused here
            :param paths: unused here
            :param paths_words: unused here
            :param paths_rel1: first relation slot of the candidate paths
            :param paths_rel2: second relation slot (may be empty for the whole batch)
            :param paths_rel3: third relation slot (may be empty for the whole batch)
            :param paths_rel4: fourth relation slot (may be empty for the whole batch)
            :param attention_value: if True, also return the question encoder's attention scores
            :param device: unused here
            :returns: ``score``, or ``(score, attention_score)`` when attention_value is True
        """
        with torch.no_grad():
            self.encoder_q.eval()
            self.encoder_p.eval()
            try:
                # Slots 2-4 may hold nothing for a whole batch; no_one_left_behind guards
                # against that (presumably by zero-padding -- see tensor utils).
                paths_rel2 = tu.no_one_left_behind(paths_rel2)
                paths_rel3 = tu.no_one_left_behind(paths_rel3)
                paths_rel4 = tu.no_one_left_behind(paths_rel4)

                ques_encoded, attention_score = self.encoder_q(tu.trim(question))
                path_encoded = self.encoder_p(tu.trim(paths_rel1), tu.trim(paths_rel2),
                                              tu.trim(paths_rel3), tu.trim(paths_rel4))

                # Dot-product similarity between question and path encodings.
                score = torch.sum(ques_encoded * path_encoded, dim=-1)
            finally:
                # Restore train mode even on failure, so a later training step does not
                # silently run with the encoders still in eval mode.
                self.encoder_q.train()
                self.encoder_p.train()

        if attention_value:
            return score, attention_score
        return score
コード例 #4
0
ファイル: network.py プロジェクト: nju-websoft/SkeletonKBQA
 def predict(self, question, question_dep=None, paths=None, paths_words=None,
             paths_rel1=None, paths_rel2=None, paths_rel3=None, paths_rel4=None,
             attention_value=False, device=None):
     """
         Score questions against candidate paths (same code for pairwise and pointwise).

         ``self.encoder`` receives the trimmed question and the four trimmed relation
         slots and produces the score directly.

         Runs under ``torch.no_grad()`` in eval mode. The encoder is always switched back
         to train mode afterwards, even if encoding raises.

         :param question: question batch
         :param question_dep: unused here
         :param paths: unused here
         :param paths_words: unused here
         :param paths_rel1: first relation slot of the candidate paths
         :param paths_rel2: second relation slot (may be empty for the whole batch)
         :param paths_rel3: third relation slot (may be empty for the whole batch)
         :param paths_rel4: fourth relation slot (may be empty for the whole batch)
         :param attention_value: ignored by this model (no attention scores are returned)
         :param device: unused here
         :returns: the encoder's score output
     """
     with torch.no_grad():
         self.encoder.eval()
         try:
             # Slots 2-4 may hold nothing for a whole batch; no_one_left_behind guards
             # against that (presumably by zero-padding -- see tensor utils).
             paths_rel2 = tu.no_one_left_behind(paths_rel2)
             paths_rel3 = tu.no_one_left_behind(paths_rel3)
             paths_rel4 = tu.no_one_left_behind(paths_rel4)
             score = self.encoder(tu.trim(question), tu.trim(paths_rel1), tu.trim(paths_rel2),
                                  tu.trim(paths_rel3), tu.trim(paths_rel4))
         finally:
             # Restore train mode even on failure so training is not left in eval mode.
             self.encoder.train()
         return score
コード例 #5
0
ファイル: network.py プロジェクト: nju-websoft/SkeletonKBQA
 def predict(self, question=None, question_dep=None, question_dep_mask=None, paths=None,
             paths_words=None, paths_rel1=None, paths_rel2=None, paths_rel3=None,
             paths_rel4=None, attention_value=False, device=None):
     """
         Score questions against candidate paths (same code for pairwise and pointwise).

         ``self.encoder`` receives the following and produces the score directly:
         - the trimmed question;
         - the (untrimmed) question dependency batch and its mask;
         - the four trimmed relation slots.

         Runs under ``torch.no_grad()`` in eval mode. The encoder is always switched back
         to train mode afterwards, even if encoding raises.

         :param question: question batch
         :param question_dep: question dependency batch, passed through as-is
         :param question_dep_mask: mask for question_dep, passed through as-is
         :param paths: unused here
         :param paths_words: unused here
         :param paths_rel1: first relation slot of the candidate paths
         :param paths_rel2: second relation slot (may be empty for the whole batch)
         :param paths_rel3: third relation slot (may be empty for the whole batch)
         :param paths_rel4: fourth relation slot (may be empty for the whole batch)
         :param attention_value: ignored by this model (no attention scores are returned)
         :param device: unused here
         :returns: the encoder's score output, not squeezed
     """
     with torch.no_grad():
         self.encoder.eval()
         try:
             # Slots 2-4 may hold nothing for a whole batch; no_one_left_behind guards
             # against that (presumably by zero-padding -- see tensor utils).
             paths_rel2 = tu.no_one_left_behind(paths_rel2)
             paths_rel3 = tu.no_one_left_behind(paths_rel3)
             paths_rel4 = tu.no_one_left_behind(paths_rel4)
             score = self.encoder(tu.trim(question), question_dep, question_dep_mask,
                                  tu.trim(paths_rel1), tu.trim(paths_rel2),
                                  tu.trim(paths_rel3), tu.trim(paths_rel4))
             # NOTE: score is deliberately not squeezed here. A v0.2 variant squeezed it
             # to work around the MRR computation (mrr_output.index(positive_path_index)).
         finally:
             # Restore train mode even on failure so training is not left in eval mode.
             self.encoder.train()
         return score
コード例 #6
0
ファイル: network.py プロジェクト: nju-websoft/SkeletonKBQA
 def _train_pairwise_(self, data, optimizer, loss_fn, device):
     """
         Run one pairwise training step with separate question and path encoders.

         The trimmed question is encoded by ``self.encoder_q``. The positive and negative
         paths (four relation slots each) are encoded by ``self.encoder_p``. Each path
         score is the dot product of the question encoding with the path encoding over
         the last dimension.

         Both score tensors and ``y_label`` are passed to ``loss_fn``. The loss is
         back-propagated and the optimizer is stepped.

         :param data: dict with keys 'ques_batch', 'pos_rel1_batch' .. 'pos_rel4_batch',
             'neg_rel1_batch' .. 'neg_rel4_batch' and 'y_label'
         :param optimizer: torch.optim object
         :param loss_fn: torch.nn loss object, called as loss_fn(pos_scores, neg_scores, y_label)
         :param device: torch.device object (not used in this method)
         :returns: the loss tensor for this batch
     """
     # Explicitly enter train mode instead of relying on whatever mode the encoders
     # were last left in.
     self.encoder_q.train()
     self.encoder_p.train()

     ques_batch = data['ques_batch']
     pos_1_batch, pos_2_batch, pos_3_batch, pos_4_batch = (
         data['pos_rel1_batch'], data['pos_rel2_batch'],
         data['pos_rel3_batch'], data['pos_rel4_batch'])
     neg_1_batch, neg_2_batch, neg_3_batch, neg_4_batch = (
         data['neg_rel1_batch'], data['neg_rel2_batch'],
         data['neg_rel3_batch'], data['neg_rel4_batch'])
     y_label = data['y_label']

     optimizer.zero_grad()

     # Relation slots 2-4 may hold nothing for a whole batch. no_one_left_behind guards
     # against that (presumably by zero-padding -- see tensor utils).
     pos_2_batch = tu.no_one_left_behind(pos_2_batch)
     pos_3_batch = tu.no_one_left_behind(pos_3_batch)
     pos_4_batch = tu.no_one_left_behind(pos_4_batch)
     neg_2_batch = tu.no_one_left_behind(neg_2_batch)
     neg_3_batch = tu.no_one_left_behind(neg_3_batch)
     neg_4_batch = tu.no_one_left_behind(neg_4_batch)

     ques_encoded, _ = self.encoder_q(tu.trim(ques_batch))
     pos_encoded = self.encoder_p(tu.trim(pos_1_batch), tu.trim(pos_2_batch),
                                  tu.trim(pos_3_batch), tu.trim(pos_4_batch))
     neg_encoded = self.encoder_p(tu.trim(neg_1_batch), tu.trim(neg_2_batch),
                                  tu.trim(neg_3_batch), tu.trim(neg_4_batch))

     # Dot-product similarity between the question and each path encoding.
     pos_scores = torch.sum(ques_encoded * pos_encoded, dim=-1)
     neg_scores = torch.sum(ques_encoded * neg_encoded, dim=-1)

     # Ranking-loss convention (e.g. torch.nn.MarginRankingLoss): y == 1 means the first
     # input (pos_scores) should be ranked higher than the second (neg_scores), and vice
     # versa for y == -1.
     loss = loss_fn(pos_scores, neg_scores, y_label)
     loss.backward()
     # NOTE(review): no gradient clipping is applied to encoder_q / encoder_p here;
     # confirm that is intended.
     optimizer.step()
     return loss