class MaskedGraphDataset(Dataset):
    """Dataset of (anchor egonet, query feature, label) triplets for taxonomy expansion.

    Each item pairs a query node with one or more candidate anchor (parent)
    positions taken from the training taxonomy. Positive anchors are the query's
    true parents; negative anchors are sampled from all training positions,
    excluding the query's ancestors/descendants (the "mask" set).

    sampling_mode semantics (only 0 and 1 are implemented):
        0 -- emit ALL true parents as positives (required for test mode).
        1 -- emit exactly ONE positive per item, cycling through the parents.
    """

    def __init__(self, graph_dataset, mode="train", sampling_mode=1,
                 negative_size=32, expand_factor=64, cache_refresh_time=128,
                 normalize_embed=False, test_topk=-1):
        """Build node lists, masks, and negative-sampling state from graph_dataset.

        Args:
            graph_dataset: project dataset object exposing `g_full` (a DGL graph
                with node features in ndata['x']), `vocab`, and
                train/validation/test node id lists.  # assumes DGL + networkx — project type, not verifiable here
            mode: one of "train", "validation", "test".
            sampling_mode: positive-sampling strategy, see class docstring.
            negative_size: number of negative anchors per item (train/validation).
            expand_factor: max number of anchor children included in an egonet.
            cache_refresh_time: how many times a cached negative egonet may be
                reused before it is regenerated.
            normalize_embed: if True, L2-normalize node feature rows.
            test_topk: at test time, keep only the top-k embedding-nearest
                negative positions (-1 keeps all).
        """
        assert mode in ["train", "validation", "test"], \
            "mode in MaskedGraphDataset must be one of train, validation, and test"
        assert sampling_mode in [0, 1, 2, 3], \
            "sampling_mode in MaskedGraphDataset must be in [0,1,2,3]"
        if mode == "test":
            assert sampling_mode == 0, "!!! During testing, sampling_mode must be 0, in order to emit all positive true parents"
        start = time.time()
        self.mode = mode
        self.sampling_mode = sampling_mode
        self.negative_size = negative_size
        self.expand_factor = expand_factor
        self.cache_refresh_time = cache_refresh_time
        self.normalize_embed = normalize_embed
        self.test_topk = test_topk

        self.node_features = graph_dataset.g_full.ndata['x']
        if self.normalize_embed:
            self.node_features = F.normalize(self.node_features, p=2, dim=1)
        self.vocab = graph_dataset.vocab
        self.full_graph = graph_dataset.g_full.to_networkx()

        # Index node feature vectors in a KeyedVectors store so that embedding
        # distances can be queried by node id (as a string key).
        self.kv = KeyedVectors(vector_size=self.node_features.shape[1])
        self.kv.add([str(i) for i in range(len(self.vocab))],
                    self.node_features.numpy())

        # Interested node list and the induced subgraph for this split.
        # Validation/test graphs also include the train nodes, because
        # candidate anchor positions always come from the train set.
        if mode == "train":
            self.node_list = graph_dataset.train_node_ids
            self.graph = self.full_graph.subgraph(
                graph_dataset.train_node_ids).copy()
        elif mode == "validation":
            self.node_list = graph_dataset.validation_node_ids
            self.graph = self.full_graph.subgraph(
                graph_dataset.train_node_ids +
                graph_dataset.validation_node_ids).copy()
        else:
            self.node_list = graph_dataset.test_node_ids
            self.graph = self.full_graph.subgraph(
                graph_dataset.train_node_ids +
                graph_dataset.test_node_ids).copy()

        # Remove root nodes (i.e., nodes WITH in-degree 0) from the query set:
        # a root has no parent to predict.
        # (Original comment said "without in-degree 0", which was inverted.)
        roots = [
            node for node in self.graph.nodes()
            if self.graph.in_degree(node) == 0
        ]
        interested_node_set = set(self.node_list) - set(roots)
        self.node_list = list(interested_node_set)

        # Generate and cache intermediate data.
        self.node2parents = {}  # node -> list of correct parent positions
        self.node2positive_pointer = {}  # node -> cycling index into node2parents (sampling_mode 1)
        self.node2masks = {}  # node -> positions that must NOT be chosen as its anchor
        # Each `position` is a candidate parent node of a query element.
        self.all_positions = set(graph_dataset.train_node_ids)
        for node in tqdm(self.graph.nodes(),
                         desc="generating intermediate data ..."):
            parents = [edge[0] for edge in self.graph.in_edges(node)]
            self.node2parents[node] = parents
            self.node2positive_pointer[node] = 0
            if node in interested_node_set:
                # Mask out the node itself, its parents, its descendants, and
                # all roots: none of these are valid negative anchors.
                descendants = nx.descendants(self.graph, node)
                masks = set(list(descendants) + parents + [node] + roots)
                self.node2masks[node] = masks

        # [IMPORTANT] Remove the edges between validation/test node ids and the
        # train graph to avoid data leakage.
        edge_to_remove = []
        if mode == "validation":
            for node in graph_dataset.validation_node_ids:
                edge_to_remove.extend(list(self.graph.in_edges(node)))
            print(
                "Remove {} edges between validation nodes and training nodes".
                format(len(edge_to_remove)))
        elif mode == "test":
            for node in graph_dataset.test_node_ids:
                edge_to_remove.extend(list(self.graph.in_edges(node)))
            print(
                "Remove {} edges between test nodes and training nodes".format(
                    len(edge_to_remove)))
        self.graph.remove_edges_from(edge_to_remove)

        # Cache of local subgraphs: self.cache[anchor_node] is the egonet
        # centered on anchor_node; self.cache_counter[anchor_node] counts how
        # many times that cached egonet has been reused.
        self.cache = {}
        self.cache_counter = {}

        # State for sampling negative positions during train/validation.
        # The queue repeats the train ids 5x so a single shuffled pass yields
        # many negative chunks before reshuffling.  (`list * 5` already builds
        # a fresh list, so no extra .copy() is needed.)
        self.pointer = 0
        self.queue = graph_dataset.train_node_ids * 5
        end = time.time()
        print("Finish loading dataset ({} seconds)".format(end - start))

    def __str__(self):
        return "MaskedGraphDataset mode:{}".format(self.mode)

    def __len__(self):
        return len(self.node_list)

    def __getitem__(self, idx):
        """Generate one data instance based on train/validation/test mode.

        One data instance is a list of (anchor_egonet, query_node_feature,
        label) triplets.

        If self.sampling_mode == 0: the list may contain more than one triplet
        with label = 1.
        If self.sampling_mode == 1: the list contains one and ONLY one triplet
        with label = 1; the others have label = 0.
        """
        res = []
        query_node = self.node_list[idx]

        # Generate positive triplet(s).
        if self.sampling_mode == 0:
            for parent_node in self.node2parents[query_node]:
                anchor_egonet, query_node_feature = self._get_subgraph_and_node_pair(
                    query_node, parent_node, 1)
                res.append([anchor_egonet, query_node_feature, 1])
        elif self.sampling_mode == 1:
            # Round-robin over the true parents so successive epochs see
            # different positives for multi-parent nodes.
            positive_pointer = self.node2positive_pointer[query_node]
            parent_node = self.node2parents[query_node][positive_pointer]
            anchor_egonet, query_node_feature = self._get_subgraph_and_node_pair(
                query_node, parent_node, 1)
            self.node2positive_pointer[query_node] = (
                positive_pointer + 1) % len(self.node2parents[query_node])
            res.append([anchor_egonet, query_node_feature, 1])

        # Select negative parents.
        if self.mode in ["train", "validation"]:
            negative_parents = self._get_negative_anchors(
                query_node, self.negative_size)
        else:
            if self.test_topk == -1:
                # Exhaustive: every unmasked training position.
                negative_parents = [
                    ele for ele in self.all_positions
                    if ele not in self.node2masks[query_node]
                ]
            else:
                # Retrieval: keep only the test_topk embedding-nearest
                # unmasked positions (smaller distance = closer).
                negative_pool = [
                    str(ele) for ele in self.all_positions
                    if ele not in self.node2masks[query_node]
                ]
                negative_dist = self.kv.distances(str(query_node),
                                                  negative_pool)
                top_negatives = sorted(zip(negative_pool, negative_dist),
                                       key=lambda x: x[1])[:self.test_topk]
                negative_parents = [int(ele[0]) for ele in top_negatives]

        # Generate negative triplets.
        for negative_parent in negative_parents:
            anchor_egonet, query_node_feature = self._get_subgraph_and_node_pair(
                query_node, negative_parent, 0)
            res.append([anchor_egonet, query_node_feature, 0])
        return tuple(res)

    def _get_negative_anchors(self, query_node, negative_size):
        """Dispatch negative sampling based on self.sampling_mode."""
        if self.sampling_mode == 0:
            return self._get_at_most_k_negatives(query_node, negative_size)
        elif self.sampling_mode == 1:
            return self._get_exactly_k_negatives(query_node, negative_size)
        # Fail loudly instead of silently returning None: sampling modes 2/3
        # pass the __init__ assert but have no sampler implemented.
        raise ValueError(
            "Negative sampling is not implemented for sampling_mode {}".format(
                self.sampling_mode))

    def _get_at_most_k_negatives(self, query_node, negative_size):
        """Generate AT MOST negative_size samples for the query node."""
        if self.pointer == 0:
            random.shuffle(self.queue)
        # Scan queue chunks until one yields at least one unmasked negative;
        # the pointer advances (and wraps) only on fully-masked chunks.
        while True:
            negatives = [
                ele for ele in
                self.queue[self.pointer:self.pointer + negative_size]
                if ele not in self.node2masks[query_node]
            ]
            if len(negatives) > 0:
                break
            self.pointer += negative_size
            if self.pointer >= len(self.queue):
                self.pointer = 0
        return negatives

    def _get_exactly_k_negatives(self, query_node, negative_size):
        """Generate EXACTLY negative_size samples for the query node."""
        if self.pointer == 0:
            random.shuffle(self.queue)
        masks = self.node2masks[query_node]
        negatives = []
        max_try = 0
        while len(negatives) != negative_size:
            n_lack = negative_size - len(negatives)
            negatives.extend([
                ele for ele in self.queue[self.pointer:self.pointer + n_lack]
                if ele not in masks
            ])
            self.pointer += n_lack
            if self.pointer >= len(self.queue):
                self.pointer = 0
                random.shuffle(self.queue)
            max_try += 1
            if max_try > 10:
                # Corner case: could not fill the quota after many passes;
                # trim/expand negatives to the exact size (the expansion may
                # include masked ids, which is the accepted fallback).
                print(
                    "Alert in _get_exactly_k_negatives, query_node: {}, current negative size: {}"
                    .format(query_node, len(negatives)))
                if len(negatives) > negative_size:
                    negatives = negatives[:negative_size]
                else:
                    negatives.extend([
                        ele for ele in
                        self.queue[:(negative_size - len(negatives))]
                    ])
        return negatives

    def _get_subgraph_and_node_pair(self, query_node, anchor_node,
                                    instance_mode):
        """Generate the anchor egonet and obtain the query node feature.

        instance_mode: 0 means negative example, 1 means positive example.
        """
        query_node_feature = self.node_features[query_node, :]
        # [IMPORTANT] Only read from cache when this pair is a negative example
        # already saved in the cache. A positive egonet cannot be cached/reused
        # because it excludes the query node itself; a stale cached version
        # could contain it and make the prediction task trivial.
        if instance_mode == 0 and (anchor_node in self.cache) and (
                self.cache_counter[anchor_node] < self.cache_refresh_time):
            g = self.cache[anchor_node]
            self.cache_counter[anchor_node] += 1
        else:
            g = self._get_subgraph(query_node, anchor_node, instance_mode)
            if instance_mode == 0:  # save to cache
                self.cache[anchor_node] = g
                self.cache_counter[anchor_node] = 0
        return g, query_node_feature

    def _get_subgraph(self, query_node, anchor_node, instance_mode):
        """Build the DGL egonet around anchor_node.

        Node position codes in ndata["pos"]:
            0 = grandparent of query (parent of a parent of anchor's child…
                i.e. parent of an anchor-parent), 1 = parent of anchor,
            2 = anchor itself, 3 = child of anchor (sibling of query).
        """
        # Grandparents of query node (i.e., parents of anchor node).
        nodes = [edge[0] for edge in self.graph.in_edges(anchor_node)]
        gp_nodes = []
        gp_index = {}
        for i, n in enumerate(nodes):
            for edge in self.graph.in_edges(n):
                gp_nodes.append(edge[0])
                # NOTE(review): if a grandparent reaches anchor through several
                # parents, only the last parent index is kept (original behavior).
                gp_index[edge[0]] = i
        nodes_pos = [0] * len(gp_nodes)
        nodes_pos.extend([1] * len(nodes))
        nodes = gp_nodes + nodes

        # Parent of query (i.e., the anchor node itself).
        parent_node_idx = len(nodes)
        nodes.append(anchor_node)
        nodes_pos.append(2)

        # Siblings of query node (i.e., children of anchor node), subsampled to
        # expand_factor when the anchor has too many children.
        if instance_mode == 0:
            # Negative example: query_node cannot be a child of anchor_node,
            # so no filtering is needed.
            if self.graph.out_degree(anchor_node) <= self.expand_factor:
                siblings = [
                    edge[1] for edge in self.graph.out_edges(anchor_node)
                ]
            else:
                siblings = [
                    edge[1] for edge in random.sample(
                        list(self.graph.out_edges(anchor_node)),
                        k=self.expand_factor)
                ]
        else:
            # Positive example: remove query_node from the children of
            # anchor_node so the model cannot see the answer.
            if self.graph.out_degree(anchor_node) <= self.expand_factor:
                siblings = [
                    edge[1] for edge in self.graph.out_edges(anchor_node)
                    if edge[1] != query_node
                ]
            else:
                siblings = [
                    edge[1] for edge in random.sample(
                        list(self.graph.out_edges(anchor_node)),
                        k=self.expand_factor) if edge[1] != query_node
                ]
        nodes.extend(siblings)
        nodes_pos.extend([3] * len(siblings))

        # Create the DGL graph with features.
        g = dgl.DGLGraph()
        g.add_nodes(
            len(nodes), {
                "x": self.node_features[nodes, :],
                "_id": torch.tensor(nodes),
                "pos": torch.tensor(nodes_pos)
            })
        # parents (and anchor itself) -> anchor
        g.add_edges(list(range(len(gp_nodes), parent_node_idx + 1)),
                    parent_node_idx)
        # anchor -> children
        g.add_edges(parent_node_idx,
                    list(range(parent_node_idx + 1, len(nodes))))
        # grandparent -> its corresponding parent
        for i, gp in enumerate(gp_nodes):
            g.add_edges(i, len(gp_nodes) + gp_index[gp])
        # Add self-cycles (common trick so GNN message passing keeps each
        # node's own representation).
        g.add_edges(g.nodes(), g.nodes())
        return g
def main(args):
    """Rank candidate parents for every test node and write the ranks to disk.

    Pipeline:
        1. Load the graph dataset and index node embeddings in KeyedVectors.
        2. Build the test-time taxonomy graph, masks, and leakage-free edges.
        3. Precompute per-anchor egonet info (neighbors + embedding distances).
        4. Score positive and retrieved negative anchors with a pickled
           XGBoost model and write the rank positions of the true parents.

    Args (argparse namespace):
        args.test_data: path to the serialized graph dataset.
        args.model: path to the pickled XGBoost model.
        args.retrieval_size: number of nearest negatives to score (-1 = all).
        args.output_ranks: output file, one rank list per query node.
    """
    # --- Load graph dataset -------------------------------------------------
    graph_dataset = MAGDataset(name="", path=args.test_data, raw=False)
    node_features = graph_dataset.g_full.ndata['x']
    node_features = F.normalize(node_features, p=2, dim=1)
    vocab = graph_dataset.vocab
    full_graph = graph_dataset.g_full.to_networkx()
    kv = KeyedVectors(vector_size=node_features.shape[1])
    kv.add([str(i) for i in range(len(vocab))], node_features.numpy())

    node_list = graph_dataset.test_node_ids
    graph = full_graph.subgraph(graph_dataset.train_node_ids +
                                graph_dataset.test_node_ids).copy()
    # Roots (in-degree 0) have no parent to predict; drop them from queries.
    roots = [node for node in graph.nodes() if graph.in_degree(node) == 0]
    interested_node_set = set(node_list) - set(roots)
    node_list = list(interested_node_set)

    node2parents = {}  # node -> list of correct parent positions
    node2masks = {}  # node -> positions that must not be chosen as negatives
    # Each `position` is a candidate parent node of a query element.
    all_positions = set(graph_dataset.train_node_ids)
    for node in tqdm(graph.nodes(), desc="generating intermediate data ..."):
        parents = [edge[0] for edge in graph.in_edges(node)]
        node2parents[node] = parents
        if node in interested_node_set:
            descendants = nx.descendants(graph, node)
            masks = set(list(descendants) + parents + [node] + roots)
            node2masks[node] = masks

    # Remove test-node in-edges to avoid leaking the answer into the graph.
    edge_to_remove = []
    for node in graph_dataset.test_node_ids:
        edge_to_remove.extend(list(graph.in_edges(node)))
    print(
        f"Remove {len(edge_to_remove)} edges between test nodes and training nodes"
    )
    graph.remove_edges_from(edge_to_remove)

    # --- Cache information --------------------------------------------------
    all_positions_list = list(all_positions)
    tx_id2rank_id = {v: k for k, v in enumerate(all_positions_list)}
    rank_id2tx_id = {k: v for k, v in enumerate(all_positions_list)}
    all_positions_string = [str(ele) for ele in all_positions]
    parent_node2info = {}
    ego2parent = {}  # anchor -> its parents (query grandparents)
    ego2children = {}  # anchor -> its children (query siblings)
    for parent_node in tqdm(graph, desc="generate egonet distances"):
        # Neighborhood order matters downstream: [self, grandparents..., siblings...]
        neighbor = [parent_node]
        grand_parents = [edge[0] for edge in graph.in_edges(parent_node)]
        ego2parent[parent_node] = grand_parents
        num_gp = len(grand_parents)
        neighbor.extend(grand_parents)
        siblings = [edge[1] for edge in graph.out_edges(parent_node)]
        ego2children[parent_node] = siblings
        neighbor.extend(siblings)
        # Embedding distances from the anchor to its whole egonet.
        p_distances = kv.distances(str(parent_node),
                                   [str(ele) for ele in neighbor])
        parent_node2info[parent_node] = {
            "p_distances": p_distances,
            "num_gp": num_gp
        }

    # --- Load model for prediction and save results -------------------------
    # SECURITY NOTE: pickle.load executes arbitrary code from the model file;
    # only load models from trusted sources.
    with open(args.model, "rb") as fin:
        model = pickle.load(fin)
    feature_extractor = FeatureExtractor(graph, kv)
    result_file_path = args.output_ranks
    with open(result_file_path, "w") as fout:
        for query_node in tqdm(node_list):
            featMat = []
            # Positives are appended first, so their row indices are the labels.
            labels = list(range(len(node2parents[query_node])))
            # Cache all query-to-position distances once per query.
            all_distances = kv.distances(str(query_node),
                                         all_positions_string)
            # Select negatives: all unmasked positions, optionally trimmed to
            # the retrieval_size embedding-nearest ones.
            negative_pool = [
                str(ele) for ele in all_positions
                if ele not in node2masks[query_node]
            ]
            negative_dist = kv.distances(str(query_node), negative_pool)
            if args.retrieval_size == -1:
                top_negatives = list(zip(negative_pool, negative_dist))
            else:
                top_negatives = sorted(
                    zip(negative_pool, negative_dist),
                    key=lambda x: x[1])[:args.retrieval_size]
            negatives = [int(ele[0]) for ele in top_negatives]

            # Add positive positions.
            for positive_parent in node2parents[query_node]:
                parent_node_info = parent_node2info[positive_parent]
                features = feature_extractor.extract_features_fast(
                    query_node, positive_parent, ego2parent, ego2children,
                    parent_node_info, all_distances, tx_id2rank_id,
                    rank_id2tx_id)
                featMat.append(features)
            # Add negative positions.
            for negative_parent in negatives:
                parent_node_info = parent_node2info[negative_parent]
                features = feature_extractor.extract_features_fast(
                    query_node, negative_parent, ego2parent, ego2children,
                    parent_node_info, all_distances, tx_id2rank_id,
                    rank_id2tx_id)
                featMat.append(features)

            dtest = xgb.DMatrix(np.array(featMat), missing=-999)
            # NOTE(review): ntree_limit is deprecated in xgboost >= 1.4 in
            # favor of iteration_range; kept for compatibility with the
            # pickled model's xgboost version.
            ypred = model.predict(dtest, ntree_limit=model.best_ntree_limit)
            # Higher model score = better parent, so convert to a distance.
            distance = 1.0 - ypred
            ranks = calculate_ranks_from_distance(distance, labels)
            fout.write(f"{ranks}\n")