Example #1

from collections import OrderedDict
from pathlib import Path
from shutil import rmtree

import numpy as np
import torch
from torch.optim import Adagrad

# Project-specific helpers (set_random_seed, print_time_info, LoadData,
# GNNChannel, AlignLoss, get_hits, get_nearest_neighbor, sort_and_keep_indices)
# are assumed to be importable from the surrounding code base.


class AttConf(object):
    def __init__(self):
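        # Default hyper-parameters for data split, model size, training schedule
        # and negative sampling; the *_channel flags below select the information
        # sources used by the GNN.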
        self.train_seeds_ratio = 0.3
        self.dim = 128
        self.drop_out = 0.0
        self.layer_num = 2
        self.epoch_num = 100
        self.nega_sample_freq = 5
        self.nega_sample_num = 25

        self.learning_rate = 0.001
        self.l2_regularization = 1e-2
        self.margin_gamma = 1.0

        self.log_comment = "comment"

        self.structure_channel = False
        self.name_channel = False
        self.attribute_value_channel = False
        self.literal_attribute_channel = False
        self.digit_attribute_channel = False

        self.load_new_seed_split = False

    def set_load_new_seed_split(self, load_new_seed_split):
        self.load_new_seed_split = load_new_seed_split

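    # Enable exactly one information channel by name ('Literal', 'Digital',
    # 'Attribute', 'Structure' or 'Name').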
    def set_channel(self, channel_name):
        if channel_name == 'Literal':
            self.set_literal_attribute_channel(True)
        elif channel_name == 'Digital':
            self.set_digit_attribute_channel(True)
        elif channel_name == 'Attribute':
            self.set_attribute_value_channel(True)
        elif channel_name == 'Structure':
            self.set_structure_channel(True)
        elif channel_name == 'Name':
            self.set_name_channel(True)
        else:
            raise Exception('Unknown channel name: %s' % channel_name)

    def set_epoch_num(self, epoch_num):
        self.epoch_num = epoch_num

    def set_nega_sample_num(self, nega_sample_num):
        self.nega_sample_num = nega_sample_num

    def set_log_comment(self, log_comment):
        self.log_comment = log_comment

    def set_name_channel(self, use_name_channel):
        self.name_channel = use_name_channel

    def set_digit_attribute_channel(self, use_digit_attribute_channel):
        self.digit_attribute_channel = use_digit_attribute_channel

    def set_literal_attribute_channel(self, use_literal_attribute_channel):
        self.literal_attribute_channel = use_literal_attribute_channel

    def set_attribute_value_channel(self, use_attribute_value_channel):
        self.attribute_value_channel = use_attribute_value_channel

    def set_structure_channel(self, use_structure_channel):
        self.structure_channel = use_structure_channel

    def set_drop_out(self, drop_out):
        self.drop_out = drop_out

    def set_learning_rate(self, learning_rate):
        self.learning_rate = learning_rate

    def set_l2_regularization(self, l2_regularization):
        self.l2_regularization = l2_regularization

    def print_parameter(self, file=None):
        parameters = self.__dict__
        print_time_info('Parameter settings:', dash_top=True, file=file)
        for key, value in parameters.items():
            if type(value) in {int, float, str, bool}:
                print('\t%s:' % key, value, file=file)
        print('---------------------------------------', file=file)

    def init_log(self, log_dir):
        log_dir = Path(log_dir)
        self.log_dir = log_dir
        if log_dir.exists():
            rmtree(str(log_dir), ignore_errors=True)
            print_time_info("Warning! Forced remove directory %s." %
                            (str(log_dir)))
        log_dir.mkdir()
        comment = log_dir.name
        with open(log_dir / 'parameters.txt', 'w') as f:
            print_time_info(comment, file=f)
            self.print_parameter(f)

    def init(self, directory, device):
        set_random_seed()
        self.directory = Path(directory)
        self.loaded_data = LoadData(
            self.train_seeds_ratio,
            self.directory,
            self.nega_sample_num,
            name_channel=self.name_channel,
            attribute_channel=self.attribute_value_channel,
            digit_literal_channel=self.digit_attribute_channel
            or self.literal_attribute_channel,
            load_new_seed_split=self.load_new_seed_split,
            device=device)
        self.sr_ent_num = self.loaded_data.sr_ent_num
        self.tg_ent_num = self.loaded_data.tg_ent_num
        self.att_num = self.loaded_data.att_num

        # Init graph adjacency matrix
        print_time_info('Begin preprocessing adjacency matrix')
        self.channels = {}

        edges_sr = torch.tensor(self.loaded_data.triples_sr)[:, :2]
        edges_tg = torch.tensor(self.loaded_data.triples_tg)[:, :2]
        edges_sr = torch.unique(edges_sr, dim=0)
        edges_tg = torch.unique(edges_tg, dim=0)

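        # Each enabled channel receives the deduplicated edge lists of both graphs
        # plus its own channel-specific inputs.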
        if self.name_channel:
            self.channels['name'] = {
                'edges_sr': edges_sr,
                'edges_tg': edges_tg,
                'sr_ent_embed': self.loaded_data.sr_embed,
                'tg_ent_embed': self.loaded_data.tg_embed,
            }
        if self.structure_channel:
            self.channels['structure'] = {
                'edges_sr': edges_sr,
                'edges_tg': edges_tg
            }
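        # The three attribute-style channels all write to the 'attribute' key, so
        # only one of them should be enabled at a time (see set_channel).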
        if self.attribute_value_channel:
            self.channels['attribute'] = {
                'edges_sr': edges_sr,
                'edges_tg': edges_tg,
                'att_num': self.loaded_data.att_num,
                'attribute_triples_sr': self.loaded_data.attribute_triples_sr,
                'attribute_triples_tg': self.loaded_data.attribute_triples_tg,
                'value_embedding': self.loaded_data.value_embedding
            }
        if self.literal_attribute_channel:
            self.channels['attribute'] = {
                'edges_sr': edges_sr,
                'edges_tg': edges_tg,
                'att_num': self.loaded_data.literal_att_num,
                'attribute_triples_sr': self.loaded_data.literal_triples_sr,
                'attribute_triples_tg': self.loaded_data.literal_triples_tg,
                'value_embedding': self.loaded_data.literal_value_embedding
            }
        if self.digit_attribute_channel:
            self.channels['attribute'] = {
                'edges_sr': edges_sr,
                'edges_tg': edges_tg,
                'att_num': self.loaded_data.digit_att_num,
                'attribute_triples_sr': self.loaded_data.digital_triples_sr,
                'attribute_triples_tg': self.loaded_data.digital_triples_tg,
                'value_embedding': self.loaded_data.digit_value_embedding
            }
        print_time_info('Finished preprocessing adjacency matrix')

    def train(self, device):
        set_random_seed()
        self.loaded_data.negative_sample()
        # Compose Graph NN
        gnn_channel = GNNChannel(self.sr_ent_num, self.tg_ent_num, self.dim,
                                 self.layer_num, self.drop_out, self.channels)
        self.gnn_channel = gnn_channel
        gnn_channel.to(device)
        gnn_channel.train()

        # Prepare optimizer
        optimizer = Adagrad(filter(lambda p: p.requires_grad,
                                   gnn_channel.parameters()),
                            lr=self.learning_rate,
                            weight_decay=self.l2_regularization)
        criterion = AlignLoss(self.margin_gamma)

        best_hit_at_1 = 0
        best_epoch_num = 0

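        # Full-batch training: one forward/backward pass over all training seed
        # alignments per epoch.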
        for epoch_num in range(1, self.epoch_num + 1):
            gnn_channel.train()
            optimizer.zero_grad()
            sr_seed_hid, tg_seed_hid, _, _ = gnn_channel.forward(
                self.loaded_data.train_sr_ent_seeds,
                self.loaded_data.train_tg_ent_seeds)
            loss = criterion(sr_seed_hid, tg_seed_hid)
            loss.backward()
            optimizer.step()
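            # Every nega_sample_freq epochs: refresh the negative samples (the data
            # loader's own sampling for DWY100k, nearest-neighbour mining otherwise)
            # and track the best validation Hit@1.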
            if epoch_num % self.nega_sample_freq == 0:
                if str(self.directory).find('DWY100k') >= 0:
                    self.loaded_data.negative_sample()
                else:
                    self.negative_sample()
                hit_at_1 = self.evaluate(epoch_num,
                                         gnn_channel,
                                         print_info=False,
                                         device=device)
                if hit_at_1 > best_hit_at_1:
                    best_hit_at_1 = hit_at_1
                    best_epoch_num = epoch_num
        print('Best Hit@1 on the validation set is %.2f at epoch %d.' %
              (best_hit_at_1, best_epoch_num))
        return best_hit_at_1, best_epoch_num

    def evaluate(self, epoch_num, info_gnn, print_info=True, device='cpu'):
        info_gnn.eval()
        sim = info_gnn.predict(self.loaded_data.valid_sr_ent_seeds,
                               self.loaded_data.valid_tg_ent_seeds)
        top_lr, top_rl, mr_lr, mr_rl, mrr_lr, mrr_rl = get_hits(
            sim, print_info=print_info, device=device)
        hit_at_1 = (top_lr[0] + top_rl[0]) / 2
        return hit_at_1

    def negative_sample(self):
        sim_sr, sim_tg = self.gnn_channel.negative_sample(
            self.loaded_data.train_sr_ent_seeds_ori,
            self.loaded_data.train_tg_ent_seeds_ori)
        sr_nns = get_nearest_neighbor(sim_sr, self.nega_sample_num)
        tg_nns = get_nearest_neighbor(sim_tg, self.nega_sample_num)
        self.loaded_data.update_negative_sample(sr_nns, tg_nns)

    def save_sim_matrix(self, device):
        # Get the similarity matrix of the current model
        self.gnn_channel.eval()
        sim_train = self.gnn_channel.predict(
            self.loaded_data.train_sr_ent_seeds_ori,
            self.loaded_data.train_tg_ent_seeds_ori)
        sim_valid = self.gnn_channel.predict(
            self.loaded_data.valid_sr_ent_seeds,
            self.loaded_data.valid_tg_ent_seeds)
        sim_test = self.gnn_channel.predict(self.loaded_data.test_sr_ent_seeds,
                                            self.loaded_data.test_tg_ent_seeds)
        get_hits(sim_test, print_info=True, device=device)
        print_time_info('Best result on the test set', dash_top=True)
        sim_train = sim_train.cpu().numpy()
        sim_valid = sim_valid.cpu().numpy()
        sim_test = sim_test.cpu().numpy()

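        # For large similarity matrices, save only a sorted partial result (and its
        # transpose) instead of the full matrix to limit disk usage.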
        def save_sim(sim, comment):
            if sim.shape[0] > 20000:
                partial_sim = sort_and_keep_indices(sim, device)
                partial_sim_t = sort_and_keep_indices(sim.T, device)
                np.save(str(self.log_dir / ('%s_sim.npy' % comment)),
                        partial_sim)
                np.save(str(self.log_dir / ('%s_sim_t.npy' % comment)),
                        partial_sim_t)
            else:
                np.save(str(self.log_dir / ('%s_sim.npy' % comment)), sim)

        save_sim(sim_train, 'train')
        save_sim(sim_valid, 'valid')
        save_sim(sim_test, 'test')
        print_time_info(
            "Model configs and predictions saved to directory: %s." %
            str(self.log_dir))

    def save_model(self):
        save_path = self.log_dir / 'model.pt'
        state_dict = self.gnn_channel.state_dict()
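        # Keep only dense parameters in the checkpoint; sparse COO tensors are
        # filtered out.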
        state_dict = OrderedDict(
            filter(lambda x: x[1].layout != torch.sparse_coo,
                   state_dict.items()))
        torch.save(state_dict, str(save_path))
        print_time_info("Model is saved to directory: %s." % str(self.log_dir))