Example #1
0
def distributed_induction(graph: nx.MultiDiGraph, sample: nx.MultiDiGraph,
                          partition_map: PartitionMap, ownership: Set[Vertex]):
    """
    Add to ``sample`` the edges of ``graph`` that leave sampled vertices, resolving
    each edge through one owner rank of its target vertex.

    Candidate edges are those whose source vertex is in ``sample`` but which are
    not yet edges of ``sample``. Each candidate is bucketed under one rank picked
    at random from ``partition_map.get_owners(target)``. Candidates bucketed
    under this rank are resolved locally (added when the target is in
    ``ownership``); all buckets are then passed to ``query_inductions``.

    :param graph: graph whose edges are considered for induction.
    :param sample: sampled graph; edges are added to it in place.
    :param partition_map: provides the owner ranks of each vertex.
    :param ownership: vertices this rank resolves locally
                      (presumably the vertices it owns — verify against caller).
    """
    # Step 1: bucket every candidate edge under a single owner rank of its target.
    buckets = [[] for _ in range(mpi.size)]
    for edge in graph.edges:
        if sample.has_edge(*edge) or not sample.has_node(edge[0]):
            continue
        target_owners = partition_map.get_owners(edge[1])
        # Picking exactly one owner means each edge is queried only once.
        buckets[random.choice(target_owners)].append(edge)

    # Step 2: resolve the candidates that landed in this rank's own bucket.
    local_bucket = buckets[mpi.rank]
    for edge in local_bucket:
        if edge[1] in ownership:
            sample.add_edge(*edge)
    local_bucket.clear()

    # Step 3: ask the owners of the remaining targets to resolve their buckets.
    query_inductions(sample, buckets, ownership)
Example #2
0
def fuse_sequence_of_reshapes(graph: nx.MultiDiGraph):
    """
    Remove the first op of every ``Reshape -> data -> Reshape`` chain.

    A Reshape op qualifies when it has exactly one output node, that node has
    ``kind == 'data'``, the data node has exactly one consumer, and that consumer
    is itself a Reshape. The first Reshape (and its output data node) is then
    removed via ``remove_op_node_with_data_node``.

    :param graph: the graph to transform in place.
    """
    # Snapshot the ids: the graph is mutated while we iterate.
    for node_id in list(graph.nodes()):
        # Check existence before wrapping: a data node may already have been
        # removed together with an earlier Reshape.
        if not graph.has_node(node_id):
            continue
        node = Node(graph, node_id)

        if not (node.has_valid('type') and node.type == 'Reshape'
                and len(node.out_nodes()) == 1
                and node.out_node().has_valid('kind')
                and node.out_node().kind == 'data'
                and len(node.out_node().out_nodes()) == 1):
            continue

        log.debug('First phase for Reshape: {}'.format(node.name))

        next_op = node.out_node().out_node()
        # ``graph.node`` was removed in networkx 2.4; ``graph.nodes[...]`` is the
        # supported attribute accessor.
        log.debug('second node: {}'.format(next_op.graph.nodes[next_op.id]))
        if next_op.has_valid('type') and next_op.type == 'Reshape':
            # Detected Reshape1 --> data --> Reshape2 pattern without side edges:
            # remove Reshape1.
            log.debug('Second phase for Reshape: {}'.format(node.name))
            remove_op_node_with_data_node(graph, node)
Example #3
0
class NxGraph(BaseGraph):
    """
    NxGraph is a wrapper that provides methods to interact with a networkx.MultiDiGraph.

    NxGraph extends kgx.graph.base_graph.BaseGraph and implements all the methods from BaseGraph.
    """
    def __init__(self):
        super().__init__()
        self.graph = MultiDiGraph()
        self.name = None

    def add_node(self, node: str, **kwargs: Any) -> None:
        """
        Add a node to the graph.

        Parameters
        ----------
        node: str
            Node identifier
        **kwargs: Any
            Any additional node properties. If a ``data`` keyword is given, its
            value is used as the property dictionary and other keywords are ignored.

        """
        data = kwargs["data"] if "data" in kwargs else kwargs
        self.graph.add_node(node, **data)

    def add_edge(self,
                 subject_node: str,
                 object_node: str,
                 edge_key: Optional[str] = None,
                 **kwargs: Any) -> Any:
        """
        Add an edge to the graph.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key
        kwargs: Any
            Any additional edge properties. If a ``data`` keyword is given, its
            value is used as the property dictionary and other keywords are ignored.

        Returns
        -------
        Any
            The key of the added edge, as returned by ``MultiDiGraph.add_edge``

        """
        data = kwargs["data"] if "data" in kwargs else kwargs
        return self.graph.add_edge(subject_node,
                                   object_node,
                                   key=edge_key,
                                   **data)

    def add_node_attribute(self, node: str, attr_key: str,
                           attr_value: Any) -> None:
        """
        Add an attribute to a given node.

        Parameters
        ----------
        node: str
            The node identifier
        attr_key: str
            The key for an attribute
        attr_value: Any
            The value corresponding to the key

        """
        self.graph.add_node(node, **{attr_key: attr_value})

    def add_edge_attribute(
        self,
        subject_node: str,
        object_node: str,
        edge_key: Optional[str],
        attr_key: str,
        attr_value: Any,
    ) -> None:
        """
        Add an attribute to a given edge.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key
        attr_key: str
            The attribute key
        attr_value: Any
            The attribute value

        """
        self.graph.add_edge(subject_node,
                            object_node,
                            key=edge_key,
                            **{attr_key: attr_value})

    def update_node_attribute(self,
                              node: str,
                              attr_key: str,
                              attr_value: Any,
                              preserve: bool = False) -> Dict:
        """
        Update an attribute of a given node.

        Parameters
        ----------
        node: str
            The node identifier
        attr_key: str
            The key for an attribute
        attr_value: Any
            The value corresponding to the key
        preserve: bool
            Whether or not to preserve existing values for the given attr_key

        Returns
        -------
        Dict
            A dictionary corresponding to the updated node properties

        """
        node_data = self.graph.nodes[node]
        updated = prepare_data_dict(node_data, {attr_key: attr_value},
                                    preserve=preserve)
        self.graph.add_node(node, **updated)
        return updated

    def update_edge_attribute(
        self,
        subject_node: str,
        object_node: str,
        edge_key: Optional[str],
        attr_key: str,
        attr_value: Any,
        preserve: bool = False,
    ) -> Dict:
        """
        Update an attribute of a given edge.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key. If ``None`` and the nodes are already connected, the
            first existing parallel edge between them is updated.
        attr_key: str
            The attribute key
        attr_value: Any
            The attribute value
        preserve: bool
            Whether or not to preserve existing values for the given attr_key

        Returns
        -------
        Dict
            A dictionary corresponding to the updated edge properties

        """
        if edge_key is None and self.graph.has_edge(subject_node, object_node):
            # Without a key, ``add_edge(key=None)`` below would create a new
            # parallel edge; target an existing one instead.
            edge_key = next(iter(self.graph[subject_node][object_node]))
        # Look up exactly this (subject, object, key) edge. Passing a tuple to
        # ``graph.edges(...)`` would treat it as a bunch of *nodes* and return
        # unrelated out-edges of ``subject_node``.
        edge_data = self.graph.get_edge_data(subject_node, object_node,
                                             key=edge_key, default={})
        updated = prepare_data_dict(edge_data, {attr_key: attr_value},
                                    preserve)
        self.graph.add_edge(subject_node, object_node, key=edge_key, **updated)
        return updated

    def get_node(self, node: str) -> Dict:
        """
        Get a node and its properties.

        Parameters
        ----------
        node: str
            The node identifier

        Returns
        -------
        Dict
            The node dictionary (empty if the node does not exist)

        """
        n = {}
        if self.graph.has_node(node):
            n = self.graph.nodes[node]
        return n

    def get_edge(self,
                 subject_node: str,
                 object_node: str,
                 edge_key: Optional[str] = None) -> Dict:
        """
        Get an edge and its properties.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key

        Returns
        -------
        Dict
            The edge dictionary (empty if the edge does not exist)

        """
        e = {}
        if self.graph.has_edge(subject_node, object_node, edge_key):
            e = self.graph.get_edge_data(subject_node, object_node, edge_key)
        return e

    def nodes(self, data: bool = True) -> Dict:
        """
        Get all nodes in a graph.

        Parameters
        ----------
        data: bool
            Whether or not to fetch node properties

        Returns
        -------
        Dict
            A dictionary of nodes

        """
        return self.graph.nodes(data)

    def edges(self, keys: bool = False, data: bool = True) -> Dict:
        """
        Get all edges in a graph.

        Parameters
        ----------
        keys: bool
            Whether or not to include edge keys
        data: bool
            Whether or not to fetch edge properties

        Returns
        -------
        Dict
            A dictionary of edges

        """
        return self.graph.edges(keys=keys, data=data)

    def in_edges(self,
                 node: str,
                 keys: bool = False,
                 data: bool = False) -> List:
        """
        Get all incoming edges for a given node.

        Parameters
        ----------
        node: str
            The node identifier
        keys: bool
            Whether or not to include edge keys
        data: bool
            Whether or not to fetch edge properties

        Returns
        -------
        List
            A list of edges

        """
        return self.graph.in_edges(node, keys=keys, data=data)

    def out_edges(self,
                  node: str,
                  keys: bool = False,
                  data: bool = False) -> List:
        """
        Get all outgoing edges for a given node.

        Parameters
        ----------
        node: str
            The node identifier
        keys: bool
            Whether or not to include edge keys
        data: bool
            Whether or not to fetch edge properties

        Returns
        -------
        List
            A list of edges

        """
        return self.graph.out_edges(node, keys=keys, data=data)

    def nodes_iter(self) -> Generator:
        """
        Get an iterable to traverse through all the nodes in a graph.

        Returns
        -------
        Generator
            A generator for nodes where each element is a Tuple that
            contains (node_id, node_data)

        """
        yield from self.graph.nodes(data=True)

    def edges_iter(self) -> Generator:
        """
        Get an iterable to traverse through all the edges in a graph.

        Returns
        -------
        Generator
            A generator for edges where each element is a 4-tuple that
            contains (subject, object, edge_key, edge_data)

        """
        for u, v, k, data in self.graph.edges(keys=True, data=True):
            yield u, v, k, data

    def remove_node(self, node: str) -> None:
        """
        Remove a given node from the graph.

        Parameters
        ----------
        node: str
            The node identifier

        """
        self.graph.remove_node(node)

    def remove_edge(self,
                    subject_node: str,
                    object_node: str,
                    edge_key: Optional[str] = None) -> None:
        """
        Remove a given edge from the graph.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key

        """
        self.graph.remove_edge(subject_node, object_node, edge_key)

    def has_node(self, node: str) -> bool:
        """
        Check whether a given node exists in the graph.

        Parameters
        ----------
        node: str
            The node identifier

        Returns
        -------
        bool
            Whether or not the given node exists

        """
        return self.graph.has_node(node)

    def has_edge(self,
                 subject_node: str,
                 object_node: str,
                 edge_key: Optional[str] = None) -> bool:
        """
        Check whether a given edge exists in the graph.

        Parameters
        ----------
        subject_node: str
            The subject (source) node
        object_node: str
            The object (target) node
        edge_key: Optional[str]
            The edge key

        Returns
        -------
        bool
            Whether or not the given edge exists

        """
        return self.graph.has_edge(subject_node, object_node, key=edge_key)

    def number_of_nodes(self) -> int:
        """
        Returns the number of nodes in a graph.

        Returns
        -------
        int

        """
        return self.graph.number_of_nodes()

    def number_of_edges(self) -> int:
        """
        Returns the number of edges in a graph.

        Returns
        -------
        int

        """
        return self.graph.number_of_edges()

    def degree(self):
        """
        Get the degree of all the nodes in a graph.
        """
        return self.graph.degree()

    def clear(self) -> None:
        """
        Remove all the nodes and edges in the graph.
        """
        self.graph.clear()

    @staticmethod
    def set_node_attributes(graph: BaseGraph, attributes: Dict) -> None:
        """
        Set nodes attributes from a dictionary of key-values.

        Parameters
        ----------
        graph: kgx.graph.base_graph.BaseGraph
            The graph to modify
        attributes: Dict
            A dictionary of node identifier to key-value pairs

        """
        return set_node_attributes(graph.graph, attributes)

    @staticmethod
    def set_edge_attributes(graph: BaseGraph, attributes: Dict) -> None:
        """
        Set edges attributes from a dictionary of key-values.

        Parameters
        ----------
        graph: kgx.graph.base_graph.BaseGraph
            The graph to modify
        attributes: Dict
            A dictionary of edge identifier to key-value pairs

        Returns
        -------
        Any

        """
        return set_edge_attributes(graph.graph, attributes)

    @staticmethod
    def get_node_attributes(graph: BaseGraph, attr_key: str) -> Dict:
        """
        Get all nodes that have a value for the given attribute ``attr_key``.

        Parameters
        ----------
        graph: kgx.graph.base_graph.BaseGraph
            The graph to query
        attr_key: str
            The attribute key

        Returns
        -------
        Dict
            A dictionary where nodes are the keys and the values
            are the attribute values for ``attr_key``

        """
        return get_node_attributes(graph.graph, attr_key)

    @staticmethod
    def get_edge_attributes(graph: BaseGraph, attr_key: str) -> Dict:
        """
        Get all edges that have a value for the given attribute ``attr_key``.

        Parameters
        ----------
        graph: kgx.graph.base_graph.BaseGraph
            The graph to query
        attr_key: str
            The attribute key

        Returns
        -------
        Dict
            A dictionary where edges are the keys and the values
            are the attribute values for ``attr_key``

        """
        return get_edge_attributes(graph.graph, attr_key)

    @staticmethod
    def relabel_nodes(graph: BaseGraph, mapping: Dict) -> None:
        """
        Relabel identifiers for a series of nodes based on mappings.

        Parameters
        ----------
        graph: kgx.graph.base_graph.BaseGraph
            The graph to modify
        mapping: Dict
            A dictionary of mapping where the key is the old identifier
            and the value is the new identifier.

        """
        relabel_nodes(graph.graph, mapping, copy=False)
Example #4
0
class MossNet:
    '''A network of MOSS results backed by a NetworkX ``MultiDiGraph``.

    Each link between two nodes is one parallel edge whose ``'attr_dict'`` attribute
    holds ``{'files': (u_fn, v_fn), 'left': (u_percent, u_html), 'right': (v_percent, v_html)}``.
    '''

    def __init__(self, moss_results_dict):
        '''Create a ``MossNet`` object from a 3D dictionary of downloaded MOSS results

        Args:
            ``moss_results_dict`` (``dict``): A 3D dictionary of downloaded MOSS results
            (``d[u][v][files] = (left, right)``). A ``MultiDiGraph`` is wrapped as-is, and a
            ``str`` is treated as the path of a saved (optionally ``.gz``) dictionary

        Raises:
            ``ValueError``: A path was given but could not be loaded

            ``TypeError``: The input is not a well-formed 3D dictionary
        '''
        if isinstance(moss_results_dict, MultiDiGraph):
            self.graph = moss_results_dict
            return
        if isinstance(moss_results_dict, str):
            opener = gopen if moss_results_dict.lower().endswith('.gz') else open
            try:
                # NOTE(review): ``load`` is presumably ``pickle.load`` (``save`` uses ``pkldump``);
                # unpickling can execute arbitrary code, so only load trusted files
                with opener(moss_results_dict, 'rb') as f:
                    moss_results_dict = load(f)
            except Exception as e:
                raise ValueError("Unable to load dictionary: %s" % moss_results_dict) from e
        err = "moss_results_dict must be a 3D dictionary of MOSS results"
        if not isinstance(moss_results_dict, dict):
            raise TypeError(err)
        self.graph = MultiDiGraph()
        for u, u_edges in moss_results_dict.items():
            if not isinstance(u_edges, dict):
                raise TypeError(err)
            for v, u_v_links in u_edges.items():
                if not isinstance(u_v_links, dict):
                    raise TypeError(err)
                for f, link in u_v_links.items():
                    try:
                        left, right = link
                    except (TypeError, ValueError):
                        raise TypeError(err) from None
                    # NetworkX 2.x has no ``attr_dict`` parameter, so this stores one edge
                    # attribute literally named 'attr_dict'; every reader below relies on that
                    self.graph.add_edge(u, v, attr_dict = {'files':f, 'left':left, 'right':right})

    def save(self, outfile):
        '''Save this ``MossNet`` object as a 3D dictionary of MOSS results

        Args:
            ``outfile`` (``str``): The desired output file's path (gzip-compressed if it ends in ``.gz``)
        '''
        out = dict()
        for u in self.graph.nodes:
            u_edges = out[u] = dict()
            for v in self.graph.neighbors(u):
                u_v_links = u_edges[v] = dict()
                for data in self.graph.get_edge_data(u, v).values():
                    edge = data['attr_dict']
                    u_v_links[edge['files']] = (edge['left'], edge['right'])
        if outfile.lower().endswith('.gz'):
            f = gopen(outfile, mode='wb', compresslevel=9)
        else:
            f = open(outfile, 'wb')
        with f:
            pkldump(out, f)

    def __add__(self, o):
        '''Return a new ``MossNet`` containing the nodes and links of both operands'''
        if not isinstance(o, MossNet):
            raise TypeError("unsupported operand type(s) for +: 'MossNet' and '%s'" % type(o).__name__)
        g = MultiDiGraph()
        g.add_edges_from(list(self.graph.edges(data=True)) + list(o.graph.edges(data=True)))
        g.add_nodes_from(list(self.graph.nodes(data=True)) + list(o.graph.nodes(data=True)))
        return MossNet(g)

    def get_networkx(self):
        '''Return a NetworkX ``MultiDiGraph`` equivalent to this ``MossNet`` object

        Returns:
            ``MultiDiGraph``: A copy of the NetworkX ``MultiDiGraph`` backing this ``MossNet`` object
        '''
        return self.graph.copy()

    def get_nodes(self):
        '''Returns a ``set`` of node labels in this ``MossNet`` object

        Returns:
            ``set``: The node labels in this ``MossNet`` object
        '''
        return set(self.graph.nodes)

    def get_pair(self, u, v, style='tuples'):
        '''Returns the links between nodes ``u`` and ``v``

        Args:
            ``u`` (``str``): A node label

            ``v`` (``str``): A node label not equal to ``u``

            ``style`` (``str``): The representation of a given link

            * ``"tuples"``: Links are ``((u_percent, u_html), (v_percent, v_html))`` tuples

            * ``"html"``: Links are HTML representation (one HTML for all links)

            * ``"htmls"``: Links are HTML representations (one HTML per link)

        Returns:
            ``dict``: The links between ``u`` and ``v`` (keys are filenames); empty if they are
            not linked. With ``style="html"`` a single ``str`` is returned instead
        '''
        if style not in {'tuples', 'html', 'htmls'}:
            raise ValueError("Invalid link style: %s" % style)
        if u == v:
            raise ValueError("u and v cannot be equal: %s" % u)
        for node in [u,v]:
            if not self.graph.has_node(node):
                raise ValueError("Nonexistant node: %s" % node)
        # get_edge_data returns None when the nodes exist but are not linked
        links = self.graph.get_edge_data(u, v) or {}
        out = dict()
        for k in sorted(links, key=lambda x: links[x]['attr_dict']['files']):
            d = links[k]['attr_dict']
            u_fn, v_fn = d['files']
            u_percent, u_html = d['left']
            v_percent, v_html = d['right']
            if style == 'tuples':
                out[(u_fn, v_fn)] = ((u_percent, u_html), (v_percent, v_html))
            else:
                out[(u_fn, v_fn)] = '<html><table style="width:100%%" border="1"><tr><td colspan="2"><center><b>%s/%s --- %s/%s</b></center></td></tr><tr><td>%s (%d%%)</td><td>%s (%d%%)</td></tr><tr><td><pre>%s</pre></td><td><pre>%s</pre></td></tr></table></html>' % (u, u_fn, v, v_fn, u, u_percent, v, v_percent, u_html, v_html)
        if style == 'html':
            out = '<html>' + '<br>'.join(out[fns].replace('<html>','').replace('</html>','') for fns in sorted(out.keys())) + '</html>'
        return out

    def get_summary(self, style='html'):
        '''Returns a summary of this ``MossNet``

        Args:
            ``style`` (``str``): The representation of this ``MossNet`` (only ``"html"`` is supported)

        Returns:
            ``str``: An HTML table of all matched file pairs, sorted by decreasing maximum percentage
        '''
        if style not in {'html'}:
            raise ValueError("Invalid summary style: %s" % style)
        matches = list() # list of (u_path, u_percent, v_path, v_percent) tuples
        for u, v in self.traverse_pairs(order=None):
            for data in self.graph.get_edge_data(u, v).values():
                d = data['attr_dict']
                u_fn, v_fn = d['files']
                u_percent = d['left'][0]
                v_percent = d['right'][0]
                matches.append(('%s/%s' % (u,u_fn), u_percent, '%s/%s' % (v,v_fn), v_percent))
        matches.sort(reverse=True, key=lambda x: max(x[1],x[3]))
        return '<html><table style="width:100%%" border="1">%s</table></html>' % ''.join(('<tr><td>%s (%d%%)</td><td>%s (%d%%)</td></tr>' % tup) for tup in matches)

    def num_links(self, u, v):
        '''Returns the number of links between ``u`` and ``v``

        Args:
            ``u`` (``str``): A node label

            ``v`` (``str``): A node label not equal to ``u``

        Returns:
            ``int``: The number of links between ``u`` and ``v`` (0 if they are not linked)
        '''
        for node in [u,v]:
            if not self.graph.has_node(node):
                raise ValueError("Nonexistant node: %s" % node)
        links = self.graph.get_edge_data(u, v)
        return 0 if links is None else len(links)

    def num_nodes(self):
        '''Returns the number of nodes in this ``MossNet`` object

        Returns:
            ``int``: The number of nodes in this ``MossNet`` object
        '''
        return self.graph.number_of_nodes()

    def num_edges(self):
        '''Returns the number of (undirected) edges in this ``MossNet`` object (including parallel edges)

        Returns:
            ``int``: The number of (undirected) edges in this ``MossNet`` object (including parallel edges)
        '''
        # Assumes every link is stored in both directions, so directed edges are halved
        return self.graph.number_of_edges() // 2

    def outlier_pairs(self):
        '''Predict which student pairs are outliers (i.e., too many problem similarities).
        The distribution of number of links between student pairs (i.e., histogram) is modeled as y = A/(B^x),
        where x = a number of links, and y = the number of student pairs with that many links

        Returns:
            ``list`` of ``tuple``: ``(num_links, u, v)`` for the student pairs expected to be outliers
            (in decreasing order of significance); empty if the model cannot be fitted
        '''
        links = dict() # key = number of links; value = set of student pairs that have that number of links
        for u, v in self.traverse_pairs(order=None):
            links.setdefault(self.num_links(u, v), set()).add((u, v))
        if not links:
            return list()
        # x ranges over numbers of links, i.e. the keys of ``links`` (not the bucket sizes)
        min_links, max_links = min(links), max(links)
        mult = list() # ratios y(x)/y(x+1) over the leading run of non-increasing buckets
        for i in range(min_links, max_links):
            if i not in links or i+1 not in links or len(links[i+1]) > len(links[i]):
                break
            mult.append(float(len(links[i]))/len(links[i+1]))
        if not mult:
            return list()
        B = sum(mult)/len(mult)
        if B <= 1:
            # No decay: the cutoff log(A)/log(B) diverges, so no pair is an outlier
            return list()
        A = len(links[min_links]) * (B**min_links)
        # Solving A/B^x = 1: beyond this many links, fewer than one pair is expected
        n_cutoff = log(A)/log(B)
        out = list()
        for n in sorted(links, reverse=True):
            if n < n_cutoff:
                break
            out.extend((n, u, v) for u, v in links[n])
        return out

    def traverse_pairs(self, order='descending'):
        '''Iterate over student pairs

        Each linked pair is yielded once, as ``(u, v)`` with ``u < v``.

        Args:
            ``order`` (``str``): Order to sort pairs in iteration

            * ``None`` to not sort (may be faster for large/dense graphs)

            * ``"ascending"`` to sort in ascending order of number of links

            * ``"descending"`` to sort in descending order of number of links
        '''
        if order not in {None, 'None', 'none', 'ascending', 'descending'}:
            raise ValueError("Invalid order: %s" % order)
        pairs = [(u,v) for u in self.graph.nodes for v in self.graph.neighbors(u) if u < v]
        if order in {'ascending', 'descending'}:
            pairs.sort(key=lambda pair: len(self.graph.get_edge_data(pair[0], pair[1])),
                       reverse=(order == 'descending'))
        yield from pairs

    def export(self, outpath, style='html', gte=0, verbose=False):
        '''Export the links in this ``MossNet`` in the specified style

        Args:
            ``outpath`` (``str``): Path to desired output folder/file (must not already exist)

            ``style`` (``str``): Desired output style

            * ``"dot"`` to export as a GraphViz DOT file

            * ``"gexf"`` to export as a Graph Exchange XML Format (GEXF) file

            * ``"html"`` to export one HTML file per pair (plus ``summary.html``)

            ``gte`` (``int``): The minimum number of links for an edge to be exported

            ``verbose`` (``bool``): ``True`` to show verbose messages, otherwise ``False``
        '''
        if style not in {'dot', 'gexf', 'html'}:
            raise ValueError("Invalid export style: %s" % style)
        if isdir(outpath) or isfile(outpath):
            raise ValueError("Output path exists: %s" % outpath)
        if not isinstance(gte, int):
            raise TypeError("'gte' must be an 'int', but you provided a '%s'" % type(gte).__name__)
        if gte < 0:
            raise ValueError("'gte' must be non-negative, but yours was %d" % gte)

        # export as folder of HTML files
        if style == 'html':
            summary = self.get_summary(style='html')
            pairs = list(self.traverse_pairs(order=None))
            makedirs(outpath)
            with open('%s/summary.html' % outpath, 'w') as f:
                f.write(summary)
            num_exported = 0
            for i, (u, v) in enumerate(pairs):
                if verbose:
                    print("Exporting pair %d of %d..." % (i+1, len(pairs)), end='\r')
                curr_num_links = self.num_links(u, v)
                if curr_num_links < gte:
                    continue
                with open("%s/%d_%s_%s.html" % (outpath, curr_num_links, u, v), 'w') as f:
                    f.write(self.get_pair(u, v, style='html'))
                num_exported += 1
            if verbose:
                print("Successfully exported %d pairs" % num_exported)

        # export as GraphViz DOT or a GEXF file
        else:
            if verbose:
                print("Computing colors...", end='')
            # default=1 keeps the palette request valid when there are no pairs at all
            max_links = max((self.num_links(u,v) for u,v in self.traverse_pairs(order=None)), default=1)
            try:
                from seaborn import color_palette
            except ImportError as e:
                raise RuntimeError("Exporting as a DOT or GEXF file currently requires seaborn") from e
            pal = color_palette("Reds", max_links)
            if verbose:
                print(" done")
                print("Computing node information...", end='')
            nodes = list(self.get_nodes())
            index = {u:i for i,u in enumerate(nodes)}
            if verbose:
                print(" done")
                print("Writing output file...", end='')
            with open(outpath, 'w') as outfile:
                if style == 'dot':
                    pal = [str(c).upper() for c in pal.as_hex()]
                    outfile.write("graph G {\n")
                    for u in nodes:
                        outfile.write('  node%d[label="%s"]\n' % (index[u], u))
                    for u,v in self.traverse_pairs():
                        curr_num_links = self.num_links(u,v)
                        if curr_num_links < gte:
                            continue
                        outfile.write('  node%d -- node%d[color="%s"]\n' % (index[u], index[v], pal[curr_num_links-1]))
                    outfile.write('}\n')
                else:
                    from datetime import datetime
                    pal = [(int(255*c[0]), int(255*c[1]), int(255*c[2])) for c in pal]
                    outfile.write('<?xml version="1.0" encoding="UTF-8"?>\n')
                    outfile.write('<gexf xmlns="http://www.gexf.net/1.3draft" xmlns:viz="http://www.gexf.net/1.3draft/viz">\n')
                    outfile.write('  <meta lastmodifieddate="%s">\n' % datetime.today().strftime('%Y-%m-%d'))
                    outfile.write('    <creator>MossNet</creator>\n')
                    outfile.write('    <description>A MossNet network exported to GEXF</description>\n')
                    outfile.write('  </meta>\n')
                    outfile.write('  <graph mode="static" defaultedgetype="undirected">\n')
                    outfile.write('    <nodes>\n')
                    for u in nodes:
                        outfile.write('      <node id="%d" label="%s"/>\n' % (index[u], u))
                    outfile.write('    </nodes>\n')
                    outfile.write('    <edges>\n')
                    for i, (u, v) in enumerate(self.traverse_pairs()):
                        curr_num_links = self.num_links(u,v)
                        # honour ``gte`` like the DOT export; 0 links would index pal[-1]
                        if curr_num_links == 0 or curr_num_links < gte:
                            continue
                        color = pal[curr_num_links-1]
                        outfile.write('      <edge id="%d" source="%d" target="%d">\n' % (i, index[u], index[v]))
                        outfile.write('        <viz:color r="%d" g="%d" b="%d"/>\n' % (color[0], color[1], color[2]))
                        outfile.write('      </edge>\n')
                    outfile.write('    </edges>\n')
                    outfile.write('  </graph>\n')
                    outfile.write('</gexf>\n')
            if verbose:
                print(" done")
Example #5
0
def merge_nodes(graph: nx.MultiDiGraph,
                nodes_to_merge_names: list,
                inputs_desc: list = None,
                outputs_desc: list = None):
    """
    Merges nodes specified in 'nodes_to_merge_names' into one mega-node, creating new edges between the mega-node
    and the inputs/outputs nodes of the merged sub-graph. The added edges contain the names of the internal
    input/output tensors, which will be used for generation of placeholders and will be saved to the IR xml so the IE
    plug-in knows how to map input/output data for the layer. The function also adds the protobufs of the nodes of the
    sub-graph to the new node's attribute 'pbs' (via add_node_pb_if_not_yet_added; any extra handling such as for
    'Const' producers presumably lives in that helper -- confirm there).

    The graph is modified in place: a new node named 'TFSubgraphCall_<N>' is added, the boundary edges are re-created
    on it, and the merged nodes (those still present) are removed.

    :param graph: the graph object to operate on.
    :param nodes_to_merge_names: list of nodes names that should be merged into a single node.
    :param inputs_desc: optional list describing input nodes order (forwarded to find_input_port).
    :param outputs_desc: optional list describing output nodes order (forwarded to find_output_port).
    :return: Node wrapping the newly created mega-node.
    """
    if not is_connected_component(graph, nodes_to_merge_names):
        log.warning(
            "The following nodes do not form connected sub-graph: {}".format(
                nodes_to_merge_names))
        dump_graph_for_graphviz(graph, nodes_to_dump=nodes_to_merge_names)

    # Set copy for O(1) membership tests inside the nested loops below; the
    # list itself is still iterated so the caller's order is preserved.
    nodes_to_merge = set(nodes_to_merge_names)

    new_node_name = unique_id(graph, "TFSubgraphCall_")
    log.info("Create new node with name '{}' for nodes '{}'".format(
        new_node_name, ', '.join(nodes_to_merge_names)))
    graph.add_node(new_node_name)
    # `graph.nodes[n]` returns the node's attribute dict in networkx 2.x;
    # the older `graph.node` alias was removed in networkx 2.4.
    new_node_attrs = graph.nodes[new_node_name]

    new_node_attrs['name'] = new_node_name
    set_tf_custom_call_node_attrs(new_node_attrs)
    new_node = Node(graph, new_node_name)

    # tensors that were already added as inputs to the sub-graph
    added_input_tensors_names = set()
    # key - internal tensor name, value - out port of the new node
    added_new_node_output_tensors = dict()

    for node_name in nodes_to_merge_names:
        node = Node(graph, node_name)
        add_node_pb_if_not_yet_added(node, new_node)

        for in_node_name, edge_attrs in get_inputs(graph, node_name):
            # internal edges between nodes of the sub-graph
            if in_node_name in nodes_to_merge:
                add_node_pb_if_not_yet_added(Node(graph, in_node_name), new_node)
                continue

            # edge outside of sub-graph into sub-graph
            # we cannot use the 'in_node_name' as a protobuf operation name here
            # because the 'in_node_name' could be a sub-graph matched before.
            input_tensor_name = node.pb.input[edge_attrs['in']]
            if input_tensor_name in added_input_tensors_names:
                # the same external tensor may feed several sub-graph nodes;
                # only one edge into the mega-node is created for it
                continue
            graph.add_edge(
                in_node_name, new_node_name,
                **merge_edge_props(
                    {
                        'in':
                        find_input_port(new_node, inputs_desc,
                                        node_name, edge_attrs['in']),
                        'out':
                        edge_attrs['out'],
                        'internal_input_node_name':
                        input_tensor_name,
                        'original_dst_node_name':
                        node_name,
                        'original_dst_port':
                        edge_attrs['in'],
                        'in_attrs': [
                            'in', 'internal_input_node_name',
                            'original_dst_node_name',
                            'original_dst_port', 'placeholder_name'
                        ],
                        'out_attrs': ['out']
                    }, edge_attrs))
            log.debug(
                "Creating edge from outside of sub-graph to inside sub-graph: {} -> {}"
                .format(in_node_name, new_node_name))
            added_input_tensors_names.add(input_tensor_name)

        # edge from inside sub-graph to outside sub-graph
        for out_node_name, edge_attrs in get_outputs(graph, node_name):
            if out_node_name in nodes_to_merge:
                continue
            log.debug(
                "Creating edge from inside of sub-graph to outside sub-graph: {} -> {}"
                .format(new_node_name, out_node_name))
            out_name = internal_output_name_for_node(
                node_name, edge_attrs['out'])
            # an internal tensor consumed by several outside nodes keeps a single out port
            if out_name not in added_new_node_output_tensors:
                added_new_node_output_tensors[out_name] = find_output_port(
                    new_node, outputs_desc, node_name, edge_attrs['out'])
            graph.add_edge(
                new_node_name, out_node_name,
                **merge_edge_props(
                    {
                        'in': edge_attrs['in'],
                        'out': added_new_node_output_tensors[out_name],
                        'internal_output_node_name': out_name,
                        'in_attrs': ['in', 'internal_input_node_name'],
                        'out_attrs': ['out', 'internal_output_node_name']
                    }, edge_attrs))

    # Computed once, after all boundary edges exist (previously recomputed on
    # every iteration with the same final result). Inverting the mapping keeps
    # one tensor name per out port (the last one seen for that port), ordered
    # by the port's first appearance.
    new_node['output_tensors_names'] = list(
        {port: name
         for name, port in added_new_node_output_tensors.items()}.values())

    # add nodes using the same order as in initial GraphDef so we can dump them to IR in "correct" order
    merged_pbs = new_node['pbs']
    new_node['nodes_order'] = [
        node for node in graph.graph['initial_nodes_order']
        if node in merged_pbs
    ]

    for n in nodes_to_merge_names:
        # check if not deleted by another (similar) pattern
        if graph.has_node(n):
            graph.remove_node(n)
    return Node(graph, new_node_name)