Example #1
0
def propagate_downwards(G: NNGraph):
    """Push dimension-name hints from inputs towards outputs of ``G``.

    Nodes are visited in ``G.dfs()`` order and updated in place.

    For each node, step 1 derives ``out_dims_hint`` from ``in_dims_hint``:

    * reshapes name their ``old_shape`` when the ranks agree;
    * keep-dims global pools copy the input hints;
    * broadcasted linear ops take the longest input hint;
    * concats fill empty input slots with the first available hint;
    * every other node copies its input hints only if it has no output hints
      yet.

    Step 2 copies each non-``None`` output hint onto the matching input slot
    of every downstream node whose slot is still empty.

    Args:
        G: graph whose nodes' ``in_dims_hint``/``out_dims_hint`` are updated.
    """
    for node in G.dfs():
        # Step 1: turn in-hints into out-hints. Nodes that want to keep their
        # own out-hints are expected to have them set already.
        in_hints = node.in_dims_hint
        if in_hints is not None:
            if isinstance(node, ReshapeParameters):
                first_hint = in_hints[0]
                # only name the pre-reshape shape when the ranks agree
                if len(node.old_shape) == len(first_hint):
                    LOG.debug("set reshape %s in dims hint %s", node.name,
                              first_hint)
                    node.old_shape.apply_naming_hints(first_hint)
            elif isinstance(node, GlobalPoolParameters):
                if node.keep_dims:
                    node.out_dims_hint = deepcopy(in_hints)
            elif isinstance(node, MatrixBroadcastedLinearOpParameters):
                # longest hint wins; max() keeps the first of equal-length ones
                widest = max((h for h in in_hints if h is not None),
                             key=len, default=None)
                if widest is not None:
                    node.out_dims_hint = [widest]
            elif isinstance(node, ConcatParameters):
                # give every input slot lacking a hint the first hint found
                fill_hint = next((h for h in in_hints if h is not None), None)
                if fill_hint:
                    LOG.debug("set concat %s in dims hint %s", node.name,
                              fill_hint)
                    for in_edge in G.in_edges(node.name):
                        slot = in_edge.to_idx
                        if not in_hints[slot]:
                            in_hints[slot] = fill_hint
                    node.out_dims_hint = [fill_hint]
            elif node.out_dims_hint is None:
                node.out_dims_hint = deepcopy(in_hints)

        # Step 2: hand out-hints on to consumers whose slot is still empty.
        out_hints = node.out_dims_hint
        if out_hints is None:
            continue
        LOG.debug("propagate down hint from %s", node.name)
        for out_edge in G.out_edges(node.name):
            hint = out_hints[out_edge.from_idx]
            if hint is None:
                continue
            consumer = out_edge.to_node
            if consumer.in_dims_hint is None:
                consumer.in_dims_hint = SparseList()
            if consumer.in_dims_hint[out_edge.to_idx] is None:
                consumer.in_dims_hint[out_edge.to_idx] = hint
Example #2
0
def propagate_upwards(G: NNGraph):
    """Pull dimension-name hints from outputs back towards inputs of ``G``.

    Nodes are visited in ``G.dfs(reverse=True)`` order and updated in place.

    Step 1 derives ``in_dims_hint`` from ``out_dims_hint``:

    * reshapes left-pad ``shape`` with 1s up to the hint rank, name it, and
      seed placeholder input names ("0", "1", ...) when no input hint exists;
    * broadcasted linear ops give both inputs the output hint;
    * matmul nodes are skipped entirely, including step 2;
    * keep-dims global pools copy the output hints;
    * unnamed constant inputs have their dims named;
    * every other node copies its output hints only if it has no input hints
      yet.

    Step 2 copies each non-``None`` input hint onto the producer's output
    slot when that slot is empty. ``InputParameters`` producers additionally
    get their ``dims`` left-padded with 1s up to the hint's rank.

    Args:
        G: graph whose nodes' ``in_dims_hint``/``out_dims_hint`` are updated.
    """
    for node in G.dfs(reverse=True):
        # Step 1: turn out-hints into in-hints. Nodes that want to keep their
        # own in-hints are expected to have them set already.
        out_hints = node.out_dims_hint
        if out_hints is not None:
            if isinstance(node, ReshapeParameters):
                target = out_hints[0]
                pad = len(target) - len(node.shape)
                if pad > 0:
                    node.shape = Dim.unnamed([1] * pad + node.shape.shape)
                node.shape.apply_naming_hints(target)
                if node.in_dims_hint is None:
                    node.in_dims_hint = SparseList(
                        [["%s" % i for i in range(len(node.old_shape))]])
            elif isinstance(node, MatrixBroadcastedLinearOpParameters):
                node.in_dims_hint = [out_hints[0]] * 2
            elif isinstance(node, MatrixMulParameters):
                # no hint travels upstream through a matmul
                continue
            elif isinstance(node, GlobalPoolParameters):
                if node.keep_dims:
                    node.in_dims_hint = deepcopy(out_hints)
            elif (isinstance(node, ConstantInputParameters)
                  and not node.dims.is_named):
                node.dims.apply_naming_hints(out_hints[0])
            elif node.in_dims_hint is None:
                node.in_dims_hint = deepcopy(out_hints)

        # Step 2: hand in-hints on to producers whose slot is still empty.
        in_hints = node.in_dims_hint
        if in_hints is None:
            continue
        for in_edge in G.in_edges(node.name):
            hint = in_hints[in_edge.to_idx]
            if hint is None:
                continue
            producer = in_edge.from_node
            if producer.out_dims_hint is None:
                producer.out_dims_hint = SparseList()
            if producer.out_dims_hint[in_edge.from_idx] is not None:
                continue
            producer.out_dims_hint[in_edge.from_idx] = hint
            if not isinstance(producer, InputParameters):
                continue
            assert in_edge.from_idx == 0, "input node should only have one output"
            # grow the graph input's rank to match the hint
            pad = len(hint) - len(producer.dims)
            if pad > 0:
                producer.dims = Dim.unnamed([1] * pad + producer.dims.shape)
Example #3
0
def propagate_downwards(G: NNGraph):
    """Push dimension-name hints from inputs towards outputs of ``G``.

    Nodes are visited in ``G.dfs()`` order and updated in place.

    For each node, the input hints become the output hints unless the node
    already has output hints. Reshapes are the exception: they apply their
    input hint to ``old_shape`` instead.

    Each non-``None`` output hint is then copied onto the matching input slot
    of every downstream node whose slot is still empty.

    Args:
        G: graph whose nodes' ``in_dims_hint``/``out_dims_hint`` are updated.

    Raises:
        AssertionError: if a reshape's ``old_shape`` rank differs from the
            rank of its input hint.
    """
    for node in G.dfs():
        # First propagate the in dim hints to the out dim hints.
        # Any node that does not want this to happen should set its out dim hints.
        if node.in_dims_hint is not None:
            if isinstance(node, ReshapeParameters):
                in_hint = node.in_dims_hint[0]
                # the reshape's single input may not carry a hint yet
                if in_hint is not None:
                    assert len(node.old_shape) == len(in_hint), "reshape doesn't match input"
                    node.old_shape.apply_naming_hints(in_hint)
            elif node.out_dims_hint is None:
                node.out_dims_hint = deepcopy(node.in_dims_hint)

        # If we have an out dim hint then propagate it to downstream nodes.
        if node.out_dims_hint is not None:
            for edge in G.out_edges(node.name):
                hint = node.out_dims_hint[edge.from_idx]
                # Skip outputs without a hint. Copying None would still create
                # an empty SparseList on the consumer and make it look hinted
                # (e.g. a downstream reshape would then evaluate len(None)).
                if hint is None:
                    continue
                if edge.to_node.in_dims_hint is None:
                    edge.to_node.in_dims_hint = SparseList()
                if edge.to_node.in_dims_hint[edge.to_idx] is None:
                    edge.to_node.in_dims_hint[edge.to_idx] = hint
Example #4
0
def propagate_upwards(G: NNGraph):
    """Pull dimension-name hints from outputs back towards inputs of ``G``.

    Nodes are visited in ``G.dfs(reverse=True)`` order and updated in place.

    For each node, the output hints become the input hints unless the node
    already has input hints. Reshapes are the exception: they name their
    ``shape`` from the output hint and seed placeholder input names
    ("0", "1", ...) for ``old_shape``.

    Each non-``None`` input hint is then copied onto the producer's output
    slot when that slot is empty.

    Args:
        G: graph whose nodes' ``in_dims_hint``/``out_dims_hint`` are updated.

    Raises:
        AssertionError: if a reshape's ``shape`` rank differs from the rank
            of its output hint.
    """
    for node in G.dfs(reverse=True):
        # First propagate the out dim hints to the in dim hints.
        # Any node that does not want this to happen should set its in dim hints.
        if node.out_dims_hint is not None:
            if isinstance(node, ReshapeParameters):
                out_hint = node.out_dims_hint[0]
                # the reshape's single output may not carry a hint yet
                if out_hint is not None:
                    assert len(node.shape) == len(out_hint), "reshape doesn't match output"
                    node.shape.apply_naming_hints(out_hint)
                if node.in_dims_hint is None:
                    node.in_dims_hint = SparseList([["%s" % i for i in range(len(node.old_shape))]])
            elif node.in_dims_hint is None:
                node.in_dims_hint = deepcopy(node.out_dims_hint)

        # If we have an in dim hint then propagate it to upstream nodes.
        if node.in_dims_hint is not None:
            for edge in G.in_edges(node.name):
                hint = node.in_dims_hint[edge.to_idx]
                # Skip inputs without a hint rather than giving the producer
                # an empty SparseList that makes it look hinted.
                if hint is None:
                    continue
                if edge.from_node.out_dims_hint is None:
                    edge.from_node.out_dims_hint = SparseList()
                if edge.from_node.out_dims_hint[edge.from_idx] is None:
                    edge.from_node.out_dims_hint[edge.from_idx] = hint
    def report_graph(self,
                     G: NNGraph,
                     dot,
                     all_ports,
                     fake_idx,
                     nodes=None,
                     all_dims=False,
                     anonymise=False,
                     expressions=False,
                     qrecs=None,
                     fusions=False,
                     parent=None):
        """Emit the nodes and edges of ``G`` into the graphviz graph ``dot``.

        Nodes are walked in ``G.dfs()`` order; only members of ``nodes``
        (default: every node of ``G``) are drawn, and fusion input nodes are
        never drawn.

        * Expression fusions are drawn by ``report_expression`` when
          ``expressions`` is truthy.
        * Fusions are drawn by ``report_fusion`` when ``fusions`` is set.
        * Other nodes become record boxes; fusion output nodes get ports only.
        * Incoming edges from drawn nodes are connected directly.
        * With ``all_dims``, incoming edges from undrawn nodes start at
          placeholder ``fake_N`` point nodes, and output groups with no drawn
          consumer end at placeholder plaintext nodes.
        * Edges whose label helper reports an error are coloured red.

        NOTE(review): ``fake_idx`` is only advanced locally and is not
        returned, so a caller (or ``report_fusion``) reusing its own counter
        may produce repeated ``fake_N`` names — verify against callers.
        """
        if nodes is None:
            nodes = set(G.nodes())

        def fresh_fake_name():
            # hand out the next placeholder node name
            nonlocal fake_idx
            fake_name = f'fake_{fake_idx}'
            fake_idx += 1
            return fake_name

        def edge_color(edge_error):
            return "red" if edge_error else "black"

        for node in G.dfs():
            if node not in nodes or isinstance(node, FusionInputParameters):
                continue

            if expressions and isinstance(node, ExpressionFusionParameters):
                all_ports[node] = self.report_expression(
                    dot, G, node,
                    anonymise=anonymise,
                    report_quantized=expressions == "quantized")
            elif fusions and isinstance(node, FusionBase):
                all_ports[node] = self.report_fusion(
                    dot, G, node, all_ports, fake_idx,
                    all_dims=all_dims,
                    anonymise=anonymise,
                    expressions=expressions,
                    qrecs=qrecs)
            else:
                in_count = len(G.indexed_in_edges(node.name))
                out_count = len(G.indexed_out_edges(node.name))
                ports = all_ports.setdefault(node, [None] * 2)
                if not isinstance(node, FusionOutputParameters):
                    names = self.build_nodebox(node, ports, in_count,
                                               out_count, anon=anonymise)
                    dot.node(node.name,
                             nohtml(names),
                             shape='record',
                             xlabel=str(node.step_idx),
                             color="blue" if node.is_not_generated else "black")

            for edge in G.in_edges(node.name):
                src_shown = edge.from_node in nodes
                # edges from undrawn producers only appear with all_dims
                if not (src_shown or all_dims):
                    continue
                out_port, in_port = self.get_ports(all_ports, edge)
                if src_shown:
                    src_id = self.get_from_id(all_ports, edge, out_port)
                    dst_id = self.get_to_id(all_ports, edge, in_port)
                    label, err = self.in_label(
                        G, edge, qrecs,
                        parent=parent,
                        from_node=not isinstance(edge.from_node,
                                                 FusionInputParameters),
                        to_node=not isinstance(edge.to_node,
                                               FusionOutputParameters))
                else:
                    src_id = fresh_fake_name()
                    dot.node(src_id, shape='point', fillcolor='black')
                    dst_id = self.get_to_id(all_ports, edge, in_port)
                    label, err = self.in_label(G, edge, qrecs, parent=parent)
                dot.edge(src_id, dst_id, xlabel=label, color=edge_color(err))

            if all_dims:
                # dangling outputs: groups with no drawn consumer at all
                for edge_group in G.indexed_out_edges(node.name):
                    if any(e.to_node in nodes for e in edge_group):
                        continue
                    edge = edge_group[0]
                    out_port, _ = self.get_ports(all_ports, edge)
                    sink_name = fresh_fake_name()
                    dot.node(sink_name,
                             shape='plaintext',
                             label=' ',
                             fillcolor='black')
                    src_id = self.get_from_id(all_ports, edge, out_port)
                    label, err = self.out_label(
                        G, edge, qrecs,
                        parent=parent,
                        from_node=not isinstance(edge.from_node,
                                                 FusionInputParameters),
                        to_node=not isinstance(edge.to_node,
                                               FusionOutputParameters))
                    dot.edge(src_id, sink_name, xlabel=label,
                             color=edge_color(err))
Example #6
0
    def report(self,
               G: NNGraph,
               nodes=None,
               graph_format='PDF',
               all_dims=False,
               filename=None,
               view=True,
               anonymise=False,
               expressions=False,
               quant_labels=False):
        """Draw ``G`` with graphviz, then optionally render and/or view it.

        Args:
            G: graph to draw.
            nodes: subset of nodes to draw; defaults to every node of ``G``.
            graph_format: output format passed to ``Digraph``.
            all_dims: when set, edges to/from undrawn nodes are attached to
                placeholder ``fake_N`` nodes.
            filename: if given, the graph is rendered to this path.
            view: if true, the graph is rendered to a fresh temporary file
                and opened with ``dot.view``.
            anonymise: forwarded to ``build_nodebox``/``report_expression``.
            expressions: when truthy, expression fusions are drawn by
                ``report_expression``; ``"quantized"`` requests the quantized
                variant.
            quant_labels: forwarded to ``in_label``/``out_label``.

        The name cache set up by ``init_name_cache`` is always reset, even if
        drawing or rendering raises.
        """
        import os
        import tempfile

        if nodes is None:
            nodes = set(G.nodes())

        self.init_name_cache()
        try:
            all_ports = {}
            graph_name = getattr(G, 'graphname', 'graph')
            dot = Digraph(comment=graph_name,
                          format=graph_format,
                          node_attr={'height': '.1'},
                          edge_attr={'fontsize': '10.0'})
            fake_idx = 0
            for node in G.dfs():
                if node not in nodes:
                    continue
                if expressions and isinstance(node, ExpressionFusionParameters):
                    all_ports[node] = self.report_expression(
                        dot,
                        G,
                        node,
                        anonymise=anonymise,
                        report_quantized=expressions == "quantized")
                else:
                    num_in_edges = len(G.indexed_in_edges(node.name))
                    num_out_edges = len(G.indexed_out_edges(node.name))
                    ports = all_ports.setdefault(node, [None] * 2)
                    names = self.build_nodebox(node,
                                               ports,
                                               num_in_edges,
                                               num_out_edges,
                                               anon=anonymise)
                    dot.node(node.name,
                             nohtml(names),
                             shape='record',
                             xlabel=str(node.step_idx))
                for edge in G.in_edges(node.name):
                    # edges from undrawn producers only appear with all_dims
                    if edge.from_node not in nodes and not all_dims:
                        continue

                    out_port, in_port = self.get_ports(all_ports, edge)
                    if edge.from_node in nodes:
                        from_node_id = self.get_from_id(all_ports, edge, out_port)
                        to_node_id = self.get_to_id(all_ports, edge, in_port)
                        dot.edge(from_node_id,
                                 to_node_id,
                                 xlabel=self.in_label(G, node, edge.to_idx,
                                                      quant_labels))
                    else:
                        fake_name = f'fake_{fake_idx}'
                        fake_idx += 1
                        dot.node(fake_name, shape='point', fillcolor='black')
                        to_node_id = self.get_to_id(all_ports, edge, in_port)
                        dot.edge(fake_name,
                                 to_node_id,
                                 xlabel=self.in_label(G, node, edge.to_idx,
                                                      quant_labels))
                if not all_dims:
                    continue
                # dangling outputs: groups with no drawn consumer at all
                for edge_group in G.indexed_out_edges(node.name):
                    if any(edge.to_node in nodes for edge in edge_group):
                        continue
                    edge = edge_group[0]
                    out_port, _ = self.get_ports(all_ports, edge)
                    fake_name = f'fake_{fake_idx}'
                    fake_idx += 1
                    dot.node(fake_name,
                             shape='plaintext',
                             label=' ',
                             fillcolor='black')
                    from_node_id = self.get_from_id(all_ports, edge, out_port)
                    dot.edge(from_node_id,
                             fake_name,
                             xlabel=self.out_label(G, node, edge.from_idx,
                                                   quant_labels))

            if filename:
                dot.render(filename, cleanup=True)
            if view:
                # mkstemp creates the file securely, unlike the deprecated,
                # race-prone tempfile.mktemp; dot.view overwrites it.
                fd, filename = tempfile.mkstemp(suffix='.gv')
                os.close(fd)
                dot.view(filename, cleanup=True, quiet=True)
        finally:
            self.reset_name_cache()