class AttConf(object):
    """Hyper-parameters and the load / train / evaluate / save workflow of one GNN channel.

    Typical call order, as implied by the attributes each method reads:
    configure through the ``set_*`` methods, then ``init_log`` (sets
    ``self.log_dir``), ``init`` (sets ``self.loaded_data`` and
    ``self.channels``), ``train`` (sets ``self.gnn_channel``), and finally
    ``save_sim_matrix`` / ``save_model``.
    """

    # Public channel names accepted by ``set_channel`` mapped to the name of
    # the setter that switches the matching flag on.  Setters are looked up by
    # name at call time so subclasses overriding a setter are still honoured.
    _CHANNEL_SETTERS = {
        'Literal': 'set_literal_attribute_channel',
        'Digital': 'set_digit_attribute_channel',
        'Attribute': 'set_attribute_value_channel',
        'Structure': 'set_structure_channel',
        'Name': 'set_name_channel',
    }

    def __init__(self):
        """Populate every hyper-parameter with its default; all channels start disabled."""
        # NOTE: attribute insertion order is the order ``print_parameter``
        # lists the values in, so keep it stable.
        self.train_seeds_ratio = 0.3
        self.dim = 128
        self.drop_out = 0.0
        self.layer_num = 2
        self.epoch_num = 100
        self.nega_sample_freq = 5
        self.nega_sample_num = 25
        self.learning_rate = 0.001
        self.l2_regularization = 1e-2
        self.margin_gamma = 1.0
        self.log_comment = "comment"
        self.structure_channel = False
        self.name_channel = False
        self.attribute_value_channel = False
        self.literal_attribute_channel = False
        self.digit_attribute_channel = False
        self.load_new_seed_split = False

    def set_load_new_seed_split(self, load_new_seed_split):
        """Store the flag later forwarded to ``LoadData`` as ``load_new_seed_split``."""
        self.load_new_seed_split = load_new_seed_split

    def set_channel(self, channel_name):
        """Enable the channel identified by ``channel_name``.

        Args:
            channel_name: one of ``'Literal'``, ``'Digital'``, ``'Attribute'``,
                ``'Structure'`` or ``'Name'`` (case-sensitive).

        Raises:
            ValueError: if ``channel_name`` is not one of the names above.
                ``ValueError`` derives from ``Exception``, so callers that
                catch ``Exception`` keep working.
        """
        try:
            setter_name = self._CHANNEL_SETTERS[channel_name]
        except (KeyError, TypeError):
            # TypeError covers unhashable arguments such as lists.
            raise ValueError(
                'Unknown channel name %r; expected one of %s.'
                % (channel_name, sorted(self._CHANNEL_SETTERS))) from None
        getattr(self, setter_name)(True)

    def set_epoch_num(self, epoch_num):
        """Set the number of training epochs run by ``train``."""
        self.epoch_num = epoch_num

    def set_nega_sample_num(self, nega_sample_num):
        """Set the negative-sample count passed to ``LoadData`` and ``get_nearest_neighbor``."""
        self.nega_sample_num = nega_sample_num

    def set_log_comment(self, log_comment):
        """Set the free-form comment stored with this configuration."""
        self.log_comment = log_comment

    def set_name_channel(self, use_name_channel):
        """Enable or disable the name channel."""
        self.name_channel = use_name_channel

    def set_digit_attribute_channel(self, use_digit_attribute_channel):
        """Enable or disable the digit-attribute channel."""
        self.digit_attribute_channel = use_digit_attribute_channel

    def set_literal_attribute_channel(self, use_literal_attribute_channel):
        """Enable or disable the literal-attribute channel."""
        self.literal_attribute_channel = use_literal_attribute_channel

    def set_attribute_value_channel(self, use_attribute_value_channel):
        """Enable or disable the attribute-value channel."""
        self.attribute_value_channel = use_attribute_value_channel

    def set_structure_channel(self, use_structure_channel):
        """Enable or disable the structure channel."""
        self.structure_channel = use_structure_channel

    def set_drop_out(self, drop_out):
        """Set the dropout value handed to ``GNNChannel``."""
        self.drop_out = drop_out

    def set_learning_rate(self, learning_rate):
        """Set the learning rate handed to the Adagrad optimizer."""
        self.learning_rate = learning_rate

    def set_l2_regularization(self, l2_regularization):
        """Set the value handed to the optimizer as ``weight_decay``."""
        self.l2_regularization = l2_regularization

    def print_parameter(self, file=None):
        """Print every scalar (int / float / str / bool) instance attribute.

        Args:
            file: optional writable text stream; ``None`` means standard output.
        """
        print_time_info('Parameter setttings:', dash_top=True, file=file)
        for key, value in self.__dict__.items():
            # isinstance (rather than an exact type match) so that subclasses
            # of the scalar types are reported too instead of silently skipped.
            if isinstance(value, (int, float, str, bool)):
                print('\t%s:' % key, value, file=file)
        print('---------------------------------------', file=file)

    def init_log(self, log_dir):
        """(Re)create ``log_dir`` and write the current parameters into it.

        Any existing directory at ``log_dir`` is removed first.  The resolved
        path is stored on ``self.log_dir`` for the ``save_*`` methods.

        Args:
            log_dir: path (``str`` or ``Path``) of the log directory.
        """
        log_dir = Path(log_dir)
        self.log_dir = log_dir
        if log_dir.exists():
            rmtree(str(log_dir), ignore_errors=True)
            print_time_info("Warning! Forced remove directory %s." % (str(log_dir)))
        # rmtree ignores errors, so the directory may survive the removal, and
        # the parent may not exist yet; tolerate both instead of crashing.
        log_dir.mkdir(parents=True, exist_ok=True)
        comment = log_dir.name
        with open(log_dir / 'parameters.txt', 'w') as f:
            print_time_info(comment, file=f)
            self.print_parameter(f)

    def init(self, directory, device):
        """Load the dataset and assemble the per-channel inputs in ``self.channels``.

        Args:
            directory: dataset directory, forwarded to ``LoadData``.
            device: device identifier, forwarded to ``LoadData``.
        """
        set_random_seed()
        self.directory = Path(directory)
        self.loaded_data = LoadData(
            self.train_seeds_ratio, self.directory, self.nega_sample_num,
            name_channel=self.name_channel,
            attribute_channel=self.attribute_value_channel,
            digit_literal_channel=(self.digit_attribute_channel
                                   or self.literal_attribute_channel),
            load_new_seed_split=self.load_new_seed_split,
            device=device)
        data = self.loaded_data
        self.sr_ent_num = data.sr_ent_num
        self.tg_ent_num = data.tg_ent_num
        self.att_num = data.att_num

        # Init graph adjacent matrix
        print_time_info('Begin preprocessing adjacent matrix')
        self.channels = {}
        # Keep the first two columns of every triple and drop duplicate rows.
        # NOTE(review): presumably those columns are the two entity ids of an
        # edge -- confirm against LoadData.
        edges_sr = torch.unique(torch.tensor(data.triples_sr)[:, :2], dim=0)
        edges_tg = torch.unique(torch.tensor(data.triples_tg)[:, :2], dim=0)

        if self.name_channel:
            self.channels['name'] = {
                'edges_sr': edges_sr,
                'edges_tg': edges_tg,
                'sr_ent_embed': data.sr_embed,
                'tg_ent_embed': data.tg_embed,
            }
        if self.structure_channel:
            self.channels['structure'] = {
                'edges_sr': edges_sr,
                'edges_tg': edges_tg
            }
        # NOTE(review): the three branches below all write the same
        # 'attribute' key, so when several of these flags are set the last
        # enabled branch wins.  Looks intentional (one flavour per run) --
        # confirm before enabling more than one at a time.
        if self.attribute_value_channel:
            self.channels['attribute'] = {
                'edges_sr': edges_sr,
                'edges_tg': edges_tg,
                'att_num': data.att_num,
                'attribute_triples_sr': data.attribute_triples_sr,
                'attribute_triples_tg': data.attribute_triples_tg,
                'value_embedding': data.value_embedding
            }
        if self.literal_attribute_channel:
            self.channels['attribute'] = {
                'edges_sr': edges_sr,
                'edges_tg': edges_tg,
                'att_num': data.literal_att_num,
                'attribute_triples_sr': data.literal_triples_sr,
                'attribute_triples_tg': data.literal_triples_tg,
                'value_embedding': data.literal_value_embedding
            }
        if self.digit_attribute_channel:
            self.channels['attribute'] = {
                'edges_sr': edges_sr,
                'edges_tg': edges_tg,
                'att_num': data.digit_att_num,
                'attribute_triples_sr': data.digital_triples_sr,
                'attribute_triples_tg': data.digital_triples_tg,
                'value_embedding': data.digit_value_embedding
            }
        print_time_info('Finished preprocesssing adjacent matrix')

    def train(self, device):
        """Train a fresh ``GNNChannel`` and track the best validation Hit@1.

        Requires ``init`` to have been called.  The model is stored on
        ``self.gnn_channel``.

        Args:
            device: device the model is moved to and that is forwarded to
                ``evaluate``.

        Returns:
            Tuple ``(best_hit_at_1, best_epoch_num)``; both stay ``0`` when no
            evaluation improved on the initial score.
        """
        set_random_seed()
        self.loaded_data.negative_sample()

        # Compose Graph NN
        gnn_channel = GNNChannel(self.sr_ent_num, self.tg_ent_num, self.dim,
                                 self.layer_num, self.drop_out, self.channels)
        self.gnn_channel = gnn_channel
        gnn_channel.to(device)
        gnn_channel.train()

        # Prepare optimizer: only parameters that require gradients are optimized.
        trainable_params = [p for p in gnn_channel.parameters() if p.requires_grad]
        optimizer = Adagrad(trainable_params, lr=self.learning_rate,
                            weight_decay=self.l2_regularization)
        criterion = AlignLoss(self.margin_gamma)

        best_hit_at_1 = 0
        best_epoch_num = 0
        # This check does not change between epochs, so do it once.
        use_loader_sampling = 'DWY100k' in str(self.directory)

        for epoch_num in range(1, self.epoch_num + 1):
            # evaluate() switches the model to eval mode, so switch back.
            gnn_channel.train()
            optimizer.zero_grad()
            sr_seed_hid, tg_seed_hid, _, _ = gnn_channel.forward(
                self.loaded_data.train_sr_ent_seeds,
                self.loaded_data.train_tg_ent_seeds)
            loss = criterion(sr_seed_hid, tg_seed_hid)
            loss.backward()
            optimizer.step()

            # Every ``nega_sample_freq`` epochs: refresh negatives, then validate.
            if epoch_num % self.nega_sample_freq == 0:
                if use_loader_sampling:
                    # DWY100k paths use the data loader's own sampler rather
                    # than the model-based one below.
                    self.loaded_data.negative_sample()
                else:
                    self.negative_sample()
                hit_at_1 = self.evaluate(epoch_num, gnn_channel,
                                         print_info=False, device=device)
                if hit_at_1 > best_hit_at_1:
                    best_hit_at_1 = hit_at_1
                    best_epoch_num = epoch_num

        print('Model best Hit@1 on valid set is %.2f at %d epoch.'
              % (best_hit_at_1, best_epoch_num))
        return best_hit_at_1, best_epoch_num

    def evaluate(self, epoch_num, info_gnn, print_info=True, device='cpu'):
        """Score ``info_gnn`` on the validation seeds.

        Args:
            epoch_num: current epoch; accepted for the caller's convenience
                but not used here.
            info_gnn: model exposing ``eval()`` and ``predict(sr, tg)``.
            print_info: forwarded to ``get_hits``.
            device: forwarded to ``get_hits``.

        Returns:
            Mean of the first entries of the two "top" results returned by
            ``get_hits`` (one per alignment direction).
        """
        info_gnn.eval()
        sim = info_gnn.predict(self.loaded_data.valid_sr_ent_seeds,
                               self.loaded_data.valid_tg_ent_seeds)
        top_lr, top_rl, _mr_lr, _mr_rl, _mrr_lr, _mrr_rl = get_hits(
            sim, print_info=print_info, device=device)
        return (top_lr[0] + top_rl[0]) / 2

    def negative_sample(self, ):
        """Refresh the negative samples from the current model's nearest neighbours."""
        sim_sr, sim_tg = self.gnn_channel.negative_sample(
            self.loaded_data.train_sr_ent_seeds_ori,
            self.loaded_data.train_tg_ent_seeds_ori)
        sr_nns = get_nearest_neighbor(sim_sr, self.nega_sample_num)
        tg_nns = get_nearest_neighbor(sim_tg, self.nega_sample_num)
        self.loaded_data.update_negative_sample(sr_nns, tg_nns)

    def save_sim_matrix(self, device):
        """Save the train / valid / test similarity matrices as ``.npy`` files.

        Files are written into ``self.log_dir`` (set by ``init_log``) as
        ``<split>_sim.npy``.  Matrices with more than 20000 rows are reduced
        with ``sort_and_keep_indices`` and additionally saved in transposed
        form as ``<split>_sim_t.npy``.

        Args:
            device: forwarded to ``get_hits`` and ``sort_and_keep_indices``.
        """
        # Get the similarity matrix of the current model
        model = self.gnn_channel
        data = self.loaded_data
        model.eval()
        sim_train = model.predict(data.train_sr_ent_seeds_ori,
                                  data.train_tg_ent_seeds_ori)
        sim_valid = model.predict(data.valid_sr_ent_seeds,
                                  data.valid_tg_ent_seeds)
        sim_test = model.predict(data.test_sr_ent_seeds,
                                 data.test_tg_ent_seeds)
        get_hits(sim_test, print_info=True, device=device)
        print_time_info('Best result on the test set', dash_top=True)

        def save_sim(sim, comment):
            """Write one similarity matrix, reducing it first when it is large."""
            if sim.shape[0] > 20000:
                partial_sim = sort_and_keep_indices(sim, device)
                partial_sim_t = sort_and_keep_indices(sim.T, device)
                np.save(str(self.log_dir / ('%s_sim.npy' % comment)), partial_sim)
                np.save(str(self.log_dir / ('%s_sim_t.npy' % comment)), partial_sim_t)
            else:
                np.save(str(self.log_dir / ('%s_sim.npy' % comment)), sim)

        save_sim(sim_train.cpu().numpy(), 'train')
        save_sim(sim_valid.cpu().numpy(), 'valid')
        save_sim(sim_test.cpu().numpy(), 'test')
        print_time_info(
            "Model configs and predictions saved to directory: %s." % str(self.log_dir))

    def save_model(self):
        """Save the model's ``state_dict`` to ``<log_dir>/model.pt``.

        Entries whose tensor layout is ``torch.sparse_coo`` are left out of
        the saved dictionary.
        """
        save_path = self.log_dir / 'model.pt'
        state_dict = OrderedDict(
            (key, tensor)
            for key, tensor in self.gnn_channel.state_dict().items()
            if tensor.layout != torch.sparse_coo)
        torch.save(state_dict, str(save_path))
        print_time_info("Model is saved to directory: %s." % str(self.log_dir))