Esempio n. 1
0
 def run(self):
     """Alternate description-based and relation-based co-training rounds.

     Each of up to ``args.max_iter`` rounds has two phases, and each phase
     trains for at most ``args.max_epoch`` epochs with its own fresh
     early-stopping state (flags reset to -1 / False):

     1. description view: ``launch_desc_1epo`` per epoch, validated with
        ``valid_desc``; afterwards ``find_new_alignment_desc`` runs between
        two ``gc.collect()`` calls;
     2. relation view: triple training plus both mapping objectives per
        epoch, validated with ``valid``; afterwards
        ``find_new_alignment_rel`` runs.

     A truthy result from either alignment search ends co-training.
     """
     start_time = time.time()
     total_triples = (self.kgs.kg1.relation_triples_num +
                      self.kgs.kg2.relation_triples_num)
     triple_steps = int(math.ceil(total_triples / self.args.batch_size))
     steps_tasks = task_divide(list(range(triple_steps)),
                               self.args.batch_threads_num)
     # The manager must stay referenced for as long as its queue is in use.
     manager = mp.Manager()
     batch_queue = manager.Queue()

     def reset_early_stop():
         # Each training phase starts from a clean early-stopping state.
         self.flag1, self.flag2, self.early_stop = -1, -1, False

     def should_stop(epoch, metric):
         # Fold the new validation metric into the early-stopping state and
         # report whether the current phase should end.
         self.flag1, self.flag2, self.early_stop = early_stop(
             self.flag1, self.flag2, metric)
         return self.early_stop or epoch == self.args.max_epoch

     for _ in range(1, self.args.max_iter + 1):
         # Phase 1: description view.
         reset_early_stop()
         for epoch in range(1, self.args.max_epoch + 1):
             self.launch_desc_1epo(epoch)
             if epoch % self.args.eval_freq != 0:
                 continue
             if should_stop(epoch, self.valid_desc(self.args.stop_metric)):
                 break
         gc.collect()
         finished = self.find_new_alignment_desc()
         gc.collect()
         if finished:
             print("co-training ends")
             break

         # Phase 2: relation triples plus entity mapping.
         reset_early_stop()
         for epoch in range(1, self.args.max_epoch + 1):
             self.launch_triple_training_1epo(epoch, triple_steps, steps_tasks,
                                              batch_queue, None, None)
             self.launch_mapping_training_1epo(epoch, triple_steps)
             self.launch_mapping_training_1epo_new(epoch, triple_steps)
             if epoch % self.args.eval_freq != 0:
                 continue
             if should_stop(epoch, self.valid(self.args.stop_metric)):
                 break
         if self.find_new_alignment_rel():
             print("co-training ends")
             break
     print("Training ends. Total time = {:.3f} s.".format(time.time() - start_time))
Esempio n. 2
0
 def run(self):
     """Train on relation triples with periodic validation and early stopping.

     Only the relation-triple objective is trained here. Attribute-triple and
     joint training are not run in this variant, so the batching state they
     would need (attribute step split, attribute queue, entity list) is no
     longer built.

     Validation runs every ``args.eval_freq`` epochs once ``args.start_valid``
     is reached. Training stops when ``early_stop`` signals it or the final
     epoch has been validated.
     """
     t = time.time()
     relation_triples_num = len(self.kgs.kg1.relation_triples_list) + len(
         self.kgs.kg2.relation_triples_list)
     relation_triple_steps = int(
         math.ceil(relation_triples_num / self.args.batch_size))
     relation_step_tasks = task_divide(list(range(relation_triple_steps)),
                                       self.args.batch_threads_num)
     # Keep the manager referenced: a manager process is shut down once the
     # manager object is garbage collected, which would break its queue proxy.
     manager = mp.Manager()
     relation_batch_queue = manager.Queue()
     for i in range(1, self.args.max_epoch + 1):
         self.launch_triple_training_1epo(i, relation_triple_steps,
                                          relation_step_tasks,
                                          relation_batch_queue, None, None)
         if i >= self.args.start_valid and i % self.args.eval_freq == 0:
             flag = self.valid(self.args.stop_metric)
             self.flag1, self.flag2, self.early_stop = early_stop(
                 self.flag1, self.flag2, flag)
             if self.early_stop or i == self.args.max_epoch:
                 break
     print("Training ends. Total time = {:.3f} s.".format(time.time() - t))
Esempio n. 3
0
 def run(self):
     """Train with periodic validation and optional truncated negative sampling.

     Validation and early stopping run every ``args.eval_freq`` epochs once
     ``args.start_valid`` is reached. With ``args.neg_sampling == 'truncated'``
     the neighbour tables passed to ``launch_training_1epo`` are rebuilt from
     the current useful-entity embeddings every ``args.truncated_freq``
     epochs. Otherwise they stay ``None``.
     """
     start_time = time.time()
     total_triples = (self.kgs.kg1.relation_triples_num +
                      self.kgs.kg2.relation_triples_num)
     n_steps = int(math.ceil(total_triples / self.args.batch_size))
     step_tasks = task_divide(list(range(n_steps)), self.args.batch_threads_num)
     manager = mp.Manager()
     batch_queue = manager.Queue()
     neighbors1 = neighbors2 = None
     for epoch in range(1, self.args.max_epoch + 1):
         self.launch_training_1epo(epoch, n_steps, step_tasks, batch_queue,
                                   neighbors1, neighbors2)
         if epoch >= self.args.start_valid and epoch % self.args.eval_freq == 0:
             score = self.valid(self.args.stop_metric)
             self.flag1, self.flag2, self.early_stop = early_stop(
                 self.flag1, self.flag2, score)
             if self.early_stop or epoch == self.args.max_epoch:
                 break
         refresh_due = (self.args.neg_sampling == 'truncated'
                        and epoch % self.args.truncated_freq == 0)
         if not refresh_due:
             continue
         refresh_start = time.time()
         epsilon = self.args.truncated_epsilon
         assert 0.0 < epsilon < 1.0
         keep = 1 - epsilon
         count1 = int(keep * self.kgs.kg1.entities_num)
         count2 = int(keep * self.kgs.kg2.entities_num)
         # Drop the previous tables before collecting so they can be
         # reclaimed before the new ones are generated.
         neighbors1 = neighbors2 = None
         gc.collect()
         neighbors1 = bat.generate_neighbours(
             self.eval_kg1_useful_ent_embeddings(),
             self.kgs.useful_entities_list1, count1,
             self.args.batch_threads_num)
         neighbors2 = bat.generate_neighbours(
             self.eval_kg2_useful_ent_embeddings(),
             self.kgs.useful_entities_list2, count2,
             self.args.batch_threads_num)
         entity_total = (len(self.kgs.kg1.entities_list) +
                         len(self.kgs.kg2.entities_list))
         print("\ngenerating neighbors of {} entities costs {:.3f} s.".format(
             entity_total, time.time() - refresh_start))
         gc.collect()
     print("Training ends. Total time = {:.3f} s.".format(time.time() - start_time))
Esempio n. 4
0
 def run(self):
     """Iterative training with bootstrapped alignment and truncated negatives.

     Training is split into ``max_epoch // sub_epoch`` iterations of
     ``sub_epoch`` epochs each (via ``launch_training_k_epo``). Per iteration:

     * once ``i * sub_epoch >= start_valid``: validate and update the
       early-stopping state. Stop if early stopping fired and at least
       ``min_iter`` iterations ran, or if this is the last iteration;
     * once ``i * sub_epoch >= start_bp``: bootstrap new aligned pairs from
       the reference similarity matrix, train on them with
       ``train_alignment`` (``align_times`` passes), and re-validate;
     * if ``neg_sampling == "truncated"``: rebuild the neighbour tables used
       for negative sampling in the next iteration.

     NOTE(review): when ``max_epoch < sub_epoch`` ``iter_nums`` is 0 and no
     training happens at all -- confirm this is intended.
     """
     t = time.time()
     triples_num = self.kgs.kg1.relation_triples_num + self.kgs.kg2.relation_triples_num
     triple_steps = int(math.ceil(triples_num / self.args.batch_size))
     # Split the per-epoch batch steps across the batch-generation workers.
     steps_tasks = task_divide(list(range(triple_steps)),
                               self.args.batch_threads_num)
     manager = mp.Manager()
     training_batch_queue = manager.Queue()
     # Neighbour tables stay None until the first truncated-sampling rebuild.
     neighbors1, neighbors2 = None, None
     # Accumulated bootstrapped alignment, fed back into each bootstrapping call.
     labeled_align = set()
     sub_num = self.args.sub_epoch
     iter_nums = self.args.max_epoch // sub_num
     for i in range(1, iter_nums + 1):
         print("\niteration", i)
         self.launch_training_k_epo(i, sub_num, triple_steps, steps_tasks,
                                    training_batch_queue, neighbors1,
                                    neighbors2)
         # Thresholds are in epochs, so compare the epochs run so far.
         if i * sub_num >= self.args.start_valid:
             flag = self.valid(self.args.stop_metric)
             self.flag1, self.flag2, self.early_stop = early_stop(
                 self.flag1, self.flag2, flag)
             if (self.early_stop
                     and i >= self.args.min_iter) or i == iter_nums:
                 break
         if i * sub_num >= self.args.start_bp:
             print("bootstrapping")
             labeled_align, entities1, entities2 = bootstrapping(
                 self.eval_ref_sim_mat(), self.ref_ent1, self.ref_ent2,
                 labeled_align, self.args.sim_th, self.args.k)
             self.train_alignment(self.kgs.kg1, self.kgs.kg2, entities1,
                                  entities2, self.args.align_times)
             # Re-validate after alignment training; the result is not used
             # for early stopping here.
             if i * sub_num >= self.args.start_valid:
                 self.valid(self.args.stop_metric)
         t1 = time.time()
         if self.args.neg_sampling == "truncated":
             assert 0.0 < self.args.truncated_epsilon < 1.0
             # Neighbour-list size: keep a (1 - epsilon) fraction of each KG's
             # entities.
             neighbors_num1 = int((1 - self.args.truncated_epsilon) *
                                  self.kgs.kg1.entities_num)
             neighbors_num2 = int((1 - self.args.truncated_epsilon) *
                                  self.kgs.kg2.entities_num)
             # Release the old tables before collecting so they can be freed
             # ahead of the new generation.
             if neighbors1 is not None:
                 del neighbors1, neighbors2
             gc.collect()
             neighbors1 = bat.generate_neighbours(
                 self.eval_kg1_useful_ent_embeddings(),
                 self.kgs.useful_entities_list1, neighbors_num1,
                 self.args.batch_threads_num)
             neighbors2 = bat.generate_neighbours(
                 self.eval_kg2_useful_ent_embeddings(),
                 self.kgs.useful_entities_list2, neighbors_num2,
                 self.args.batch_threads_num)
             ent_num = len(self.kgs.kg1.entities_list) + len(
                 self.kgs.kg2.entities_list)
             print("generating neighbors of {} entities costs {:.3f} s.".
                   format(ent_num,
                          time.time() - t1))
     print("Training ends. Total time = {:.3f} s.".format(time.time() - t))
Esempio n. 5
0
    def train_embeddings(self, loss, optimizer, output):
        """Alternately train the attribute (AE) and structure (SE) models.

        Each epoch runs one optimisation step of ``self.model_ae`` followed by
        one of ``self.model_se``. Both steps receive the same negative-sample
        index arrays. Validation through ``valid_`` with early stopping runs
        every ``args.eval_freq`` epochs once ``args.start_valid`` is reached.

        Args:
            loss: Not used by this method.
            optimizer: Not used by this method.
            output: Tensor fetched with the last SE feed dict to produce the
                returned structure embeddings.

        Returns:
            Tuple ``(vec_se, vec_ae)``. Both are also stored on ``self``.

        Raises:
            ValueError: If ``args.max_epoch < 1``. No epoch would run, so no
                feed dict would exist for the final embedding evaluation.
        """
        if self.args.max_epoch < 1:
            # Without this guard the final session.run calls would receive
            # feed_dict=None and fail far away from the actual cause.
            raise ValueError(
                "max_epoch must be at least 1, got {}".format(self.args.max_epoch))
        # t = number of training links, k = negatives per link.
        neg_num = self.args.neg_triple_num
        train_num = len(self.kgs.train_links)
        train_links = np.array(self.kgs.train_links)
        # Fixed sides of the negative pairs: every left (resp. right) aligned
        # entity repeated k times in a row, as a flat float array of length t*k.
        pos = np.ones((train_num, neg_num)) * (train_links[:, 0].reshape((train_num, 1)))
        neg_left = pos.reshape((train_num * neg_num,))
        pos = np.ones((train_num, neg_num)) * (train_links[:, 1].reshape((train_num, 1)))
        neg2_right = pos.reshape((train_num * neg_num,))
        # Random sides are (re)sampled inside the loop on epochs 1, 11, 21, ...
        neg2_left = None
        neg_right = None
        feed_dict_se = None
        feed_dict_ae = None

        for i in range(1, self.args.max_epoch + 1):
            start = time.time()
            if i % 10 == 1:
                # NOTE(review): presumably self.e is the entity count (or an
                # array of candidate ids) -- confirm against the constructor.
                neg2_left = np.random.choice(self.e, train_num * neg_num)
                neg_right = np.random.choice(self.e, train_num * neg_num)
            neg_feed = {'neg_left:0': neg_left, 'neg_right:0': neg_right,
                        'neg2_left:0': neg2_left, 'neg2_right:0': neg2_right}
            feed_dict_ae = self.utils.construct_feed_dict(self.ae_input, self.support, self.ph_ae)
            feed_dict_ae.update({self.ph_ae['dropout']: self.args.dropout})
            feed_dict_ae.update(neg_feed)
            feed_dict_se = self.utils.construct_feed_dict(1., self.support, self.ph_se)
            feed_dict_se.update({self.ph_se['dropout']: self.args.dropout})
            feed_dict_se.update(neg_feed)
            batch_loss1, _ = self.session.run(fetches=[self.model_ae.loss, self.model_ae.opt_op],
                                              feed_dict=feed_dict_ae)
            batch_loss2, _ = self.session.run(fetches=[self.model_se.loss, self.model_se.opt_op],
                                              feed_dict=feed_dict_se)

            batch_loss = batch_loss1 + batch_loss2
            print('epoch {}, avg. relation triple loss: {:.4f}, cost time: {:.4f}s'.format(i, batch_loss,
                                                                                           time.time() - start))

            if i >= self.args.start_valid and i % self.args.eval_freq == 0:
                # valid_ reads the current feed dicts from self.
                self.feed_dict_se = feed_dict_se
                self.feed_dict_ae = feed_dict_ae
                flag = self.valid_(self.args.stop_metric)
                self.flag1, self.flag2, self.early_stop = early_stop(self.flag1, self.flag2, flag)
                if self.early_stop or i == self.args.max_epoch:
                    break
        # NOTE(review): the final embeddings are computed with the last
        # training feed dicts, which set dropout to args.dropout; confirm
        # whether evaluation should feed a dropout of 0 instead.
        vec_se = self.session.run(output, feed_dict=feed_dict_se)
        vec_ae = self.session.run(self.model_ae.outputs, feed_dict=feed_dict_ae)
        self.vec_se = vec_se
        self.vec_ae = vec_ae
        return vec_se, vec_ae
Esempio n. 6
0
 def run(self):
     """Run the basic training loop with periodic validation and early stopping.

     Each epoch calls ``launch_training_1epo`` without neighbour tables.
     Validation happens every ``args.eval_freq`` epochs once
     ``args.start_valid`` is reached. Training ends when ``early_stop`` signals
     it or the final epoch has been validated.
     """
     started = time.time()
     n_triples = (self.kgs.kg1.relation_triples_num
                  + self.kgs.kg2.relation_triples_num)
     n_steps = int(math.ceil(n_triples / self.args.batch_size))
     tasks = task_divide(list(range(n_steps)), self.args.batch_threads_num)
     manager = mp.Manager()
     batch_queue = manager.Queue()
     for epoch in range(1, self.args.max_epoch + 1):
         self.launch_training_1epo(epoch, n_steps, tasks, batch_queue, None, None)
         if epoch < self.args.start_valid or epoch % self.args.eval_freq != 0:
             continue
         score = self.valid(self.args.stop_metric)
         self.flag1, self.flag2, self.early_stop = early_stop(
             self.flag1, self.flag2, score)
         if self.early_stop or epoch == self.args.max_epoch:
             break
     print("Training ends. Total time = {:.3f} s.".format(time.time() - started))
Esempio n. 7
0
 def run(self):
     """Run ``seq_train`` once per epoch with periodic validation.

     Validation happens every ``args.eval_freq`` epochs once
     ``args.start_valid`` is reached. Training ends when ``early_stop``
     signals it or the last epoch has been validated.
     """
     begin = time.time()
     data = self._train_data
     for epoch in range(1, self.args.max_epoch + 1):
         epoch_begin = time.time()
         mean_loss = self.seq_train(data)
         elapsed = time.time() - epoch_begin
         print('epoch %i, avg. batch_loss: %f,  cost time: %.4f s' %
               (epoch, mean_loss, elapsed))
         if epoch < self.args.start_valid or epoch % self.args.eval_freq != 0:
             continue
         score = self.valid(self.args.stop_metric)
         self.flag1, self.flag2, self.early_stop = early_stop(
             self.flag1, self.flag2, score)
         if self.early_stop or epoch >= self.args.max_epoch:
             break
     print("Training ends. Total time = {:.3f} s.".format(time.time() - begin))
Esempio n. 8
0
 def run(self):
     """Train on seed-alignment batches with periodic validation.

     Each epoch runs ``len(self.sup_ent2) // args.batch_size`` steps, or one
     step when that quotient is 0. When ``args.rel_param > 0`` the
     relation-aware link placeholders are fed together with relation
     head/tail batches. Otherwise only the plain link placeholders are fed.

     On validation epochs (every ``args.eval_freq`` epochs once
     ``args.start_valid`` is reached) training may stop early. If it
     continues, the neighbour tables used for batch generation are refreshed.
     From epoch ``args.start_augment * args.eval_freq`` on, the neighbourhood
     is also augmented when ``args.sim_th > 0``.
     """
     es_flag1, es_flag2 = 0, 0
     n_steps = (len(self.sup_ent2) // self.args.batch_size) or 1
     nbrs1 = nbrs2 = None
     for epoch in range(1, self.args.max_epoch + 1):
         tic = time.time()
         epoch_loss = 0.0
         for _step in range(n_steps):
             self.pos_link_batch, self.neg_link_batch = self.generate_input_batch(
                 self.args.batch_size, neighbors1=nbrs1, neighbors2=nbrs2)
             fetches = {"loss": self.loss, "optimizer": self.optimizer}
             if self.args.rel_param > 0:
                 hs, _, ts = self.generate_rel_batch()
                 feed_dict = {
                     self.rel_pos_links: self.pos_link_batch,
                     self.rel_neg_links: self.neg_link_batch,
                     self.hs: hs,
                     self.ts: ts,
                 }
             else:
                 feed_dict = {
                     self.pos_links: self.pos_link_batch,
                     self.neg_links: self.neg_link_batch,
                 }
             # Both branches run the same fetches; only the feed differs.
             epoch_loss += self.session.run(fetches=fetches,
                                            feed_dict=feed_dict)["loss"]
         print('epoch {}, loss: {:.4f}, cost time: {:.4f}s'.format(
             epoch, epoch_loss, time.time() - tic))
         if epoch % self.args.eval_freq != 0 or epoch < self.args.start_valid:
             continue
         score = self.valid(self.args.stop_metric)
         es_flag1, es_flag2, should_halt = early_stop(es_flag1, es_flag2, score)
         if should_halt:
             print("\n == training stop == \n")
             break
         nbrs1, nbrs2 = self.find_neighbors()
         if (epoch >= self.args.start_augment * self.args.eval_freq
                 and self.args.sim_th > 0.0):
             self.augment_neighborhood()
Esempio n. 9
0
    def training(self):
        """Train with negatives regenerated from the current embeddings.

        The fixed anchor arrays repeat each aligned entity of
        ``kgs.train_links`` ``neg_triple_num`` times. On epochs 1, 11, 21, ...
        ``self.output`` is evaluated and ``get_neg`` builds the paired
        negatives. The resulting feed dict is stored on ``self.feeddict`` and
        reused until the next refresh. Validation via ``valid_`` with early
        stopping runs every ``args.eval_freq`` epochs from
        ``args.start_valid`` on.
        """
        neg_num = self.args.neg_triple_num
        train_num = len(self.kgs.train_links)
        train_links = np.array(self.kgs.train_links)

        def repeat_each(column):
            # Repeat every value of `column` neg_num times in a row, flattened
            # into a float array of length train_num * neg_num.
            grid = np.ones((train_num, neg_num)) * column.reshape((train_num, 1))
            return grid.reshape((train_num * neg_num, ))

        neg_left = repeat_each(train_links[:, 0])
        neg2_right = repeat_each(train_links[:, 1])

        for epoch in range(1, self.args.max_epoch + 1):
            tic = time.time()
            if epoch % 10 == 1:
                embeddings = self.sess.run(self.output)
                # Keep the original call order: right-side links first.
                neg2_left = get_neg(train_links[:, 1], embeddings, neg_num)
                neg_right = get_neg(train_links[:, 0], embeddings, neg_num)
                self.feeddict = {
                    "neg_left:0": neg_left,
                    "neg_right:0": neg_right,
                    "neg2_left:0": neg2_left,
                    "neg2_right:0": neg2_right,
                }

            _, loss_value = self.sess.run([self.optimizer, self.loss],
                                          feed_dict=self.feeddict)
            print('epoch {}, avg. relation triple loss: {:.4f}, cost time: {:.4f}s'
                  .format(epoch, loss_value, time.time() - tic))

            if epoch < self.args.start_valid or epoch % self.args.eval_freq != 0:
                continue
            score = self.valid_(self.args.stop_metric)
            self.flag1, self.flag2, self.early_stop = early_stop(
                self.flag1, self.flag2, score)
            if self.early_stop or epoch == self.args.max_epoch:
                break
Esempio n. 10
0
File: mmea.py Progetto: zhilizl/MMEA
    def run(self):
        """Train every MMEA view each epoch and validate periodically.

        Each epoch trains, in order: the relation view, the image view, the
        attribute view, the common-space objective, and the entity mapping.
        On validation epochs (every ``args.eval_freq`` epochs once
        ``args.start_valid`` is reached), each single-view embedding
        ('rv', 'iv', 'av') is first checked with ``valid_temp``. The combined
        ``self.valid`` score then drives early stopping. Leaving the loop
        early is only allowed when ``args.early_stop`` is set.
        """
        began = time.time()
        kg1, kg2 = self.kgs.kg1, self.kgs.kg2
        n_rel = kg1.relation_triples_num + kg2.relation_triples_num
        n_attr = kg1.local_attribute_triples_num + kg2.local_attribute_triples_num
        rel_steps = int(math.ceil(n_rel / self.args.batch_size))
        attr_steps = int(math.ceil(n_attr / self.args.batch_size))
        rel_tasks = task_divide(list(range(rel_steps)), self.args.batch_threads_num)
        attr_tasks = task_divide(list(range(attr_steps)), self.args.batch_threads_num)
        manager = mp.Manager()
        rel_queue = manager.Queue()
        attr_queue = manager.Queue()

        # Neighbour tables are never generated by this trainer, so the
        # per-view trainers always receive None for both.
        no_neighbors = (None, None)
        entity_list = kg1.entities_list + kg2.entities_list

        for epoch in range(1, self.args.max_epoch + 1):
            print('epoch {}:'.format(epoch))
            self.train_only_relation_1epo(epoch, rel_steps, rel_tasks, rel_queue,
                                          *no_neighbors)
            self.train_only_image_1epo(epoch, entity_list)
            self.train_only_attribute_1epo(epoch, attr_steps, attr_tasks, attr_queue,
                                           *no_neighbors)
            self.train_common_space_learning_1epo(epoch, entity_list)
            self.train_entity_mapping_1epo(epoch, rel_steps)

            if epoch < self.args.start_valid or epoch % self.args.eval_freq != 0:
                continue
            for view in ('rv', 'iv', 'av'):
                valid_temp(self, embed_choice=view)
            score = self.valid(self.args.stop_metric)
            self.flag1, self.flag2, self.early_stop = early_stop(self.flag1, self.flag2, score)
            if self.args.early_stop and (self.early_stop or epoch == self.args.max_epoch):
                break

        print("Training ends. Total time = {:.3f} s.".format(time.time() - began))
Esempio n. 11
0
    def run(self):
        """BootEA-style iterative training driven by ``valid_new_bootea``.

        Training is split into ``max_epoch // sub_epoch`` iterations. Each
        iteration does the following:

        * calls ``self.save(i)`` and then trains ``sub_epoch`` epochs;
        * once ``i * sub_epoch >= start_valid``: runs ``valid`` (result
          unused) and ``valid_new_bootea``, whose result drives early
          stopping. The loop stops on early stop or on the last iteration;
        * bootstraps new aligned pairs (passing ``self.len_valid_test``) and
          trains once on them with ``train_alignment``. Unlike the
          ``start_bp``-gated variant, this runs every iteration;
        * regenerates the truncated-sampling neighbour tables.
          NOTE(review): this happens regardless of ``args.neg_sampling``.
        """
        t = time.time()
        # Count total number of triples and how many steps they need
        triples_num = self.kgs.kg1.relation_triples_num + self.kgs.kg2.relation_triples_num
        triple_steps = int(math.ceil(triples_num / self.args.batch_size))
        # Decide which batch-generation worker gets which step
        steps_tasks = task_divide(list(range(triple_steps)),
                                  self.args.batch_threads_num)
        # Initialize multiprocess architecture
        manager = mp.Manager()
        training_batch_queue = manager.Queue()
        # Initialize neighbours and the maximum number of iterations
        neighbors1, neighbors2 = None, None
        labeled_align = set()
        sub_num = self.args.sub_epoch
        iter_nums = self.args.max_epoch // sub_num
        for i in range(1, iter_nums + 1):
            print("\niteration", i)
            self.save(i)
            self.launch_training_k_epo(i, sub_num, triple_steps, steps_tasks,
                                       training_batch_queue, neighbors1,
                                       neighbors2)
            if i * sub_num >= self.args.start_valid:
                # Validation with stop_metric; only run for its output, its
                # result is not used for early stopping
                self.valid(self.args.stop_metric)
                # Early stopping is driven by the new BootEA validation
                # procedure (stop_metric_new) instead of stop_metric
                flag = self.valid_new_bootea(self.args.stop_metric_new,
                                             "bootea")
                # Debug-only check on the test data (never used for training):
                # self.test_new()
                self.flag1, self.flag2, self.early_stop = early_stop(
                    self.flag1, self.flag2, flag)

                # Disabled: saving the best embeddings. Rolled back to keeping
                # only the last ones.
                # self.flag1, new_flag2, self.early_stop = early_stop(self.flag1, self.flag2, flag)
                # if new_flag2 >= self.flag2:
                #     self.save_best_embeds()
                # self.flag2 = new_flag2

                if self.early_stop or i == iter_nums:
                    break
            labeled_align, entities1, entities2 = bootstrapping(
                self.eval_ref_sim_mat(), self.ref_ent1, self.ref_ent2,
                labeled_align, self.args.sim_th, self.args.k,
                self.len_valid_test)
            # entities1, entities2 = self.bootstrapping_new()
            # Extra training pass on the bootstrapped pairs to improve the
            # matched embeddings (fixed at 1 pass, not args.align_times)
            self.train_alignment(self.kgs.kg1, self.kgs.kg2, entities1,
                                 entities2, 1)
            # self.likelihood(labeled_align)
            if i * sub_num >= self.args.start_valid:
                self.valid(self.args.stop_metric)
                self.valid_new_bootea(self.args.stop_metric_new, "bootea")
            t1 = time.time()
            assert 0.0 < self.args.truncated_epsilon < 1.0
            # Keep a (1 - epsilon) fraction of each KG's entities as neighbours
            neighbors_num1 = int(
                (1 - self.args.truncated_epsilon) * self.kgs.kg1.entities_num)
            neighbors_num2 = int(
                (1 - self.args.truncated_epsilon) * self.kgs.kg2.entities_num)
            # Release the old tables before collecting so they can be freed
            # ahead of the new generation
            if neighbors1 is not None:
                del neighbors1, neighbors2
            gc.collect()
            neighbors1 = bat.generate_neighbours(
                self.eval_kg1_useful_ent_embeddings(),
                self.kgs.useful_entities_list1, neighbors_num1,
                self.args.batch_threads_num)
            neighbors2 = bat.generate_neighbours(
                self.eval_kg2_useful_ent_embeddings(),
                self.kgs.useful_entities_list2, neighbors_num2,
                self.args.batch_threads_num)
            ent_num = len(self.kgs.kg1.entities_list) + len(
                self.kgs.kg2.entities_list)
            print("generating neighbors of {} entities costs {:.3f} s.".format(
                ent_num,
                time.time() - t1))
        print("Training ends. Total time = {:.3f} s.".format(time.time() - t))