Example #1
    def evaluate(self, t_list, val=True):
        graph_dict = self.graph_dict_val if val else self.graph_dict_test
        g_forward_batched_list, t_forward_batched_list, g_backward_batched_list, t_backward_batched_list = BiDynamicRGCN.get_batch_graph_list(
            t_list, self.test_seq_len, self.graph_dict_train)
        g_val_batched_list, val_time_list, _, _ = BiDynamicRGCN.get_batch_graph_list(
            t_list, 1, graph_dict)

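        # Encode the history in both temporal directions: one pass over the
        # forward-ordered history graphs and one over the reversed history.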
        hist_embeddings_forward, attn_mask_forward = self.pre_forward(
            g_forward_batched_list, t_forward_batched_list, forward=True)
        hist_embeddings_backward, attn_mask_backward = self.pre_forward(
            g_backward_batched_list, t_backward_batched_list, forward=False)

        test_graphs, _ = self.get_val_vars(g_val_batched_list, -1)
        train_graphs, time_batched_list_t = g_forward_batched_list[-1], t_forward_batched_list[-1]

        node_sizes = [len(g.nodes()) for g in train_graphs]
        hist_embeddings = torch.cat(
            [hist_embeddings_forward, hist_embeddings_backward],
            dim=0)  # (2 * seq_len - 2, bsz, num_ents)
        attn_mask = torch.cat(
            [attn_mask_forward, attn_mask_backward,
             attn_mask_forward.new_zeros(1, *attn_mask_forward.shape[1:])],
            dim=0)  # (2 * seq_len - 1, bsz, num_ents)
        per_graph_ent_embeds = self.get_final_graph_embeds(train_graphs,
                                                           time_batched_list_t,
                                                           node_sizes,
                                                           hist_embeddings,
                                                           attn_mask,
                                                           full=True)

        return self.calc_metrics(per_graph_ent_embeds, test_graphs,
                                 time_batched_list_t, hist_embeddings,
                                 attn_mask)
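
The two torch.cat calls above stitch the forward and backward history tensors together, with one extra all-zero row appended to the mask for the current step. A minimal, self-contained sketch of that shape bookkeeping (the concrete sizes and the trailing embedding dimension are illustrative assumptions, not the model's actual layout):

import torch

# Illustrative sizes only: each direction contributes seq_len - 1 history steps.
seq_len, bsz, num_ents, embed_dim = 4, 2, 5, 3

hist_forward = torch.randn(seq_len - 1, bsz, num_ents, embed_dim)
hist_backward = torch.randn(seq_len - 1, bsz, num_ents, embed_dim)
mask_forward = torch.zeros(seq_len - 1, bsz, num_ents)
mask_backward = torch.zeros(seq_len - 1, bsz, num_ents)

hist = torch.cat([hist_forward, hist_backward], dim=0)
mask = torch.cat([mask_forward, mask_backward,
                  mask_forward.new_zeros(1, *mask_forward.shape[1:])], dim=0)

print(hist.shape)  # torch.Size([6, 2, 5, 3]) -> 2 * seq_len - 2 steps
print(mask.shape)  # torch.Size([7, 2, 5])    -> 2 * seq_len - 1 steps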
Example #2
    def forward(self, t_list, reverse=False):
        reconstruct_loss = 0
        g_forward_batched_list, t_forward_batched_list, g_backward_batched_list, t_backward_batched_list = BiDynamicRGCN.get_batch_graph_list(t_list, self.train_seq_len, self.graph_dict_train)

        hist_embeddings_forward, attn_mask_forward = self.pre_forward(g_forward_batched_list, t_forward_batched_list, forward=True)
        hist_embeddings_backward, attn_mask_backward = self.pre_forward(g_backward_batched_list, t_backward_batched_list, forward=False)
        train_graphs, time_batched_list_t = g_forward_batched_list[-1], t_forward_batched_list[-1]

        node_sizes = [len(g.nodes()) for g in train_graphs]
        hist_embeddings = torch.cat([hist_embeddings_forward, hist_embeddings_backward], dim=0)  # 2 * seq_len - 2, bsz, num_ents
        attn_mask = torch.cat([attn_mask_forward, attn_mask_backward, attn_mask_forward.new_zeros(1, *attn_mask_forward.shape[1:])], dim=0) # 2 * seq_len - 1, bsz, num_ents
        per_graph_ent_embeds_loc, per_graph_ent_embeds_rec = self.get_final_graph_embeds(train_graphs, time_batched_list_t, node_sizes, hist_embeddings, attn_mask, full=False)

        for i, (t, g, ent_embed_loc, ent_embed_rec) in enumerate(
                zip(time_batched_list_t, train_graphs, per_graph_ent_embeds_loc, per_graph_ent_embeds_rec)):
            # Negative sampling: corrupt the head and tail of every triple in the graph at time t.
            triplets, neg_tail_samples, neg_head_samples, labels = self.corrupter.single_graph_negative_sampling(t, g, self.num_ents)
            all_embeds_g_loc, all_embeds_g_rec = self.get_all_embeds_Gt(ent_embed_loc, ent_embed_rec, g, t, hist_embeddings[:, i, 0], hist_embeddings[:, i, 1], attn_mask[:, i], val=False)
            weight_subject_query_subject_embed, weight_subject_query_object_embed, weight_object_query_subject_embed, weight_object_query_object_embed = self.calc_ensemble_ratio(triplets, t, g)
            # Score corrupted-tail and corrupted-head queries separately, then accumulate both losses.
            loss_tail = self.train_link_prediction(ent_embed_loc, ent_embed_rec, triplets, neg_tail_samples, labels, all_embeds_g_loc,
                                                   all_embeds_g_rec, weight_object_query_subject_embed, weight_object_query_object_embed, corrupt_tail=True)
            loss_head = self.train_link_prediction(ent_embed_loc, ent_embed_rec, triplets, neg_head_samples, labels, all_embeds_g_loc,
                                                   all_embeds_g_rec, weight_subject_query_subject_embed, weight_subject_query_object_embed, corrupt_tail=False)
            reconstruct_loss += loss_tail + loss_head
            # Free per-graph intermediates before the next iteration.
            del all_embeds_g_loc, all_embeds_g_rec, weight_subject_query_subject_embed, weight_subject_query_object_embed, weight_object_query_subject_embed, weight_object_query_object_embed, triplets, neg_tail_samples, neg_head_samples, labels
        return reconstruct_loss
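
For context, a minimal driver loop for a forward() like the one above. DummyModel, the timestamp range, and the batch size are hypothetical stand-ins, not part of the original snippet; only the contract (forward(t_list) returns a scalar loss) mirrors the code.

import torch

class DummyModel(torch.nn.Module):
    """Stand-in whose forward(t_list) returns a scalar loss, like the snippet."""

    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.ones(1))

    def forward(self, t_list):
        # Placeholder for the accumulated reconstruction loss computed above.
        return (self.w * t_list.float()).pow(2).sum()

model = DummyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_times = torch.arange(100)  # hypothetical timestamp vocabulary

for epoch in range(2):
    for t_list in torch.split(train_times, 8):  # mini-batches of timestamps
        optimizer.zero_grad()
        loss = model(t_list)  # the real model sums loss_tail + loss_head per graph
        loss.backward()
        optimizer.step()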