Esempio n. 1
0
    def quantize_forward(self, G: NNGraph, edge_recs, result=None):
        """Quantize the nodes of ``G`` in execution-step order.

        Args:
            G: graph whose ``graph_state.steps`` list the nodes to visit.
            edge_recs: mapping keyed by ``edge.params``; each processed node
                stores the output quantization of its out edges here.
            result: container receiving one record per ``NodeId``; a fresh
                ``QuantizationSet`` is used when ``None`` is passed.

        Returns:
            The ``result`` container. The walk stops right after storing the
            first falsy record.
        """
        if result is None:
            result = QuantizationSet()
        ordered_nodes = [step['node'] for step in G.graph_state.steps]
        for node in ordered_nodes:
            LOG.debug("quantize forward %s", node.name)
            in_qs = self.get_in_qs(G, edge_recs, node)
            if isinstance(node, ConvFusionParameters):
                # quantize_fusion returns the node's own record plus a dict
                # of further records that are stored alongside it.
                qrec, inner_qrecs = self.quantize_fusion(G, node, in_qs)
                for inner_id, inner_qrec in inner_qrecs.items():
                    result[inner_id] = inner_qrec
            elif isinstance(node, ConcatParameters):
                qrec = self.quantize_backward(G, result, edge_recs, node)
            else:
                stats_key = NodeId(node, None)
                qrec = self.calculate_q(node,
                                        self._activation_stats.get(stats_key),
                                        self._filter_stats.get(stats_key),
                                        in_qs,
                                        self._min_qsnr,
                                        self._force_width)
            result[NodeId(node, None)] = qrec
            if not qrec:
                break

            # Publish this node's output quantization on every out edge.
            out_edges = (edge
                         for edge_group in G.indexed_out_edges(node.name)
                         for edge in edge_group)
            for edge in out_edges:
                edge_recs[edge.params] = qrec.out_qs[edge.from_idx]
        return result
Esempio n. 2
0
    def quantize(self, G: NNGraph) -> OrderedDict:
        """Compute a quantization record for each node of ``G`` in step order.

        Returns:
            An ``OrderedDict`` keyed by ``NodeId``. Nodes contained in a
            ``FusionParameters`` node get their own entries in addition to
            the entry of the fusion node itself. Processing stops right
            after storing the first falsy node record.
        """
        edge_recs = {}
        result = OrderedDict()
        for step in G.graph_state.steps:
            node = step['node']
            # Input nodes have no incoming quantization to look up.
            in_qs = [] if isinstance(node, InputParameters) else [
                edge_recs[in_edge.params]
                for in_edge in G.indexed_in_edges(node.name)
            ]
            if not isinstance(node, FusionParameters):
                node_key = NodeId(node, None)
                qrec = self.calculate_q(node,
                                        self._activation_stats.get(node_key),
                                        self._filter_stats.get(node_key),
                                        in_qs,
                                        self._min_qsnr,
                                        self._force_width)
            else:
                # Chain the contained nodes: the outputs of one become the
                # inputs of the next.
                chained_qs = in_qs
                for inner in node.contained_nodes():
                    inner_key = NodeId(node, inner)
                    inner_qrec = self.calculate_q(
                        inner,
                        self._activation_stats.get(inner_key),
                        self._filter_stats.get(inner_key),
                        chained_qs,
                        self._min_qsnr,
                        self._force_width)
                    result[inner_key] = inner_qrec
                    chained_qs = inner_qrec.out_qs
                qrec = QuantizationRecord(in_qs=in_qs, out_qs=chained_qs)
            result[NodeId(node, None)] = qrec
            if not qrec:
                break

            out_edges = (edge
                         for edge_group in G.indexed_out_edges(node.name)
                         for edge in edge_group)
            for edge in out_edges:
                edge_recs[edge.params] = qrec.out_qs[edge.from_idx]
        return result
Esempio n. 3
0
 def initialize_edge_recs(G: NNGraph, qrecs):
     """Build the edge record dictionary from existing quantization records.

     For every node in ``G.graph_state.steps`` the record found in ``qrecs``
     under ``NodeId(node)`` supplies, for each out edge, the entry
     ``out_qs[edge.from_idx]`` stored under ``edge.params``.
     """
     ordered_nodes = [step['node'] for step in G.graph_state.steps]
     edge_recs = {}
     for node in ordered_nodes:
         node_qrec = qrecs[NodeId(node)]
         out_edges = (edge
                      for edge_group in G.indexed_out_edges(node.name)
                      for edge in edge_group)
         for edge in out_edges:
             edge_recs[edge.params] = node_qrec.out_qs[edge.from_idx]
     return edge_recs
Esempio n. 4
0
    def quantize_forward(self, G: NNGraph, edge_recs, dtype=np.int8):
        """Quantize the nodes of ``G`` in step order into ``self.qrecs``.

        Args:
            G: graph whose ``graph_state.steps`` list the nodes to visit.
            edge_recs: mapping keyed by ``edge.params``; updated with the
                output quantization of each processed node.
            dtype: passed through to ``quantize_fusion`` / ``calculate_q``.

        The walk stops right after storing the first falsy record.
        """
        fusion_types = (ConvFusionParameters, ActivationFusion)
        ordered_nodes = [step['node'] for step in G.graph_state.steps]
        for node in ordered_nodes:
            LOG.debug("quantize forward %s", node.name)
            in_qs = self.get_in_qs(G, edge_recs, node)
            if isinstance(node, fusion_types):
                qrec = self.quantize_fusion(G, node, in_qs, dtype)
            else:
                node_stats = self._activation_stats.get(NodeId(node, None))
                qrec = self.calculate_q(G, node, node_stats, in_qs, dtype)
            self.qrecs[NodeId(node, None)] = qrec
            if not qrec:
                break

            out_edges = (edge
                         for edge_group in G.indexed_out_edges(node.name)
                         for edge in edge_group)
            for edge in out_edges:
                edge_recs[edge.params] = qrec.out_qs[edge.from_idx]
Esempio n. 5
0
    def propagate_forward(self, G: NNGraph, edge_recs, start_node,
                          new_out_qrec, result):
        """Propagate a new output qrec at node start_node through the graph.

        Nodes that precede ``start_node`` in ``G.graph_state.steps`` are
        skipped. ``start_node`` itself is recomputed with
        ``quantize_backward(..., force_out=new_out_qrec)``; every later node
        is recomputed from its current input quantization. Records are
        written to ``result`` and ``edge_recs``; nothing is returned. The
        walk stops right after storing the first falsy record.
        """
        reached_start = False
        ordered_nodes = [step['node'] for step in G.graph_state.steps]
        for node in ordered_nodes:
            if not reached_start:
                if not node == start_node:
                    continue
                reached_start = True
                qrec = self.quantize_backward(G,
                                              result,
                                              edge_recs,
                                              node,
                                              force_out=new_out_qrec)
            else:
                LOG.debug("propagate forwards %s", node.name)
                in_qs = self.get_in_qs(G, edge_recs, node)
                if isinstance(node, ConvFusionParameters):
                    qrec, inner_qrecs = self.quantize_fusion(G, node, in_qs)
                    for inner_id, inner_qrec in inner_qrecs.items():
                        result[inner_id] = inner_qrec
                elif isinstance(node, ConcatParameters):
                    qrec = self.quantize_backward(G, result, edge_recs, node)
                else:
                    stats_key = NodeId(node, None)
                    qrec = self.calculate_q(
                        node,
                        self._activation_stats.get(stats_key),
                        self._filter_stats.get(stats_key),
                        in_qs,
                        self._min_qsnr,
                        self._force_width)

            result[NodeId(node, None)] = qrec
            if not qrec:
                break

            out_edges = (edge
                         for edge_group in G.indexed_out_edges(node.name)
                         for edge in edge_group)
            for edge in out_edges:
                edge_recs[edge.params] = qrec.out_qs[edge.from_idx]
    def report_graph(self,
                     G: NNGraph,
                     dot,
                     all_ports,
                     fake_idx,
                     nodes=None,
                     all_dims=False,
                     anonymise=False,
                     expressions=False,
                     qrecs=None,
                     fusions=False,
                     parent=None):
        """Add the nodes and edges of ``G`` to the drawing object ``dot``.

        Args:
            G: graph to draw; traversed with ``G.dfs()``.
            dot: drawing object; its ``node`` and ``edge`` methods are called.
            all_ports: dict keyed by node holding port information; filled in
                here and read back through ``get_ports`` / ``get_from_id`` /
                ``get_to_id``.
            fake_idx: counter used to name the placeholder ``fake_<n>`` nodes.
            nodes: nodes to draw; defaults to every node of ``G``.
            all_dims: when true, edges whose other end is outside ``nodes``
                are still drawn, attached to a placeholder node.
            anonymise: forwarded to the node/expression/fusion renderers.
            expressions: when truthy, ``ExpressionFusionParameters`` nodes are
                drawn with ``report_expression``; the value ``"quantized"``
                selects ``report_quantized=True``.
            qrecs: forwarded to ``report_fusion``, ``in_label`` and
                ``out_label``.
            fusions: when truthy, ``FusionBase`` nodes are drawn with
                ``report_fusion``.
            parent: forwarded to ``in_label`` / ``out_label``.

        Returns:
            None. All output goes to ``dot`` and ``all_ports``.
        """
        # NOTE(review): fake_idx is only rebound locally (``fake_idx += 1``)
        # and is not returned, so increments made here are presumably not
        # seen by the caller - confirm placeholder names cannot collide
        # across calls sharing the same ``dot``.
        if nodes is None:
            nodes = set(G.nodes())
        for node in G.dfs():
            if node not in nodes:
                continue
            # Fusion input nodes are never drawn themselves.
            if isinstance(node, (FusionInputParameters)):
                continue
            if expressions and isinstance(node, ExpressionFusionParameters):
                all_ports[node] = self.report_expression(
                    dot,
                    G,
                    node,
                    anonymise=anonymise,
                    report_quantized=expressions == "quantized")
            elif fusions and isinstance(node, FusionBase):
                all_ports[node] = self.report_fusion(dot,
                                                     G,
                                                     node,
                                                     all_ports,
                                                     fake_idx,
                                                     all_dims=all_dims,
                                                     anonymise=anonymise,
                                                     expressions=expressions,
                                                     qrecs=qrecs)

            else:
                num_in_edges = len(G.indexed_in_edges(node.name))
                num_out_edges = len(G.indexed_out_edges(node.name))
                ports = all_ports.setdefault(node, [None] * 2)
                # Fusion output nodes get a ports entry but no drawn box.
                if not isinstance(node, FusionOutputParameters):
                    names = self.build_nodebox(node,
                                               ports,
                                               num_in_edges,
                                               num_out_edges,
                                               anon=anonymise)
                    dot.node(
                        node.name,
                        nohtml(names),
                        shape='record',
                        xlabel=str(node.step_idx),
                        color="blue" if node.is_not_generated else "black")
            # Draw the incoming edges of this node.
            for edge in G.in_edges(node.name):
                # Edges from outside the selection are drawn only with
                # all_dims set.
                if edge.from_node not in nodes:
                    if not all_dims:
                        continue

                out_port, in_port = self.get_ports(all_ports, edge)
                if edge.from_node in nodes:
                    from_node_id = self.get_from_id(all_ports, edge, out_port)
                    to_node_id = self.get_to_id(all_ports, edge, in_port)
                    edge_label, edge_error = self.in_label(
                        G,
                        edge,
                        qrecs,
                        parent=parent,
                        from_node=not isinstance(edge.from_node,
                                                 FusionInputParameters),
                        to_node=not isinstance(edge.to_node,
                                               FusionOutputParameters))
                    dot.edge(from_node_id,
                             to_node_id,
                             xlabel=edge_label,
                             color="red" if edge_error else "black")
                else:
                    # The source is outside the selection: draw the edge
                    # from a placeholder point node instead.
                    fake_name = f'fake_{fake_idx}'
                    fake_idx += 1
                    dot.node(fake_name, shape='point', fillcolor='black')
                    to_node_id = self.get_to_id(all_ports, edge, in_port)
                    edge_label, edge_error = self.in_label(G,
                                                           edge,
                                                           qrecs,
                                                           parent=parent)
                    dot.edge(fake_name,
                             to_node_id,
                             xlabel=edge_label,
                             color="red" if edge_error else "black")
            # Dangling outputs are only drawn when all_dims is set.
            if not all_dims:
                continue
            for edge_group in G.indexed_out_edges(node.name):
                # Skip outputs that reach at least one selected node; those
                # edges are drawn as in-edges of that node.
                if any(edge.to_node in nodes for edge in edge_group):
                    continue
                # One placeholder per output, using the first edge of the
                # group.
                edge = edge_group[0]
                out_port, _ = self.get_ports(all_ports, edge)
                fake_name = f'fake_{fake_idx}'
                fake_idx += 1
                dot.node(fake_name,
                         shape='plaintext',
                         label=' ',
                         fillcolor='black')
                from_node_id = self.get_from_id(all_ports, edge, out_port)
                edge_label, edge_error = self.out_label(
                    G,
                    edge,
                    qrecs,
                    parent=parent,
                    from_node=not isinstance(edge.from_node,
                                             FusionInputParameters),
                    to_node=not isinstance(edge.to_node,
                                           FusionOutputParameters))
                dot.edge(from_node_id,
                         fake_name,
                         xlabel=edge_label,
                         color="red" if edge_error else "black")
Esempio n. 7
0
    def report(self,
               G: NNGraph,
               nodes=None,
               graph_format='PDF',
               all_dims=False,
               filename=None,
               view=True,
               anonymise=False,
               expressions=False,
               quant_labels=False):
        """Draw ``G`` into a ``Digraph`` and optionally render and/or view it.

        Args:
            G: graph to draw; traversed with ``G.dfs()``.
            nodes: nodes to draw; defaults to every node of ``G``.
            graph_format: output format passed to ``Digraph``.
            all_dims: when true, edges whose other end is outside ``nodes``
                are still drawn, attached to a placeholder node.
            filename: when truthy, the drawing is rendered to this path.
            view: when truthy, the drawing is written to a temporary file
                and opened with ``dot.view``.
            anonymise: forwarded to the node/expression renderers.
            expressions: when truthy, ``ExpressionFusionParameters`` nodes are
                drawn with ``report_expression``; the value ``"quantized"``
                selects ``report_quantized=True``.
            quant_labels: forwarded to ``in_label`` / ``out_label``.

        Returns:
            None.
        """
        if nodes is None:
            nodes = set(G.nodes())

        self.init_name_cache()
        # The name cache is reset even when drawing or rendering fails.
        try:
            all_ports = {}
            graph_name = getattr(G, 'graphname', 'graph')
            dot = Digraph(comment=graph_name,
                          format=graph_format,
                          node_attr={'height': '.1'},
                          edge_attr={'fontsize': '10.0'})
            fake_idx = 0
            for node in G.dfs():
                if node not in nodes:
                    continue
                if expressions and isinstance(node,
                                              ExpressionFusionParameters):
                    all_ports[node] = self.report_expression(
                        dot,
                        G,
                        node,
                        anonymise=anonymise,
                        report_quantized=expressions == "quantized")
                else:
                    num_in_edges = len(G.indexed_in_edges(node.name))
                    num_out_edges = len(G.indexed_out_edges(node.name))
                    ports = all_ports.setdefault(node, [None] * 2)
                    names = self.build_nodebox(node,
                                               ports,
                                               num_in_edges,
                                               num_out_edges,
                                               anon=anonymise)
                    dot.node(node.name,
                             nohtml(names),
                             shape='record',
                             xlabel=str(node.step_idx))

                # Incoming edges of this node.
                for edge in G.in_edges(node.name):
                    from_selected = edge.from_node in nodes
                    # Edges from outside the selection are drawn only with
                    # all_dims set.
                    if not from_selected and not all_dims:
                        continue

                    out_port, in_port = self.get_ports(all_ports, edge)
                    if from_selected:
                        from_node_id = self.get_from_id(
                            all_ports, edge, out_port)
                    else:
                        # Source is outside the selection: start the edge
                        # at a placeholder point node.
                        from_node_id = f'fake_{fake_idx}'
                        fake_idx += 1
                        dot.node(from_node_id,
                                 shape='point',
                                 fillcolor='black')
                    to_node_id = self.get_to_id(all_ports, edge, in_port)
                    dot.edge(from_node_id,
                             to_node_id,
                             xlabel=self.in_label(G, node, edge.to_idx,
                                                  quant_labels))

                # Dangling outputs are only drawn when all_dims is set.
                if not all_dims:
                    continue
                for edge_group in G.indexed_out_edges(node.name):
                    if any(edge.to_node in nodes for edge in edge_group):
                        continue
                    edge = edge_group[0]
                    out_port, _ = self.get_ports(all_ports, edge)
                    fake_name = f'fake_{fake_idx}'
                    fake_idx += 1
                    dot.node(fake_name,
                             shape='plaintext',
                             label=' ',
                             fillcolor='black')
                    from_node_id = self.get_from_id(all_ports, edge, out_port)
                    dot.edge(from_node_id,
                             fake_name,
                             xlabel=self.out_label(G, node, edge.from_idx,
                                                   quant_labels))

            if filename:
                dot.render(filename, cleanup=True)
            if view:
                # tempfile.mktemp is deprecated because another process can
                # claim the name before it is used; create the file
                # atomically and hand its path to the viewer instead.
                with tempfile.NamedTemporaryFile(suffix='.gv',
                                                 delete=False) as tmp_file:
                    view_path = tmp_file.name
                dot.view(view_path, cleanup=True, quiet=True)
        finally:
            self.reset_name_cache()
Esempio n. 8
0
    def quantize_backward(self,
                          G: NNGraph,
                          result,
                          edge_recs,
                          node,
                          force_out=None):
        """Recompute the record of ``node``, pushing constraints upstream.

        The record of ``node`` is calculated; when it does not meet
        ``force_out`` the producing node(s) are asked, through recursive
        calls, to change their output and the record is calculated once
        more. At most two passes are made.

        Args:
            G: graph containing ``node``.
            result: container of records keyed by ``NodeId``; updated in
                place for ``node`` and every node recomputed recursively.
            edge_recs: mapping keyed by ``edge.params``; updated with the
                output quantization of every recomputed node.
            node: node to recompute.
            force_out: requested output quantization (its ``q`` and ``bits``
                attributes are read), or ``None``.

        Returns:
            The record stored in ``result`` for ``node``.

        Raises:
            NotImplementedError: when no solution is found after the second
                pass, when the needed bits exceed the input ``q``, for
                ``SoftMaxParameters`` nodes, or when a non-filter node with
                several input edges would need its inputs changed.
        """

        LOG.debug("quantize backwards %s", node.name)
        # Set once the first pass has failed to satisfy force_out.
        recalculated = False
        while True:
            in_qs = self.get_in_qs(G, edge_recs, node)
            if self.is_filter_node(node):
                if isinstance(node, ConvFusionParameters):
                    qrec, qrecs = self.quantize_fusion(G,
                                                       node,
                                                       in_qs,
                                                       force_out=force_out)
                    for node_id, fqrec in qrecs.items():
                        result[node_id] = fqrec
                else:
                    qrec = self.calculate_q(node,
                                            self._activation_stats.get(
                                                NodeId(node, None)),
                                            in_qs,
                                            self._force_width,
                                            force_out=force_out)

                # The output has fewer fractional bits than requested.
                if force_out and force_out.q is not None and qrec.out_qs[
                        0].q < force_out.q:
                    if recalculated:
                        raise NotImplementedError(
                            "no quantization solution found")
                    # NOTE(review): this reads ``qrec.q`` while the condition
                    # above reads ``qrec.out_qs[0].q`` - confirm the record
                    # type really exposes ``q``.
                    bits_to_gain = force_out.q - qrec.q
                    if bits_to_gain > in_qs[0].q:
                        raise NotImplementedError()
                    # Try to adjust the inputs to satisfy and then
                    # recalculate
                    pnode = G.in_edges(node.name)[0].from_node
                    self.quantize_backward(G,
                                           result,
                                           edge_recs,
                                           pnode,
                                           force_out=QType(bits=force_out.bits,
                                                           q=in_qs[0].q -
                                                           bits_to_gain,
                                                           signed=True))
            elif isinstance(node, ConcatParameters):
                assert not recalculated
                # Normalize all inputs to the widest width and smallest q.
                max_width = max(in_q.bits for in_q in in_qs)
                min_q = min(in_q.q for in_q in in_qs)
                if force_out:
                    # A forced value that is not already satisfied replaces
                    # the one derived from the inputs.
                    if not self.satisfied(force_out.bits, max_width):
                        max_width = force_out.bits
                    if not self.satisfied(force_out.q, min_q):
                        min_q = force_out.q
                LOG.debug("normalizing concat to %s",
                          QType(bits=max_width, q=min_q, signed=True))
                # Force every producer that differs onto the common type.
                for pidx, pnode in enumerate(
                    [edge.from_node for edge in G.in_edges(node.name)]):
                    pqrec = in_qs[pidx]
                    if pqrec.q != min_q or pqrec.bits != max_width:
                        self.quantize_backward(G,
                                               result,
                                               edge_recs,
                                               pnode,
                                               force_out=QType(bits=max_width,
                                                               q=min_q,
                                                               signed=True))
                o_q = QType(bits=max_width, q=min_q, signed=True)
                # Inputs are read again since producers may have changed.
                qrec = SymmetricQuantizationRecord(in_qs=self.get_in_qs(
                    G, edge_recs, node),
                                                   out_qs=[o_q])
            elif isinstance(node, SoftMaxParameters):
                raise NotImplementedError(
                    "softmax kernel cannot change width or q")
            else:
                if isinstance(node, ConvFusionParameters):
                    qrec, qrecs = self.quantize_fusion(G,
                                                       node,
                                                       in_qs,
                                                       force_out=force_out)
                    for node_id, fqrec in qrecs.items():
                        result[node_id] = fqrec
                else:
                    qrec = self.calculate_q(node,
                                            self._activation_stats.get(
                                                NodeId(node, None)),
                                            in_qs,
                                            self._force_width,
                                            force_out=force_out)
                o_q = qrec.out_qs[0]
                # NOTE(review): force_out is dereferenced here without a None
                # check although it defaults to None - confirm this branch is
                # never reached without force_out.
                if not (self.satisfied(force_out.q, o_q.q)
                        and self.satisfied(force_out.bits, o_q.bits)):
                    if recalculated:
                        raise NotImplementedError(
                            "no quantization solution found")
                    if len(G.in_edges(node.name)) > 1:
                        raise NotImplementedError(
                            "Nodes with multiple input edges \
                            need custom handling")
                    # Pass the same constraint on to the single producer.
                    pnode = G.in_edges(node.name)[0].from_node
                    self.quantize_backward(G,
                                           result,
                                           edge_recs,
                                           pnode,
                                           force_out=force_out)

            # Publish the new output quantization on every out edge.
            for edges in G.indexed_out_edges(node.name):
                for edge in edges:
                    edge_recs[edge.params] = qrec.out_qs[edge.from_idx]

            result[NodeId(node, None)] = qrec

            o_q = qrec.out_qs[0]
            if self.satisfied_force(force_out, o_q):
                break
            # Only one retry is allowed.
            if recalculated:
                raise NotImplementedError("no quantization solution found")
            LOG.debug("recalculate %s", node.name)
            recalculated = True
        LOG.debug("back complete %s %s", node.name, qrec)
        return qrec