Example #1
def draw_hit(hit, mg, instance=None):
    """
    Plot the hit. If an instance is given, compare the hit with the original instance.
    :param hit: list of node ids (or meta-graph integer ids) forming the hit
    :param mg: the meta-graph, used to map integer ids back to node ids
    :param instance: optional original motif instance to compare against
    :return:
    """
    try:
        hit_graph = whole_graph_from_node(hit[0])
    except Exception:
        # the hit may be given as meta-graph integer ids; map them back to node ids
        hit = [mg.reversed_node_map[i] for i in hit]
        hit_graph = whole_graph_from_node(hit[0])

    out_border = get_outer_border(hit, hit_graph)
    full_hit = hit + list(out_border)
    out_border = get_outer_border(full_hit, hit_graph)
    extended_hit = full_hit + list(out_border)
    g_hit = hit_graph.subgraph(extended_hit)
    if instance is not None:
        source_graph = whole_graph_from_node(instance[0])
        trimmed = trim(instance)
        out_border = get_outer_border(instance, source_graph)
        extended = instance + list(out_border)
        g_motif = source_graph.subgraph(extended)
        rna_draw_pair((g_motif, g_hit),
                      node_colors=(
                          ['red' if n in trimmed else 'blue' if n in instance else 'grey'
                           for n in g_motif.nodes()],
                          ['red' if n in hit else 'blue' if n in full_hit else 'grey'
                           for n in g_hit.nodes()]))
        plt.show()
    else:
        rna_draw(g_hit, node_colors=['red' if n in hit else 'blue' if n in full_hit else 'grey' for n in g_hit.nodes()])
        plt.show()
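The color lists passed to the drawing routines rely on the comprehension being aligned with the iteration order of graph.nodes(). A minimal sketch of the same red/blue/grey pattern on a toy networkx graph (the node sets here are made up):

import networkx as nx

g = nx.path_graph(5)                 # nodes 0..4
hit = {1, 2}                         # core of the hit
full_hit = hit | {0, 3}              # core plus one ring of context
colors = ['red' if n in hit else 'blue' if n in full_hit else 'grey'
          for n in g.nodes()]
print(colors)                        # ['blue', 'red', 'red', 'blue', 'grey']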
Example #2
def draw_smooth(motif, mg, depth=1, save=None):
    """
    Draw the query motif next to retrieved hits further and further down the ranking.
    :param motif: list of motif instances; the first one is used as the query
    :param mg: the meta-graph used for retrieval
    :param depth: context depth used for trimming and drawing
    :param save: optional path forwarded to the drawing routine
    :return:
    """

    query_instance = motif[0]
    query_whole_g = whole_graph_from_node(query_instance[0])

    # Sometimes the motif cannot be trimmed as much as requested, so we trim less
    # trimmed, trimmed_graph = trim_try(whole_graph=query_whole_g, instance=query_instance, max_depth=0)
    # query_instance_graph = query_whole_g.subgraph(query_instance)
    trimmed, trimmed_graph, actual_depth = trim_try(whole_graph=query_whole_g, instance=query_instance, depth=depth)
    query_instance_graph = induced_edge_filter(query_whole_g, trimmed, depth=actual_depth)

    retrieved_instances = retrieve_instances(query_instance=query_instance, mg=mg, depth=depth)
    sorted_hits = sorted(list(retrieved_instances.items()), key=lambda x: -x[1])

    # TO GET CONTEXT NODES
    # out_border = get_outer_border(motif[0], query_whole_g)
    # expanded = motif[0] + list(out_border)
    # expanded_graph = query_whole_g.subgraph(expanded)

    graphs = [query_instance_graph]
    colors = [['red' if n in trimmed else 'white' for n in query_instance_graph.nodes()]]
    # colors = [['red' if n in trimmed else 'grey' if n in query_instance else 'blue' for n in expanded_graph.nodes()]]
    subtitles = ['Query']
    plot_index = [10, 100, 1000]
    for i in plot_index:
        # In case there are fewer hits than the requested rank
        try:
            hit = sorted_hits[i][0]
        except IndexError:
            continue
        hit = [mg.reversed_node_map[node] for node in hit]
        hit_whole_graph = whole_graph_from_node(hit[0])
        # expand if trimmed
        full_hit = hit

        hit_graph = induced_edge_filter(hit_whole_graph, hit, depth=actual_depth)
        # if depth > 0:
        #     for d in range(depth):
        #         out_border = get_outer_border(full_hit, hit_whole_graph)
        #         full_hit = full_hit + list(out_border)
        # hit_graph = hit_whole_graph.subgraph(full_hit)

        graphs.append(hit_graph)
        colors.append(['red' if n in hit else 'white' for n in hit_graph.nodes()])
        subtitles.append(f'{i}-th hit with score: {sorted_hits[i][1]:.2f}')
    # rna_draw_pair(graphs=graphs, node_colors=colors, subtitles=subtitles, save=save)

    rna_draw_grid(graphs=graphs, node_colors=colors, subtitles=subtitles, save=save, grid_shape=(2, 2))

    plt.show()
Example #3
def retrieve_instances(query_instance, mg, depth=1):
    """
    Trim the query instance, then retrieve all candidate instances from the meta-graph.
    :param query_instance: list of node ids forming the query motif
    :param mg: the meta-graph to retrieve from
    :param depth: how much context to trim from the query before retrieval
    :return: {frozenset of node ids: score}
    """

    query_whole_graph = whole_graph_from_node(query_instance[0])

    # Sometimes the motif cannot be trimmed as much as requested, so we trim less
    trimmed, trimmed_graph, actual_depth = trim_try(query_whole_graph, query_instance, depth=depth)
    # print('starting the retrieval')
    start = time.perf_counter()
    retrieved_instances = mg.retrieve_2(trimmed)
    print(f">>> Retrieved {len(retrieved_instances)} instances in {time.perf_counter() - start:.2f}s")

    # retrieved_instances_2 = mg.retrieve_2(trimmed)
    # print(retrieved_instances == retrieved_instances_2)
    # set1 = set(retrieved_instances.items())
    # set2 = set(retrieved_instances_2.items())
    # print(set1 ^ set2)
    # print(f">>> Retrieved {len(retrieved_instances_2)} instances in {time.perf_counter() - start}")

    return retrieved_instances
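retrieve_instances returns a dict mapping frozensets of node ids to scores, and callers such as draw_smooth sort it by descending score. A minimal sketch of that sorting step on toy data:

retrieved = {frozenset({1, 2}): 3.5,
             frozenset({4, 5, 6}): 7.1,
             frozenset({9}): 0.4}
sorted_hits = sorted(retrieved.items(), key=lambda x: -x[1])
best_nodes, best_score = sorted_hits[0]
print(sorted(best_nodes), best_score)  # [4, 5, 6] 7.1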
Example #4
def get_embeddings_inference(run,
                             annot_path='../data/annotated/whole_v4',
                             max_graphs=300,
                             nc_only=False):
    """
    Build an embedding matrix and node id list for clustering.
    If nc_only is set, filter out nodes with no non-canonical interaction in their neighbourhood.
    """
    from tools.learning_utils import inference_on_list
    annot_list = os.listdir(annot_path)[:max_graphs]
    keep_node_ids = []
    keep_inds = []
    # Get predictions
    model_output = inference_on_list(run,
                                     graph_list=annot_list,
                                     graphs_path=annot_path)
    Z = model_output['Z']
    node_to_ind = model_output['node_to_zind']
    node_ids = model_output['node_id_list']

    if not nc_only:
        return Z, node_to_ind
    for i, node in enumerate(node_ids):
        G = whole_graph_from_node(node)
        if has_NC_bfs(G, node, depth=1):
            keep_node_ids.append(node)
            keep_inds.append(i)

    print(f">>> kept {len(keep_inds)} of {len(Z)} nodes")
    return Z[keep_inds], keep_node_ids
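The nc_only branch keeps only the selected rows of the embedding matrix through numpy fancy indexing. A minimal sketch of that filtering on a toy matrix:

import numpy as np

Z = np.arange(12).reshape(4, 3)  # four node embeddings of dimension 3
keep_inds = [0, 2]               # indices that passed the filter
print(Z[keep_inds].shape)        # (2, 3)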
Example #5
def compute_embs(instance, run):
    """
    Compute embeddings for the whole graph containing a motif instance.
    :param instance: a list of nodes that form a motif
    :param run: the trained model run used for inference
    :return: the embeddings and node map of the source graph
    """
    source_graph = whole_graph_from_node(instance[0])
    embs, node_map = inference_on_graph_run(run, source_graph)
    return embs, node_map
Example #6
def get_outer_border(nodes, graph=None):
    """
    Return the set of nodes outside `nodes` that have at least one neighbor inside it.
    """
    if graph is None:
        graph = whole_graph_from_node(nodes[0])
    # collect the ring of neighbors just outside the node set
    out_border = set()
    for node in nodes:
        for nei in graph.neighbors(node):
            if nei not in nodes:
                out_border.add(nei)
    return out_border
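A minimal usage sketch on a toy networkx graph, passing the graph explicitly so that whole_graph_from_node is not needed:

import networkx as nx

g = nx.path_graph(6)                # 0-1-2-3-4-5
print(get_outer_border([2, 3], g))  # {1, 4}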
Example #7
def plot_instance(instance, source_graph=None):
    """
    Plot a motif instance: trimmed core in red, remaining motif nodes in blue, context border in grey.
    :param instance: list of node ids forming the motif
    :param source_graph: optional whole graph; recomputed from the first node if None
    :return:
    """
    if source_graph is None:
        source_graph = whole_graph_from_node(instance[0])
    trimmed = trim(instance)
    out_border = get_outer_border(instance, source_graph)
    extended = instance + list(out_border)
    g_motif = source_graph.subgraph(extended)
    rna_draw(g_motif, node_colors=['red' if n in trimmed else 'blue' if n in instance else 'grey' for n in
                                   g_motif.nodes()])
    plt.show()
Example #8
def trim(instance, depth=1, whole_graph=None):
    """
    Remove the nodes of the instance that lie within `depth` hops of its outer border.
    """
    if whole_graph is None:
        whole_graph = whole_graph_from_node(instance[0])
    out_border = get_outer_border(instance, whole_graph)

    # grow rings around the border, keeping the newest ring and the cumulative set
    cumulative, last = out_border, out_border
    for d in range(depth):
        depth_ring = set()
        for node in last:
            for nei in whole_graph.neighbors(node):
                if nei not in cumulative:
                    depth_ring.add(nei)
        last = depth_ring
        cumulative = cumulative.union(depth_ring)
    trimmed_instance = [node for node in instance if node not in cumulative]
    return trimmed_instance
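A toy check of the trimming on a path graph, again passing the whole graph explicitly; with the default depth=1, every instance node adjacent to the outer border is removed:

import networkx as nx

g = nx.path_graph(7)                         # 0-1-2-3-4-5-6
print(trim([1, 2, 3, 4, 5], whole_graph=g))  # [2, 3, 4]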
Example #9
def prune_motifs(motifs_dict, shortest=4, sparsest=3, non_canonical=True, non_redundant=True):
    """
    Clean the dict by removing sparse or small motifs.
    :param motifs_dict: {motif_id: list of instances (lists of node ids)}
    :param shortest: minimum number of nodes a motif must have
    :param sparsest: minimum number of instances a motif must have
    :param non_canonical: if True, keep only motifs containing a non-canonical interaction
    :param non_redundant: if True, keep only instances from the non-redundant set
    :return: the filtered dict
    """
    res_dict = {}
    sparse, short, nc = 0, 0, 0
    tot_inst, nr_inst = 0, 0
    mean_instance, mean_nodes = list(), list()
    non_redundant_list = set(os.listdir(os.path.join(script_dir, '../data/unchopped_v4_nr')))
    for mid, instances in motifs_dict.items():
        if non_redundant:
            tot_inst += len(instances)
            instances = [instance for instance in instances if instance[0][0] in non_redundant_list]
            nr_inst += len(instances)
        if len(instances) < sparsest:
            sparse += 1
            continue
        # use the first surviving instance to check the motif size
        instance = instances[0]
        if len(instance) < shortest:
            short += 1
            continue
        if non_canonical:
            graph = whole_graph_from_node(instance[0])
            motif_graph = graph.subgraph(instance)
            if not has_NC(motif_graph):
                nc += 1
                continue
        mean_instance.append(len(instances))
        mean_nodes.append(len(instances[0]))
        res_dict[mid] = instances

    print(f'filtered {sparse} on sparsity, {short} on length, {nc} on non-canonicals')
    print(f'non-redundancy removed {tot_inst - nr_inst}/{tot_inst} instances')
    print(f'On average, {np.mean(mean_instance):.1f} instances of motifs with {np.mean(mean_nodes):.1f} nodes')
    return res_dict
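The non-redundancy filter keeps an instance only if the file name in its first node id appears in the non-redundant listing. A minimal sketch of that membership test, with made-up node ids of the form (graph_file, residue):

non_redundant_list = {'1abc.json', '2xyz.json'}
instances = [[('1abc.json', 10), ('1abc.json', 11)],
             [('3def.json', 4)]]
kept = [inst for inst in instances if inst[0][0] in non_redundant_list]
print(len(kept))  # 1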
Example #10
def ged_computing(motifs, mg, depth=1):
    from tools.rna_ged_nx import ged
    res_dict = dict()
    all_motifs = list(motifs.items())
    for i, (motif_id, motif) in enumerate(all_motifs):
        inner_dict = {}

        # Get the hits
        print('attempting id:', motif_id)
        query_instance = motif[0]
        query_whole_graph = whole_graph_from_node(query_instance[0])
        retrieved_instances = retrieve_instances(query_instance=query_instance, mg=mg, depth=depth)
        sorted_hits = sorted(list(retrieved_instances.items()), key=lambda x: -x[1])

        # Get the actual query that was used (because of trimming) and expand it with the depth
        # This is V2, the original one was just computing between the trimmed and the reduced hit
        # trimmed, trimmed_graph = trim_try(whole_graph=query_whole_graph, instance=query_instance)
        # query_instance_graph = query_whole_graph.subgraph(query_instance)
        trimmed, trimmed_graph, actual_depth = trim_try(whole_graph=query_whole_graph, instance=query_instance)
        query_instance_graph = induced_edge_filter(query_whole_graph, trimmed, depth=actual_depth)

        plot_index = [0, 10, 100, 1000]
        for j in plot_index:
            # In case there are fewer hits than the requested rank
            try:
                hit = sorted_hits[j][0]
            except IndexError:
                continue
            hit = [mg.reversed_node_map[node] for node in hit]
            # print(hit)
            hit_whole_graph = whole_graph_from_node(hit[0])

            # If one changes this, one should also remove the query expansion
            hit_graph = induced_edge_filter(hit_whole_graph, hit, depth=actual_depth)
            # hit_graph = whole_graph_from_node(hit[0]).subgraph(hit)
            start = time.perf_counter()
            ged_value = ged(query_instance_graph, hit_graph, timeout=2)
            print(j, len(query_instance_graph), len(hit_graph), ged_value, time.perf_counter() - start)
            # res_dict[j].append(ged_value)
            inner_dict[j] = ged_value

            # TO PLOT THE HITS
            # expanded = hit
            # if depth > 0:
            #     out_border = get_outer_border(hit, hit_whole_graph)
            #     expanded = hit + list(out_border)
            # expanded_graph = hit_whole_graph.subgraph(expanded)

            colors = [['red' if n in trimmed else 'grey' for n in query_instance_graph.nodes()],
                      ['red' if n in hit else 'grey' for n in hit_graph.nodes()]]
            subtitles = ('', ged_value)
            rna_draw_pair((query_instance_graph, hit_graph), node_colors=colors, subtitles=subtitles)
            plt.show()

        # Pick another random that is not the current graph
        other_random = random.randint(0, len(all_motifs) - 2)
        if other_random >= i:
            other_random += 1
        random_query_instance = all_motifs[other_random][1][0]
        # random_graph = whole_graph_from_node(random_query_instance[0]).subgraph(random_query_instance)
        random_graph = induced_edge_filter(whole_graph_from_node(random_query_instance[0]),
                                           random_query_instance, depth=actual_depth)


        # TO PLOT THE RANDOM
        # colors = [['red' if n in trimmed else 'grey' for n in query_instance_graph.nodes()],
        #           ['grey' for n in random_graph.nodes()]]
        # rna_draw_pair((query_instance_graph, random_graph), node_colors=colors)
        # plt.show()

        # res_dict['random_other'].append(ged(query_instance_graph, random_graph, timeout=5))
        start = time.perf_counter()
        inner_dict['random_other'] = ged(query_instance_graph, random_graph, timeout=2)
        print('random', len(trimmed_graph), len(random_graph), inner_dict['random_other'], time.perf_counter() - start)
        print(inner_dict)
        res_dict[motif_id] = inner_dict

    return res_dict
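The random baseline above picks a uniform index over all motifs except the current one: draw from a range that is one element short, then shift every value at or above i up by one. A standalone sketch of that trick (the helper name is made up):

import random

def random_other(i, n):
    """Uniform index in [0, n) that is never equal to i."""
    j = random.randint(0, n - 2)   # inclusive bounds, n - 1 possible values
    return j + 1 if j >= i else j

assert all(random_other(2, 5) != 2 for _ in range(1000))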
Example #11
def maga(mgraph, levels=10):
    print(f">>> Meta-graph has {len(mgraph.graph.nodes())} nodes",
          f"and {len(mgraph.graph.edges())} edges.")
    maga_graph = nx.relabel_nodes(
        mgraph.graph, {n: ms.FrozenMultiset([n])
                       for n in mgraph.graph})

    maga_graph = maga_graph.to_directed()

    # how many times to sample a cluster when checking for boring ones
    n_boring_samples = 100
    # keep track of how many sampled instances of each cluster are boring;
    # 'samples' starts at 0.01 to avoid a division by zero later on
    boring_clusters = {
        c: {
            'samples': 0.01,
            'boring': 0
        }
        for c in set(mgraph.labels)
    }

    maga_tree = nx.DiGraph()
    maga_tree.add_nodes_from((ms.FrozenMultiset([n]) for n in mgraph.graph))

    # this dictionary is of the following form
    # {u: {c1: {v, w}, c2: {x}}}
    # maps each node to the clusters to which it is connected
    # for each connected cluster, we store the endpoint of the edge.
    # in this case, node `u` is connected to clusters `c1` via `v`, `w`,
    # and connected to `c2` via `x`

    maga_adj = defaultdict(def_set)

    print(">>> Building MAGA graph.")
    for n in maga_graph.nodes():
        maga_graph.nodes[n]['node_set'] = set()
    for c1, c2, d in tqdm(maga_graph.edges(data=True)):
        # drop the distances; d is the live edge dict, so the loop below
        # iterates over the new frozensets of endpoint pairs
        maga_graph[c1][c2]['edge_set'] = {
            frozenset([u, v])
            for u, v, _ in d['edge_set']
        }
        for u, v in d['edge_set']:

            maga_adj[u][mgraph.labels[v]].add(v)
            maga_adj[v][mgraph.labels[u]].add(u)

            for node in (u, v):
                clust = mgraph.labels[node]
                if boring_clusters[clust]['samples'] < n_boring_samples:
                    boring_clusters[clust]['samples'] += 1
                    node_id = mgraph.reversed_node_map[node]
                    G = whole_graph_from_node(node_id)
                    if not has_NC_bfs(G, node_id, depth=1):
                        boring_clusters[clust]['boring'] += 1

            u_node = ms.FrozenMultiset([mgraph.labels[u]])
            v_node = ms.FrozenMultiset([mgraph.labels[v]])
            maga_graph.nodes[u_node]['node_set'].add(frozenset([u]))
            maga_graph.nodes[v_node]['node_set'].add(frozenset([v]))

    # consider a cluster boring if at least 80% of its sampled instances are boring
    boring_clusters = {clust for clust, counts in boring_clusters.items()\
                            if counts['boring'] / counts['samples'] > .8}

    print(">>> Doing MAGA.")
    maga_build = maga_next(maga_graph,
                           maga_tree,
                           maga_adj,
                           mgraph,
                           boring_clusters,
                           levels=levels)
    for l, maga_graph in enumerate(maga_build):
        print("maga level ", l)
        print("maga nodes ", len(maga_graph.nodes()), "maga edges ",
              len(maga_graph.edges()))

    return maga_graph
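maga keys its graphs by ms.FrozenMultiset rather than frozenset because merging motifs must keep duplicate cluster labels while the key stays hashable. A minimal sketch, assuming the multiset package:

import multiset as ms

a = ms.FrozenMultiset([1, 2])
b = ms.FrozenMultiset([2])
merged = a + b                  # multiset sum keeps the duplicated label 2
print(sorted(merged))           # [1, 2, 2]
print({merged: 'node data'})    # hashable, so usable as a graph node key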
Example #12
    def retrieve_2(self, motif):
        """
        Start from a motif representative: a list of nodes that make up the motif.
        Build the query graph:
        - Create embeddings for the motif nodes (they need the whole graph), cluster them, and put each
          query node in the appropriate cluster. Then add the edges that make up the connectivity of the
          query motif.
        - Add every node of each cluster that is part of the query graph to a big set of motif instances.
        - Follow the query graph edges and connect these instances.

        We maintain both a dict motifs_instances {frozenset of ids: score}
        and a dict motifs_instances_grouped {pdb_id: set of frozensets} for more efficient looping:
        when exploring a new edge of the query meta-graph, we loop through the edges that make up this
        meta-edge and each time only look at the frozensets in motifs_instances_grouped[current_pdb].
        :param motif: list of node ids forming the query motif
        :return: {frozenset of node ids: score}
        """
        original_graph = whole_graph_from_node(motif[0])
        query_nodes, query_edges = self.build_query_graph(
            original_graph, motif)

        # Sort the query edges based on meta-edge identity to get a speedup.
        # Of the orderings tried, the fastest starts with the most populated
        # edges, which thus do not have to go through a large set of instances.
        # query_edges = sorted(list(query_edges), key=lambda x: (x[2], x[3]))
        clusts_populations = {
            clust_id: len(self.graph.nodes[clust_id]['node_ids'])
            for node, clust_id in query_nodes
        }
        query_edges = sorted(
            list(query_edges),
            key=lambda x: (-sum(
                (clusts_populations[x[2]], clusts_populations[x[3]])), x[2]))

        def node_to_pdbid(node_index):
            """
                Return the PDB id containing the given frozenset of motif nodes.
            """
            pdbid, _ = self.reversed_node_map[list(node_index)[0]]
            return pdbid[:4]

        def add_mnode(clust_id, mg, motifs_instances,
                      motifs_instances_grouped):
            # get all nodes in the data that might be involved, i.e. singleton
            # motifs that could be part of the query motif
            mnode = mg.graph.nodes[clust_id]
            node_candidates = mnode['node_ids']

            # map pdbid -> set(frozensets):
            # turn the seed ids into frozensets representing current motif instances to be expanded
            for int_id in node_candidates:
                motif_nodes = frozenset([int_id])
                motifs_instances[motif_nodes] = float(mg.id_to_score[int_id])
                motifs_instances_grouped[node_to_pdbid(motif_nodes)].add(
                    motif_nodes)

        motifs_instances = dict()
        motifs_instances_grouped = defaultdict(set)
        visited_clusts = set()

        for edge in query_edges:
            # Try to access the corresponding meta-edge and get the list of all candidate edges
            _, _, start_clust, end_clust, _ = edge
            try:
                medge_set = self.graph.edges[start_clust,
                                             end_clust]['edge_set']
            except KeyError:
                continue

            if start_clust not in visited_clusts:
                add_mnode(start_clust,
                          mg=self,
                          motifs_instances=motifs_instances,
                          motifs_instances_grouped=motifs_instances_grouped)
                visited_clusts.add(start_clust)

            if end_clust not in visited_clusts:
                add_mnode(end_clust,
                          mg=self,
                          motifs_instances=motifs_instances,
                          motifs_instances_grouped=motifs_instances_grouped)
                visited_clusts.add(end_clust)

            # A node that was expanded should be removed after this round
            # as its merged version is strictly more promising
            visited_ones = set()
            new_ones = dict()

            for start_node, end_node, distance in medge_set:
                # Adding nodes in both directions can introduce two sets where one
                # is a subset of the other; we check for that manually in a temp dict
                temp_new_ones = dict()
                current_pdb = node_to_pdbid(frozenset([start_node]))
                # start = time.perf_counter()
                for current_motif in motifs_instances_grouped[current_pdb]:
                    # Only consider motifs this edge can expand: exactly one of
                    # the two endpoints must already be in the motif
                    if start_node not in current_motif and end_node not in current_motif:
                        continue
                    if start_node in current_motif and end_node in current_motif:
                        continue

                    # If exactly one end of the edge is in the current motif, extend it.
                    # The original is then removed, as it can only yield inferior scores
                    visited_ones.add(current_motif)
                    score = motifs_instances[current_motif]
                    extended_motif = set(current_motif)
                    if start_node in current_motif:
                        extended_motif.add(end_node)
                        new_score = float(self.id_to_score[end_node])
                    else:
                        extended_motif.add(start_node)
                        new_score = float(self.id_to_score[start_node])

                    extended_motif = frozenset(extended_motif)
                    temp_new_ones[extended_motif] = new_score + score
                    # new_ones[extended_motif] = new_score + score

                # print(f">>> time1 {time.perf_counter() - start}")
                # start = time.perf_counter()

                # Now remove duplicated results (when applying the edge in both
                # directions produced a set and its superset).
                # This has negligible runtime compared to the iteration above
                to_remove = set()
                for sa, sb in itertools.combinations(temp_new_ones.keys(), 2):
                    if sa.issubset(sb):
                        to_remove.add(sa)
                    elif sb.issubset(sa):
                        to_remove.add(sb)
                for subset in to_remove:
                    del temp_new_ones[subset]
                new_ones.update(temp_new_ones)
                # print(f">>> time2 {time.perf_counter() - start}")

            for visited in visited_ones:
                motifs_instances.pop(visited)
                motifs_instances_grouped[node_to_pdbid(visited)].remove(
                    visited)
            motifs_instances.update(new_ones)
            for new_one in new_ones:
                motifs_instances_grouped[node_to_pdbid(new_one)].add(new_one)
        return motifs_instances
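The subset removal inside retrieve_2 is what prevents keeping both an instance and its own extension when an edge was applied in both directions. A standalone sketch of that pairwise check on toy frozensets:

import itertools

candidates = {frozenset({1, 2}): 1.0,
              frozenset({1, 2, 3}): 1.5,
              frozenset({7, 8}): 0.9}
to_remove = set()
for sa, sb in itertools.combinations(candidates, 2):
    if sa.issubset(sb):
        to_remove.add(sa)
    elif sb.issubset(sa):
        to_remove.add(sb)
for subset in to_remove:
    del candidates[subset]
print(sorted(map(sorted, candidates)))  # [[1, 2, 3], [7, 8]]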