Esempio n. 1
0
 def run(self):
     """Train the model epoch by epoch until early stopping or ``max_epoch``.

     The relation triples of both KGs are split into
     ``ceil(triples / batch_size)`` steps, which ``task_divide`` distributes
     over ``batch_threads_num`` workers. Batches travel through a queue owned
     by a ``multiprocessing`` manager. From epoch ``start_valid`` on, every
     ``eval_freq`` epochs the model is validated on ``stop_metric``, and
     ``early_stop`` updates ``self.flag1``, ``self.flag2`` and
     ``self.early_stop``. Training stops when early stopping fires.

     With ``neg_sampling == 'truncated'``, the entity neighbour lists passed to
     the next epoch's training are rebuilt every ``truncated_freq`` epochs
     from the current embeddings. They are presumably used for truncated
     negative sampling; confirm against ``launch_training_1epo``. They are not
     rebuilt after the final epoch, since nothing would consume them.

     Raises:
         ValueError: if truncated sampling is enabled and
             ``truncated_epsilon`` is not strictly between 0 and 1.
     """
     t = time.time()
     triples_num = self.kgs.kg1.relation_triples_num + self.kgs.kg2.relation_triples_num
     triple_steps = int(math.ceil(triples_num / self.args.batch_size))
     steps_tasks = task_divide(list(range(triple_steps)), self.args.batch_threads_num)
     neighbors1, neighbors2 = None, None
     # The context manager shuts the manager's server process down even if
     # training raises, instead of leaving that to garbage collection.
     with mp.Manager() as manager:
         training_batch_queue = manager.Queue()
         for i in range(1, self.args.max_epoch + 1):
             self.launch_training_1epo(i, triple_steps, steps_tasks, training_batch_queue, neighbors1, neighbors2)
             if i >= self.args.start_valid and i % self.args.eval_freq == 0:
                 flag = self.valid(self.args.stop_metric)
                 self.flag1, self.flag2, self.early_stop = early_stop(self.flag1, self.flag2, flag)
                 if self.early_stop or i == self.args.max_epoch:
                     break
             # Neighbours are only consumed by the *next* epoch's training, so
             # skip the (expensive) rebuild after the last epoch.
             is_last_epoch = i == self.args.max_epoch
             if (self.args.neg_sampling == 'truncated' and not is_last_epoch
                     and i % self.args.truncated_freq == 0):
                 t1 = time.time()
                 epsilon = self.args.truncated_epsilon
                 # An explicit check survives `python -O`, unlike `assert`.
                 if not 0.0 < epsilon < 1.0:
                     raise ValueError("truncated_epsilon must be in (0, 1), got {}".format(epsilon))
                 neighbors_num1 = int((1 - epsilon) * self.kgs.kg1.entities_num)
                 neighbors_num2 = int((1 - epsilon) * self.kgs.kg2.entities_num)
                 # Release the previous neighbour lists before building new
                 # ones to keep peak memory down.
                 neighbors1, neighbors2 = None, None
                 gc.collect()
                 neighbors1 = bat.generate_neighbours(self.eval_kg1_useful_ent_embeddings(),
                                                      self.kgs.useful_entities_list1,
                                                      neighbors_num1, self.args.batch_threads_num)
                 neighbors2 = bat.generate_neighbours(self.eval_kg2_useful_ent_embeddings(),
                                                      self.kgs.useful_entities_list2,
                                                      neighbors_num2, self.args.batch_threads_num)
                 ent_num = len(self.kgs.kg1.entities_list) + len(self.kgs.kg2.entities_list)
                 print("\ngenerating neighbors of {} entities costs {:.3f} s.".format(ent_num, time.time() - t1))
                 gc.collect()
     print("Training ends. Total time = {:.3f} s.".format(time.time() - t))
Esempio n. 2
0
 def run(self):
     """Train in rounds of ``sub_epoch`` epochs with validation and bootstrapping.

     Runs ``max_epoch // sub_epoch`` rounds; any leftover epochs are not
     trained. Each round calls ``launch_training_k_epo``.

     Once ``round * sub_epoch`` reaches ``start_valid``, the model is validated
     on ``stop_metric``. Training stops when ``early_stop`` fires after at
     least ``min_iter`` rounds, or after the last round.

     Once ``round * sub_epoch`` reaches ``start_bp``, ``bootstrapping`` on the
     reference similarity matrix returns an updated ``labeled_align`` plus
     entity pairs. Those pairs are trained with ``train_alignment``
     (``align_times`` passes).

     With ``neg_sampling == "truncated"``, the neighbour lists passed to the
     next round's training are rebuilt from the current embeddings. This is
     skipped after the final round, where nothing would consume them.

     Raises:
         ValueError: if truncated sampling is enabled and
             ``truncated_epsilon`` is not strictly between 0 and 1.
     """
     t = time.time()
     triples_num = self.kgs.kg1.relation_triples_num + self.kgs.kg2.relation_triples_num
     triple_steps = int(math.ceil(triples_num / self.args.batch_size))
     steps_tasks = task_divide(list(range(triple_steps)),
                               self.args.batch_threads_num)
     neighbors1, neighbors2 = None, None
     labeled_align = set()
     sub_num = self.args.sub_epoch
     # Floor division: epochs beyond the last full round are not trained.
     iter_nums = self.args.max_epoch // sub_num
     # The context manager shuts the manager's server process down even if
     # training raises, instead of leaving that to garbage collection.
     with mp.Manager() as manager:
         training_batch_queue = manager.Queue()
         for i in range(1, iter_nums + 1):
             print("\niteration", i)
             self.launch_training_k_epo(i, sub_num, triple_steps, steps_tasks,
                                        training_batch_queue, neighbors1,
                                        neighbors2)
             if i * sub_num >= self.args.start_valid:
                 flag = self.valid(self.args.stop_metric)
                 self.flag1, self.flag2, self.early_stop = early_stop(
                     self.flag1, self.flag2, flag)
                 if (self.early_stop
                         and i >= self.args.min_iter) or i == iter_nums:
                     break
             if i * sub_num >= self.args.start_bp:
                 print("bootstrapping")
                 labeled_align, entities1, entities2 = bootstrapping(
                     self.eval_ref_sim_mat(), self.ref_ent1, self.ref_ent2,
                     labeled_align, self.args.sim_th, self.args.k)
                 self.train_alignment(self.kgs.kg1, self.kgs.kg2, entities1,
                                      entities2, self.args.align_times)
                 if i * sub_num >= self.args.start_valid:
                     self.valid(self.args.stop_metric)
             # Neighbours only feed the next round's training; skip the
             # expensive rebuild after the final round.
             if self.args.neg_sampling == "truncated" and i < iter_nums:
                 t1 = time.time()
                 epsilon = self.args.truncated_epsilon
                 # An explicit check survives `python -O`, unlike `assert`.
                 if not 0.0 < epsilon < 1.0:
                     raise ValueError(
                         "truncated_epsilon must be in (0, 1), got {}".format(epsilon))
                 neighbors_num1 = int((1 - epsilon) * self.kgs.kg1.entities_num)
                 neighbors_num2 = int((1 - epsilon) * self.kgs.kg2.entities_num)
                 # Release the previous neighbour lists before building new
                 # ones to keep peak memory down.
                 neighbors1, neighbors2 = None, None
                 gc.collect()
                 neighbors1 = bat.generate_neighbours(
                     self.eval_kg1_useful_ent_embeddings(),
                     self.kgs.useful_entities_list1, neighbors_num1,
                     self.args.batch_threads_num)
                 neighbors2 = bat.generate_neighbours(
                     self.eval_kg2_useful_ent_embeddings(),
                     self.kgs.useful_entities_list2, neighbors_num2,
                     self.args.batch_threads_num)
                 ent_num = len(self.kgs.kg1.entities_list) + len(
                     self.kgs.kg2.entities_list)
                 print("generating neighbors of {} entities costs {:.3f} s.".
                       format(ent_num,
                              time.time() - t1))
     print("Training ends. Total time = {:.3f} s.".format(time.time() - t))
Esempio n. 3
0
    def run(self):
        """Train in rounds of ``sub_epoch`` epochs with bootstrapped alignment.

        Runs ``max_epoch // sub_epoch`` rounds; any leftover epochs are not
        trained. Each round:

        1. calls ``self.save(i)`` before training, so it sees the model as left
           by the previous round;
        2. trains ``sub_epoch`` epochs via ``launch_training_k_epo``;
        3. once ``i * sub_epoch >= start_valid``, reports ``stop_metric`` via
           ``valid`` and feeds the result of ``valid_new_bootea`` on
           ``stop_metric_new`` to ``early_stop``; stops on early stop or after
           the last round;
        4. bootstraps entity pairs from the reference similarity matrix, runs
           one extra ``train_alignment`` pass on them, then re-validates if
           validation is active;
        5. rebuilds the truncated neighbour lists from the current embeddings
           for the next round. This is skipped after the last round, where
           nothing would consume them.

        Raises:
            ValueError: if ``truncated_epsilon`` is not strictly between 0 and 1.
        """
        t = time.time()
        # Total relation triples of both KGs and the number of batch steps needed.
        triples_num = self.kgs.kg1.relation_triples_num + self.kgs.kg2.relation_triples_num
        triple_steps = int(math.ceil(triples_num / self.args.batch_size))
        # Decide which worker gets which batch steps.
        steps_tasks = task_divide(list(range(triple_steps)),
                                  self.args.batch_threads_num)
        neighbors1, neighbors2 = None, None
        labeled_align = set()
        sub_num = self.args.sub_epoch
        # Floor division: epochs beyond the last full round are not trained.
        iter_nums = self.args.max_epoch // sub_num
        # The context manager shuts the manager's server process down even if
        # training raises, instead of leaving that to garbage collection.
        with mp.Manager() as manager:
            training_batch_queue = manager.Queue()
            for i in range(1, iter_nums + 1):
                print("\niteration", i)
                self.save(i)
                self.launch_training_k_epo(i, sub_num, triple_steps, steps_tasks,
                                           training_batch_queue, neighbors1,
                                           neighbors2)
                validating = i * sub_num >= self.args.start_valid
                if validating:
                    # Reported only; early stopping is driven by the
                    # valid_new_bootea result below.
                    self.valid(self.args.stop_metric)
                    flag = self.valid_new_bootea(self.args.stop_metric_new,
                                                 "bootea")
                    # Embeddings from the last round are kept; no separate
                    # "best embeddings" checkpoint is saved here.
                    self.flag1, self.flag2, self.early_stop = early_stop(
                        self.flag1, self.flag2, flag)
                    if self.early_stop or i == iter_nums:
                        break
                labeled_align, entities1, entities2 = bootstrapping(
                    self.eval_ref_sim_mat(), self.ref_ent1, self.ref_ent2,
                    labeled_align, self.args.sim_th, self.args.k,
                    self.len_valid_test)
                # One extra alignment pass (align_times fixed to 1) on the
                # bootstrapped entity pairs.
                self.train_alignment(self.kgs.kg1, self.kgs.kg2, entities1,
                                     entities2, 1)
                if validating:
                    self.valid(self.args.stop_metric)
                    self.valid_new_bootea(self.args.stop_metric_new, "bootea")
                if i == iter_nums:
                    # Neighbours only feed the next round's training.
                    break
                t1 = time.time()
                epsilon = self.args.truncated_epsilon
                # An explicit check survives `python -O`, unlike `assert`.
                if not 0.0 < epsilon < 1.0:
                    raise ValueError(
                        "truncated_epsilon must be in (0, 1), got {}".format(epsilon))
                neighbors_num1 = int((1 - epsilon) * self.kgs.kg1.entities_num)
                neighbors_num2 = int((1 - epsilon) * self.kgs.kg2.entities_num)
                # Release the previous neighbour lists before building new ones
                # to keep peak memory down.
                neighbors1, neighbors2 = None, None
                gc.collect()
                neighbors1 = bat.generate_neighbours(
                    self.eval_kg1_useful_ent_embeddings(),
                    self.kgs.useful_entities_list1, neighbors_num1,
                    self.args.batch_threads_num)
                neighbors2 = bat.generate_neighbours(
                    self.eval_kg2_useful_ent_embeddings(),
                    self.kgs.useful_entities_list2, neighbors_num2,
                    self.args.batch_threads_num)
                ent_num = len(self.kgs.kg1.entities_list) + len(
                    self.kgs.kg2.entities_list)
                print("generating neighbors of {} entities costs {:.3f} s.".format(
                    ent_num,
                    time.time() - t1))
        print("Training ends. Total time = {:.3f} s.".format(time.time() - t))
Esempio n. 4
0
    def run(self):
        """Train the multi-view model, then the shared-space mapping, then test.

        Phase 1 (up to ``max_epoch`` epochs):

        - Each epoch trains the relation view plus its cross-KG entity
          inference, then the attribute view plus its cross-KG entity
          inference.
        - Cross-KG relation/attribute inference is trained only once
          ``i > start_predicate_soft_alignment``.
        - From ``start_valid`` on, every ``eval_freq`` epochs the views are
          validated. Once ``i >= start_predicate_soft_alignment``, the
          predicate alignment is also refreshed from the current
          relation/attribute embeddings and the inference triples are rebuilt.
        - With ``neg_sampling == 'truncated'``, the neighbour lists are rebuilt
          every ``truncated_freq`` epochs, never after the final epoch.

        Phase 2 (``shared_learning_max_epoch`` epochs): trains the shared-space
        mapping over all entities of both KGs. The ``'final'`` embedding is
        validated periodically.

        Finally, the model is saved and ``test``/``test_WVA`` are run for every
        embedding choice.

        Raises:
            ValueError: if truncated sampling is enabled and
                ``truncated_epsilon`` is not strictly between 0 and 1.
        """
        t = time.time()
        relation_triples_num = self.kgs.kg1.local_relation_triples_num + self.kgs.kg2.local_relation_triples_num
        attribute_triples_num = self.kgs.kg1.local_attribute_triples_num + self.kgs.kg2.local_attribute_triples_num
        relation_triple_steps = int(
            math.ceil(relation_triples_num / self.args.batch_size))
        attribute_triple_steps = int(
            math.ceil(attribute_triples_num / self.args.batch_size))
        relation_step_tasks = task_divide(list(range(relation_triple_steps)),
                                          self.args.batch_threads_num)
        attribute_step_tasks = task_divide(list(range(attribute_triple_steps)),
                                           self.args.batch_threads_num)
        cross_kg_relation_triples = self.kgs.kg1.sup_relation_triples_list + self.kgs.kg2.sup_relation_triples_list
        cross_kg_entity_inference_in_attribute_triples = self.kgs.kg1.sup_attribute_triples_list + \
                                                         self.kgs.kg2.sup_attribute_triples_list
        cross_kg_relation_inference = self.predicate_align_model.sup_relation_alignment_triples1 + \
                                      self.predicate_align_model.sup_relation_alignment_triples2
        cross_kg_attribute_inference = self.predicate_align_model.sup_attribute_alignment_triples1 + \
                                       self.predicate_align_model.sup_attribute_alignment_triples2
        neighbors1, neighbors2 = None, None

        entity_list = self.kgs.kg1.entities_list + self.kgs.kg2.entities_list

        valid(self, embed_choice='nv')
        valid(self, embed_choice='avg')
        # The manager (and its batch queues) is only needed for phase 1; the
        # context manager shuts its server process down even on error.
        with mp.Manager() as manager:
            relation_batch_queue = manager.Queue()
            attribute_batch_queue = manager.Queue()
            for i in range(1, self.args.max_epoch + 1):
                print('epoch {}:'.format(i))
                self.train_relation_view_1epo(i, relation_triple_steps,
                                              relation_step_tasks,
                                              relation_batch_queue, neighbors1,
                                              neighbors2)
                self.train_cross_kg_entity_inference_relation_view_1epo(
                    i, cross_kg_relation_triples)
                # Strict `>`: the alignment is first refreshed at
                # i == start_predicate_soft_alignment (below) before training on it.
                if i > self.args.start_predicate_soft_alignment:
                    self.train_cross_kg_relation_inference_1epo(
                        i, cross_kg_relation_inference)

                self.train_attribute_view_1epo(i, attribute_triple_steps,
                                               attribute_step_tasks,
                                               attribute_batch_queue, neighbors1,
                                               neighbors2)
                self.train_cross_kg_entity_inference_attribute_view_1epo(
                    i, cross_kg_entity_inference_in_attribute_triples)
                if i > self.args.start_predicate_soft_alignment:
                    self.train_cross_kg_attribute_inference_1epo(
                        i, cross_kg_attribute_inference)

                if i >= self.args.start_valid and i % self.args.eval_freq == 0:
                    valid(self, embed_choice='rv')
                    valid(self, embed_choice='av')
                    valid(self, embed_choice='avg')
                    valid_WVA(self)
                    if i >= self.args.start_predicate_soft_alignment:
                        self.predicate_align_model.update_predicate_alignment(
                            self.rel_embeds.eval(session=self.session))
                        self.predicate_align_model.update_predicate_alignment(
                            self.attr_embeds.eval(session=self.session),
                            predicate_type='attribute')
                        cross_kg_relation_inference = self.predicate_align_model.sup_relation_alignment_triples1 + \
                                                      self.predicate_align_model.sup_relation_alignment_triples2
                        cross_kg_attribute_inference = self.predicate_align_model.sup_attribute_alignment_triples1 + \
                                                       self.predicate_align_model.sup_attribute_alignment_triples2

                # NOTE(review): self.early_stop is not updated in this method;
                # presumably valid()/valid_WVA() set it -- confirm, otherwise
                # only max_epoch ends this loop.
                if self.early_stop or i == self.args.max_epoch:
                    break

                if self.args.neg_sampling == 'truncated' and i % self.args.truncated_freq == 0:
                    t1 = time.time()
                    epsilon = self.args.truncated_epsilon
                    # An explicit check survives `python -O`, unlike `assert`.
                    if not 0.0 < epsilon < 1.0:
                        raise ValueError(
                            "truncated_epsilon must be in (0, 1), got {}".format(epsilon))
                    neighbors_num1 = int((1 - epsilon) * self.kgs.kg1.entities_num)
                    neighbors_num2 = int((1 - epsilon) * self.kgs.kg2.entities_num)
                    # Drop the old neighbour lists first so they can be freed
                    # before the new ones are built (lower peak memory).
                    neighbors1, neighbors2 = None, None
                    neighbors1 = bat.generate_neighbours(
                        self.eval_kg1_useful_ent_embeddings(),
                        self.kgs.useful_entities_list1, neighbors_num1,
                        self.args.batch_threads_num)
                    neighbors2 = bat.generate_neighbours(
                        self.eval_kg2_useful_ent_embeddings(),
                        self.kgs.useful_entities_list2, neighbors_num2,
                        self.args.batch_threads_num)
                    ent_num = len(self.kgs.kg1.entities_list) + len(
                        self.kgs.kg2.entities_list)
                    # Report the size of both neighbour dicts (was type() for the second).
                    print('neighbor dict:', len(neighbors1), len(neighbors2))
                    print("generating neighbors of {} entities costs {:.3f} s.".
                          format(ent_num,
                                 time.time() - t1))
        for i in range(1, self.args.shared_learning_max_epoch + 1):
            self.train_shared_space_mapping_1epo(i, entity_list)
            if i >= self.args.start_valid and i % self.args.eval_freq == 0:
                valid(self, embed_choice='final')
        self.save()
        test(self, embed_choice='nv')
        test(self, embed_choice='rv')
        test(self, embed_choice='av')
        test(self, embed_choice='avg')
        test_WVA(self)
        test(self, embed_choice='final')