Example #1
File: emb.py Project: xiandshi/DacKGR
 def get_corrupt_triple(self, e1, e2, r):
     batch_size = e1.size()[0]
     e1_list = []
     e2_list = []
     r_list = []
     # One corruption type (head / tail / relation) is drawn for the whole batch.
     decide_num = random.randint(0, 99)
     for i in range(batch_size):
         if decide_num < 25:
             rand_e1 = random.randint(0, len(self.kg.entity2id) - 1)
             try:
                 while int(e2[i]) in self.kg.train_objects[rand_e1][int(
                         r[i])]:
                     rand_e1 = random.randint(0, len(self.kg.entity2id) - 1)
             except KeyError:
                 # (rand_e1, r) has no recorded objects; accept the sample.
                 pass
             e1_list.append(rand_e1)
             e2_list.append(int(e2[i]))
             r_list.append(int(r[i]))
         elif decide_num < 50:
             rand_e2 = random.randint(0, len(self.kg.entity2id) - 1)
             try:
                 while rand_e2 in self.kg.train_objects[int(e1[i])][int(
                         r[i])]:
                     rand_e2 = random.randint(0, len(self.kg.entity2id) - 1)
             except KeyError:
                 # (e1, r) has no recorded objects; accept the sample.
                 pass
             e1_list.append(int(e1[i]))
             e2_list.append(rand_e2)
             r_list.append(int(r[i]))
         else:
             rand_r = random.randint(0, len(self.kg.relation2id) - 1)
             try:
                 while int(e2[i]) in self.kg.train_objects[int(
                         e1[i])][rand_r]:
                     rand_r = random.randint(0,
                                             len(self.kg.relation2id) - 1)
             except KeyError:
                 # (e1, rand_r) has no recorded objects; accept the sample.
                 pass
             e1_list.append(int(e1[i]))
             e2_list.append(int(e2[i]))
             r_list.append(rand_r)
     return (var_cuda(torch.LongTensor(e1_list), requires_grad=False),
             var_cuda(torch.LongTensor(e2_list), requires_grad=False),
             var_cuda(torch.LongTensor(r_list), requires_grad=False))
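For reference, the tail-corruption branch can be exercised on its own. A minimal standalone sketch, with a toy stand-in for self.kg.train_objects (the {head: {relation: set(tails)}} layout is an assumption inferred from how it is indexed above):

import random

train_objects = {0: {0: {1, 2}}, 1: {0: {2}}}  # toy known positives
num_entities = 4

def corrupt_tail(e1, r):
    # Resample until the tail is not a known positive for (e1, r).
    rand_e2 = random.randint(0, num_entities - 1)
    while rand_e2 in train_objects.get(e1, {}).get(r, set()):
        rand_e2 = random.randint(0, num_entities - 1)
    return rand_e2

print(corrupt_tail(0, 0))  # never prints 1 or 2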
Example #2

    def top_k_answer_unique(log_action_dist, action_space):
        """
        Get top k unique entities
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*beam_size, action_space_size]
        :param action_space (r_space, e_space):
            r_space: [batch_size*beam_size, action_space_size]
            e_space: [batch_size*beam_size, action_space_size]
        :return:
            (next_r, next_e), action_prob, action_offset: [batch_size*k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = full_size // batch_size
        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]

        r_space = r_space.view(batch_size, -1)
        e_space = e_space.view(batch_size, -1)
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        assert (beam_action_space_size % action_space_size == 0)
        k = min(beam_size, beam_action_space_size)
        next_r_list, next_e_list = [], []
        log_action_prob_list = []
        action_offset_list = []
        for i in range(batch_size):
            log_action_dist_b = log_action_dist[i]
            r_space_b = r_space[i]
            e_space_b = e_space[i]
            unique_e_space_b = var_cuda(torch.unique(e_space_b.data.cpu()))
            unique_log_action_dist, unique_idx = unique_max(
                unique_e_space_b, e_space_b, log_action_dist_b)
            k_prime = min(len(unique_e_space_b), k)
            top_unique_log_action_dist, top_unique_idx2 = torch.topk(
                unique_log_action_dist, k_prime)
            top_unique_idx = unique_idx[top_unique_idx2]
            top_unique_beam_offset = top_unique_idx // action_space_size
            top_r = r_space_b[top_unique_idx]
            top_e = e_space_b[top_unique_idx]
            next_r_list.append(top_r.unsqueeze(0))
            next_e_list.append(top_e.unsqueeze(0))
            log_action_prob_list.append(
                top_unique_log_action_dist.unsqueeze(0))
            top_unique_batch_offset = i * last_k
            top_unique_action_offset = top_unique_batch_offset + top_unique_beam_offset
            action_offset_list.append(top_unique_action_offset.unsqueeze(0))
        next_r = ops.pad_and_cat(next_r_list,
                                 padding_value=kg.dummy_r).view(-1)
        next_e = ops.pad_and_cat(next_e_list,
                                 padding_value=kg.dummy_e).view(-1)
        log_action_prob = ops.pad_and_cat(log_action_prob_list,
                                          padding_value=-ops.HUGE_INT)
        action_offset = ops.pad_and_cat(action_offset_list, padding_value=-1)
        return (next_r,
                next_e), log_action_prob.view(-1), action_offset.view(-1)
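The snippet depends on a unique_max helper that is not shown here. A guess at its contract, inferred from the call site (for each unique entity, keep the best log-probability and the position it came from); a sketch, not the project's implementation:

import torch

def unique_max_sketch(unique_space, space, values):
    # For each entry of unique_space, take the max of `values` over the
    # positions where `space` equals that entry, and remember the argmax.
    max_vals, max_idx = [], []
    for u in unique_space:
        masked = values.masked_fill(space != u, float('-inf'))
        v, i = masked.max(dim=0)
        max_vals.append(v)
        max_idx.append(i)
    return torch.stack(max_vals), torch.stack(max_idx)

space = torch.tensor([3, 1, 3, 2])
vals = torch.tensor([0.1, 0.7, 0.5, 0.2])
print(unique_max_sketch(torch.tensor([1, 2, 3]), space, vals))
# (tensor([0.7000, 0.2000, 0.5000]), tensor([1, 3, 2]))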
Example #3
    def convert_tuples_to_tensors(self,
                                  batch_data,
                                  num_labels=-1,
                                  num_tiles=1):
        """
        Convert batched tuples to the tensors accepted by the NN.
        num_tiles == num_rollouts == beam-size
        """
        def convert_to_binary_multi_subject(e1):
            e1_label = zeros_var_cuda([len(e1), num_labels])
            for i in range(len(e1)):
                e1_label[i][e1[i]] = 1
            return e1_label

        def convert_to_binary_multi_object(e2):
            e2_label = zeros_var_cuda([len(e2), num_labels])
            for i in range(len(e2)):
                e2_label[i][e2[i]] = 1
            return e2_label

        batch_e1, batch_e2, batch_r = [], [], []
        for i in range(len(batch_data)):
            e1, e2, r = batch_data[i]
            batch_e1.append(e1)
            batch_e2.append(e2)
            batch_r.append(r)
        batch_r = var_cuda(torch.LongTensor(batch_r), requires_grad=False)
        # Check the raw lists before converting batch_e1 to a tensor;
        # otherwise the multi-subject branch below could never fire.
        if type(batch_e2[0]) is list:
            batch_e1 = var_cuda(torch.LongTensor(batch_e1), requires_grad=False)
            batch_e2 = convert_to_binary_multi_object(batch_e2)
        elif type(batch_e1[0]) is list:
            batch_e1 = convert_to_binary_multi_subject(batch_e1)
            batch_e2 = var_cuda(torch.LongTensor(batch_e2), requires_grad=False)
        else:
            batch_e1 = var_cuda(torch.LongTensor(batch_e1), requires_grad=False)
            batch_e2 = var_cuda(torch.LongTensor(batch_e2), requires_grad=False)
        # Roll out multiple times for each example: each row is repeated
        # num_tiles times, [batch_size] -> [batch_size * num_tiles].
        if num_tiles > 1:
            batch_e1 = ops.tile_along_beam(batch_e1, num_tiles)
            batch_r = ops.tile_along_beam(batch_r, num_tiles)
            batch_e2 = ops.tile_along_beam(batch_e2, num_tiles)
        return batch_e1, batch_e2, batch_r
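ops.tile_along_beam itself is not part of this listing; judging from the inline comment it repeats each row num_tiles times. A minimal sketch under that assumption:

import torch

def tile_along_beam_sketch(v, beam_size):
    # [a, b, c] -> [a, a, b, b, c, c] for beam_size=2; the real ops helper
    # may be implemented differently.
    return v.unsqueeze(1).repeat(1, beam_size,
                                 *([1] * (v.dim() - 1))).view(-1, *v.size()[1:])

print(tile_along_beam_sketch(torch.tensor([10, 20, 30]), 2))
# tensor([10, 10, 20, 20, 30, 30])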
Example #4

 def vectorize_action_space(action_space_list, action_space_size):
     bucket_size = len(action_space_list)
     r_space = torch.zeros(bucket_size, action_space_size) + self.dummy_r
     e_space = torch.zeros(bucket_size, action_space_size) + self.dummy_e
     action_mask = torch.zeros(bucket_size, action_space_size)
     for i, action_space in enumerate(action_space_list):
         for j, (r, e) in enumerate(action_space):
             r_space[i, j] = r
             e_space[i, j] = e
             action_mask[i, j] = 1
     return (int_var_cuda(r_space), int_var_cuda(e_space)), var_cuda(action_mask)
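A standalone, CPU-only rendering of the same padding logic, with placeholder pad ids in place of self.dummy_r / self.dummy_e and the *_var_cuda wrappers dropped:

import torch

DUMMY_R, DUMMY_E = 0, 0  # placeholders; the real pad ids come from the KG

def vectorize_action_space_sketch(action_space_list, action_space_size):
    bucket_size = len(action_space_list)
    r_space = torch.full((bucket_size, action_space_size), DUMMY_R, dtype=torch.long)
    e_space = torch.full((bucket_size, action_space_size), DUMMY_E, dtype=torch.long)
    action_mask = torch.zeros(bucket_size, action_space_size)
    for i, action_space in enumerate(action_space_list):
        for j, (r, e) in enumerate(action_space):
            r_space[i, j] = r
            e_space[i, j] = e
            action_mask[i, j] = 1
    return (r_space, e_space), action_mask

(r_sp, e_sp), mask = vectorize_action_space_sketch([[(5, 7), (6, 8)], [(5, 9)]], 3)
print(r_sp)   # [[5, 6, 0], [5, 0, 0]]
print(mask)   # [[1., 1., 0.], [1., 0., 0.]]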
Example #5
 def build(forks: List[Fork], action_space_size, dummy_r, dummy_e):
     bucket_size = len(forks)
     r_space = torch.zeros(bucket_size, action_space_size) + dummy_r
     e_space = torch.zeros(bucket_size, action_space_size) + dummy_e
     action_mask = torch.zeros(bucket_size, action_space_size)
     for i, fork in enumerate(forks):
         for j, direction in enumerate(fork.directions):
             r_space[i, j] = direction.rel
             e_space[i, j] = direction.ent
             action_mask[i, j] = 1
     return ActionSpace(forks, int_var_cuda(r_space), int_var_cuda(e_space),
                        var_cuda(action_mask))
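Fork, Direction, and ActionSpace are project types that do not appear in this listing. Minimal stand-ins sufficient to call build, under assumed shapes:

from collections import namedtuple

Direction = namedtuple('Direction', ['rel', 'ent'])
Fork = namedtuple('Fork', ['directions'])
ActionSpace = namedtuple('ActionSpace',
                         ['forks', 'r_space', 'e_space', 'action_mask'])

fork = Fork(directions=[Direction(rel=2, ent=4), Direction(rel=3, ent=5)])
# build([fork], action_space_size=4, dummy_r=0, dummy_e=0) would then yield
# r_space == [[2, 3, 0, 0]] and action_mask == [[1, 1, 0, 0]].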
Example #6
File: pg.py Project: h-peng17/KGE
 def apply_action_dropout_mask(action_dist, action_mask):
     if self.action_dropout_rate > 0:
         rand = torch.rand(action_dist.size())
         action_keep_mask = var_cuda(rand > self.action_dropout_rate).float()
         # There is a small chance that action_keep_mask is accidentally all zeros.
         # When this happens, the EPSILON floor below makes sampling effectively
         # uniform over the valid (unmasked) actions.
         sample_action_dist = \
             action_dist * action_keep_mask + ops.EPSILON * (1 - action_keep_mask) * action_mask
         return sample_action_dist
     else:
         return action_dist
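A toy numeric walkthrough of the keep-mask formula above (the small constant stands in for ops.EPSILON; the distribution values are made up):

import torch

torch.manual_seed(0)
EPSILON = 1e-30
action_dist = torch.tensor([[0.5, 0.3, 0.2, 0.0]])
action_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # last slot is padding
dropout_rate = 0.5

keep = (torch.rand(action_dist.size()) > dropout_rate).float()
# Kept actions retain their mass; dropped-but-valid actions get an EPSILON
# floor so a row can never become entirely unsampleable.
sampled = action_dist * keep + EPSILON * (1 - keep) * action_mask
print(sampled)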
Example #7
 def get_corrupt_relation(self, e1, e2, r):
     batch_size = e1.size()[0]
     r_list = []
     for i in range(batch_size):
         rand_r = random.randint(0, len(self.kg.relation2id) - 1)
         try:
             while int(e2[i]) in self.kg.train_objects[int(e1[i])][rand_r]:
                 rand_r = random.randint(0, len(self.kg.relation2id) - 1)
          except KeyError:
              # (e1, rand_r) has no recorded objects; accept the sample.
              pass
         r_list.append(rand_r)
     return var_cuda(torch.LongTensor(r_list), requires_grad=False)
Example #8
        def apply_action_dropout_mask(r_space, action_dist, action_mask):
            # Sample one relation per row from r_prob, mute actions that use
            # a different relation, then apply standard action dropout.
            bucket_size = len(r_space)
            batch_idx = torch.tensor(offset[self.ct:self.ct +
                                            bucket_size]).cuda()
            r_prob_b = r_prob[batch_idx]

            if not self.pretrain:
                uni_mask = torch.zeros(r_prob_b.shape).cuda()
                uni_mask = torch.scatter(uni_mask, 1, r_space,
                                         torch.ones(r_space.shape).cuda())
                r_prob_b = r_prob_b * uni_mask
            r_prob_sum = torch.sum(r_prob_b, 1)
            is_zero = (r_prob_sum == 0).float().unsqueeze(1)
            r_prob_b = r_prob_b + is_zero * torch.ones(
                r_prob_b[0].size()).cuda()

            r_chosen_b = torch.multinomial(r_prob_b, 1, replacement=True)

            r_prob_chosen_b = ops.batch_lookup(r_prob_b,
                                               r_chosen_b).unsqueeze(1)

            action_mute_mask = (r_space == r_chosen_b).float()
            if self.pretrain:
                action_dist_muted = action_mute_mask * r_prob_chosen_b
            else:
                action_dist_muted = action_dist * action_mute_mask * r_prob_chosen_b

            dist_sum = torch.sum(action_dist_muted, 1)
            is_zero = (dist_sum == 0).float().unsqueeze(1)
            if self.pretrain:
                uniform_dist = torch.ones(action_dist[0].size()).float().cuda()
                action_dist_muted = action_dist_muted + is_zero * uniform_dist
            else:
                action_dist_muted = action_dist_muted + is_zero * action_dist
            self.ct += bucket_size
            new_action_dist = action_dist_muted

            if self.action_dropout_rate > 0:
                rand = torch.rand(action_dist.size())
                action_keep_mask = var_cuda(
                    rand > self.action_dropout_rate).float()
                sample_action_dist = new_action_dist * action_keep_mask + ops.EPSILON * (
                    1 - action_keep_mask) * action_mask
                # if dropout to 0, keep original value
                dist_sum = torch.sum(sample_action_dist, 1)
                is_zero = (dist_sum == 0).float().unsqueeze(1)
                sample_action_dist = sample_action_dist + is_zero * new_action_dist
                return sample_action_dist
            else:
                return new_action_dist
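The uni_mask construction above, in isolation: scatter writes ones at the relation ids present in each row of r_space, producing a multi-hot mask (toy sizes, CPU instead of .cuda()):

import torch

num_relations = 6
r_space = torch.tensor([[1, 3, 3], [2, 4, 0]])
uni_mask = torch.zeros(2, num_relations).scatter(1, r_space,
                                                 torch.ones(r_space.shape))
print(uni_mask)
# tensor([[0., 1., 0., 1., 0., 0.],
#         [1., 0., 1., 0., 1., 0.]])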
Example #9

    def format_batch(self, batch_data, num_labels=-1, num_tiles=1):
        """
        Convert batched tuples to the tensors accepted by the NN.
        """
        def convert_to_binary_multi_subject(e1):
            e1_label = zeros_var_cuda([len(e1), num_labels])
            for i in range(len(e1)):
                e1_label[i][e1[i]] = 1
            return e1_label

        def convert_to_binary_multi_object(e2):
            e2_label = zeros_var_cuda([len(e2), num_labels])
            for i in range(len(e2)):
                e2_label[i][e2[i]] = 1
            return e2_label

        batch_e1, batch_e2, batch_r = [], [], []
        for i in range(len(batch_data)):
            e1, e2, r = batch_data[i]
            batch_e1.append(e1)
            batch_e2.append(e2)
            batch_r.append(r)
        batch_r = var_cuda(torch.LongTensor(batch_r), requires_grad=False)
        # Check the raw lists before converting batch_e1 to a tensor;
        # otherwise the multi-subject branch below could never fire.
        if type(batch_e2[0]) is list:
            batch_e1 = var_cuda(torch.LongTensor(batch_e1), requires_grad=False)
            batch_e2 = convert_to_binary_multi_object(batch_e2)
        elif type(batch_e1[0]) is list:
            batch_e1 = convert_to_binary_multi_subject(batch_e1)
            batch_e2 = var_cuda(torch.LongTensor(batch_e2), requires_grad=False)
        else:
            batch_e1 = var_cuda(torch.LongTensor(batch_e1), requires_grad=False)
            batch_e2 = var_cuda(torch.LongTensor(batch_e2), requires_grad=False)
        # Roll out multiple times for each example,
        # e.g. batch_e1: [128] -> [128 * num_tiles].
        if num_tiles > 1:
            batch_e1 = ops.tile_along_beam(batch_e1, num_tiles)
            batch_r = ops.tile_along_beam(batch_r, num_tiles)
            batch_e2 = ops.tile_along_beam(batch_e2, num_tiles)
        return batch_e1, batch_e2, batch_r
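A toy call illustrating the intended shapes (the ids are made up, and on a machine without CUDA the var_cuda/zeros_var_cuda helpers would need CPU equivalents):

batch_data = [(0, 1, 2), (3, 4, 2)]      # (e1, e2, r) id triples
# format_batch(batch_data)               -> three LongTensors of shape [2]
# format_batch(batch_data, num_tiles=5)  -> three LongTensors of shape [10]
# With multi-object rows such as (0, [1, 4], 2) and num_labels=N, batch_e2
# instead becomes a [batch_size, N] binary label matrix.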
Example #10
 def find_mute_idx(space, valid_space):
     # Build a 0/1 mask over `space` marking entries that appear in valid_space.
     space = space.cpu().numpy()
     action_mute_mask = [1 if s in valid_space else 0 for s in space]
     return var_cuda(torch.FloatTensor(action_mute_mask))
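For large spaces the per-element loop can be replaced by a vectorized equivalent; a sketch using np.isin, with the var_cuda wrapper dropped so it runs on CPU:

import numpy as np
import torch

def find_mute_idx_vectorized(space, valid_space):
    # 1.0 where an entry of `space` appears in valid_space, else 0.0.
    mask = np.isin(space.cpu().numpy(), list(valid_space)).astype(np.float32)
    return torch.from_numpy(mask)

print(find_mute_idx_vectorized(torch.tensor([1, 2, 3, 4]), {2, 4}))
# tensor([0., 1., 0., 1.])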
Example #11

 def virtual_step(self, e_set, r):
     """
     Given a set of entities (e_set), find the set of entities (e_set_out) that have at
     least one incoming edge labeled r whose source entity is in e_set.
     """
     batch_size = len(e_set)
     e_set_1D = e_set.view(-1)
     r_space = self.action_space[0][0][e_set_1D]
     e_space = self.action_space[0][1][e_set_1D]
     e_space = (r_space.view(batch_size, -1) == r.unsqueeze(1)).long() * e_space.view(batch_size, -1)
     e_set_out = []
     for i in range(len(e_space)):
         e_set_out_b = var_cuda(unique(e_space[i].data))
         e_set_out.append(e_set_out_b.unsqueeze(0))
     e_set_out = ops.pad_and_cat(e_set_out, padding_value=self.dummy_e)
     return e_set_out
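The core trick in virtual_step on toy numbers: comparing r_space against r broadcasts a 0/1 match that zeroes out entities reached via other relations (entity id 0 plays the dummy entity here):

import torch

r_space = torch.tensor([[2, 3, 2]])   # outgoing edge relations of one node
e_space = torch.tensor([[7, 8, 9]])   # corresponding target entities
r = torch.tensor([2])
reachable = (r_space == r.unsqueeze(1)).long() * e_space
print(reachable)  # tensor([[7, 0, 9]]) - entity 8 is filtered out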
Example #12
 def get_answer_mask(self, e_space, e_s, q, kg):
     if kg.args.mask_test_false_negatives:
         answer_vectors = kg.all_object_vectors
     else:
         answer_vectors = kg.train_object_vectors
     answer_masks = []
     for i in range(len(e_space)):
         _e_s, _q = int(e_s[i]), int(q[i])
         if _e_s not in answer_vectors or _q not in answer_vectors[_e_s]:
             answer_vector = var_cuda(torch.LongTensor([[kg.num_entities]]))
         else:
             answer_vector = answer_vectors[_e_s][_q]
         answer_mask = torch.sum(e_space[i].unsqueeze(0) == answer_vector, dim=0).long()
         answer_masks.append(answer_mask)
     answer_mask = torch.cat(answer_masks).view(len(e_space), -1)
     return answer_mask
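The membership test above in isolation: broadcasting the candidate entities against a column of known answers and summing the matches yields the 0/1 mask (toy ids):

import torch

e_space = torch.tensor([4, 7, 9, 2])
answer_vector = torch.tensor([[7], [2]])   # known answers, one per row
answer_mask = torch.sum(e_space.unsqueeze(0) == answer_vector, dim=0).long()
print(answer_mask)  # tensor([0, 1, 0, 1])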
Example #13
File: emb.py Project: h-peng17/KGE
 def get_object_mask(self, e2_space, e1, q):
     kg = self.kg
     if kg.args.mask_test_false_negatives:
         answer_vectors = kg.all_object_vectors
     else:
         answer_vectors = kg.train_object_vectors
     object_masks = []
     for i in range(len(e2_space)):
         _e1, _q = int(e1[i]), int(q[i])
         if _e1 not in answer_vectors or _q not in answer_vectors[_e1]:
             answer_vector = var_cuda(torch.LongTensor([[kg.num_entities]]))
         else:
             answer_vector = answer_vectors[_e1][_q]
         object_mask = torch.sum(e2_space[i].unsqueeze(0) == answer_vector, dim=0)
         object_masks.append(object_mask)
     object_mask = torch.cat(object_masks).view(len(e2_space), -1)
     return object_mask
Example #14
File: emb.py Project: patzaa/MultiHopKG
 def get_subject_mask(self, e1_space, e2, q):
     assert False #TODO(tilo) !?
     kg = self.kg
     if kg.args.mask_test_false_negatives:
         answer_vectors = kg.all_subject_vectors
     else:
         answer_vectors = kg.train_subject_vectors
     subject_masks = []
     for i in range(len(e1_space)):
         _e2, _q = int(e2[i]), int(q[i])
         if _e2 not in answer_vectors or _q not in answer_vectors[_e2]:
             answer_vector = var_cuda(torch.LongTensor([[kg.num_entities]]))
         else:
             answer_vector = answer_vectors[_e2][_q]
         subject_mask = torch.sum(e1_space[i].unsqueeze(0) == answer_vector, dim=0)
         subject_masks.append(subject_mask)
     subject_mask = torch.cat(subject_masks).view(len(e1_space), -1)
     return subject_mask
Example #15

    def format_batch_with_abs(self, batch_data, num_labels=-1, num_tiles=1):
        """
        Convert batched tuples to the tensors accepted by the NN.
        """

        def convert_to_binary_multi_subject(e1):
            e1_label = zeros_var_cuda([len(e1), num_labels])
            e1_label_abs = zeros_var_cuda([len(e1), num_labels])
            for i in range(len(e1)):
                e1_label[i][e1[i]] = 1
                e1_label_abs[i][self.kg.get_typeid(e1[i])] = 1
            return e1_label, e1_label_abs

        def convert_to_binary_multi_object(e2):
            e2_label = zeros_var_cuda([len(e2), num_labels])
            e2_label_abs = zeros_var_cuda([len(e2), num_labels])
            for i in range(len(e2)):
                e2_label[i][e2[i]] = 1
                e2_label_abs[i][self.kg.get_typeid(e2[i])] = 1
            return e2_label, e2_label_abs

        batch_e1, batch_e2, batch_r, batch_e1_abs, batch_e2_abs, batch_r_abs = [], [], [], [], [], []
        for i in range(len(batch_data)):
            e1, e2, r = batch_data[i]
            batch_e1.append(e1)
            batch_e2.append(e2)
            batch_r.append(r)

            batch_e1_abs.append(self.kg.get_typeid(e1))
            batch_e2_abs.append(self.kg.get_typeid(e2))
            batch_r_abs.append(r)

        batch_r = var_cuda(torch.LongTensor(batch_r), requires_grad=False)
        batch_r_abs = var_cuda(torch.LongTensor(batch_r_abs), requires_grad=False)
        # Check the raw lists before tensor conversion; otherwise the
        # multi-subject branch below could never fire.
        if type(batch_e2[0]) is list:
            batch_e1 = var_cuda(torch.LongTensor(batch_e1), requires_grad=False)
            batch_e1_abs = var_cuda(torch.LongTensor(batch_e1_abs), requires_grad=False)
            batch_e2, batch_e2_abs = convert_to_binary_multi_object(batch_e2)
        elif type(batch_e1[0]) is list:
            batch_e1, batch_e1_abs = convert_to_binary_multi_subject(batch_e1)
            batch_e2 = var_cuda(torch.LongTensor(batch_e2), requires_grad=False)
            batch_e2_abs = var_cuda(torch.LongTensor(batch_e2_abs), requires_grad=False)
        else:
            batch_e1 = var_cuda(torch.LongTensor(batch_e1), requires_grad=False)
            batch_e1_abs = var_cuda(torch.LongTensor(batch_e1_abs), requires_grad=False)
            batch_e2 = var_cuda(torch.LongTensor(batch_e2), requires_grad=False)
            batch_e2_abs = var_cuda(torch.LongTensor(batch_e2_abs), requires_grad=False)
        # Rollout multiple times for each example
        if num_tiles > 1:
            batch_e1 = ops.tile_along_beam(batch_e1, num_tiles)
            batch_r = ops.tile_along_beam(batch_r, num_tiles)
            batch_e2 = ops.tile_along_beam(batch_e2, num_tiles)
            batch_e1_abs = ops.tile_along_beam(batch_e1_abs, num_tiles)
            batch_r_abs = ops.tile_along_beam(batch_r_abs, num_tiles)
            batch_e2_abs = ops.tile_along_beam(batch_e2_abs, num_tiles)
        return batch_e1, batch_e2, batch_r, batch_e1_abs, batch_e2_abs, batch_r_abs
Example #16
File: emb.py Project: h-peng17/KGE
    def export_fuzzy_facts(self):
        """
        Export high confidence facts according to the model.
        """
        kg, mdl = self.kg, self.mdl

        # Gather all possible (subject, relation) and (relation, object) pairs
        sub_rel, rel_obj = {}, {}
        for file_name in ['raw.kb', 'train.triples', 'dev.triples', 'test.triples']:
            with open(os.path.join(self.data_dir, file_name)) as f:
                for line in f:
                    e1, e2, r = line.strip().split()
                    e1_id, e2_id, r_id = kg.triple2ids((e1, e2, r))
                    if e1_id not in sub_rel:
                        sub_rel[e1_id] = {}
                    if r_id not in sub_rel[e1_id]:
                        sub_rel[e1_id][r_id] = set()
                    sub_rel[e1_id][r_id].add(e2_id)
                    if e2_id not in rel_obj:
                        rel_obj[e2_id] = {}
                    if r_id not in rel_obj[e2_id]:
                        rel_obj[e2_id][r_id] = set()
                    rel_obj[e2_id][r_id].add(e1_id)

        o_f = open(os.path.join(self.data_dir, 'train.fuzzy.triples'), 'w')
        print('Saving fuzzy facts to {}'.format(os.path.join(self.data_dir, 'train.fuzzy.triples')))
        count = 0
        # Save recovered objects
        e1_ids, r_ids = [], []
        for e1_id in sub_rel:
            for r_id in sub_rel[e1_id]:
                e1_ids.append(e1_id)
                r_ids.append(r_id)
        for i in range(0, len(e1_ids), self.batch_size):
            e1_ids_b = e1_ids[i:i+self.batch_size]
            r_ids_b = r_ids[i:i+self.batch_size]
            e1 = var_cuda(torch.LongTensor(e1_ids_b))
            r = var_cuda(torch.LongTensor(r_ids_b))
            pred_scores = mdl.forward(e1, r, kg)
            for j in range(pred_scores.size(0)):
                for _e2 in range(pred_scores.size(1)):
                    if _e2 in [NO_OP_ENTITY_ID, DUMMY_ENTITY_ID]:
                        continue
                    if pred_scores[j, _e2] >= self.theta:
                        _e1 = int(e1[j])
                        _r = int(r[j])
                        o_f.write('{}\t{}\t{}\t{}\n'.format(
                            kg.id2entity[_e1], kg.id2entity[_e2], kg.id2relation[_r], float(pred_scores[j, _e2])))
                        count += 1
                        if count % 1000 == 0:
                            print('{} fuzzy facts exported'.format(count))
        # Save recovered subjects
        e2_ids, r_ids = [], []
        for e2_id in rel_obj:
            for r_id in rel_obj[e2_id]:
                e2_ids.append(e2_id)
                r_ids.append(r_id)
        e1 = int_var_cuda(torch.arange(kg.num_entities))
        for i in range(len(e2_ids)):
            r = int_fill_var_cuda(e1.size(), r_ids[i])
            e2 = int_fill_var_cuda(e1.size(), e2_ids[i])
            pred_scores = mdl.forward_fact(e1, r, e2, kg)
            for j in range(pred_scores.size(1)):
                if pred_scores[j] > self.theta:
                    _e1 = int(e1[j])
                    if _e1 in [NO_OP_ENTITY_ID, DUMMY_ENTITY_ID]:
                        continue
                    _r = int(r[j])
                    _e2 = int(e2[j])
                    if _e1 in sub_rel and _r in sub_rel[_e1]:
                        continue
                    o_f.write('{}\t{}\t{}\t{}\n'.format(
                        kg.id2entity[_e1], kg.id2entity[_e2], kg.id2relation[_r], float(pred_scores[j])))
                    count += 1
                    if count % 1000 == 0:
                        print('{} fuzzy facts exported'.format(count))
        o_f.close()
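The thresholding step in miniature (theta, the scores, and the ids below are invented):

import torch

theta = 0.8
pred_scores = torch.tensor([[0.90, 0.10, 0.85],
                            [0.20, 0.95, 0.30]])
e1_ids, r_ids = [10, 11], [0, 1]
for j in range(pred_scores.size(0)):
    for e2 in range(pred_scores.size(1)):
        if pred_scores[j, e2] >= theta:
            print(e1_ids[j], r_ids[j], e2, float(pred_scores[j, e2]))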
Example #17
    def sample_action(self,
                      obs,
                      t,
                      path_trace,
                      db_outcomes,
                      r_prob,
                      inv_offset=None):
        """
        Sample an action based on current policy.
        :param db_outcomes (((r_space, e_space), action_mask), action_dist):
                r_space: (Variable:batch) relation space
                e_space: (Variable:batch) target entity space
                action_mask: (Variable:batch) binary mask indicating padding actions.
                action_dist: (Variable:batch) action distribution at the current step under the
                    policy network parameters
        :param inv_offset: Indexes for restoring original order in a batch.
        :return next_action (next_r, next_e): Sampled next action.
        :return action_prob: Probability of the sampled action.
        """
        def find_mute_idx(space, valid_space):
            # Build a 0/1 mask over `space` marking entries in valid_space.
            space = space.cpu().numpy()
            action_mute_mask = [1 if s in valid_space else 0 for s in space]
            return var_cuda(torch.FloatTensor(action_mute_mask))

        def to_one_hot(x):
            y = torch.eye(self.kg.num_relations).cuda()
            return y[x]

        def apply_action_dropout_mask(r_space, action_dist, action_mask):
            # Sample one relation per row from r_prob, mute actions that use
            # a different relation, then apply standard action dropout.
            bucket_size = len(r_space)
            batch_idx = torch.tensor(offset[self.ct:self.ct +
                                            bucket_size]).cuda()
            r_prob_b = r_prob[batch_idx]

            if not self.pretrain:
                uni_mask = torch.zeros(r_prob_b.shape).cuda()
                uni_mask = torch.scatter(uni_mask, 1, r_space,
                                         torch.ones(r_space.shape).cuda())
                r_prob_b = r_prob_b * uni_mask
            r_prob_sum = torch.sum(r_prob_b, 1)
            is_zero = (r_prob_sum == 0).float().unsqueeze(1)
            r_prob_b = r_prob_b + is_zero * torch.ones(
                r_prob_b[0].size()).cuda()

            r_chosen_b = torch.multinomial(r_prob_b, 1, replacement=True)

            r_prob_chosen_b = ops.batch_lookup(r_prob_b,
                                               r_chosen_b).unsqueeze(1)

            action_mute_mask = (r_space == r_chosen_b).float()
            if self.pretrain:
                action_dist_muted = action_mute_mask * r_prob_chosen_b
            else:
                action_dist_muted = action_dist * action_mute_mask * r_prob_chosen_b

            dist_sum = torch.sum(action_dist_muted, 1)
            is_zero = (dist_sum == 0).float().unsqueeze(1)
            if self.pretrain:
                uniform_dist = torch.ones(action_dist[0].size()).float().cuda()
                action_dist_muted = action_dist_muted + is_zero * uniform_dist
            else:
                action_dist_muted = action_dist_muted + is_zero * action_dist
            self.ct += bucket_size
            new_action_dist = action_dist_muted

            if self.action_dropout_rate > 0:
                rand = torch.rand(action_dist.size())
                action_keep_mask = var_cuda(
                    rand > self.action_dropout_rate).float()
                sample_action_dist = new_action_dist * action_keep_mask + ops.EPSILON * (
                    1 - action_keep_mask) * action_mask
                # if dropout to 0, keep original value
                dist_sum = torch.sum(sample_action_dist, 1)
                is_zero = (dist_sum == 0).float().unsqueeze(1)
                sample_action_dist = sample_action_dist + is_zero * new_action_dist
                return sample_action_dist
            else:
                return new_action_dist

        def sample(action_space, action_dist):
            sample_outcome = {}
            ((r_space, e_space), action_mask) = action_space
            sample_action_dist = apply_action_dropout_mask(
                r_space, action_dist, action_mask)
            idx = torch.multinomial(sample_action_dist, 1, replacement=True)
            next_r = ops.batch_lookup(r_space, idx)
            next_e = ops.batch_lookup(e_space, idx)
            action_prob = ops.batch_lookup(sample_action_dist, idx)

            sample_outcome['action_sample'] = (next_r, next_e)
            sample_outcome['action_prob'] = action_prob
            return sample_outcome

        e_s, q, e_t, last_step, last_r, seen_nodes = obs
        if inv_offset is not None:
            # Invert the bucket permutation: bucketed row i sits at batch
            # position inv_offset[i], so offset[b] is batch slot b's row.
            offset = [0] * len(inv_offset)
            for i in range(len(inv_offset)):
                offset[inv_offset[i]] = i
            next_r_list = []
            next_e_list = []
            action_dist_list = []
            action_prob_list = []
            self.ct = 0
            self.zero_ct = 0

            # relation dropout
            rand = torch.rand(r_prob.size())
            r_prob_keep_mask = var_cuda(
                rand > self.action_dropout_rate).float()
            r_prob = r_prob * r_prob_keep_mask

            r_prob_sum = torch.sum(r_prob, 1)
            is_zero = (r_prob_sum == 0).float().unsqueeze(1)
            r_prob = r_prob + is_zero * torch.ones(r_prob[0].size()).cuda()
            for action_space, action_dist in db_outcomes:
                sample_outcome = sample(action_space, action_dist)
                next_r_list.append(sample_outcome['action_sample'][0])
                next_e_list.append(sample_outcome['action_sample'][1])
                action_prob_list.append(sample_outcome['action_prob'])
                action_dist_list.append(action_dist)
            next_r = torch.cat(next_r_list, dim=0)[inv_offset]
            next_e = torch.cat(next_e_list, dim=0)[inv_offset]
            action_sample = (next_r, next_e)
            action_prob = torch.cat(action_prob_list, dim=0)[inv_offset]
            sample_outcome = {}
            sample_outcome['action_sample'] = action_sample
            sample_outcome['action_prob'] = action_prob
        else:
            sample_outcome = sample(db_outcomes[0][0], db_outcomes[0][1])

        return sample_outcome
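The offset/inv_offset bookkeeping above is a permutation inversion: inv_offset maps bucketed rows back to batch order, so offset maps batch slots to bucketed rows. A toy illustration:

inv_offset = [2, 0, 1]   # bucketed row i belongs at batch slot inv_offset[i]
offset = [0] * len(inv_offset)
for i in range(len(inv_offset)):
    offset[inv_offset[i]] = i
print(offset)            # [1, 2, 0]: batch slot b lives at bucketed row offset[b]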
Example #18
    def teacher_forcing_pretrain(self, e_s, q, e_t, num_steps):
        """
        Perform multi-step rollout from the source entity conditioned on the query relation.
        :param pn: Policy network.
        :param e_s: (Variable:batch) source entity indices.
        :param q: (Variable:batch) query relation indices.
        :param e_t: (Variable:batch) target entity indices.
        :param kg: Knowledge graph environment.
        :param num_steps: Number of rollout steps.
        :return log_action_probs: Log probability of the sampled path.
        :return action_entropy: Entropy regularization term.
        """
        assert (num_steps > 0)
        kg, pn = self.kg, self.mdl

        # Initialization
        log_action_probs = []
        action_entropy = []

        r_s = int_fill_var_cuda(e_s.size(), kg.dummy_start_r)
        path_label = []
        cnt = 0
        for q_b in q:
            if int(q_b) not in self.rel2rules.keys():
                path_label.append(
                    torch.randint(kg.num_relations,
                                  (num_steps, )).numpy().tolist())
                continue
            rules = self.rel2rules[int(q_b)]
            cnt += 1
            # uniform
            # sample_rule_id = torch.randint(len(rules.keys()), (1,)).item()

            # weighted by confidence score
            rule_dist_orig = torch.tensor(list(rules.values())).cuda()
            rand = torch.rand(rule_dist_orig.size())
            keep_mask = var_cuda(rand > self.action_dropout_rate).float()
            rule_dist = rule_dist_orig * keep_mask
            rule_sum = torch.sum(rule_dist, 0)
            is_zero = (rule_sum == 0).float()
            rule_dist = rule_dist + is_zero * rule_dist_orig
            sample_rule_id = torch.multinomial(rule_dist, 1).item()
            path_label.append(list(rules.keys())[sample_rule_id])
        path_label = torch.tensor(path_label).cuda()
        #print('rule_path_percentage:', cnt/len(path_label))

        path_trace = [r_s]
        pn.initialize_path((r_s, e_s), kg)
        for t in range(num_steps):
            last_r = path_trace[-1]
            obs = [e_s, q, e_t, t == (num_steps - 1), last_r, None]

            # relation selection
            r_prob, policy_entropy = pn.transit_r(None, obs, kg)
            action_r = path_label[:, t]
            action_prob = ops.batch_lookup(r_prob, action_r.view(-1, 1))
            pn.update_path_r(action_r, kg)

            action_entropy.append(policy_entropy)

            log_action_probs.append(ops.safe_log(action_prob))
            path_trace.append(action_r)

        return {
            'log_action_probs': log_action_probs,
            'action_entropy': action_entropy,
            'path_trace': path_trace
        }
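The confidence-weighted rule sampling from the loop above, reduced to toy numbers (the rule confidences are invented; the zero-sum fallback mirrors the is_zero trick in the snippet):

import torch

torch.manual_seed(1)
rule_dist_orig = torch.tensor([0.9, 0.6, 0.3])   # per-rule confidences
dropout_rate = 0.5
keep_mask = (torch.rand(rule_dist_orig.size()) > dropout_rate).float()
rule_dist = rule_dist_orig * keep_mask
# Fall back to the undropped distribution if everything was zeroed out.
rule_dist = rule_dist + (rule_dist.sum() == 0).float() * rule_dist_orig
print(torch.multinomial(rule_dist, 1).item())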
Example #19

def beam_search_same(pn,
                     e_s,
                     q,
                     e_t,
                     kg,
                     num_steps,
                     beam_size,
                     return_path_components=False,
                     return_merge_scores=None,
                     same_start=False):
    """
    Beam search from source.

    :param pn: Policy network.
    :param e_s: (Variable:batch) source entity indices.
    :param q: (Variable:batch) query relation indices.
    :param e_t: (Variable:batch) target entity indices.
    :param kg: Knowledge graph environment.
    :param num_steps: Number of search steps.
    :param beam_size: Beam size used in search.
    :param return_path_components: If set, return all path components at the end of search.
    """
    assert (num_steps >= 1)
    batch_size = len(e_s)

    def top_k_action(log_action_dist, action_space, return_merge_scores=None):
        """
        Get top k actions.
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*k, action_space_size]
        :param action_space (r_space, e_space):
            r_space: [batch_size*k, action_space_size]
            e_space: [batch_size*k, action_space_size]
        :return:
            (next_r, next_e), action_prob, action_offset: [batch_size*new_k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = full_size // batch_size

        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]
        # => [batch_size, k'*action_space_size]
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        k = min(beam_size, beam_action_space_size)

        if return_merge_scores is not None:
            if return_merge_scores == 'sum':
                reduce_method = torch.sum
            elif return_merge_scores == 'mean':
                reduce_method = torch.mean
            else:
                reduce_method = None

            all_action_ind = torch.arange(beam_action_space_size).unsqueeze(
                0).repeat(len(log_action_dist), 1).cuda(device=0)

            all_next_r = ops.batch_lookup(r_space.view(batch_size, -1),
                                          all_action_ind)
            all_next_e = ops.batch_lookup(e_space.view(batch_size, -1),
                                          all_action_ind)

            real_log_action_prob, real_next_e, real_action_ind, real_next_r = ops.merge_same(
                log_action_dist, all_next_e, all_next_r, method=reduce_method)

            next_e_list, next_r_list, action_ind_list, log_action_prob_list = [], [], [], []
            for i in range(batch_size):
                k_prime = min(len(real_log_action_prob[i]), k)
                top_log_prob, top_ind = torch.topk(real_log_action_prob[i],
                                                   k_prime)
                top_next_e, top_next_r, top_ind = real_next_e[i][
                    top_ind], real_next_r[i][top_ind], real_action_ind[i][
                        top_ind]
                next_e_list.append(top_next_e.unsqueeze(0))
                next_r_list.append(top_next_r.unsqueeze(0))
                action_ind_list.append(top_ind.unsqueeze(0))
                log_action_prob_list.append(top_log_prob.unsqueeze(0))

            next_r = ops.pad_and_cat(next_r_list,
                                     padding_value=kg.dummy_r).view(-1)
            next_e = ops.pad_and_cat(next_e_list,
                                     padding_value=kg.dummy_e).view(-1)
            log_action_prob = ops.pad_and_cat(log_action_prob_list,
                                              padding_value=0.0).view(-1)
            action_ind = ops.pad_and_cat(action_ind_list,
                                         padding_value=-1).view(-1)
        else:
            log_action_prob, action_ind = torch.topk(log_action_dist, k)
            next_r = ops.batch_lookup(r_space.view(batch_size, -1),
                                      action_ind).view(-1)
            next_e = ops.batch_lookup(e_space.view(batch_size, -1),
                                      action_ind).view(-1)

        # [batch_size, k] => [batch_size*k]
        log_action_prob = log_action_prob.view(-1)
        # *** compute parent offset
        # [batch_size, k]
        action_beam_offset = action_ind // action_space_size
        # [batch_size, 1]
        action_batch_offset = int_var_cuda(torch.arange(batch_size) *
                                           last_k).unsqueeze(1)
        # [batch_size, k] => [batch_size*k]
        action_offset = (action_batch_offset + action_beam_offset).view(-1)
        return (next_r, next_e), log_action_prob, action_offset

    def top_k_answer_unique(log_action_dist, action_space):
        """
        Get top k unique entities
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*beam_size, action_space_size] log probabilities
        :param action_space (r_space, e_space): candidate relations and entities
            r_space: [batch_size*beam_size, action_space_size]
            e_space: [batch_size*beam_size, action_space_size]
        :return:
            (next_r, next_e), action_prob, action_offset: [batch_size*k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = full_size // batch_size
        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]

        r_space = r_space.view(batch_size, -1)
        e_space = e_space.view(batch_size, -1)
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        assert (beam_action_space_size % action_space_size == 0)
        k = min(beam_size, beam_action_space_size)
        next_r_list, next_e_list = [], []
        log_action_prob_list = []
        action_offset_list = []
        for i in range(batch_size):
            log_action_dist_b = log_action_dist[i]
            r_space_b = r_space[i]
            e_space_b = e_space[i]
            unique_e_space_b = var_cuda(torch.unique(e_space_b.data.cpu()))
            unique_log_action_dist, unique_idx = unique_max(
                unique_e_space_b, e_space_b, log_action_dist_b)
            k_prime = min(len(unique_e_space_b), k)
            top_unique_log_action_dist, top_unique_idx2 = torch.topk(
                unique_log_action_dist, k_prime)
            top_unique_idx = unique_idx[top_unique_idx2]
            top_unique_beam_offset = top_unique_idx // action_space_size
            top_r = r_space_b[top_unique_idx]
            top_e = e_space_b[top_unique_idx]
            next_r_list.append(top_r.unsqueeze(0))
            next_e_list.append(top_e.unsqueeze(0))
            log_action_prob_list.append(
                top_unique_log_action_dist.unsqueeze(0))
            top_unique_batch_offset = i * last_k
            top_unique_action_offset = top_unique_batch_offset + top_unique_beam_offset
            action_offset_list.append(top_unique_action_offset.unsqueeze(0))
        next_r = ops.pad_and_cat(next_r_list,
                                 padding_value=kg.dummy_r).view(-1)
        next_e = ops.pad_and_cat(next_e_list,
                                 padding_value=kg.dummy_e).view(-1)
        log_action_prob = ops.pad_and_cat(log_action_prob_list,
                                          padding_value=-ops.HUGE_INT)
        action_offset = ops.pad_and_cat(action_offset_list, padding_value=-1)
        return (next_r,
                next_e), log_action_prob.view(-1), action_offset.view(-1)

    def adjust_search_trace(search_trace, action_offset):
        for i, (r, e) in enumerate(search_trace):
            new_r = r[action_offset]
            new_e = e[action_offset]
            search_trace[i] = (new_r, new_e)

    # Initialization
    r_s = int_fill_var_cuda(e_s.size(), kg.dummy_start_r)  # WHY does the path start with a dummy relation?
    seen_nodes = int_fill_var_cuda(e_s.size(), kg.dummy_e).unsqueeze(1)  # TODO
    init_action = (r_s, e_s)
    e_s_abs = var_cuda(torch.LongTensor([kg.get_typeid(_) for _ in e_s]),
                       requires_grad=False)
    e_t_abs = var_cuda(torch.LongTensor([kg.get_typeid(_) for _ in e_t]),
                       requires_grad=False)
    seen_nodes_abs = int_fill_var_cuda(e_s_abs.size(), kg.dummy_e).unsqueeze(1)
    init_action_abs = (r_s, e_s_abs)

    # path encoder
    # pn.initialize_path(init_action, kg)
    pn.initialize_abs_path(init_action, init_action_abs, kg, same_start)
    if kg.args.save_paths_to_csv:
        search_trace = [(r_s, e_s)]

    # Run beam search for num_steps
    # [batch_size*k], k=1
    log_action_prob = zeros_var_cuda(batch_size)
    if return_path_components:
        log_action_probs = []

    action = init_action
    action_abs = init_action_abs

    for t in range(num_steps):
        last_r, e = action
        _, e_abs = action_abs  # action_abs carries the same relation as action
        assert (q.size() == e_s.size())
        assert (q.size() == e_t.size())
        assert (e.size()[0] % batch_size == 0)
        assert (q.size()[0] % batch_size == 0)
        k = int(e.size()[0] / batch_size)

        # => [batch_size*k]
        q = ops.tile_along_beam(q.view(batch_size, -1)[:, 0], k)
        e_s = ops.tile_along_beam(e_s.view(batch_size, -1)[:, 0], k)
        e_s_abs = ops.tile_along_beam(e_s_abs.view(batch_size, -1)[:, 0], k)
        e_t = ops.tile_along_beam(e_t.view(batch_size, -1)[:, 0], k)
        e_t_abs = ops.tile_along_beam(e_t_abs.view(batch_size, -1)[:, 0], k)
        obs = [e_s, q, e_t, t == (num_steps - 1), last_r, seen_nodes]
        obs_abs = [
            e_s_abs, q, e_t_abs, t == (num_steps - 1), last_r, seen_nodes_abs
        ]
        # one step forward in search
        db_outcomes, _, _ = pn.transit_same(
            e,
            obs,
            e_abs,
            obs_abs,
            kg,
            use_action_space_bucketing=False,
            merge_aspace_batching_outcome=True
        )  # TODO: step through get_action_space_in_buckets inside this call
        action_space, action_dist = db_outcomes[0]
        # => [batch_size*k, action_space_size]
        log_action_dist = log_action_prob.view(-1,
                                               1) + ops.safe_log(action_dist)
        # [batch_size*k, action_space_size] => [batch_size*new_k]

        if t == num_steps - 1:
            if return_merge_scores is None:
                action, log_action_prob, action_offset = top_k_answer_unique(
                    log_action_dist, action_space)
            else:
                action, log_action_prob, action_offset = top_k_action(
                    log_action_dist, action_space, return_merge_scores)
        else:
            action, log_action_prob, action_offset = top_k_action(
                log_action_dist, action_space, None)

        # action = (next_r, next_e); derive the abstract (type-level) action
        action_abs = (action[0],
                      var_cuda(torch.LongTensor(
                          [kg.get_typeid(_) for _ in action[1]]),
                               requires_grad=False))

        if return_path_components:
            ops.rearrange_vector_list(log_action_probs, action_offset)
            log_action_probs.append(log_action_prob)
        pn.update_path_abs(action_abs, kg, offset=action_offset)
        seen_nodes = torch.cat(
            [seen_nodes[action_offset], action[1].unsqueeze(1)], dim=1)
        seen_nodes_abs = torch.cat(
            [seen_nodes_abs[action_offset], action_abs[1].unsqueeze(1)], dim=1)
        if kg.args.save_paths_to_csv:
            adjust_search_trace(search_trace, action_offset)
            search_trace.append(action)

    output_beam_size = int(action[0].size()[0] / batch_size)
    # [batch_size*beam_size] => [batch_size, beam_size]
    beam_search_output = dict()
    beam_search_output['pred_e2s'] = action[1].view(batch_size, -1)
    beam_search_output['pred_e2_scores'] = log_action_prob.view(batch_size, -1)

    if return_path_components:
        path_width = 10
        path_components_list = []
        for i in range(batch_size):
            p_c = []
            for k, log_action_prob in enumerate(log_action_probs):
                top_k_edge_labels = []
                for j in range(min(output_beam_size, path_width)):
                    ind = i * output_beam_size + j
                    r = kg.id2relation[int(search_trace[k + 1][0][ind])]
                    e = kg.id2entity[int(search_trace[k + 1][1][ind])]
                    if r.endswith('_inv'):
                        edge_label = '<-{}-{} {}'.format(
                            r[:-4], e, float(log_action_probs[k][ind]))
                    else:
                        edge_label = '-{}->{} {}'.format(
                            r, e, float(log_action_probs[k][ind]))
                    top_k_edge_labels.append(edge_label)
                top_k_action_prob = log_action_prob[:path_width]
                e_name = kg.id2entity[int(
                    search_trace[1][0][i *
                                       output_beam_size])] if k == 0 else ''
                p_c.append((e_name, top_k_edge_labels,
                            var_to_numpy(top_k_action_prob)))
            path_components_list.append(p_c)
        beam_search_output['search_traces'] = search_trace
        SAME_ALL_PATHS.append(beam_search_output)

    return beam_search_output
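How action_offset threads state between steps of the beam search above: each surviving action records the parent row it extends, so per-beam tensors such as seen_nodes can be re-gathered with one indexing operation. A toy illustration (batch_size=2, last_k=2):

import torch

seen_nodes = torch.tensor([[3], [4], [5], [6]])   # one row per old beam entry
action_offset = torch.tensor([1, 0, 2, 2])        # chosen parent rows
next_e = torch.tensor([7, 8, 9, 9])
seen_nodes = torch.cat([seen_nodes[action_offset], next_e.unsqueeze(1)], dim=1)
print(seen_nodes)
# tensor([[4, 7], [3, 8], [5, 9], [5, 9]])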