Example #1
import os

import numpy as np
import torch
from tqdm import tqdm

# fetch_graph and send_graph_to_device are project-local helpers assumed to
# be importable from the surrounding package.

def predict_gen(model,
                loader,
                max_graphs=10,
                nc_only=False,
                get_sim_mat=False,
                device='cpu'):
    """
    Yield embeddings one batch at a time.

    :param model:
    :param loader:
    :param max_graphs:
    :param get_sim_mat:
    :param device:
    :return:
    """
    all_graphs = loader.dataset.all_graphs
    graph_dir = loader.dataset.path

    model = model.to(device)
    # Put the model in eval mode so dropout/batchnorm behave deterministically
    model.eval()
    Ks = []
    with torch.no_grad():
        # For each batch, we get the batched graph, the similarity matrix K,
        # the graph indices in the dataset and the graph sizes
        for i, (graph, K, graph_indices,
                graph_sizes) in tqdm(enumerate(loader), total=len(loader)):
            # Stop once max_graphs batches have been yielded. Raising
            # StopIteration inside a generator is a RuntimeError since
            # PEP 479, so return instead.
            if max_graphs is not None and i > max_graphs - 1:
                return
            if get_sim_mat:
                Ks.append(K)

            # Per-batch accumulators; Ks keeps accumulating across batches
            g_inds = []
            node_ids = []

            graph_indices = list(graph_indices.numpy().flatten())
            for graph_index, n_nodes in zip(graph_indices, graph_sizes):
                # For each graph, build (graph name, node index) pairs
                # covering all of its nodes
                rep = [(all_graphs[graph_index], node_index)
                       for node_index in range(n_nodes)]
                g_inds.extend(rep)

                # Node ids from the original graph on disk, in sorted order
                g_path = os.path.join(graph_dir, all_graphs[graph_index])
                G = fetch_graph(g_path)
                g_nodes = sorted(G.nodes())
                node_ids.extend(g_nodes[:n_nodes])

            graph = send_graph_to_device(graph, device)
            z = model(graph)

            Z = z.cpu().numpy()
            g_inds = {value: i for i, value in enumerate(g_inds)}
            node_map = {value: i for i, value in enumerate(node_ids)}
            yield Z, g_inds, node_map
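
A minimal consumption sketch, assuming a trained model and a dataset loader built elsewhere in the project (both names are placeholders, not part of this snippet):

import numpy as np

# Stream per-batch embeddings from the generator and stack them.
chunks = []
for Z, g_inds, node_map in predict_gen(model, loader, max_graphs=5):
    chunks.append(Z)
all_Z = np.concatenate(chunks)
print(all_Z.shape)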
Example #2
import os

import numpy as np
import torch
from tqdm import tqdm

# fetch_graph, send_graph_to_device and get_nc_nodes_index are project-local
# helpers assumed to be importable from the surrounding package.

def predict(model,
            loader,
            max_graphs=10,
            nc_only=False,
            get_sim_mat=False,
            device='cpu'):
    """

    :param model:
    :param loader:
    :param max_graphs:
    :param get_sim_mat:
    :param device:
    :return:
    """

    all_graphs = loader.dataset.all_graphs
    graph_dir = loader.dataset.path
    Z = []
    Ks = []
    g_inds = []
    node_ids = []
    tot = max_graphs if max_graphs is not None else len(loader)

    model = model.to(device)
    model.eval()
    with torch.no_grad():
        # For each batch, we get the batched graph, the similarity matrix K,
        # the graph indices in the dataset and the graph sizes
        for i, (graph, K, graph_indices,
                graph_sizes) in tqdm(enumerate(loader), total=tot):
            if get_sim_mat:
                Ks.append(K)
            graph_indices = list(graph_indices.numpy().flatten())
            keep_Z_indices = []
            offset = 0
            for graph_index, n_nodes in zip(graph_indices, graph_sizes):
                # Start from all of this graph's node indices; nc_only may
                # narrow the selection down below
                keep_indices = list(range(n_nodes))

                # Load the original graph from disk
                g_path = os.path.join(graph_dir, all_graphs[graph_index])
                G = fetch_graph(g_path)

                assert n_nodes == len(G.nodes())
                if nc_only:
                    keep_indices = get_nc_nodes_index(G)
                keep_Z_indices.extend([ind + offset for ind in keep_indices])

                # (graph name, node index) pairs for the kept nodes
                rep = [(all_graphs[graph_index], node_index)
                       for node_index in keep_indices]

                g_inds.extend(rep)

                g_nodes = sorted(G.nodes())
                node_ids.extend([g_nodes[i] for i in keep_indices])

                offset += n_nodes

            graph = send_graph_to_device(graph, device)
            z = model(graph)
            z = z.cpu().numpy()
            z = z[keep_Z_indices]
            Z.append(z)

            # Stop after max_graphs batches (matches the tqdm total above)
            if max_graphs is not None and i >= max_graphs - 1:
                break

    Z = np.concatenate(Z)

    # (graph file, node index) pair to its position in the kept-node list
    g_inds = {value: i for i, value in enumerate(g_inds)}
    # node id to row index in Z
    node_map = {value: i for i, value in enumerate(node_ids)}
    # row index in Z to node id
    node_map_r = {i: value for i, value in enumerate(node_ids)}

    return {
        'Z': Z,
        'node_to_gind': g_inds,
        'node_to_zind': node_map,
        'ind_to_node': node_map_r,
        'node_id_list': node_ids
    }
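
A minimal usage sketch for the returned dictionary, again with placeholder model and loader objects:

# Run inference once, then look up one node's embedding row by its id.
out = predict(model, loader, max_graphs=5, nc_only=True)
Z = out['Z']
some_node = out['node_id_list'][0]
row = out['node_to_zind'][some_node]  # row index of this node in Z
embedding = Z[row]
assert out['ind_to_node'][row] == some_node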
Example #3
# This excerpt is the constructor of an enclosing class that is not shown.
# It assumes module-level imports of os, networkx as nx, numpy as np and
# scipy.spatial.distance.cdist, plus the project-local helpers
# inference_on_list, cluster and fetch_graph.
    def __init__(self,
                 run,
                 graph_dir='../data/annotated/whole_v4',
                 n_components=8,
                 min_count=50,
                 max_var=0.1,
                 min_edge=50,
                 clust_algo='k_means',
                 optimize=True,
                 max_graphs=None,
                 nc_only=False,
                 bb_only=False):

        # General
        self.run = run
        self.graph_dir = graph_dir

        # Nodes parameters
        self.n_components = n_components
        self.min_count = min_count
        self.max_var = max_var
        self.clust_algo = clust_algo

        # Edges parameters
        self.bb_only = bb_only
        self.min_edge = min_edge

        # BUILD MNODES
        model_output = inference_on_list(self.run,
                                         self.graph_dir,
                                         os.listdir(self.graph_dir),
                                         max_graphs=max_graphs,
                                         nc_only=nc_only)

        Z = model_output['Z']
        self.node_map = model_output['node_to_zind']
        self.reversed_node_map = model_output['ind_to_node']

        print(f"Embedded {len(Z)} nodes")

        clust_info = cluster(Z,
                             algo=clust_algo,
                             optimize=optimize,
                             n_clusters=n_components)

        if self.clust_algo == 'gmm':
            # Score nodes by distance to their cluster center; set distance
            # to False to use the GMM posterior probabilities instead
            distance = True
            self.cluster_model = clust_info['model']
            self.labels = clust_info['labels']
            if distance:
                centers = clust_info['centers']
                dists = cdist(Z, centers)
                scores = np.take_along_axis(dists,
                                            self.labels[:, None],
                                            axis=1)
                scores = np.exp(-scores)
            else:
                probas = clust_info['scores']
                scores = np.take_along_axis(probas,
                                            self.labels[:, None],
                                            axis=1)
        elif self.clust_algo == 'som':
            self.cluster_model = clust_info['model']
            self.labels = clust_info['labels']
            scores = np.exp(-clust_info['errors'])
        elif self.clust_algo == 'k_means':
            self.cluster_model = clust_info['model']
            self.labels = clust_info['labels']
            centers = clust_info['centers']
            dists = cdist(Z, centers)
            scores = np.take_along_axis(dists, self.labels[:, None], axis=1)
            scores = np.exp(-scores)
        else:
            raise NotImplementedError

        self.spread = clust_info['spread']

        self.components = np.unique(self.labels)
        self.id_to_score = {
            ind: scores[ind]
            for ind in self.reversed_node_map
        }
        print("Clustered")

        self.graph = nx.Graph()

        # Clusters that are too sparse or under-populated could be dropped
        # here with cluster_filter(clusts, cov, self.min_count, self.max_var);
        # that filtering is currently disabled.

        for id_clust in self.components:
            self.graph.add_node(id_clust, node_ids=set())

        for index, clust in enumerate(self.labels):
            self.graph.nodes[clust]['node_ids'].add(index)

        # BUILD MEDGES
        for graph_name in os.listdir(self.graph_dir)[:max_graphs]:
            graph_path = os.path.join(self.graph_dir, graph_name)
            g = fetch_graph(graph_path)
            g = g.to_undirected()
            for start_node, end_node in g.edges():
                # Skip nodes that were filtered out of the embeddings
                if start_node not in self.node_map:
                    continue
                if end_node not in self.node_map:
                    continue
                # Optionally keep only backbone (B53) edges
                if self.bb_only and g[start_node][end_node]['label'] != 'B53':
                    continue

                start_id = self.node_map[start_node]
                end_id = self.node_map[end_node]
                start_clust = self.labels[start_id]
                end_clust = self.labels[end_id]

                # Filtering of edges between discarded clusters and canonical
                # reordering of the endpoints are currently disabled.

                # Create the meta-edge if needed, then record this edge in it

                if not self.graph.has_edge(start_clust, end_clust):
                    self.graph.add_edge(start_clust, end_clust, edge_set=set())

                self.graph.edges[(start_clust, end_clust)]['edge_set'].add(
                    (start_id, end_id, 1))

        # Filtering and hashing: drop meta-edges supported by fewer than
        # min_edge original edges
        to_remove = list()
        for start, end, edge_set in self.graph.edges(data='edge_set'):
            if len(edge_set) < self.min_edge:
                to_remove.append((start, end))
        for start, end in to_remove:
            self.graph.remove_edge(start, end)
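
A hedged instantiation sketch; the class name MGraph, the run name and the keyword values are placeholders, since the enclosing class and available runs are not shown in this excerpt:

# Build the meta-graph: embedding clusters become meta-nodes, and edges
# between clusters become meta-edges kept only above min_edge support.
mgraph = MGraph('my_trained_run',
                graph_dir='../data/annotated/whole_v4',
                n_components=8,
                clust_algo='k_means',
                max_graphs=100)
print(mgraph.graph.number_of_nodes(), mgraph.graph.number_of_edges())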