# Example 1
    def forward(self, t_list, reverse=False):
        """Training reconstruction loss for a batch of target time steps.

        Builds forward and backward graph histories around each target time,
        derives recurrent entity embeddings for the final (target) graphs, and
        accumulates negative-sampling link-prediction losses for corrupted
        tails and heads.

        Args:
            t_list: batch of time-step indices.
            reverse: unused here — kept for interface compatibility with callers.

        Returns:
            Accumulated scalar loss over all graphs in the batch.
        """
        reconstruct_loss = 0
        # Split the batch into per-direction sequences of graphs and times.
        g_forward_batched_list, t_forward_batched_list, g_backward_batched_list, t_backward_batched_list = self.get_batch_graph_list(
            t_list, self.train_seq_len, self.graph_dict_train)
        # Encode history in both temporal directions. "loc"/"rec" presumably
        # distinguish local vs. recurrent embeddings — confirm in pre_forward.
        hist_embeddings_forward_loc, hist_embeddings_forward_rec, start_time_tensor_forward = self.pre_forward(
            g_forward_batched_list, t_forward_batched_list, forward=True)
        hist_embeddings_backward_loc, hist_embeddings_backward_rec, start_time_tensor_backward = self.pre_forward(
            g_backward_batched_list, t_backward_batched_list, forward=False)

        # The last element of the forward sequence holds the target graphs.
        train_graphs, time_batched_list_t = g_forward_batched_list[
            -1], t_forward_batched_list[-1]
        # Only the recurrent embeddings are used for training here.
        _, per_graph_ent_embeds_rec = self.get_final_graph_embeds(
            train_graphs,
            time_batched_list_t,
            self.train_seq_len,
            hist_embeddings_forward_loc,
            hist_embeddings_forward_rec,
            start_time_tensor_forward,
            hist_embeddings_backward_loc,
            hist_embeddings_backward_rec,
            start_time_tensor_backward,
            full=False)

        i = 0
        for t, g, ent_embed in zip(time_batched_list_t, train_graphs,
                                   per_graph_ent_embeds_rec):
            # Sample corrupted heads and tails for this graph.
            triplets, neg_tail_samples, neg_head_samples, labels = self.corrupter.single_graph_negative_sampling(
                t, g, self.num_ents)
            # Offset of each entity's last update from the target step.
            time_diff_tensor_forward = self.train_seq_len - 1 - start_time_tensor_forward[
                i]
            time_diff_tensor_backward = self.train_seq_len - 1 - start_time_tensor_backward[
                i]
            all_embeds_g = self.get_all_embeds_Gt(
                ent_embed, g, t, hist_embeddings_forward_loc[i],
                hist_embeddings_forward_rec[i][0],
                hist_embeddings_forward_rec[i][1], time_diff_tensor_forward,
                hist_embeddings_backward_loc[i],
                hist_embeddings_backward_rec[i][0],
                hist_embeddings_backward_rec[i][1], time_diff_tensor_backward)
            # NOTE(review): BiDynamicRGCN's implementation is called explicitly
            # rather than via self — presumably to bypass an override on self's
            # class; confirm against the class hierarchy.
            loss_tail = BiDynamicRGCN.train_link_prediction(self,
                                                            ent_embed,
                                                            triplets,
                                                            neg_tail_samples,
                                                            labels,
                                                            all_embeds_g,
                                                            corrupt_tail=True)
            loss_head = BiDynamicRGCN.train_link_prediction(self,
                                                            ent_embed,
                                                            triplets,
                                                            neg_head_samples,
                                                            labels,
                                                            all_embeds_g,
                                                            corrupt_tail=False)
            reconstruct_loss += loss_tail + loss_head
            i += 1
        return reconstruct_loss
# Example 2
    def evaluate(self, t_list, val=True):
        """Evaluate link prediction on validation or test graphs.

        Encodes forward/backward *training* histories, embeds the target
        training graphs with full attention over the concatenated history,
        and scores the held-out graphs at the same time steps.

        Args:
            t_list: batch of time-step indices.
            val: if True evaluate on the validation split, else on test.

        Returns:
            The result of ``self.calc_metrics`` for the batch.
        """
        graph_dict = self.graph_dict_val if val else self.graph_dict_test
        # Histories are always built from the training graphs.
        g_forward_batched_list, t_forward_batched_list, g_backward_batched_list, t_backward_batched_list = BiDynamicRGCN.get_batch_graph_list(
            t_list, self.test_seq_len, self.graph_dict_train)
        # Sequence length 1: only the evaluation graphs themselves.
        g_val_batched_list, val_time_list, _, _ = BiDynamicRGCN.get_batch_graph_list(
            t_list, 1, graph_dict)

        hist_embeddings_forward, attn_mask_forward = self.pre_forward(
            g_forward_batched_list, t_forward_batched_list, forward=True)
        hist_embeddings_backward, attn_mask_backward = self.pre_forward(
            g_backward_batched_list, t_backward_batched_list, forward=False)

        test_graphs, _ = self.get_val_vars(g_val_batched_list, -1)
        train_graphs, time_batched_list_t = g_forward_batched_list[
            -1], t_forward_batched_list[-1]

        node_sizes = [len(g.nodes()) for g in train_graphs]
        # Stack both temporal directions along the sequence axis.
        hist_embeddings = torch.cat(
            [hist_embeddings_forward, hist_embeddings_backward],
            dim=0)  # 2 * seq_len - 2, bsz, num_ents
        # The appended zero row leaves the current step unmasked — presumably
        # so attention can attend to it; confirm in get_final_graph_embeds.
        attn_mask = torch.cat([
            attn_mask_forward, attn_mask_backward,
            attn_mask_forward.new_zeros(1, *attn_mask_forward.shape[1:])
        ],
                              dim=0)  # 2 * seq_len - 1, bsz, num_ents
        per_graph_ent_embeds = self.get_final_graph_embeds(train_graphs,
                                                           time_batched_list_t,
                                                           node_sizes,
                                                           hist_embeddings,
                                                           attn_mask,
                                                           full=True)

        return self.calc_metrics(per_graph_ent_embeds, test_graphs,
                                 time_batched_list_t, hist_embeddings,
                                 attn_mask)
# Example 3
    def forward(self, t_list, reverse=False):
        """Training reconstruction loss with attention over the full history.

        Concatenates forward and backward history embeddings, embeds the
        target graphs with attention over that history, and accumulates
        gated link-prediction losses for corrupted heads and tails.

        Args:
            t_list: batch of time-step indices.
            reverse: unused here — kept for interface compatibility with callers.

        Returns:
            Accumulated scalar loss over all graphs in the batch.
        """
        reconstruct_loss = 0
        g_forward_batched_list, t_forward_batched_list, g_backward_batched_list, t_backward_batched_list = BiDynamicRGCN.get_batch_graph_list(t_list, self.train_seq_len, self.graph_dict_train)

        hist_embeddings_forward, attn_mask_forward = self.pre_forward(g_forward_batched_list, t_forward_batched_list, forward=True)
        hist_embeddings_backward, attn_mask_backward = self.pre_forward(g_backward_batched_list, t_backward_batched_list, forward=False)
        train_graphs, time_batched_list_t = g_forward_batched_list[-1], t_forward_batched_list[-1]

        node_sizes = [len(g.nodes()) for g in train_graphs]
        # Stack both temporal directions along the sequence axis.
        hist_embeddings = torch.cat([hist_embeddings_forward, hist_embeddings_backward], dim=0)  # 2 * seq_len - 2, bsz, num_ents
        # The appended zero row leaves the current step unmasked — presumably
        # so attention can attend to it; confirm in get_final_graph_embeds.
        attn_mask = torch.cat([attn_mask_forward, attn_mask_backward, attn_mask_forward.new_zeros(1, *attn_mask_forward.shape[1:])], dim=0) # 2 * seq_len - 1, bsz, num_ents
        per_graph_ent_embeds_loc, per_graph_ent_embeds_rec = self.get_final_graph_embeds(train_graphs, time_batched_list_t, node_sizes, hist_embeddings, attn_mask, full=False)

        i = 0
        for t, g, ent_embed_loc, ent_embed_rec in zip(time_batched_list_t, train_graphs, per_graph_ent_embeds_loc, per_graph_ent_embeds_rec):
            triplets, neg_tail_samples, neg_head_samples, labels = self.corrupter.single_graph_negative_sampling(t, g, self.num_ents)
            all_embeds_g_loc, all_embeds_g_rec = self.get_all_embeds_Gt(ent_embed_loc, ent_embed_rec, g, t, hist_embeddings[:, i, 0], hist_embeddings[:, i, 1], attn_mask[:, i], val=False)
            # Per-triple gating weights between local and recurrent scores.
            weight_subject_query_subject_embed, weight_subject_query_object_embed, weight_object_query_subject_embed, weight_object_query_object_embed = self.calc_ensemble_ratio(triplets, t, g)
            # Tail corruption uses object-query weights; head corruption uses
            # subject-query weights.
            loss_tail = self.train_link_prediction(ent_embed_loc, ent_embed_rec, triplets, neg_tail_samples, labels, all_embeds_g_loc,
                                                   all_embeds_g_rec, weight_object_query_subject_embed, weight_object_query_object_embed, corrupt_tail=True)
            loss_head = self.train_link_prediction(ent_embed_loc, ent_embed_rec, triplets, neg_head_samples, labels, all_embeds_g_loc,
                                                   all_embeds_g_rec, weight_subject_query_subject_embed, weight_subject_query_object_embed, corrupt_tail=False)
            reconstruct_loss += loss_tail + loss_head
            # Free large per-graph tensors early to reduce peak memory.
            del all_embeds_g_loc, all_embeds_g_rec, weight_subject_query_subject_embed, weight_subject_query_object_embed, weight_object_query_subject_embed, weight_object_query_object_embed, triplets, neg_tail_samples, neg_head_samples, labels
            i += 1
        return reconstruct_loss
# Example 4
    def __init__(self, args, num_ents, num_rels, graph_dict_train,
                 graph_dict_val, graph_dict_test):
        """Build the ensemble aggregator from a spatial and a temporal model.

        In debug mode both sub-models are constructed fresh from ``args``;
        otherwise they are restored from the checkpoint directories named by
        ``args.spatial_checkpoint`` and ``args.temporal_checkpoint``. Both
        sub-models are frozen; only the two small gating MLPs
        (``subject_linear``/``object_linear``) remain trainable.

        Args:
            args: experiment namespace (uses ``temporal_module``, ``debug``,
                ``spatial_checkpoint``, ``temporal_checkpoint``,
                ``train_seq_len``).
            num_ents: number of entities in the KG.
            num_rels: number of relations in the KG.
            graph_dict_train: time step -> training graph.
            graph_dict_val: time step -> validation graph.
            graph_dict_test: time step -> test graph.
        """
        super(Aggregator,
              self).__init__(args, num_ents, num_rels, graph_dict_train,
                             graph_dict_val, graph_dict_test)
        # Resolve the temporal-model class from its configured name.
        module = {
            "GRRGCN": DynamicRGCN,
            "RRGCN": DynamicRGCN,
            "SARGCN": SelfAttentionRGCN,
            "BiGRRGCN": BiDynamicRGCN,
            "BiRRGCN": BiDynamicRGCN,
            "BiSARGCN": BiSelfAttentionRGCN
        }[args.temporal_module]

        self.graph_dict_total = {
            **self.graph_dict_train,
            **self.graph_dict_val,
            **self.graph_dict_test
        }

        self.get_true_head_and_tail_all()

        spatial_path = args.spatial_checkpoint
        temporal_path = args.temporal_checkpoint

        if args.debug:
            # Debug mode: build untrained sub-models directly from `args`.
            self.bidirectional = True
            self.drop_edge = DropEdge(args, graph_dict_train, graph_dict_val,
                                      graph_dict_test)
            self.spatial_model = StaticRGCN(args, num_ents, num_rels,
                                            graph_dict_train, graph_dict_val,
                                            graph_dict_test)
            args.module = 'BiGRRGCN'
            self.temporal_model = BiDynamicRGCN(args, num_ents, num_rels,
                                                graph_dict_train,
                                                graph_dict_val,
                                                graph_dict_test)

            self.train_seq_len = args.train_seq_len
            self.test_seq_len = args.train_seq_len
        else:
            # Restore the temporal model from its checkpoint directory.
            checkpoint_path = glob.glob(
                os.path.join(temporal_path, "checkpoints", "*.ckpt"))[0]
            temporal_checkpoint = torch.load(
                checkpoint_path, map_location=lambda storage, loc: storage)
            temporal_config_path = os.path.join(temporal_path, "config.json")
            # `with` closes the config file promptly (the original leaked it).
            with open(temporal_config_path) as config_file:
                temporal_args_json = json.load(config_file)
            temporal_args = process_args()
            temporal_args.__dict__.update(dict(temporal_args_json))
            self.drop_edge = DropEdge(temporal_args, graph_dict_train,
                                      graph_dict_val, graph_dict_test)

            # Upgrade the base class to its post-aggregation / post-ensemble /
            # imputation variant according to the checkpointed configuration.
            if module == BiDynamicRGCN:
                if temporal_args.post_aggregation:
                    module = PostBiDynamicRGCN
                elif temporal_args.post_ensemble:
                    module = PostEnsembleBiDynamicRGCN
                elif temporal_args.impute:
                    module = ImputeBiDynamicRGCN

            elif module == DynamicRGCN:
                if temporal_args.post_aggregation:
                    module = PostDynamicRGCN
                # BUGFIX: was a second `if`, which let post_ensemble silently
                # overwrite the post_aggregation choice; `elif` matches the
                # BiDynamicRGCN branch above.
                elif temporal_args.post_ensemble:
                    module = PostEnsembleDynamicRGCN
                elif temporal_args.impute:
                    module = ImputeDynamicRGCN

            elif module == BiSelfAttentionRGCN:
                if temporal_args.post_aggregation:
                    module = PostBiSelfAttentionRGCN

            self.temporal_model = module(temporal_args, num_ents, num_rels,
                                         graph_dict_train, graph_dict_val,
                                         graph_dict_test)
            self.temporal_model.load_state_dict(
                temporal_checkpoint['state_dict'])

            self.train_seq_len = temporal_args.train_seq_len
            self.test_seq_len = temporal_args.train_seq_len
            self.bidirectional = "Bi" in temporal_args.module

            # Restore the spatial (static) model from its checkpoint directory.
            checkpoint_path = glob.glob(
                os.path.join(spatial_path, "checkpoints", "*.ckpt"))[0]
            local_checkpoint = torch.load(
                checkpoint_path, map_location=lambda storage, loc: storage)
            local_config_path = os.path.join(spatial_path, "config.json")
            with open(local_config_path) as config_file:
                local_args_json = json.load(config_file)
            local_args = process_args()
            local_args.__dict__.update(dict(local_args_json))
            self.spatial_model = StaticRGCN(local_args, num_ents, num_rels,
                                            graph_dict_train, graph_dict_val,
                                            graph_dict_test)
            self.spatial_model.load_state_dict(local_checkpoint['state_dict'])

        # Freeze both sub-models: only the gating MLPs below are trained.
        for para in self.spatial_model.parameters():
            para.requires_grad = False
        for para in self.temporal_model.parameters():
            para.requires_grad = False

        # 3 frequency features -> 1 gating weight per query side
        # (sigmoid is applied by calc_ensemble_ratio).
        self.subject_linear = nn.Sequential(nn.Linear(3, 3), nn.ReLU(),
                                            nn.Linear(3, 1))
        self.object_linear = nn.Sequential(nn.Linear(3, 3), nn.ReLU(),
                                           nn.Linear(3, 1))
# Example 5
class Aggregator(TKG_Module):
    def __init__(self, args, num_ents, num_rels, graph_dict_train,
                 graph_dict_val, graph_dict_test):
        """Build the ensemble aggregator from a spatial and a temporal model.

        In debug mode both sub-models are constructed fresh from ``args``;
        otherwise they are restored from the checkpoint directories named by
        ``args.spatial_checkpoint`` and ``args.temporal_checkpoint``. Both
        sub-models are frozen; only the two small gating MLPs
        (``subject_linear``/``object_linear``) remain trainable.

        Args:
            args: experiment namespace (uses ``temporal_module``, ``debug``,
                ``spatial_checkpoint``, ``temporal_checkpoint``,
                ``train_seq_len``).
            num_ents: number of entities in the KG.
            num_rels: number of relations in the KG.
            graph_dict_train: time step -> training graph.
            graph_dict_val: time step -> validation graph.
            graph_dict_test: time step -> test graph.
        """
        super(Aggregator,
              self).__init__(args, num_ents, num_rels, graph_dict_train,
                             graph_dict_val, graph_dict_test)
        # Resolve the temporal-model class from its configured name.
        module = {
            "GRRGCN": DynamicRGCN,
            "RRGCN": DynamicRGCN,
            "SARGCN": SelfAttentionRGCN,
            "BiGRRGCN": BiDynamicRGCN,
            "BiRRGCN": BiDynamicRGCN,
            "BiSARGCN": BiSelfAttentionRGCN
        }[args.temporal_module]

        self.graph_dict_total = {
            **self.graph_dict_train,
            **self.graph_dict_val,
            **self.graph_dict_test
        }

        self.get_true_head_and_tail_all()

        spatial_path = args.spatial_checkpoint
        temporal_path = args.temporal_checkpoint

        if args.debug:
            # Debug mode: build untrained sub-models directly from `args`.
            self.bidirectional = True
            self.drop_edge = DropEdge(args, graph_dict_train, graph_dict_val,
                                      graph_dict_test)
            self.spatial_model = StaticRGCN(args, num_ents, num_rels,
                                            graph_dict_train, graph_dict_val,
                                            graph_dict_test)
            args.module = 'BiGRRGCN'
            self.temporal_model = BiDynamicRGCN(args, num_ents, num_rels,
                                                graph_dict_train,
                                                graph_dict_val,
                                                graph_dict_test)

            self.train_seq_len = args.train_seq_len
            self.test_seq_len = args.train_seq_len
        else:
            # Restore the temporal model from its checkpoint directory.
            checkpoint_path = glob.glob(
                os.path.join(temporal_path, "checkpoints", "*.ckpt"))[0]
            temporal_checkpoint = torch.load(
                checkpoint_path, map_location=lambda storage, loc: storage)
            temporal_config_path = os.path.join(temporal_path, "config.json")
            # `with` closes the config file promptly (the original leaked it).
            with open(temporal_config_path) as config_file:
                temporal_args_json = json.load(config_file)
            temporal_args = process_args()
            temporal_args.__dict__.update(dict(temporal_args_json))
            self.drop_edge = DropEdge(temporal_args, graph_dict_train,
                                      graph_dict_val, graph_dict_test)

            # Upgrade the base class to its post-aggregation / post-ensemble /
            # imputation variant according to the checkpointed configuration.
            if module == BiDynamicRGCN:
                if temporal_args.post_aggregation:
                    module = PostBiDynamicRGCN
                elif temporal_args.post_ensemble:
                    module = PostEnsembleBiDynamicRGCN
                elif temporal_args.impute:
                    module = ImputeBiDynamicRGCN

            elif module == DynamicRGCN:
                if temporal_args.post_aggregation:
                    module = PostDynamicRGCN
                # BUGFIX: was a second `if`, which let post_ensemble silently
                # overwrite the post_aggregation choice; `elif` matches the
                # BiDynamicRGCN branch above.
                elif temporal_args.post_ensemble:
                    module = PostEnsembleDynamicRGCN
                elif temporal_args.impute:
                    module = ImputeDynamicRGCN

            elif module == BiSelfAttentionRGCN:
                if temporal_args.post_aggregation:
                    module = PostBiSelfAttentionRGCN

            self.temporal_model = module(temporal_args, num_ents, num_rels,
                                         graph_dict_train, graph_dict_val,
                                         graph_dict_test)
            self.temporal_model.load_state_dict(
                temporal_checkpoint['state_dict'])

            self.train_seq_len = temporal_args.train_seq_len
            self.test_seq_len = temporal_args.train_seq_len
            self.bidirectional = "Bi" in temporal_args.module

            # Restore the spatial (static) model from its checkpoint directory.
            checkpoint_path = glob.glob(
                os.path.join(spatial_path, "checkpoints", "*.ckpt"))[0]
            local_checkpoint = torch.load(
                checkpoint_path, map_location=lambda storage, loc: storage)
            local_config_path = os.path.join(spatial_path, "config.json")
            with open(local_config_path) as config_file:
                local_args_json = json.load(config_file)
            local_args = process_args()
            local_args.__dict__.update(dict(local_args_json))
            self.spatial_model = StaticRGCN(local_args, num_ents, num_rels,
                                            graph_dict_train, graph_dict_val,
                                            graph_dict_test)
            self.spatial_model.load_state_dict(local_checkpoint['state_dict'])

        # Freeze both sub-models: only the gating MLPs below are trained.
        for para in self.spatial_model.parameters():
            para.requires_grad = False
        for para in self.temporal_model.parameters():
            para.requires_grad = False

        # 3 frequency features -> 1 gating weight per query side
        # (sigmoid is applied by calc_ensemble_ratio).
        self.subject_linear = nn.Sequential(nn.Linear(3, 3), nn.ReLU(),
                                            nn.Linear(3, 1))
        self.object_linear = nn.Sequential(nn.Linear(3, 3), nn.ReLU(),
                                           nn.Linear(3, 1))

    def build_model(self):
        """Intentionally a no-op: the spatial and temporal sub-models are
        constructed (or loaded from checkpoints) directly in ``__init__``."""

    def get_true_head_and_tail_all(self):
        """Precompute, per time step, the true heads/tails over all splits.

        Pools the (head, relation, tail) triples of the train, validation and
        test graphs at each time step and stores the per-graph true-head /
        true-tail lookups in ``self.true_heads`` / ``self.true_tails``.
        """
        self.true_heads = {}
        self.true_tails = {}
        for t in self.total_time:
            # Gather triples from all three splits at this time step.
            graphs = (self.graph_dict_train[t], self.graph_dict_val[t],
                      self.graph_dict_test[t])
            split_triples = [
                torch.stack([g.edges()[0], g.edata['type_s'],
                             g.edges()[1]]).transpose(0, 1) for g in graphs
            ]
            pooled = torch.cat(split_triples, dim=0)
            true_head, true_tail = CorruptTriples.get_true_head_and_tail_per_graph(
                pooled)
            self.true_heads[t] = true_head
            self.true_tails[t] = true_tail

    def calc_ensemble_ratio(self, triples, t, g):
        """Compute per-triple gating weights for the local/temporal ensemble.

        For each (s, r, o) triple, frequency statistics of the *other* side of
        the query are fed through a small MLP and squashed with a sigmoid:
        subject-query weights use object/relation frequencies and vice versa.

        Args:
            triples: iterable of (s, r, o) id tensors (batch-local node ids).
            t: scalar tensor holding the time step.
            g: graph with an ``ids`` mapping from local to global node ids.

        Returns:
            (weight_subject, weight_object): sigmoid weights of shape (N, 1),
            or empty long tensors when ``triples`` is empty (preserving the
            original fallback behaviour).
        """
        sub_feature_vecs = []
        obj_feature_vecs = []
        t = t.item()
        # Hoist the per-time-step frequency tables out of the triple loop.
        sub_freq_t = self.drop_edge.sub_freq_per_time_step_agg[t]
        obj_freq_t = self.drop_edge.obj_freq_per_time_step_agg[t]
        rel_freq_t = self.drop_edge.rel_freq_per_time_step_agg[t]
        sub_rel_freq_t = self.drop_edge.sub_rel_freq_per_time_step_agg[t]
        obj_rel_freq_t = self.drop_edge.obj_rel_freq_per_time_step_agg[t]
        for s, r, o in triples:
            # Map batch-local node ids back to global entity ids.
            s = g.ids[s.item()]
            r = r.item()
            o = g.ids[o.item()]
            sub_freq = sub_freq_t[s]
            obj_freq = obj_freq_t[o]
            rel_freq = rel_freq_t[r]
            sub_rel_freq = sub_rel_freq_t[(s, r)]
            obj_rel_freq = obj_rel_freq_t[(o, r)]
            # Subject queries are gated by object-side statistics and vice versa.
            sub_feature_vecs.append(
                torch.tensor([obj_freq, rel_freq, obj_rel_freq]))
            obj_feature_vecs.append(
                torch.tensor([sub_freq, rel_freq, sub_rel_freq]))

        # BUGFIX: the original wrapped the stack in a bare `except:` that also
        # swallowed real errors (CUDA failures, shape bugs); the only expected
        # failure is an empty batch, so handle that explicitly and let
        # everything else propagate.
        if not sub_feature_vecs:
            weight_subject = torch.tensor([]).long()
            weight_object = torch.tensor([]).long()
            if self.use_cuda:
                weight_subject = cuda(weight_subject)
                weight_object = cuda(weight_object)
            return weight_subject, weight_object

        sub_features = torch.stack(sub_feature_vecs).float()
        obj_features = torch.stack(obj_feature_vecs).float()
        if self.use_cuda:
            sub_features = cuda(sub_features)
            obj_features = cuda(obj_features)
        weight_subject = torch.sigmoid(self.subject_linear(sub_features))
        weight_object = torch.sigmoid(self.object_linear(obj_features))

        return weight_subject, weight_object

    def forward(self, t_list, reverse=False):
        """Training loss of the frozen local+temporal ensemble.

        Embeds each batch graph with both the (frozen) spatial model and the
        (frozen) temporal model, scores negative-sampled triples under each,
        and combines the two score sets with the learned gating weights.

        Args:
            t_list: batch of time-step tensors; sorted descending here.
            reverse: unused here — kept for interface compatibility with callers.

        Returns:
            Accumulated reconstruction loss over the batch.
        """
        t_list = t_list.sort(descending=True)[0]
        per_graph_ent_embeds_local, g_list = self.spatial_model.train_embed(
            t_list)
        # The temporal model returns extra history state whose shape depends
        # on whether the checkpointed model is bidirectional.
        if self.bidirectional:
            per_graph_ent_embeds_temporal, train_graphs, time_list, hist_embeddings_forward, start_time_tensor_forward, hist_embeddings_backward, start_time_tensor_backward = self.temporal_model.train_embed(
                t_list)
        else:
            per_graph_ent_embeds_temporal, train_graphs, time_list, hist_embeddings, start_time_tensor = self.temporal_model.train_embed(
                t_list)

        # Both models must have produced graphs for the same time steps.
        assert t_list.tolist() == time_list
        rel_enc_local = self.spatial_model.rel_embeds
        rel_enc_temp = self.temporal_model.rel_embeds
        reconstruct_loss = 0
        i = 0
        for t, g, ent_embed_local, ent_embed_temp in zip(
                t_list, train_graphs, per_graph_ent_embeds_local,
                per_graph_ent_embeds_temporal):
            triplets, neg_tail_samples, neg_head_samples, labels = self.corrupter.single_graph_negative_sampling(
                t, g, self.num_ents)
            if self.bidirectional:
                # Offset of each entity's last update from the target step,
                # per temporal direction.
                time_diff_tensor_forward = self.train_seq_len - 1 - start_time_tensor_forward[
                    i]
                time_diff_tensor_backward = self.train_seq_len - 1 - start_time_tensor_backward[
                    i]
                all_embeds_g_temp = self.temporal_model.get_all_embeds_Gt(
                    ent_embed_temp, g, t, hist_embeddings_forward[i][0],
                    hist_embeddings_forward[i][1], time_diff_tensor_forward,
                    hist_embeddings_backward[i][0],
                    hist_embeddings_backward[i][1], time_diff_tensor_backward)
            else:
                time_diff_tensor = self.train_seq_len - 1 - start_time_tensor[i]
                all_embeds_g_temp = self.temporal_model.get_all_embeds_Gt(
                    ent_embed_temp, g, t, hist_embeddings[i][0],
                    hist_embeddings[i][1], time_diff_tensor)

            all_embeds_g_local = self.spatial_model.get_all_embeds_Gt(
                t, g, ent_embed_local)

            # Gating weights derived from triple frequency statistics.
            weight_subject, weight_object = self.calc_ensemble_ratio(
                triplets, t, g)
            score_tail_local = self.train_link_prediction(ent_embed_local,
                                                          rel_enc_local,
                                                          triplets,
                                                          neg_tail_samples,
                                                          labels,
                                                          all_embeds_g_local,
                                                          corrupt_tail=True)
            score_head_local = self.train_link_prediction(ent_embed_local,
                                                          rel_enc_local,
                                                          triplets,
                                                          neg_head_samples,
                                                          labels,
                                                          all_embeds_g_local,
                                                          corrupt_tail=False)
            score_tail_temporal = self.train_link_prediction(ent_embed_temp,
                                                             rel_enc_temp,
                                                             triplets,
                                                             neg_tail_samples,
                                                             labels,
                                                             all_embeds_g_temp,
                                                             corrupt_tail=True)
            score_head_temporal = self.train_link_prediction(
                ent_embed_temp,
                rel_enc_temp,
                triplets,
                neg_head_samples,
                labels,
                all_embeds_g_temp,
                corrupt_tail=False)
            # Tail corruption is gated by object-side weights, head
            # corruption by subject-side weights.
            loss_tail = self.combined_scores(score_tail_local,
                                             score_tail_temporal, labels,
                                             weight_object)
            loss_head = self.combined_scores(score_head_local,
                                             score_head_temporal, labels,
                                             weight_subject)
            reconstruct_loss += loss_tail + loss_head
            i += 1
        return reconstruct_loss

    def evaluate(self, t_list, val=True):
        """Rank-based evaluation of the ensemble on validation or test data.

        Args:
            t_list: batch of time-step tensors; sorted descending here.
            val: if True evaluate on the validation split, else on test.

        Returns:
            (ranks, mean_loss): concatenated ranks for the batch and the mean
            of ``losses``. NOTE(review): ``losses`` is never appended to (the
            loss computation is commented out below), so ``np.mean(losses)``
            evaluates to NaN — confirm callers ignore this value.
        """
        t_list = t_list.sort(descending=True)[0]
        per_graph_ent_embeds_local, g_list = self.spatial_model.evaluate_embed(
            t_list, val)
        # The temporal model returns extra history state whose shape depends
        # on whether the checkpointed model is bidirectional.
        if self.bidirectional:
            per_graph_ent_embeds_temporal, test_graphs, time_list, hist_embeddings_forward, start_time_tensor_forward, hist_embeddings_backward, start_time_tensor_backward = self.temporal_model.evaluate_embed(
                t_list, val)
        else:
            per_graph_ent_embeds_temporal, test_graphs, time_list, hist_embeddings, start_time_tensor = self.temporal_model.evaluate_embed(
                t_list, val)
        mrrs, hit_1s, hit_3s, hit_10s, losses = [], [], [], [], []
        ranks = []
        i = 0
        cur_t = self.test_seq_len - 1
        rel_enc_local = self.spatial_model.rel_embeds
        rel_enc_temporal = self.temporal_model.rel_embeds

        for g, t, ent_embed_local, ent_embed_temporal in zip(
                g_list, t_list, per_graph_ent_embeds_local,
                per_graph_ent_embeds_temporal):
            all_embeds_g_local = self.spatial_model.get_all_embeds_Gt(
                t, g, ent_embed_local)
            if self.bidirectional:
                # Offset of each entity's last update from the current step,
                # per temporal direction.
                time_diff_tensor_forward = cur_t - start_time_tensor_forward[i]
                time_diff_tensor_backward = cur_t - start_time_tensor_backward[
                    i]
                all_embeds_g_temporal = self.temporal_model.get_all_embeds_Gt(
                    ent_embed_temporal, g, t, hist_embeddings_forward[i][0],
                    hist_embeddings_forward[i][1], time_diff_tensor_forward,
                    hist_embeddings_backward[i][0],
                    hist_embeddings_backward[i][1], time_diff_tensor_backward)
            else:
                time_diff_tensor = cur_t - start_time_tensor[i]
                all_embeds_g_temporal = self.temporal_model.get_all_embeds_Gt(
                    ent_embed_temporal, g, t, hist_embeddings[i][0],
                    hist_embeddings[i][1], time_diff_tensor)
            # All (head, relation, tail) triples of the evaluation graph.
            index_sample = torch.stack(
                [g.edges()[0], g.edata['type_s'],
                 g.edges()[1]]).transpose(0, 1)

            weight_subject, weight_object = self.calc_ensemble_ratio(
                index_sample, t, g)

            if self.use_cuda:
                index_sample = cuda(index_sample)
            # Skip graphs with no edges to evaluate.
            if index_sample.shape[0] == 0: continue

            rank = self.calc_metrics_single_graph(
                ent_embed_local, ent_embed_temporal, rel_enc_local,
                rel_enc_temporal, all_embeds_g_local, all_embeds_g_temporal,
                weight_subject, weight_object, index_sample, g, t)
            ranks.append(rank)
            i += 1
        # NOTE(review): bare except hides real failures; it appears intended
        # only for the "no ranks collected" (empty-cat) case.
        try:
            ranks = torch.cat(ranks)
        except:
            ranks = cuda(torch.tensor(
                []).long()) if self.use_cuda else torch.tensor([]).long()

        return ranks, np.mean(losses)

    def calc_metrics_single_graph(self,
                                  ent_embed_local,
                                  ent_embed_temporal,
                                  rel_enc_local,
                                  rel_enc_temporal,
                                  all_embeds_g_local,
                                  all_embeds_g_temporal,
                                  weight_subject,
                                  weight_object,
                                  samples,
                                  graph,
                                  time,
                                  eval_bz=100):
        """Ranks for all triples of one graph under the ensemble.

        Perturbs the object (tail) and subject (head) of every sample, scores
        the candidates with both models combined via the gating weights, and
        returns 1-indexed ranks for both directions.

        Args:
            ent_embed_local: per-graph entity embeddings from the spatial model.
            ent_embed_temporal: per-graph entity embeddings from the temporal model.
            rel_enc_local: relation embedding table of the spatial model.
            rel_enc_temporal: relation embedding table of the temporal model.
            all_embeds_g_local: full entity embedding matrix (spatial).
            all_embeds_g_temporal: full entity embedding matrix (temporal).
            weight_subject: gating weights for subject queries.
            weight_object: gating weights for object queries.
            samples: (N, 3) tensor of (s, r, o) triples.
            graph: the evaluation graph.
            time: scalar tensor time step.
            eval_bz: scoring batch size.

        Returns:
            1-indexed ranks, subject-perturbed first then object-perturbed,
            concatenated into one tensor.
        """
        # Evaluation only — no gradients needed.
        with torch.no_grad():
            s = samples[:, 0]
            r = samples[:, 1]
            o = samples[:, 2]
            test_size = samples.shape[0]
            num_ent = all_embeds_g_local.shape[0]
            # Candidate masks per direction — presumably filtering other known
            # true triples; confirm in mask_eval_set.
            o_mask = self.mask_eval_set(samples,
                                        test_size,
                                        num_ent,
                                        time,
                                        graph,
                                        mode="tail")
            s_mask = self.mask_eval_set(samples,
                                        test_size,
                                        num_ent,
                                        time,
                                        graph,
                                        mode="head")

            # Perturb the object: subject-query weights gate the ensemble.
            ranks_o = self.emsemble_and_get_rank(ent_embed_local,
                                                 ent_embed_temporal,
                                                 rel_enc_local,
                                                 rel_enc_temporal,
                                                 all_embeds_g_local,
                                                 all_embeds_g_temporal,
                                                 weight_subject,
                                                 s,
                                                 r,
                                                 o,
                                                 test_size,
                                                 o_mask,
                                                 graph,
                                                 eval_bz,
                                                 mode='tail')
            # Perturb the subject: object-query weights gate the ensemble.
            ranks_s = self.emsemble_and_get_rank(ent_embed_local,
                                                 ent_embed_temporal,
                                                 rel_enc_local,
                                                 rel_enc_temporal,
                                                 all_embeds_g_local,
                                                 all_embeds_g_temporal,
                                                 weight_object,
                                                 s,
                                                 r,
                                                 o,
                                                 test_size,
                                                 s_mask,
                                                 graph,
                                                 eval_bz,
                                                 mode='head')

            ranks = torch.cat([ranks_s, ranks_o])
            ranks += 1  # change to 1-indexed
        return ranks

    def emsemble_and_get_rank(self,
                              ent_embed_local,
                              ent_embed_temporal,
                              rel_enc_local,
                              rel_enc_temporal,
                              all_embeds_g_local,
                              all_embeds_g_temporal,
                              weight,
                              s,
                              r,
                              o,
                              test_size,
                              mask,
                              graph,
                              batch_size=100,
                              mode='tail'):
        """ Perturb one element in the triplets and rank the gold entity.

        Scores every candidate entity for each test triplet with both the
        local and the temporal model, blends the two score matrices as
        ``weight * local + (1 - weight) * temporal``, and returns the
        0-indexed rank of the gold entity per triplet.

        NOTE(review): the name is misspelled ("emsemble") but kept for
        backward compatibility with existing callers.
        """
        n_batch = (test_size + batch_size - 1) // batch_size
        local_scores = []
        temporal_scores = []
        targets = []
        for idx in range(n_batch):
            local_score, _ = self.get_score(ent_embed_local, rel_enc_local,
                                            all_embeds_g_local, s, r, o,
                                            test_size, mask, graph, idx,
                                            batch_size, mode)
            temporal_score, target = self.get_score(
                ent_embed_temporal, rel_enc_temporal, all_embeds_g_temporal, s,
                r, o, test_size, mask, graph, idx, batch_size, mode)
            local_scores.append(local_score)
            temporal_scores.append(temporal_score)
            targets.append(target)
        scores = weight * torch.cat(local_scores) + (
            1 - weight) * torch.cat(temporal_scores)
        # Sigmoid is monotonic, so it does not reorder unmasked entries; it is
        # kept so masked entries (-1e7) saturate to 0 exactly as before.
        # The original wrapped the result in a one-element list and
        # torch.cat-ed it; returning the ranks directly is equivalent.
        return self.sort_and_rank(torch.sigmoid(scores), torch.cat(targets))

    def get_score(self,
                  ent_mean,
                  rel_enc_means,
                  all_ent_embeds,
                  s,
                  r,
                  o,
                  test_size,
                  mask,
                  graph,
                  idx,
                  batch_size=100,
                  mode='tail'):
        """Score one evaluation mini-batch against every candidate entity.

        Slices triplet batch ``idx`` out of (s, r, o), scores the fixed pair
        against all entity embeddings (`mode='tail'` perturbs the object,
        otherwise the subject), and masks out known-true answers.

        Returns (masked_score, target): a (bsz, n_ent) score matrix and the
        graph-local id of the gold entity per row.
        """
        batch_start = idx * batch_size
        batch_end = min(test_size, (idx + 1) * batch_size)
        batch_r = rel_enc_means[r[batch_start:batch_end]]

        if mode == 'tail':
            # Perturb the object: fixed (s, r), every entity as candidate o.
            batch_s = ent_mean[s[batch_start:batch_end]]
            batch_o = all_ent_embeds
            target = o[batch_start:batch_end]
        else:
            # Perturb the subject: every entity as candidate s, fixed (r, o).
            batch_s = all_ent_embeds
            batch_o = ent_mean[o[batch_start:batch_end]]
            target = s[batch_start:batch_end]
        # Map global entity ids onto this snapshot graph's local node ids.
        target = torch.tensor([graph.ids[i.item()] for i in target])

        if self.args.use_cuda:
            target = cuda(target)

        unmasked_score = self.calc_score(batch_s, batch_r, batch_o, mode=mode)
        # Known-true entities (filtered setting) are pushed to -1e7 so they
        # cannot outrank the gold answer.  `.bool()` is required because
        # mask_eval_set returns a ByteTensor and modern torch.where rejects
        # uint8 masks; new_full avoids the allocate-then-multiply of
        # `-10e6 * new_ones(...)`.
        masked_score = torch.where(
            mask[batch_start:batch_end].bool(),
            unmasked_score.new_full(unmasked_score.shape, -10e6),
            unmasked_score)
        return masked_score, target  # bsz, n_ent

    def mask_eval_set(self,
                      test_triplets,
                      test_size,
                      num_ent,
                      time,
                      graph,
                      mode='tail'):
        """Build a (test_size, num_ent) byte mask of known-true answers.

        For each test triplet (h, r, t) at timestamp ``time``, marks every
        entity recorded in ``self.true_tails`` / ``self.true_heads`` as 1
        (to be filtered during ranking), then clears the gold entity itself
        so it is still scored.  Entity ids are translated to graph-local
        node ids via ``graph.ids``.
        """
        time = int(time.item())
        mask = test_triplets.new_zeros(test_size, num_ent)
        for i in range(test_size):
            h, r, t = test_triplets[i]
            h, r, t = h.item(), r.item(), t.item()
            if mode == 'tail':
                tails = self.true_tails[time][(h, r)]
                # dtype=np.int64 keeps the index array integral even when the
                # candidate list is empty (np.array([]) defaults to float64,
                # which is not a valid tensor index).
                tail_idx = np.array([graph.ids[x] for x in tails],
                                    dtype=np.int64)
                mask[i][tail_idx] = 1
                mask[i][graph.ids[t]] = 0  # never filter the gold answer
            elif mode == 'head':
                heads = self.true_heads[time][(r, t)]
                head_idx = np.array([graph.ids[x] for x in heads],
                                    dtype=np.int64)
                mask[i][head_idx] = 1
                mask[i][graph.ids[h]] = 0  # never filter the gold answer
        return mask.byte()

    def combined_scores(self, local_score, temporal_score, labels, weight):
        """Cross-entropy loss on the blended local/temporal score matrix.

        The ensemble logits are ``weight * local + (1 - weight) * temporal``;
        ``labels`` holds the gold class index for each row.
        """
        blended = weight * local_score + (1 - weight) * temporal_score
        return F.cross_entropy(blended, labels)

    def train_link_prediction(self,
                              ent_embed,
                              rel_embeds,
                              triplets,
                              neg_samples,
                              labels,
                              all_embeds_g,
                              corrupt_tail=True):
        """Score training triplets against their negative samples.

        Looks up the relation embedding per triplet, then scores either
        (subject, relation, corrupted objects) when ``corrupt_tail`` is true
        or (corrupted subjects, relation, object) otherwise.  ``labels`` is
        accepted for interface compatibility but is not used here.
        """
        rel = rel_embeds[triplets[:, 1]]
        negatives = all_embeds_g[neg_samples]
        if corrupt_tail:
            subj = ent_embed[triplets[:, 0]]
            return self.calc_score(subj, rel, negatives, mode='tail')
        obj = ent_embed[triplets[:, 2]]
        return self.calc_score(negatives, rel, obj, mode='head')

    def sort_and_rank(self, score, target):
        """Return the 0-indexed rank of each target under descending score.

        Each row of ``score`` is ordered best-first; the position of the
        row's ``target`` column in that ordering is its rank (0 = top).
        """
        order = score.argsort(dim=1, descending=True)
        hits = torch.nonzero(order == target.view(-1, 1))
        return hits[:, 1].view(-1)
Beispiel #6
0
 def build_model(self):
     BiDynamicRGCN.build_model(self)