示例#1
0
def fix_split_in_edges(G: NNGraph):
    """Re-attach the input of every split node so it no longer targets a
    non-zero input index.

    For each ``SplitParameters`` node, its first incoming edge is inspected.
    If that edge lands on an input index other than 0, it is removed and
    replaced by a new ``NNEdge`` with the same source node, destination node
    and ``from_idx``, but with no explicit ``to_idx``. NOTE(review): this
    presumably defaults to input 0; confirm against ``NNEdge``.

    Args:
        G: graph to fix up in place.
    """
    # Materialise the split list first because the graph is edited below.
    split_nodes = [nd for nd in G.nodes() if isinstance(nd, SplitParameters)]
    for split_node in split_nodes:
        edge = G.in_edges(split_node.name)[0]
        if edge.to_idx != 0:
            G.remove_edge(edge)
            G.add_edge(NNEdge(edge.from_node, edge.to_node,
                              from_idx=edge.from_idx))
示例#2
0
    def create_graph(self, filename, opts):
        """Import a TFLite model file into a new ``NNGraph``.

        Args:
            filename: path of the TFLite flatbuffer to load.
            opts: importer options, normalised through ``self.get_opts``.
                Keys read here: ``name``, ``load_quantization`` and
                ``remove_quantize_ops``.

        Returns:
            The imported ``NNGraph`` with dimensions computed. When
            ``load_quantization`` is set, the graph starts with an ``SQ8``
            quantization set read during import. Records for nodes no longer
            in the graph are then dropped. Finally the graph is re-quantized
            by ``UnifiedQuantizer``, starting from concat/split/softmax/
            sigmoid nodes plus any nodes reported by
            ``find_nodes_with_bad_quantization``.

        Raises:
            Whatever ``check`` raises when the model version is not 3
            (presumably an assertion-style error; confirm).
        """
        opts = self.get_opts(opts)
        self._name_cache = {}
        add_sys_path(os.path.dirname(__file__))
        # Read the whole flatbuffer and release the file handle immediately.
        with open(filename, "rb") as model_file:
            buf = model_file.read()
        model = Model.GetRootAsModel(buf, 0)
        LOG.info("Importing TFLITE model version %s", model.Version())
        check(model.Version() == 3, "Only support version 3 graphs at present")
        if model.SubgraphsLength() > 1:
            LOG.warning(
                "nntool only supports one subgraph. There may be errors loading this graph."
            )
        G = NNGraph(model=model,
                    filename=filename,
                    name=opts.get('name'),
                    constant_store=ConstantStore())
        if opts.get('load_quantization'):
            G.quantization = QuantizationSet()
            G.has_quantized_parameters = True
            G.quantization.schemes_present.add('SQ8')

        self._import_tflite_graph(G, model, opts)
        clean_dangling_nodes(G)
        fix_split_in_edges(G)
        MatchDuplicateConstants().match(G)
        G.add_dimensions()
        remove_concats(G)
        if opts['remove_quantize_ops']:
            RemoveQuantizeOperators().match(G)
            G.add_dimensions()

        if opts.get('load_quantization'):
            # Drop quantization records whose node is no longer in the graph.
            # Collect first, then delete, so the mapping is not mutated while
            # it is being iterated.
            to_remove = [nid for nid in G.quantization
                         if nid.node_name not in G]
            for nid in to_remove:
                del G.quantization[nid]
            nodes_with_bad_quantization = self.find_nodes_with_bad_quantization(
                G)
            quantizer = UnifiedQuantizer.from_quantized_graph(G)
            # Check for quantization problems:
            # 1) need to force softmax/Sigmoid input to POW2 quantization
            # 2) need to check that all concats and splits have same input and
            #    output quantization
            # 3) Need to check that all nodes have qrecs and that they are consistent
            nodes_with_bad_quantization |= set(
                G.nodes(node_classes=(ConcatParameters, SoftMaxParameters,
                                      SplitParameters,
                                      SigmoidActivationParameters)))
            G.quantization = quantizer.quantize(
                G, start_nodes=nodes_with_bad_quantization)
            G.add_dimensions()

        return G
示例#3
0
def create_graph(filename, opts):
    """Build an ``NNGraph`` from the configuration file at ``filename``.

    After the subgraph is created from the parsed config, every leaf node
    gets an edge (``order=0``) to a node returned by
    ``out_graph.add_output()``. A leaf node is one with no outgoing edges and
    at least one incoming edge.

    Args:
        filename: path of the config file, parsed with ``read_cfg``.
        opts: mapping of options; ``name`` and ``value_cache`` are read.

    Returns:
        The populated graph.
    """
    cfg = read_cfg(filename)
    out_graph = NNGraph(model=cfg,
                        filename=filename,
                        name=opts.get('name'),
                        value_cache=opts.get('value_cache'))
    create_subgraph(out_graph, cfg)
    # Snapshot the leaves before wiring outputs, so nodes added by
    # add_output() are not picked up by this scan. Isolated nodes
    # (no edges at all) are deliberately skipped.
    leaf_nodes = [n for n in out_graph.nodes()
                  if out_graph.out_degree(n) == 0 and out_graph.in_degree(n) > 0]
    for node in leaf_nodes:
        out_graph.add_edge(node, out_graph.add_output(), order=0)
    return out_graph
示例#4
0
    def report(self, G: NNGraph, nodes=None) -> Tabular:
        """Tabulate the execution steps of ``G`` between two step indices.

        Args:
            G: graph whose ``graph_state`` (steps and liveness) is reported.
            nodes: optional node subset (defaults to all nodes of ``G``).
                Only its lowest and highest ``step_idx`` matter. Every step
                in that inclusive range is reported, including steps whose
                node is not in ``nodes``.

        Returns:
            A ``Tabular`` with headers from ``do_headers``. ``do_operation``
            is called for each step in range; steps whose node is a
            ``ConstantInputParameters`` are skipped unless
            ``self._show_constants`` is set. ``check_do_totals`` is called
            when the range spans more than one step.
        """
        selected = G.nodes() if nodes is None else nodes
        ordered = sorted(selected, key=lambda nd: nd.step_idx)
        lo_step = ordered[0].step_idx
        hi_step = ordered[-1].step_idx

        steps = G.graph_state.steps
        liveness = G.graph_state.liveness
        # Header shows the dimension order of the first step's first output.
        head_node = steps[lo_step]['node']
        order_label = "x".join(head_node.out_dims[0].order)
        table = Tabular()
        self.do_headers(order_label, table)

        peak_active = 0
        params_total = 0
        ops_total = 0
        for step_no, node, active, params_size, ops in graph_walk(steps,
                                                                  liveness):
            if not lo_step <= node.step_idx <= hi_step:
                continue
            params_total += params_size
            if ops:
                ops_total += ops
            peak_active = max(peak_active, active)
            hide_node = (not self._show_constants
                         and isinstance(node, ConstantInputParameters))
            if not hide_node:
                self.do_operation(node, G, table, step_no, active,
                                  params_size, ops)

        if lo_step != hi_step:
            self.check_do_totals(table, peak_active, params_total, ops_total)
        return table
    def report_graph(self,
                     G: NNGraph,
                     dot,
                     all_ports,
                     fake_idx,
                     nodes=None,
                     all_dims=False,
                     anonymise=False,
                     expressions=False,
                     qrecs=None,
                     fusions=False,
                     parent=None):
        """Emit graphviz nodes and edges for (a subset of) ``G`` into ``dot``.

        Nodes are visited in ``G.dfs()`` order and ``FusionInputParameters``
        nodes are skipped. The remaining nodes are drawn as follows:

        - Expression fusions, when ``expressions`` is set, are drawn by
          ``report_expression``.
        - Other ``FusionBase`` nodes, when ``fusions`` is set, are drawn by
          ``report_fusion``.
        - Every other node except ``FusionOutputParameters`` becomes a
          ``record`` node. Its colour is blue when ``node.is_not_generated``,
          otherwise black.

        Each visited node's incoming edges are drawn afterwards. An edge whose
        producer is outside ``nodes`` is drawn only when ``all_dims`` is set,
        and then starts at a ``fake_N`` point node. With ``all_dims``, each
        output group whose consumers are all outside ``nodes`` also gets an
        edge to a blank ``fake_N`` node. Edges are red when the label helper
        reports an error, otherwise black.

        Args:
            G: graph to draw.
            dot: graphviz graph object receiving ``node``/``edge`` calls.
            all_ports: dict of node -> port info. It is filled in here
                (mutated in place) and used to resolve edge endpoints.
            fake_idx: starting counter for generated ``fake_N`` node names.
            nodes: nodes to draw; defaults to every node in ``G``.
            all_dims: also draw edges crossing the boundary of ``nodes``.
            anonymise: passed through to the node/expression renderers.
            expressions: when truthy, expand expression fusions; the value
                ``"quantized"`` requests the quantized form.
            qrecs: quantization records passed to the edge-label helpers.
            fusions: when truthy, expand other ``FusionBase`` nodes.
            parent: passed through to the edge-label helpers.

        NOTE(review): ``fake_idx`` is incremented only locally and is not
        returned. ``report_fusion`` also receives its value at call time, so
        nested or consecutive calls may generate the same ``fake_N`` names.
        Confirm whether that causes collisions in the emitted graph.
        """
        if nodes is None:
            nodes = set(G.nodes())
        for node in G.dfs():
            if node not in nodes:
                continue
            # Fusion input placeholders are never drawn.
            if isinstance(node, (FusionInputParameters)):
                continue
            if expressions and isinstance(node, ExpressionFusionParameters):
                all_ports[node] = self.report_expression(
                    dot,
                    G,
                    node,
                    anonymise=anonymise,
                    report_quantized=expressions == "quantized")
            elif fusions and isinstance(node, FusionBase):
                all_ports[node] = self.report_fusion(dot,
                                                     G,
                                                     node,
                                                     all_ports,
                                                     fake_idx,
                                                     all_dims=all_dims,
                                                     anonymise=anonymise,
                                                     expressions=expressions,
                                                     qrecs=qrecs)

            else:
                num_in_edges = len(G.indexed_in_edges(node.name))
                num_out_edges = len(G.indexed_out_edges(node.name))
                ports = all_ports.setdefault(node, [None] * 2)
                # Fusion outputs get no box of their own, but their incoming
                # edges are still drawn by the loop below.
                if not isinstance(node, FusionOutputParameters):
                    names = self.build_nodebox(node,
                                               ports,
                                               num_in_edges,
                                               num_out_edges,
                                               anon=anonymise)
                    dot.node(
                        node.name,
                        nohtml(names),
                        shape='record',
                        xlabel=str(node.step_idx),
                        color="blue" if node.is_not_generated else "black")
            for edge in G.in_edges(node.name):
                # Edges from producers outside the selection are only shown
                # with all_dims.
                if edge.from_node not in nodes:
                    if not all_dims:
                        continue

                out_port, in_port = self.get_ports(all_ports, edge)
                if edge.from_node in nodes:
                    from_node_id = self.get_from_id(all_ports, edge, out_port)
                    to_node_id = self.get_to_id(all_ports, edge, in_port)
                    edge_label, edge_error = self.in_label(
                        G,
                        edge,
                        qrecs,
                        parent=parent,
                        from_node=not isinstance(edge.from_node,
                                                 FusionInputParameters),
                        to_node=not isinstance(edge.to_node,
                                               FusionOutputParameters))
                    dot.edge(from_node_id,
                             to_node_id,
                             xlabel=edge_label,
                             color="red" if edge_error else "black")
                else:
                    # Producer is not drawn: start the edge at a fresh point
                    # node named fake_<n>.
                    fake_name = f'fake_{fake_idx}'
                    fake_idx += 1
                    dot.node(fake_name, shape='point', fillcolor='black')
                    to_node_id = self.get_to_id(all_ports, edge, in_port)
                    edge_label, edge_error = self.in_label(G,
                                                           edge,
                                                           qrecs,
                                                           parent=parent)
                    dot.edge(fake_name,
                             to_node_id,
                             xlabel=edge_label,
                             color="red" if edge_error else "black")
            if not all_dims:
                continue
            # With all_dims, each output group none of whose consumers are
            # selected is drawn as an edge (labelled from its first edge) into
            # a blank plaintext placeholder node.
            for edge_group in G.indexed_out_edges(node.name):
                if any(edge.to_node in nodes for edge in edge_group):
                    continue
                edge = edge_group[0]
                out_port, _ = self.get_ports(all_ports, edge)
                fake_name = f'fake_{fake_idx}'
                fake_idx += 1
                dot.node(fake_name,
                         shape='plaintext',
                         label=' ',
                         fillcolor='black')
                from_node_id = self.get_from_id(all_ports, edge, out_port)
                edge_label, edge_error = self.out_label(
                    G,
                    edge,
                    qrecs,
                    parent=parent,
                    from_node=not isinstance(edge.from_node,
                                             FusionInputParameters),
                    to_node=not isinstance(edge.to_node,
                                           FusionOutputParameters))
                dot.edge(from_node_id,
                         fake_name,
                         xlabel=edge_label,
                         color="red" if edge_error else "black")
示例#6
0
    def report(self,
               G: NNGraph,
               nodes=None,
               graph_format='PDF',
               all_dims=False,
               filename=None,
               view=True,
               anonymise=False,
               expressions=False,
               quant_labels=False):
        """Draw ``G`` (or a subset of it) with graphviz, then save and/or view it.

        Args:
            G: graph to draw.
            nodes: nodes to include; defaults to every node in ``G``.
            graph_format: graphviz output format.
            all_dims: also draw edges that cross the boundary of ``nodes``.
                Such edges start or end at placeholder ``fake_N`` nodes.
            filename: when given, render to this path with ``cleanup=True``.
            view: when true, render into a newly created private temporary
                directory and open the result in the viewer. The rendered
                file is left there for the viewer to read.
            anonymise: passed through to the node/expression renderers.
            expressions: when truthy, expand expression fusions; the value
                ``"quantized"`` requests the quantized form.
            quant_labels: passed to the edge-label helpers.
        """
        if nodes is None:
            nodes = set(G.nodes())

        self.init_name_cache()
        try:
            all_ports = {}
            graph_name = getattr(G, 'graphname', 'graph')
            dot = Digraph(comment=graph_name,
                          format=graph_format,
                          node_attr={'height': '.1'},
                          edge_attr={'fontsize': '10.0'})
            # Counter for the placeholder nodes drawn for edges crossing the
            # boundary of ``nodes``.
            fake_idx = 0
            for node in G.dfs():
                if node not in nodes:
                    continue
                if expressions and isinstance(node, ExpressionFusionParameters):
                    all_ports[node] = self.report_expression(
                        dot,
                        G,
                        node,
                        anonymise=anonymise,
                        report_quantized=expressions == "quantized")
                else:
                    num_in_edges = len(G.indexed_in_edges(node.name))
                    num_out_edges = len(G.indexed_out_edges(node.name))
                    ports = all_ports.setdefault(node, [None] * 2)
                    names = self.build_nodebox(node,
                                               ports,
                                               num_in_edges,
                                               num_out_edges,
                                               anon=anonymise)
                    dot.node(node.name,
                             nohtml(names),
                             shape='record',
                             xlabel=str(node.step_idx))
                for edge in G.in_edges(node.name):
                    # Edges from producers outside the selection are only
                    # shown with all_dims, starting at a point node.
                    if edge.from_node not in nodes and not all_dims:
                        continue

                    out_port, in_port = self.get_ports(all_ports, edge)
                    if edge.from_node in nodes:
                        from_node_id = self.get_from_id(all_ports, edge,
                                                        out_port)
                        to_node_id = self.get_to_id(all_ports, edge, in_port)
                        dot.edge(from_node_id,
                                 to_node_id,
                                 xlabel=self.in_label(G, node, edge.to_idx,
                                                      quant_labels))
                    else:
                        fake_name = f'fake_{fake_idx}'
                        fake_idx += 1
                        dot.node(fake_name, shape='point', fillcolor='black')
                        to_node_id = self.get_to_id(all_ports, edge, in_port)
                        dot.edge(fake_name,
                                 to_node_id,
                                 xlabel=self.in_label(G, node, edge.to_idx,
                                                      quant_labels))
                if not all_dims:
                    continue
                # With all_dims, each output group none of whose consumers
                # are selected is drawn as an edge into a blank placeholder.
                for edge_group in G.indexed_out_edges(node.name):
                    if any(edge.to_node in nodes for edge in edge_group):
                        continue
                    edge = edge_group[0]
                    out_port, _ = self.get_ports(all_ports, edge)
                    fake_name = f'fake_{fake_idx}'
                    fake_idx += 1
                    dot.node(fake_name,
                             shape='plaintext',
                             label=' ',
                             fillcolor='black')
                    from_node_id = self.get_from_id(all_ports, edge, out_port)
                    dot.edge(from_node_id,
                             fake_name,
                             xlabel=self.out_label(G, node, edge.from_idx,
                                                   quant_labels))

            if filename:
                dot.render(filename, cleanup=True)
            if view:
                # Render into a newly created private directory. A bare
                # tempfile.mktemp() name is deprecated and could be claimed
                # by another process before graphviz writes to it.
                dot.view(filename='graph.gv',
                         directory=tempfile.mkdtemp(),
                         cleanup=True,
                         quiet=True)
        finally:
            # Restore the name cache even when drawing or rendering fails.
            self.reset_name_cache()