Example #1
0
    def forward(self, formula, queries, source_nodes):
        """Score candidate nodes against a batch of queries sharing one formula.

        Args:
            formula: Query structure. This method reads its ``query_type``,
                ``target_mode``, ``anchor_modes`` and ``rels`` attributes.
            queries: Iterable of queries; each one exposes ``anchor_nodes``.
            source_nodes: Candidate nodes, encoded with ``formula.target_mode``.

        Returns:
            For ``"1-chain"``, whatever ``self.path_dec.forward`` returns.
            For ``"2-inter"`` / ``"3-inter"``, the element-wise product of the
            per-branch ``self.cos`` scores.

        Raises:
            Exception: If ``formula.query_type`` is not handled here.
        """
        query_type = formula.query_type
        if query_type == "1-chain":
            # A chain is simply a call to the path decoder.
            return self.path_dec.forward(
                self.enc.forward(source_nodes, formula.target_mode),
                self.enc.forward([query.anchor_nodes[0] for query in queries],
                                 formula.anchor_modes[0]),
                formula.rels)
        elif query_type == "2-inter" or query_type == "3-inter":
            target_embeds = self.enc(source_nodes, formula.target_mode)

            # Branch 1: project the first anchor through its reversed relation.
            embeds1 = self.enc([query.anchor_nodes[0] for query in queries],
                               formula.anchor_modes[0])
            embeds1 = self.path_dec.project(
                embeds1, _reverse_relation(formula.rels[0]))

            # Branch 2: rels[1] is either a single relation or a two-step
            # path; a path is walked back-to-front, one reversed hop at a time.
            embeds2 = self.enc([query.anchor_nodes[1] for query in queries],
                               formula.anchor_modes[1])
            if len(formula.rels[1]) == 2:
                for i_rel in formula.rels[1][::-1]:
                    embeds2 = self.path_dec.project(
                        embeds2, _reverse_relation(i_rel))
            else:
                embeds2 = self.path_dec.project(
                    embeds2, _reverse_relation(formula.rels[1]))

            scores1 = self.cos(target_embeds, embeds1)
            scores2 = self.cos(target_embeds, embeds2)
            if query_type != "3-inter":
                return scores1 * scores2

            # Branch 3 (only for "3-inter").
            embeds3 = self.enc([query.anchor_nodes[2] for query in queries],
                               formula.anchor_modes[2])
            embeds3 = self.path_dec.project(
                embeds3, _reverse_relation(formula.rels[2]))
            # Score against embeds3; the previous code reused embeds2 here,
            # so the third anchor never influenced the result.
            scores3 = self.cos(target_embeds, embeds3)
            return scores1 * scores2 * scores3
        else:
            raise Exception("Query type not supported for this model.")
Example #2
0
def evaluate_metapath_auc(test_metapaths, graph, enc_dec, batch_size=512):
    """Compute ROC AUC of ``enc_dec`` on held-out metapaths vs. sampled negatives.

    For every positive pair ``(e[0], e[1])`` under a relation pair ``rels``,
    one negative ``(e[0], node)`` is drawn, where ``node`` is a key of
    ``graph.adj_lists[_reverse_relation(rels[1])]`` that is not listed in
    ``graph.get_metapath_byrels(rels)[e[0]]``. Positives with no such
    candidate get no negative.

    Args:
        test_metapaths: Mapping from a relation pair to a list of positive
            node pairs. Pairs with ``rels[0] == rels[1]`` are skipped.
        graph: Object providing ``adj_lists`` and ``get_metapath_byrels``.
        enc_dec: Model whose ``forward(nodes1, nodes2, rels)`` returns scores
            exposing ``.data.tolist()``.
        batch_size: Approximate number of pairs scored per ``forward`` call.

    Returns:
        The value of ``roc_auc_score(labels, predictions)``.
    """
    predictions = []
    labels = []
    for rels in test_metapaths:
        # Single-argument print emits the same text on Python 2 and 3.
        print("Testing on %s" % (rels,))
        if rels[0] == rels[1]:
            continue
        rels_pos_metapaths = test_metapaths[rels]
        if len(rels_pos_metapaths) == 0:
            continue

        node_set = set(graph.adj_lists[_reverse_relation(rels[1])].keys())
        rels_all_metapaths = graph.get_metapath_byrels(rels)
        rels_neg_metapaths = []
        for e in rels_pos_metapaths:
            sample_space = list(node_set - set(rels_all_metapaths[e[0]]))
            if len(sample_space) > 0:
                rels_neg_metapaths.append(
                    (e[0], np.random.choice(sample_space)))

        labels.extend([1] * len(rels_pos_metapaths) +
                      [0] * len(rels_neg_metapaths))
        metapaths = rels_pos_metapaths + rels_neg_metapaths
        # Floor division: np.array_split needs an integer section count, and
        # "/" would produce a float under Python 3.
        splits = len(metapaths) // batch_size + 1
        for metapath_split in np.array_split(metapaths, splits):
            scores = enc_dec.forward([e[0] for e in metapath_split],
                                     [e[1] for e in metapath_split], rels)
            predictions.extend(scores.data.tolist())
    return roc_auc_score(labels, predictions)
Example #3
0
    def forward(self, formula, queries, target_nodes):
        """Score target nodes against a batch of queries sharing one formula.

        Args:
            formula: Query structure. This method reads its ``query_type``,
                ``target_mode``, ``anchor_modes`` and ``rels`` attributes.
            queries: Iterable of queries; each one exposes ``anchor_nodes``.
            target_nodes: Candidate nodes, encoded with ``formula.target_mode``.

        Returns:
            For chain queries, whatever ``self.path_dec.forward`` returns.
            For the intersection query types, ``self.cos`` applied to the
            target embeddings and the combined query embedding.

        Raises:
            Exception: If ``formula.query_type`` is not handled here.
        """
        query_type = formula.query_type
        if query_type in ("1-chain", "2-chain", "3-chain"):
            # A chain is simply a call to the path decoder.
            return self.path_dec.forward(
                self.enc.forward(target_nodes, formula.target_mode),
                self.enc.forward([query.anchor_nodes[0] for query in queries],
                                 formula.anchor_modes[0]),
                formula.rels)
        elif query_type in ("2-inter", "3-inter", "3-inter_chain"):
            target_embeds = self.enc(target_nodes, formula.target_mode)

            # Branch 1: project the first anchor through its reversed relation.
            embeds1 = self.enc([query.anchor_nodes[0] for query in queries],
                               formula.anchor_modes[0])
            embeds1 = self.path_dec.project(
                embeds1, _reverse_relation(formula.rels[0]))

            # Branch 2: rels[1] is either a single relation or a two-step
            # path; a path is walked back-to-front, one reversed hop at a time.
            embeds2 = self.enc([query.anchor_nodes[1] for query in queries],
                               formula.anchor_modes[1])
            if len(formula.rels[1]) == 2:
                for i_rel in formula.rels[1][::-1]:
                    embeds2 = self.path_dec.project(
                        embeds2, _reverse_relation(i_rel))
            else:
                embeds2 = self.path_dec.project(
                    embeds2, _reverse_relation(formula.rels[1]))

            if query_type == "3-inter":
                # Branch 3 exists only for the three-way intersection.
                embeds3 = self.enc(
                    [query.anchor_nodes[2] for query in queries],
                    formula.anchor_modes[2])
                embeds3 = self.path_dec.project(
                    embeds3, _reverse_relation(formula.rels[2]))
                query_intersection = self.inter_dec(
                    embeds1, embeds2, formula.target_mode, embeds3)
            else:
                query_intersection = self.inter_dec(
                    embeds1, embeds2, formula.target_mode)
            return self.cos(target_embeds, query_intersection)
        elif query_type == "3-chain_inter":
            target_embeds = self.enc(target_nodes, formula.target_mode)

            # Here rels[1] holds the two branch relations, and rels[0] is the
            # relation applied after the branches are intersected.
            embeds1 = self.enc([query.anchor_nodes[0] for query in queries],
                               formula.anchor_modes[0])
            embeds1 = self.path_dec.project(
                embeds1, _reverse_relation(formula.rels[1][0]))
            embeds2 = self.enc([query.anchor_nodes[1] for query in queries],
                               formula.anchor_modes[1])
            embeds2 = self.path_dec.project(
                embeds2, _reverse_relation(formula.rels[1][1]))
            # NOTE(review): rels[0][-1] is presumably the mode of the
            # intermediate node where the branches meet -- confirm.
            query_intersection = self.inter_dec(embeds1, embeds2,
                                                formula.rels[0][-1])
            query_intersection = self.path_dec.project(
                query_intersection, _reverse_relation(formula.rels[0]))
            return self.cos(target_embeds, query_intersection)
        else:
            # Fail loudly instead of implicitly returning None.
            raise Exception("Query type not supported for this model.")
Example #4
0
def evaluate_edge_auc(test_edges, graph, enc_dec, batch_size=512):
    """Compute ROC AUC of ``enc_dec`` on held-out edges vs. sampled negatives.

    For every positive edge ``(e[0], e[1])`` of relation ``rel``, one negative
    ``(node, e[1])`` is drawn, where ``node`` is a key of
    ``graph.adj_lists[rel]`` that is not listed in
    ``graph.adj_lists[_reverse_relation(rel)][e[1]]``. Positives with no such
    candidate get no negative.

    Args:
        test_edges: Mapping from a relation to a list of positive node pairs.
        graph: Object providing ``adj_lists``.
        enc_dec: Model whose ``forward(nodes1, nodes2, rels)`` returns scores
            exposing ``.data.tolist()``.
        batch_size: Approximate number of edges scored per ``forward`` call.

    Returns:
        The value of ``roc_auc_score(labels, predictions)``.
    """
    predictions = []
    labels = []
    for rel in test_edges:
        # Single-argument print emits the same text on Python 2 and 3.
        print("Testing on %s" % (rel,))
        node_set = set(graph.adj_lists[rel].keys())
        rel_pos_edges = test_edges[rel]

        rel_neg_edges = []
        for e in rel_pos_edges:
            sample_space = list(
                node_set - set(graph.adj_lists[_reverse_relation(rel)][e[1]]))
            # np.random.choice raises on an empty sequence, so a positive
            # with no valid corruption is left without a negative.
            if len(sample_space) > 0:
                rel_neg_edges.append((np.random.choice(sample_space), e[1]))

        labels.extend([1] * len(rel_pos_edges) + [0] * len(rel_neg_edges))
        edges = rel_pos_edges + rel_neg_edges
        if len(edges) == 0:
            # Nothing to score for this relation.
            continue
        # Floor division: np.array_split needs an integer section count, and
        # "/" would produce a float under Python 3.
        splits = len(edges) // batch_size + 1
        for edge_split in np.array_split(edges, splits):
            scores = enc_dec.forward([e[0] for e in edge_split],
                                     [e[1] for e in edge_split], [rel])
            predictions.extend(scores.data.tolist())
    return roc_auc_score(labels, predictions)