def inference_on_graph(model, graph, edge_map=default_edge_map, device='cpu', nc_only=False):
    """
    Do inference on one networkx graph.

    :param model: torch module mapping a DGL graph to per-node embeddings.
    :param graph: networkx graph whose edges carry a 'label' attribute
                  (labels must be keys of edge_map).
    :param edge_map: dict mapping edge label -> integer id stored as the
                     'one_hot' edge feature.
    :param device: torch device string, e.g. 'cpu' or 'cuda'.
    :param nc_only: if True, keep only the nodes selected by
                    get_nc_nodes_index(graph) instead of all nodes.
    :return: (embs, node_map) where embs is a numpy array containing one row
             per kept node (nodes in sorted order), and node_map maps an
             original node id -> its row index in embs.
    """
    graph = nx.to_undirected(graph)
    # Attach the integer edge-class id as a tensor feature for DGL to pick up.
    one_hot = {
        edge: torch.tensor(edge_map[label])
        for edge, label in (nx.get_edge_attributes(graph, 'label')).items()
    }
    nx.set_edge_attributes(graph, name='one_hot', values=one_hot)

    g_dgl = dgl.DGLGraph()
    g_dgl.from_networkx(nx_graph=graph, edge_attrs=['one_hot'])
    g_dgl = send_graph_to_device(g_dgl, device)

    model = model.to(device)
    with torch.no_grad():
        embs = model(g_dgl)
    # BUG FIX: the original computed embs.cpu().numpy() but discarded the
    # result, so a torch tensor (possibly still on GPU) leaked out of the
    # function instead of the intended numpy array.
    embs = embs.cpu().numpy()

    g_nodes = list(sorted(graph.nodes()))
    keep_indices = list(range(len(graph.nodes())))
    if nc_only:
        keep_indices = get_nc_nodes_index(graph)

    # node id -> row index in the filtered embedding matrix.
    node_map = {g_nodes[node_ind]: i for i, node_ind in enumerate(keep_indices)}
    embs = embs[keep_indices]
    return embs, node_map
def predict_gen(model, loader, max_graphs=10, nc_only=False, get_sim_mat=False, device='cpu'):
    """
    Yield embeddings one batch at a time.

    :param model: torch module mapping a DGL batch to per-node embeddings.
    :param loader: DataLoader whose dataset exposes `all_graphs` (filenames)
                   and `path` (graph directory); batches are
                   (graph, K, graph_indices, graph_sizes).
    :param max_graphs: stop after this many batches (None = no limit).
    :param nc_only: accepted for interface parity with `predict`; currently
                    unused here — all nodes are kept.
    :param get_sim_mat: if True, collect the batch similarity matrix K
                        (NOTE(review): K is collected but not yielded; kept
                        for parity with `predict` — confirm intended use).
    :param device: torch device string.
    :return: yields (Z, g_inds, node_map) per batch, where Z is a numpy array
             of node embeddings, g_inds maps (graph_name, node_index) -> row,
             and node_map maps original node id -> row.
    """
    all_graphs = loader.dataset.all_graphs
    graph_dir = loader.dataset.path
    model = model.to(device)
    with torch.no_grad():
        # For each batch, we have the graph, its index in the list and its size
        for batch_idx, (graph, K, graph_indices, graph_sizes) in tqdm(enumerate(loader),
                                                                      total=len(loader)):
            # Per-batch accumulators. BUG FIX: the original appended K to Ks
            # and then immediately reset Ks = [], discarding the similarity
            # matrix every time; reset must happen before the append.
            Z = []
            Ks = []
            g_inds = []
            node_ids = []
            if get_sim_mat:
                Ks.append(K)

            graph_indices = list(graph_indices.numpy().flatten())
            for graph_index, n_nodes in zip(graph_indices, graph_sizes):
                # For each graph, build an id list containing all its nodes.
                rep = [(all_graphs[graph_index], node_index)
                       for node_index in range(n_nodes)]
                g_inds.extend(rep)
                # List of node ids from the original graph (sorted order
                # matches the embedding row order).
                g_path = os.path.join(graph_dir, all_graphs[graph_index])
                G = fetch_graph(g_path)
                g_nodes = sorted(G.nodes())
                node_ids.extend([g_nodes[node_pos] for node_pos in range(n_nodes)])

            if max_graphs is not None and batch_idx > max_graphs - 1:
                # BUG FIX: the original did `raise StopIteration`, which PEP 479
                # (Python 3.7+) turns into a RuntimeError inside a generator.
                return

            graph = send_graph_to_device(graph, device)
            z = model(graph)
            Z.append(z.cpu().numpy())

            Z = np.concatenate(Z)
            # (graph_name, node_index) -> row in Z
            g_inds = {value: row for row, value in enumerate(g_inds)}
            # original node id -> row in Z
            node_map = {value: row for row, value in enumerate(node_ids)}
            yield Z, g_inds, node_map
def predict(model, loader, max_graphs=10, nc_only=False, get_sim_mat=False, device='cpu'):
    """
    Embed the nodes of the graphs served by `loader` and return them with
    the bookkeeping maps needed to recover which row belongs to which node.

    :param model: torch module mapping a DGL batch to per-node embeddings.
    :param loader: DataLoader whose dataset exposes `all_graphs` (filenames)
                   and `path` (graph directory); batches are
                   (graph, K, graph_indices, graph_sizes).
    :param max_graphs: stop after this many batches (None = all batches).
    :param nc_only: if True, keep only nodes selected by get_nc_nodes_index.
    :param get_sim_mat: if True, also collect the per-batch similarity
                        matrices K.
    :param device: torch device string.
    :return: dict with keys
             'Z'            — numpy array of kept node embeddings,
             'node_to_gind' — node id -> index in the graph/node list,
             'node_to_zind' — node id -> row in Z,
             'ind_to_node'  — row in Z -> node id,
             'node_id_list' — kept node ids in row order,
             'Ks'           — collected similarity matrices (empty unless
                              get_sim_mat is True).
    """
    all_graphs = loader.dataset.all_graphs
    graph_dir = loader.dataset.path
    Z = []
    Ks = []
    g_inds = []
    node_ids = []
    tot = max_graphs if max_graphs is not None else len(loader)
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        # For each batch, we have the graph, its index in the list and its size
        for i, (graph, K, graph_indices, graph_sizes) in tqdm(enumerate(loader), total=tot):
            if get_sim_mat:
                Ks.append(K)
            graph_indices = list(graph_indices.numpy().flatten())

            # Row indices (within the concatenated batch output) to keep.
            keep_Z_indices = []
            offset = 0
            for graph_index, n_nodes in zip(graph_indices, graph_sizes):
                # Default: keep every node of this graph.
                keep_indices = list(range(n_nodes))
                g_path = os.path.join(graph_dir, all_graphs[graph_index])
                G = fetch_graph(g_path)
                # Sanity check: batch bookkeeping must match the graph on disk.
                assert n_nodes == len(G.nodes())
                if nc_only:
                    keep_indices = get_nc_nodes_index(G)

                # Shift per-graph indices by this graph's offset in the batch.
                keep_Z_indices.extend([ind + offset for ind in keep_indices])

                rep = [(all_graphs[graph_index], node_index)
                       for node_index in keep_indices]
                g_inds.extend(rep)

                # Node ids in sorted order match the embedding row order.
                g_nodes = sorted(G.nodes())
                node_ids.extend([g_nodes[ind] for ind in keep_indices])
                offset += n_nodes

            graph = send_graph_to_device(graph, device)
            z = model(graph)
            z = z.cpu().numpy()
            z = z[keep_Z_indices]
            Z.append(z)

            if max_graphs is not None and i > max_graphs - 1:
                break

    Z = np.concatenate(Z)
    # node to index in graphlist
    g_inds = {value: i for i, value in enumerate(g_inds)}
    # node to index in Z
    node_map = {value: i for i, value in enumerate(node_ids)}
    # index in Z to node
    node_map_r = {i: value for i, value in enumerate(node_ids)}
    return {
        'Z': Z,
        'node_to_gind': g_inds,
        'node_to_zind': node_map,
        'ind_to_node': node_map_r,
        'node_id_list': node_ids,
        # BUG FIX: Ks was accumulated when get_sim_mat=True but never
        # returned; exposing it here is additive and backward-compatible.
        'Ks': Ks,
    }