Exemplo n.º 1
0
    def quantize(self, G: NNGraph) -> OrderedDict:
        """Compute a quantization record for every step of ``G``.

        Steps are visited in ``G.graph_state.steps`` order. A node's input
        quantizations are the records already stored for its incoming edges
        (``InputParameters`` nodes get none). This presumably relies on the
        steps being topologically ordered; otherwise the edge lookup raises
        KeyError.

        For a ``FusionParameters`` node, every contained node is quantized in
        turn, with each one's output quantizations feeding the next. The fusion
        itself is then summarised by a ``QuantizationRecord`` spanning the
        fusion's inputs and the last contained node's outputs.

        Args:
            G: the graph to quantize.

        Returns:
            An OrderedDict mapping ``NodeId(node, fnode)`` to the record for
            that node. ``fnode`` is None for top-level nodes. If
            ``calculate_q`` returns a falsy record (for a top-level node or for
            a node inside a fusion), that value is stored and processing stops.
            The result then only covers the steps up to and including the
            failing one.
        """
        edge_recs = {}
        result = OrderedDict()
        for step in G.graph_state.steps:
            node = step['node']
            if isinstance(node, InputParameters):
                in_qs = []
            else:
                in_qs = [
                    edge_recs[edge.params]
                    for edge in G.indexed_in_edges(node.name)
                ]
            if isinstance(node, FusionParameters):
                # Chain the contained nodes: each consumes the previous one's
                # output quantizations.
                fin_qs = in_qs
                for fnode in node.contained_nodes():
                    fnode_id = NodeId(node, fnode)
                    qrec = self.calculate_q(
                        fnode, self._activation_stats.get(fnode_id),
                        self._filter_stats.get(fnode_id), fin_qs,
                        self._min_qsnr, self._force_width)
                    result[fnode_id] = qrec
                    if not qrec:
                        # A contained node could not be quantized. Keep the
                        # falsy record so the fusion is handled like a failed
                        # plain node below, rather than reading .out_qs from
                        # a missing record.
                        break
                    fin_qs = qrec.out_qs
                else:
                    qrec = QuantizationRecord(in_qs=in_qs, out_qs=fin_qs)
            else:
                node_id = NodeId(node, None)
                qrec = self.calculate_q(
                    node, self._activation_stats.get(node_id),
                    self._filter_stats.get(node_id), in_qs,
                    self._min_qsnr, self._force_width)
            result[NodeId(node, None)] = qrec
            if not qrec:
                break

            # Publish this node's output quantizations on its outgoing edges
            # so that downstream nodes can pick them up as inputs.
            for edges in G.indexed_out_edges(node.name):
                for edge in edges:
                    edge_recs[edge.params] = qrec.out_qs[edge.from_idx]
        return result
    def report_graph(self,
                     G: NNGraph,
                     dot,
                     all_ports,
                     fake_idx,
                     nodes=None,
                     all_dims=False,
                     anonymise=False,
                     expressions=False,
                     qrecs=None,
                     fusions=False,
                     parent=None):
        """Add the nodes and edges of ``G`` to the graphviz object ``dot``.

        Nodes are visited in ``G.dfs()`` order. Only those in ``nodes``
        (default: every node of ``G``) are drawn, and
        ``FusionInputParameters`` nodes are never drawn.

        Args:
            G: graph (or fusion subgraph) to draw.
            dot: graphviz graph that receives the ``node``/``edge`` calls.
            all_ports: dict from node to its port ids. It is filled in as
                nodes are drawn and read back when edges are attached.
            fake_idx: first number used for ``fake_<n>`` placeholder nodes.
            nodes: subset of nodes to draw.
            all_dims: when True, also draw edges that cross the boundary of
                ``nodes``, using a placeholder node for the hidden end.
            anonymise: forwarded to the node-box builders.
            expressions: when truthy, ``ExpressionFusionParameters`` nodes
                are drawn via ``report_expression``. The value
                ``"quantized"`` selects the quantized form.
            qrecs: forwarded to ``report_fusion`` and the edge-label helpers.
            fusions: when truthy, ``FusionBase`` nodes are drawn via
                ``report_fusion``.
            parent: forwarded to the edge-label helpers.

        Edges whose label helper reports an error are coloured red.
        """
        # NOTE(review): fake_idx is a plain int, so neither the caller nor
        # report_fusion sees the increments made here. If placeholders are
        # also created there, ``fake_<n>`` names may collide. Confirm.
        def edge_colour(has_error):
            return "red" if has_error else "black"

        def new_placeholder_name():
            nonlocal fake_idx
            name = f'fake_{fake_idx}'
            fake_idx += 1
            return name

        visible = set(G.nodes()) if nodes is None else nodes
        for node in G.dfs():
            if node not in visible or isinstance(node, FusionInputParameters):
                continue

            # --- the node itself ---
            if expressions and isinstance(node, ExpressionFusionParameters):
                all_ports[node] = self.report_expression(
                    dot,
                    G,
                    node,
                    anonymise=anonymise,
                    report_quantized=(expressions == "quantized"))
            elif fusions and isinstance(node, FusionBase):
                all_ports[node] = self.report_fusion(
                    dot, G, node, all_ports, fake_idx,
                    all_dims=all_dims,
                    anonymise=anonymise,
                    expressions=expressions,
                    qrecs=qrecs)
            else:
                in_count = len(G.indexed_in_edges(node.name))
                out_count = len(G.indexed_out_edges(node.name))
                node_ports = all_ports.setdefault(node, [None] * 2)
                # Fusion outputs only need ports registered; no box is drawn.
                if not isinstance(node, FusionOutputParameters):
                    box = self.build_nodebox(node, node_ports, in_count,
                                             out_count, anon=anonymise)
                    box_colour = "blue" if node.is_not_generated else "black"
                    dot.node(node.name, nohtml(box), shape='record',
                             xlabel=str(node.step_idx), color=box_colour)

            # --- incoming edges ---
            for edge in G.in_edges(node.name):
                source_shown = edge.from_node in visible
                if not source_shown and not all_dims:
                    continue
                out_port, in_port = self.get_ports(all_ports, edge)
                if source_shown:
                    src_id = self.get_from_id(all_ports, edge, out_port)
                    dst_id = self.get_to_id(all_ports, edge, in_port)
                    label, has_error = self.in_label(
                        G,
                        edge,
                        qrecs,
                        parent=parent,
                        from_node=not isinstance(edge.from_node,
                                                 FusionInputParameters),
                        to_node=not isinstance(edge.to_node,
                                               FusionOutputParameters))
                else:
                    # The source is not drawn, so stand it in with a dot.
                    src_id = new_placeholder_name()
                    dot.node(src_id, shape='point', fillcolor='black')
                    dst_id = self.get_to_id(all_ports, edge, in_port)
                    label, has_error = self.in_label(G, edge, qrecs,
                                                     parent=parent)
                dot.edge(src_id, dst_id, xlabel=label,
                         color=edge_colour(has_error))

            if not all_dims:
                continue

            # --- outgoing edge groups that lead only to hidden nodes ---
            for edge_group in G.indexed_out_edges(node.name):
                if any(out.to_node in visible for out in edge_group):
                    continue
                edge = edge_group[0]
                out_port, _ = self.get_ports(all_ports, edge)
                sink_id = new_placeholder_name()
                dot.node(sink_id, shape='plaintext', label=' ',
                         fillcolor='black')
                src_id = self.get_from_id(all_ports, edge, out_port)
                label, has_error = self.out_label(
                    G,
                    edge,
                    qrecs,
                    parent=parent,
                    from_node=not isinstance(edge.from_node,
                                             FusionInputParameters),
                    to_node=not isinstance(edge.to_node,
                                           FusionOutputParameters))
                dot.edge(src_id, sink_id, xlabel=label,
                         color=edge_colour(has_error))
Exemplo n.º 3
0
    def report(self,
               G: NNGraph,
               nodes=None,
               graph_format='PDF',
               all_dims=False,
               filename=None,
               view=True,
               anonymise=False,
               expressions=False,
               quant_labels=False):
        """Draw ``G`` as a graphviz digraph and optionally save and/or view it.

        Args:
            G: graph to draw.
            nodes: subset of ``G``'s nodes to draw. Defaults to all of them.
            graph_format: output format handed to ``Digraph``.
            all_dims: when True, incoming edges from nodes outside ``nodes``
                are drawn from point placeholders. Outgoing edge groups whose
                targets all lie outside ``nodes`` are drawn to blank
                placeholders.
            filename: if given, the graph is rendered to this file.
            view: if True, the graph is rendered to a freshly created
                temporary file and opened in a viewer.
            anonymise: forwarded to the node-box / expression builders.
            expressions: when truthy, ``ExpressionFusionParameters`` nodes are
                drawn with ``report_expression``. The value ``"quantized"``
                asks for the quantized form.
            quant_labels: forwarded to the edge-label helpers.
        """
        if nodes is None:
            nodes = set(G.nodes())

        self.init_name_cache()
        # Release the name cache even if drawing or rendering raises.
        try:
            all_ports = {}
            graph_name = getattr(G, 'graphname', 'graph')
            dot = Digraph(comment=graph_name,
                          format=graph_format,
                          node_attr={'height': '.1'},
                          edge_attr={'fontsize': '10.0'})
            fake_idx = 0
            for node in G.dfs():
                if node not in nodes:
                    continue
                if (expressions
                        and isinstance(node, ExpressionFusionParameters)):
                    all_ports[node] = self.report_expression(
                        dot,
                        G,
                        node,
                        anonymise=anonymise,
                        report_quantized=expressions == "quantized")
                else:
                    num_in_edges = len(G.indexed_in_edges(node.name))
                    num_out_edges = len(G.indexed_out_edges(node.name))
                    ports = all_ports.setdefault(node, [None] * 2)
                    names = self.build_nodebox(node,
                                               ports,
                                               num_in_edges,
                                               num_out_edges,
                                               anon=anonymise)
                    dot.node(node.name,
                             nohtml(names),
                             shape='record',
                             xlabel=str(node.step_idx))
                for edge in G.in_edges(node.name):
                    from_visible = edge.from_node in nodes
                    if not from_visible and not all_dims:
                        continue
                    out_port, in_port = self.get_ports(all_ports, edge)
                    if from_visible:
                        from_node_id = self.get_from_id(all_ports, edge,
                                                        out_port)
                    else:
                        # The source is not drawn: use an anonymous point.
                        from_node_id = f'fake_{fake_idx}'
                        fake_idx += 1
                        dot.node(from_node_id, shape='point',
                                 fillcolor='black')
                    to_node_id = self.get_to_id(all_ports, edge, in_port)
                    dot.edge(from_node_id,
                             to_node_id,
                             xlabel=self.in_label(G, node, edge.to_idx,
                                                  quant_labels))
                if not all_dims:
                    continue
                for edge_group in G.indexed_out_edges(node.name):
                    if any(edge.to_node in nodes for edge in edge_group):
                        continue
                    edge = edge_group[0]
                    out_port, _ = self.get_ports(all_ports, edge)
                    fake_name = f'fake_{fake_idx}'
                    fake_idx += 1
                    dot.node(fake_name,
                             shape='plaintext',
                             label=' ',
                             fillcolor='black')
                    from_node_id = self.get_from_id(all_ports, edge,
                                                    out_port)
                    dot.edge(from_node_id,
                             fake_name,
                             xlabel=self.out_label(G, node, edge.from_idx,
                                                   quant_labels))

            if filename:
                dot.render(filename, cleanup=True)
            if view:
                # tempfile.mktemp is deprecated and race-prone. Create the
                # file securely, then let graphviz overwrite it (cleanup=True
                # removes it again after rendering).
                with tempfile.NamedTemporaryFile(suffix='.gv',
                                                 delete=False) as tmp:
                    view_name = tmp.name
                dot.view(view_name, cleanup=True, quiet=True)
        finally:
            self.reset_name_cache()