Example #1
import dgl
import networkx as nx
import torch

# default_edge_map, send_graph_to_device and get_nc_nodes_index are
# project-local helpers defined elsewhere in this codebase.


def inference_on_graph(model,
                       graph,
                       edge_map=default_edge_map,
                       device='cpu',
                       nc_only=False):
    """
    Do inference on one networkx graph.
    """
    graph = nx.to_undirected(graph)
    # Map each edge label to its integer class id and attach it
    # as a 'one_hot' edge attribute.
    one_hot = {
        edge: torch.tensor(edge_map[label])
        for edge, label in nx.get_edge_attributes(graph, 'label').items()
    }
    nx.set_edge_attributes(graph, name='one_hot', values=one_hot)

    # Legacy DGL (< 0.5) construction; newer DGL versions build this with
    # g_dgl = dgl.from_networkx(graph, edge_attrs=['one_hot']) instead.
    g_dgl = dgl.DGLGraph()
    g_dgl.from_networkx(nx_graph=graph, edge_attrs=['one_hot'])
    g_dgl = send_graph_to_device(g_dgl, device)
    model = model.to(device)
    with torch.no_grad():
        embs = model(g_dgl)
    embs = embs.cpu().numpy()
    g_nodes = sorted(graph.nodes())

    # By default keep every node; with nc_only, keep only the nodes
    # selected by get_nc_nodes_index.
    keep_indices = list(range(len(g_nodes)))
    if nc_only:
        keep_indices = get_nc_nodes_index(graph)

    # Map each kept node id to its row index in the embedding matrix.
    node_map = {
        g_nodes[node_ind]: i
        for i, node_ind in enumerate(keep_indices)
    }
    embs = embs[keep_indices]
    return embs, node_map
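A minimal usage sketch (not part of the original example): it assumes a trained model and a pickled networkx graph at a hypothetical path my_graph.nx.

# Hedged usage sketch; `model` and the graph path are assumptions.
import pickle

with open('my_graph.nx', 'rb') as f:  # hypothetical graph pickle
    graph = pickle.load(f)

embs, node_map = inference_on_graph(model, graph, device='cpu', nc_only=False)
# embs[node_map[node_id]] is the embedding row for a given node_id.
print(embs.shape, len(node_map))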
Example #2
import os

import numpy as np
import torch
from tqdm import tqdm

# fetch_graph and send_graph_to_device are project-local helpers.


def predict_gen(model,
                loader,
                max_graphs=10,
                nc_only=False,
                get_sim_mat=False,
                device='cpu'):
    """
    Yield embeddings one batch at a time.

    :param model: trained embedding model
    :param loader: DataLoader over the graph dataset
    :param max_graphs: stop after this many batches (None for no limit)
    :param nc_only: unused in this variant
    :param get_sim_mat: if True, also collect the similarity matrices K
    :param device: torch device to run inference on
    :return: yields (Z, g_inds, node_map) for each batch
    """
    all_graphs = loader.dataset.all_graphs
    graph_dir = loader.dataset.path

    model = model.to(device)
    Ks = []
    with torch.no_grad():
        # For each batch, we have the graph, its index in the list and its size
        for i, (graph, K, graph_indices,
                graph_sizes) in tqdm(enumerate(loader), total=len(loader)):
            if get_sim_mat:
                Ks.append(K)

            # Per-batch accumulators; Ks keeps accumulating across batches.
            Z = []
            g_inds = []
            node_ids = []

            graph_indices = list(graph_indices.numpy().flatten())
            for graph_index, n_nodes in zip(graph_indices, graph_sizes):
                # For each graph, we build an id list in rep that contains all the nodes
                rep = [(all_graphs[graph_index], node_index)
                       for node_index in range(n_nodes)]
                g_inds.extend(rep)

                # List of node ids from the original graph, in sorted order.
                g_path = os.path.join(graph_dir, all_graphs[graph_index])
                G = fetch_graph(g_path)
                g_nodes = sorted(G.nodes())
                node_ids.extend(g_nodes[:n_nodes])
            # PEP 479: raising StopIteration inside a generator raises a
            # RuntimeError on Python 3.7+; return ends the generator cleanly.
            if max_graphs is not None and i >= max_graphs:
                return

            graph = send_graph_to_device(graph, device)
            z = model(graph)

            Z.append(z.cpu().numpy())
            Z = np.concatenate(Z)
            # Map (graph_name, node_index) pairs and node ids to rows of Z.
            g_inds = {value: i for i, value in enumerate(g_inds)}
            node_map = {value: i for i, value in enumerate(node_ids)}
            yield Z, g_inds, node_map
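A minimal consumption sketch, assuming model and loader come from the surrounding project:

# Hedged usage sketch; `model` and `loader` are assumed to exist.
for Z, g_inds, node_map in predict_gen(model, loader, max_graphs=5):
    # Z is an array of node embeddings for this batch;
    # node_map maps original node ids to rows of Z.
    print(Z.shape, len(node_map))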
Example #3
import os

import numpy as np
import torch
from tqdm import tqdm

# fetch_graph, send_graph_to_device and get_nc_nodes_index are
# project-local helpers.


def predict(model,
            loader,
            max_graphs=10,
            nc_only=False,
            get_sim_mat=False,
            device='cpu'):
    """
    Embed every kept node in the dataset and return the embeddings
    together with the node/index mappings.

    :param model: trained embedding model
    :param loader: DataLoader over the graph dataset
    :param max_graphs: stop after this many batches (None for no limit)
    :param nc_only: if True, keep only the nodes selected by get_nc_nodes_index
    :param get_sim_mat: if True, also collect the similarity matrices K
    :param device: torch device to run inference on
    :return: dict with embeddings 'Z' and lookup dictionaries
    """

    all_graphs = loader.dataset.all_graphs
    graph_dir = loader.dataset.path
    Z = []
    Ks = []
    g_inds = []
    node_ids = []
    tot = max_graphs if max_graphs is not None else len(loader)

    model = model.to(device)
    model.eval()
    with torch.no_grad():
        # For each batch, we have the graph, its index in the list and its size
        # for i, (graph, K, graph_indices, graph_sizes) in enumerate(loader):
        for i, (graph, K, graph_indices,
                graph_sizes) in tqdm(enumerate(loader), total=tot):
            if get_sim_mat:
                Ks.append(K)
            graph_indices = list(graph_indices.numpy().flatten())
            keep_Z_indices = []
            offset = 0
            for graph_index, n_nodes in zip(graph_indices, graph_sizes):
                # For each graph, we build an id list in rep
                # that contains all the nodes

                keep_indices = list(range(n_nodes))

                # list of node ids from original graph
                g_path = os.path.join(graph_dir, all_graphs[graph_index])
                G = fetch_graph(g_path)

                assert n_nodes == len(G.nodes())
                if nc_only:
                    keep_indices = get_nc_nodes_index(G)
                keep_Z_indices.extend([ind + offset for ind in keep_indices])

                rep = [(all_graphs[graph_index], node_index)
                       for node_index in keep_indices]

                g_inds.extend(rep)

                g_nodes = sorted(G.nodes())
                node_ids.extend([g_nodes[j] for j in keep_indices])

                offset += n_nodes

            graph = send_graph_to_device(graph, device)
            z = model(graph)
            z = z.cpu().numpy()
            z = z[keep_Z_indices]
            Z.append(z)

            # Stop iterating once the batch cap is reached.
            if max_graphs is not None and i >= max_graphs:
                break

    Z = np.concatenate(Z)

    # node to index in graphlist
    g_inds = {value: i for i, value in enumerate(g_inds)}
    # node to index in Z
    node_map = {value: i for i, value in enumerate(node_ids)}
    # index in Z to node
    node_map_r = {i: value for i, value in enumerate(node_ids)}

    return {
        'Z': Z,
        'node_to_gind': g_inds,
        'node_to_zind': node_map,
        'ind_to_node': node_map_r,
        'node_id_list': node_ids
    }
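A minimal lookup sketch for the returned dictionary, again assuming model and loader from the project:

# Hedged usage sketch; `model` and `loader` are assumed to exist.
res = predict(model, loader, max_graphs=10, nc_only=False)
Z = res['Z']
node = res['node_id_list'][0]          # any node id
vec = Z[res['node_to_zind'][node]]     # its embedding row
print(vec.shape)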