Code example #1
File: TKG_Recurrent_Module.py  Project: wxd-neu/TeMP
    def forward(self, t_list, reverse=False):
        kld_loss = 0
        reconstruct_loss = 0
        h = self.h0.expand(self.num_layers, len(t_list),
                           self.hidden_size).contiguous()
        g_batched_list, time_batched_list = self.get_batch_graph_list(
            t_list, self.train_seq_len, self.graph_dict_train)

        for t in range(self.train_seq_len - 1):
            g_batched_list_t, bsz, cur_h, triplets, labels, node_sizes = self.get_val_vars(
                g_batched_list, t, h)
            # triplets, neg_tail_samples, neg_head_samples, labels = self.corrupter.samples_labels_train(time_batched_list_t, g_batched_list_t)
            per_graph_ent_embeds = self.get_per_graph_ent_embeds(
                g_batched_list_t, cur_h, node_sizes, val=True)

            pooled_fact_embeddings = []
            for i, ent_embed in enumerate(per_graph_ent_embeds):
                pooled_fact_embeddings.append(
                    self.get_pooled_facts(ent_embed, triplets[i]))
            _, h = self.rnn(
                torch.stack(pooled_fact_embeddings, dim=0).unsqueeze(0),
                h[:, :bsz])

        train_graphs, time_batched_list_t = filter_none(
            g_batched_list[-1]), filter_none(time_batched_list[-1])
        bsz = len(train_graphs)
        cur_h = h[-1][:bsz]  # bsz, hidden_size
        # run RGCN on graph to get encoded ent_embeddings and rel_embeddings in G_t

        node_sizes = [len(g.nodes()) for g in train_graphs]
        triplets, neg_tail_samples, neg_head_samples, labels = self.corrupter.single_graph_negative_sampling(
            time_batched_list_t, train_graphs)
        per_graph_ent_embeds = self.get_per_graph_ent_embeds(
            train_graphs, cur_h, node_sizes)

        for i, ent_embed in enumerate(per_graph_ent_embeds):
            loss_tail = self.train_link_prediction(ent_embed,
                                                   triplets[i],
                                                   neg_tail_samples[i],
                                                   labels[i],
                                                   corrupt_tail=True)
            loss_head = self.train_link_prediction(ent_embed,
                                                   triplets[i],
                                                   neg_head_samples[i],
                                                   labels[i],
                                                   corrupt_tail=False)
            reconstruct_loss += loss_tail + loss_head
        return reconstruct_loss, kld_loss
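
A minimal sketch of how this forward pass might be driven from a training loop. The `model`, `optimizer`, and `t_list` below are assumptions for illustration (their construction is not part of the example); only the call signature and the two returned losses follow from the code above.

    # Hypothetical training step; `model`, `optimizer` and `t_list` are assumed
    # to be defined elsewhere.
    model.train()
    optimizer.zero_grad()
    reconstruct_loss, kld_loss = model(t_list)  # forward() shown above
    # kld_loss stays 0 in this non-VAE variant, so the total is effectively the
    # reconstruction (link-prediction) loss.
    loss = reconstruct_loss + kld_loss
    loss.backward()
    optimizer.step()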
Code example #2
File: TKG_Recurrent_Module.py  Project: wxd-neu/TeMP
    def evaluate(self, t_list, val=True):
        graph_dict = self.graph_dict_val if val else self.graph_dict_test
        h = self.h0.expand(self.num_layers, len(t_list),
                           self.hidden_size).contiguous()
        g_train_batched_list, time_list = self.get_batch_graph_list(
            t_list, self.test_seq_len, self.graph_dict_train)
        g_batched_list, val_time_list = self.get_batch_graph_list(
            t_list, 1, graph_dict)

        for t in range(self.test_seq_len - 1):
            g_batched_list_t, bsz, cur_h, triplets, labels, node_sizes = self.get_val_vars(
                g_train_batched_list, t, h)
            per_graph_ent_embeds = self.get_per_graph_ent_embeds(
                g_batched_list_t, cur_h, node_sizes, val=True)

            pooled_fact_embeddings = []
            for i, ent_embed in enumerate(per_graph_ent_embeds):
                pooled_fact_embeddings.append(
                    self.get_pooled_facts(ent_embed, triplets[i]))

            _, h = self.rnn(
                torch.stack(pooled_fact_embeddings, dim=0).unsqueeze(0),
                h[:, :bsz])

        test_graph, bsz, cur_h, triplets, labels, _ = self.get_val_vars(
            g_batched_list, -1, h)
        train_graph = filter_none(g_train_batched_list[-1])
        node_sizes = [len(g.nodes()) for g in train_graph]
        per_graph_ent_embeds = self.get_per_graph_ent_embeds(train_graph,
                                                             cur_h,
                                                             node_sizes,
                                                             val=True)
        return self.calc_metrics(per_graph_ent_embeds, time_list[-1], triplets,
                                 labels)
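
One possible way to call evaluate during validation, sketched under the assumption that `model` and a batch of validation timestamps `val_t_list` exist; only the method signature and its return value come from the code above.

    # Hypothetical validation step; no gradients are needed here.
    import torch

    model.eval()
    with torch.no_grad():
        metrics = model.evaluate(val_t_list, val=True)  # evaluate() shown above
    # `metrics` is whatever calc_metrics returns; its structure is not visible
    # in this snippet.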
Code example #3
    def get_val_vars(self, g_batched_list, t, h):
        g_batched_list_t = filter_none(g_batched_list[t])
        bsz = len(g_batched_list_t)
        cur_h = h[-1][:bsz]  # bsz, hidden_size
        triplets, labels = self.corrupter.sample_labels_val(g_batched_list_t)
        # run RGCN on graph to get encoded ent_embeddings and rel_embeddings in G_t
        node_sizes = [len(g.nodes()) for g in g_batched_list_t]
        return g_batched_list_t, bsz, cur_h, triplets, labels, node_sizes
Code example #4
File: DynamicRGCN.py  Project: wxd-neu/TeMP
    def get_per_graph_ent_dropout_embeds(self, cur_time_list, target_time_list,
                                          node_sizes, time_diff_tensor,
                                          first_prev_graph_embeds,
                                          second_prev_graph_embeds):
        batched_graph = self.get_batch_graph_dropout_embeds(
            filter_none(cur_time_list), target_time_list)
        if self.use_cuda:
            move_dgl_to_cuda(batched_graph)
        first_layer_embeds, second_layer_embeds = self.ent_encoder(
            batched_graph, first_prev_graph_embeds, second_prev_graph_embeds,
            time_diff_tensor, cur_time_list, node_sizes)

        return first_layer_embeds.split(node_sizes), second_layer_embeds.split(
            node_sizes)
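
The embedding tensors returned by the encoder are flat stacks over all nodes in the batched graph; split(node_sizes) cuts them back into per-graph chunks. A small standalone illustration of that pattern, with made-up sizes and dimensions:

    import torch

    node_sizes = [3, 5, 2]                         # nodes per graph in the batch
    flat_embeds = torch.randn(sum(node_sizes), 4)  # all node embeddings stacked
    per_graph = flat_embeds.split(node_sizes)      # tensors of shape (3, 4), (5, 4), (2, 4)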
Code example #5
File: PostBiDynamicRGCN.py  Project: wxd-neu/TeMP
    def get_per_graph_ent_dropout_embeds_one_direction(
            self, cur_time_list, target_time_list, node_sizes,
            time_diff_tensor, first_prev_graph_embeds,
            second_prev_graph_embeds, forward):
        batched_graph = self.get_batch_graph_dropout_embeds(
            filter_none(cur_time_list), target_time_list)
        if self.use_cuda:
            move_dgl_to_cuda(batched_graph)
        second_local_embeds, first_layer_embeds, second_layer_embeds = self.ent_encoder.forward_post_ensemble_one_direction(
            batched_graph, first_prev_graph_embeds, second_prev_graph_embeds,
            time_diff_tensor, cur_time_list, node_sizes, forward)

        return second_local_embeds.split(node_sizes), first_layer_embeds.split(
            node_sizes), second_layer_embeds.split(node_sizes)
Code example #6
    def forward(self, t_list, reverse=False):
        kld_loss = 0
        reconstruct_loss = 0
        h = self.h0.expand(self.num_layers, len(t_list),
                           self.hidden_size).contiguous()
        g_batched_list, time_batched_list = self.get_batch_graph_list(
            t_list, self.train_seq_len, self.graph_dict_train)

        for t in range(self.train_seq_len):
            g_batched_list_t, time_batched_list_t = filter_none(
                g_batched_list[t]), filter_none(time_batched_list[t])
            bsz = len(g_batched_list_t)
            cur_h = h[-1][:bsz]  # bsz, hidden_size
            # run RGCN on graph to get encoded ent_embeddings and rel_embeddings in G_t

            node_sizes = [len(g.nodes()) for g in g_batched_list_t]
            triplets, neg_tail_samples, neg_head_samples, labels = self.corrupter.single_graph_negative_sampling(
                time_batched_list_t, g_batched_list_t)

            per_graph_ent_mean, per_graph_ent_std, ent_enc_means, ent_enc_stds = \
                self.get_posterior_embeddings(g_batched_list_t, cur_h, node_sizes)
            # run distmult decoding
            pooled_fact_embeddings = []
            i = 0
            for ent_mean, ent_std in zip(per_graph_ent_mean,
                                         per_graph_ent_std):
                if self.use_VAE:
                    loss_tail = self.train_reparametrize_link_prediction(
                        ent_mean,
                        ent_std,
                        triplets[i],
                        neg_tail_samples[i],
                        labels[i],
                        corrupt_tail=True)
                    loss_head = self.train_reparametrize_link_prediction(
                        ent_mean,
                        ent_std,
                        triplets[i],
                        neg_head_samples[i],
                        labels[i],
                        corrupt_tail=False)
                else:
                    # loss_tail = self.train_reparametrize_link_prediction(ent_mean, ent_mean.new_zeros(ent_mean.shape), triplets[i], neg_tail_samples[i], labels[i], corrupt_tail=True)
                    # loss_head = self.train_reparametrize_link_prediction(ent_mean, ent_mean.new_zeros(ent_mean.shape), triplets[i], neg_head_samples[i], labels[i], corrupt_tail=False)
                    loss_tail = self.train_link_prediction(ent_mean,
                                                           triplets[i],
                                                           neg_tail_samples[i],
                                                           labels[i],
                                                           corrupt_tail=True)
                    loss_head = self.train_link_prediction(ent_mean,
                                                           triplets[i],
                                                           neg_head_samples[i],
                                                           labels[i],
                                                           corrupt_tail=False)

                pooled_fact_embeddings.append(
                    self.get_pooled_facts(ent_mean, triplets[i]))
                reconstruct_loss += loss_tail + loss_head
                i += 1

            # get all the prior ent_embeddings and rel_embeddings in G_t
            if self.use_VAE:
                prior_ent_means, prior_ent_stds = self.get_prior_from_hidden(
                    g_batched_list_t, node_sizes, cur_h)
                kld_loss += self.kld_gauss(ent_enc_means, ent_enc_stds,
                                           prior_ent_means, prior_ent_stds)
                kld_loss += self.kld_gauss(self.rel_embeds,
                                           F.softplus(self.rel_enc_stds),
                                           self.rel_prior_means,
                                           self.rel_prior_std)

            _, h = self.rnn(
                torch.stack(pooled_fact_embeddings, dim=0).unsqueeze(0),
                h[:, :bsz])
        return reconstruct_loss, kld_loss
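
Unlike code example #1, this VAE variant can return a non-zero kld_loss. How (or whether) TeMP weights the KL term is not visible in this snippet, so the combination below is only an assumption:

    # Hypothetical loss combination for the VAE variant; the KL weight `beta`
    # is an assumption, not taken from the repository.
    beta = 1.0
    reconstruct_loss, kld_loss = model(t_list)
    loss = reconstruct_loss + beta * kld_loss
    loss.backward()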
Code example #7
    def get_val_vars(self, g_batched_list, t):
        g_batched_list_t = filter_none(g_batched_list[t])
        # run RGCN on graph to get encoded ent_embeddings and rel_embeddings in G_t
        node_sizes = [len(g.nodes()) for g in g_batched_list_t]
        return g_batched_list_t, node_sizes