Exemplo n.º 1
0
def fix_split_in_edges(G: NNGraph):
    """Re-attach the first input edge of every split node without a to_idx.

    For each ``SplitParameters`` node, the first edge returned by
    ``G.in_edges`` is inspected. If it arrives on a non-zero ``to_idx`` it is
    removed and replaced by a new ``NNEdge`` with the same source node,
    target node and ``from_idx``. ``to_idx`` is left to the ``NNEdge``
    default (presumably 0 — confirm against ``NNEdge``).

    Args:
        G: graph to patch in place.
    """
    # Snapshot the split nodes first: the loop body edits the graph's edges.
    split_nodes = [candidate for candidate in G.nodes()
                   if isinstance(candidate, SplitParameters)]
    for split_node in split_nodes:
        feeding_edge = G.in_edges(split_node.name)[0]
        if feeding_edge.to_idx != 0:
            G.remove_edge(feeding_edge)
            G.add_edge(NNEdge(feeding_edge.from_node,
                              feeding_edge.to_node,
                              from_idx=feeding_edge.from_idx))
Exemplo n.º 2
0
def propagate_downwards(G: NNGraph):
    """Push dimension-naming hints from graph inputs towards outputs.

    Nodes are visited in ``G.dfs()`` order. For every node, an input hint
    (``in_dims_hint``) is first turned into an output hint (``out_dims_hint``)
    according to the node type. Any output hint is then copied onto the
    matching input slot of each consumer that does not yet have one. A node
    that does not want its input hint copied to its output should set its own
    ``out_dims_hint`` beforehand.

    Args:
        G: graph whose node hints are updated in place.
    """
    for node in G.dfs():
        if node.in_dims_hint is not None:
            _derive_out_hint_downwards(G, node)
        if node.out_dims_hint is not None:
            _push_out_hint_to_consumers(G, node)


def _derive_out_hint_downwards(G, node):
    """Derive ``node.out_dims_hint`` (or reshape naming) from its input hint."""
    in_hints = node.in_dims_hint

    if isinstance(node, ReshapeParameters):
        # Reshape only names its source shape, and only when the ranks agree.
        first_hint = in_hints[0]
        if len(node.old_shape) == len(first_hint):
            LOG.debug("set reshape %s in dims hint %s", node.name,
                      first_hint)
            node.old_shape.apply_naming_hints(first_hint)
        return

    if isinstance(node, GlobalPoolParameters):
        # The input hint is only forwarded when keep_dims is set.
        if node.keep_dims:
            node.out_dims_hint = deepcopy(in_hints)
        return

    if isinstance(node, MatrixBroadcastedLinearOpParameters):
        # Use the longest non-None input hint; on ties the first one wins.
        present = [hint for hint in in_hints if hint is not None]
        if present:
            node.out_dims_hint = [max(present, key=len)]
        return

    if isinstance(node, ConcatParameters):
        # Give every un-hinted concat input the first hint found on any input.
        shared_hint = next(
            (hint for hint in in_hints if hint is not None), None)
        if shared_hint:
            LOG.debug("set concat %s in dims hint %s", node.name,
                      shared_hint)
            for edge in G.in_edges(node.name):
                if not in_hints[edge.to_idx]:
                    in_hints[edge.to_idx] = shared_hint
            node.out_dims_hint = [shared_hint]
        return

    # Default: the output hint mirrors the input hint unless already set.
    if node.out_dims_hint is None:
        node.out_dims_hint = deepcopy(in_hints)


def _push_out_hint_to_consumers(G, node):
    """Copy ``node``'s per-output hints onto consumers' empty input slots."""
    LOG.debug("propagate down hint from %s", node.name)
    for edge in G.out_edges(node.name):
        hint = node.out_dims_hint[edge.from_idx]
        if hint is None:
            continue
        consumer = edge.to_node
        if consumer.in_dims_hint is None:
            consumer.in_dims_hint = SparseList()
        if consumer.in_dims_hint[edge.to_idx] is None:
            consumer.in_dims_hint[edge.to_idx] = hint
Exemplo n.º 3
0
def propagate_upwards(G: NNGraph):
    """Pull dimension-naming hints from graph outputs back towards inputs.

    Nodes are visited in ``G.dfs(reverse=True)`` order. For each node with an
    output hint, an input hint is derived according to the node type. Every
    input hint is then copied onto the matching output slot of the producing
    node, if that slot is still empty. ``InputParameters`` producers whose
    ``dims`` have fewer entries than the hint they receive are left-padded
    with size-1 axes.

    ``MatrixMulParameters`` nodes that carry an output hint are skipped
    entirely, including the upstream copy.
    NOTE(review): confirm that skipping the upstream copy is intended.

    Args:
        G: graph whose node hints (and some shapes) are updated in place.
    """
    for node in G.dfs(reverse=True):
        if node.out_dims_hint is not None:
            if not _derive_in_hint_upwards(node):
                continue
        if node.in_dims_hint is not None:
            _pull_in_hint_to_producers(G, node)


def _derive_in_hint_upwards(node):
    """Derive ``node``'s input hint (or reshape/constant naming) from its output hint.

    Returns:
        False when the node must not forward hints upstream (matrix
        multiply), True otherwise.
    """
    out_hints = node.out_dims_hint

    if isinstance(node, ReshapeParameters):
        target_rank = len(out_hints[0])
        missing_axes = target_rank - len(node.shape)
        if missing_axes > 0:
            # Left-pad the reshape target with size-1 axes to match the hint.
            node.shape = Dim.unnamed(([1] * missing_axes) + node.shape.shape)
        node.shape.apply_naming_hints(out_hints[0])
        if node.in_dims_hint is None:
            node.in_dims_hint = SparseList(
                [["%s" % i for i in range(len(node.old_shape))]])
        return True

    if isinstance(node, MatrixBroadcastedLinearOpParameters):
        node.in_dims_hint = [out_hints[0]] * 2
        return True

    if isinstance(node, MatrixMulParameters):
        return False

    if isinstance(node, GlobalPoolParameters):
        if node.keep_dims:
            node.in_dims_hint = deepcopy(out_hints)
        return True

    if isinstance(node, ConstantInputParameters) and not node.dims.is_named:
        node.dims.apply_naming_hints(out_hints[0])
        return True

    # Default: the input hint mirrors the output hint unless already set.
    if node.in_dims_hint is None:
        node.in_dims_hint = deepcopy(out_hints)
    return True


def _pull_in_hint_to_producers(G, node):
    """Copy ``node``'s per-input hints onto producers' empty output slots."""
    for edge in G.in_edges(node.name):
        hint = node.in_dims_hint[edge.to_idx]
        if hint is None:
            continue
        producer = edge.from_node
        if producer.out_dims_hint is None:
            producer.out_dims_hint = SparseList()
        if producer.out_dims_hint[edge.from_idx] is not None:
            continue
        producer.out_dims_hint[edge.from_idx] = hint
        if isinstance(producer, InputParameters):
            assert edge.from_idx == 0, "input node should only have one output"
            shortfall = len(hint) - len(producer.dims)
            if shortfall > 0:
                # Left-pad the graph input's dims so they match the hint rank.
                producer.dims = Dim.unnamed(
                    [1] * shortfall + producer.dims.shape)
Exemplo n.º 4
0
def propagate_upwards(G: NNGraph):
    """Pull dimension-naming hints from graph outputs back towards inputs.

    Nodes are visited in ``G.dfs(reverse=True)`` order. For each node with an
    output hint, an input hint is derived:

    * reshapes apply the output hint to their target shape and, if they have
      no input hint yet, get positional names ("0", "1", ...) for their
      source shape;
    * every other node copies its output hint, unless it already has an
      input hint.

    Each non-None input hint is then copied onto the producing node's
    matching output slot, when that slot is still empty.

    Args:
        G: graph whose node hints are updated in place.

    Raises:
        AssertionError: if a reshape's target rank differs from the rank of
            its output hint.
    """
    for node in G.dfs(reverse=True):
        if node.out_dims_hint is not None:
            if isinstance(node, ReshapeParameters):
                assert len(node.shape) == len(node.out_dims_hint[0]), \
                    "reshape %s rank does not match its output hint" % node.name
                node.shape.apply_naming_hints(node.out_dims_hint[0])
                if node.in_dims_hint is None:
                    node.in_dims_hint = SparseList(
                        [["%s" % i for i in range(len(node.old_shape))]])
            elif node.in_dims_hint is None:
                node.in_dims_hint = deepcopy(node.out_dims_hint)

        # If we have an input hint, propagate it to upstream nodes.
        if node.in_dims_hint is not None:
            for edge in G.in_edges(node.name):
                hint = node.in_dims_hint[edge.to_idx]
                # Skip unhinted inputs. Otherwise the producer would get an
                # empty (but non-None) out_dims_hint, which the
                # `is not None` check above would treat as a real hint when
                # the producer is visited.
                if hint is None:
                    continue
                if edge.from_node.out_dims_hint is None:
                    edge.from_node.out_dims_hint = SparseList()
                if edge.from_node.out_dims_hint[edge.from_idx] is None:
                    edge.from_node.out_dims_hint[edge.from_idx] = hint
    def report_graph(self,
                     G: NNGraph,
                     dot,
                     all_ports,
                     fake_idx,
                     nodes=None,
                     all_dims=False,
                     anonymise=False,
                     expressions=False,
                     qrecs=None,
                     fusions=False,
                     parent=None):
        """Add the nodes and edges of ``G`` to the graphviz graph ``dot``.

        Drawing rules:

        * Only nodes in ``nodes`` (default: every node of ``G``) are drawn.
        * ``FusionInputParameters`` nodes are never drawn.
        * Expression fusions (when ``expressions`` is set) are delegated to
          ``report_expression``; other fusions (when ``fusions`` is set) are
          delegated to ``report_fusion``.
        * ``FusionOutputParameters`` nodes get a ports entry but no box.
        * When ``all_dims`` is set, edges to or from nodes outside ``nodes``
          are drawn against placeholder nodes named ``fake_<n>``.

        Args:
            G: graph to draw.
            dot: graphviz graph receiving ``node``/``edge`` calls.
            all_ports: mapping node -> ports, filled in as nodes are drawn.
            fake_idx: first index used for placeholder node names. The
                incremented value is not returned to the caller.
            nodes: subset of nodes to draw.
            all_dims: also draw dangling edges to nodes outside ``nodes``.
            anonymise: passed to the node box builders.
            expressions: draw expression fusions inline; ``"quantized"``
                requests the quantized form.
            qrecs: quantization records passed to the edge label helpers.
            fusions: expand fusion nodes via ``report_fusion``.
            parent: passed through to the edge label helpers.
        """
        if nodes is None:
            nodes = set(G.nodes())

        for node in G.dfs():
            if node not in nodes or isinstance(node, FusionInputParameters):
                continue

            if expressions and isinstance(node, ExpressionFusionParameters):
                all_ports[node] = self.report_expression(
                    dot,
                    G,
                    node,
                    anonymise=anonymise,
                    report_quantized=expressions == "quantized")
            elif fusions and isinstance(node, FusionBase):
                all_ports[node] = self.report_fusion(
                    dot,
                    G,
                    node,
                    all_ports,
                    fake_idx,
                    all_dims=all_dims,
                    anonymise=anonymise,
                    expressions=expressions,
                    qrecs=qrecs)
            else:
                in_count = len(G.indexed_in_edges(node.name))
                out_count = len(G.indexed_out_edges(node.name))
                node_ports = all_ports.setdefault(node, [None] * 2)
                # Fusion outputs get a ports entry but no box of their own.
                if not isinstance(node, FusionOutputParameters):
                    box_label = self.build_nodebox(node,
                                                   node_ports,
                                                   in_count,
                                                   out_count,
                                                   anon=anonymise)
                    node_colour = "blue" if node.is_not_generated else "black"
                    dot.node(node.name,
                             nohtml(box_label),
                             shape='record',
                             xlabel=str(node.step_idx),
                             color=node_colour)

            for in_edge in G.in_edges(node.name):
                source_drawn = in_edge.from_node in nodes
                if not source_drawn and not all_dims:
                    continue
                out_port, in_port = self.get_ports(all_ports, in_edge)
                if source_drawn:
                    src_id = self.get_from_id(all_ports, in_edge, out_port)
                    dst_id = self.get_to_id(all_ports, in_edge, in_port)
                    label, has_error = self.in_label(
                        G,
                        in_edge,
                        qrecs,
                        parent=parent,
                        from_node=not isinstance(in_edge.from_node,
                                                 FusionInputParameters),
                        to_node=not isinstance(in_edge.to_node,
                                               FusionOutputParameters))
                    dot.edge(src_id,
                             dst_id,
                             xlabel=label,
                             color="red" if has_error else "black")
                    continue
                # The source is outside the drawn set: hang the edge off a point.
                placeholder = f'fake_{fake_idx}'
                fake_idx += 1
                dot.node(placeholder, shape='point', fillcolor='black')
                dst_id = self.get_to_id(all_ports, in_edge, in_port)
                label, has_error = self.in_label(G,
                                                 in_edge,
                                                 qrecs,
                                                 parent=parent)
                dot.edge(placeholder,
                         dst_id,
                         xlabel=label,
                         color="red" if has_error else "black")

            if not all_dims:
                continue
            # Draw output groups whose consumers are all outside the drawn set.
            for edge_group in G.indexed_out_edges(node.name):
                if any(e.to_node in nodes for e in edge_group):
                    continue
                lead_edge = edge_group[0]
                out_port, _ = self.get_ports(all_ports, lead_edge)
                placeholder = f'fake_{fake_idx}'
                fake_idx += 1
                dot.node(placeholder,
                         shape='plaintext',
                         label=' ',
                         fillcolor='black')
                src_id = self.get_from_id(all_ports, lead_edge, out_port)
                label, has_error = self.out_label(
                    G,
                    lead_edge,
                    qrecs,
                    parent=parent,
                    from_node=not isinstance(lead_edge.from_node,
                                             FusionInputParameters),
                    to_node=not isinstance(lead_edge.to_node,
                                           FusionOutputParameters))
                dot.edge(src_id,
                         placeholder,
                         xlabel=label,
                         color="red" if has_error else "black")
Exemplo n.º 6
0
    def report(self,
               G: NNGraph,
               nodes=None,
               graph_format='PDF',
               all_dims=False,
               filename=None,
               view=True,
               anonymise=False,
               expressions=False,
               quant_labels=False):
        """Draw ``G`` as a graphviz ``Digraph`` and render and/or view it.

        The name cache is initialised before drawing and is always reset
        afterwards, even if drawing or rendering raises.

        Args:
            G: graph to draw. Its ``graphname`` attribute, if present, is
                used as the graphviz comment.
            nodes: subset of nodes to draw (default: every node of ``G``).
            graph_format: graphviz output format.
            all_dims: also draw edges to/from nodes outside ``nodes``, using
                placeholder nodes named ``fake_<n>``.
            filename: if given, render the graph to this path.
            view: if true, write the graph to a fresh temporary ``.gv`` file
                and open it with ``dot.view``.
            anonymise: passed to the node box builders.
            expressions: draw expression fusions inline; ``"quantized"``
                requests the quantized form.
            quant_labels: passed to the edge label helpers.
        """
        if nodes is None:
            nodes = set(G.nodes())

        self.init_name_cache()
        try:
            all_ports = {}
            graph_name = G.graphname if hasattr(G, 'graphname') else 'graph'
            dot = Digraph(comment=graph_name,
                          format=graph_format,
                          node_attr={'height': '.1'},
                          edge_attr={'fontsize': '10.0'})
            fake_idx = 0
            for node in G.dfs():
                if node not in nodes:
                    continue
                if expressions and isinstance(node, ExpressionFusionParameters):
                    all_ports[node] = self.report_expression(
                        dot,
                        G,
                        node,
                        anonymise=anonymise,
                        report_quantized=expressions == "quantized")
                else:
                    num_in_edges = len(G.indexed_in_edges(node.name))
                    num_out_edges = len(G.indexed_out_edges(node.name))
                    ports = all_ports.setdefault(node, [None] * 2)
                    names = self.build_nodebox(node,
                                               ports,
                                               num_in_edges,
                                               num_out_edges,
                                               anon=anonymise)
                    dot.node(node.name,
                             nohtml(names),
                             shape='record',
                             xlabel=str(node.step_idx))
                for edge in G.in_edges(node.name):
                    if edge.from_node not in nodes:
                        if not all_dims:
                            continue

                    out_port, in_port = self.get_ports(all_ports, edge)
                    if edge.from_node in nodes:
                        from_node_id = self.get_from_id(all_ports, edge, out_port)
                        to_node_id = self.get_to_id(all_ports, edge, in_port)
                        dot.edge(from_node_id,
                                 to_node_id,
                                 xlabel=self.in_label(G, node, edge.to_idx,
                                                      quant_labels))
                    else:
                        # The source is outside the drawn set: use a point.
                        fake_name = f'fake_{fake_idx}'
                        fake_idx += 1
                        dot.node(fake_name, shape='point', fillcolor='black')
                        to_node_id = self.get_to_id(all_ports, edge, in_port)
                        dot.edge(fake_name,
                                 to_node_id,
                                 xlabel=self.in_label(G, node, edge.to_idx,
                                                      quant_labels))
                if not all_dims:
                    continue
                # Draw output groups whose consumers are all outside the set.
                for edge_group in G.indexed_out_edges(node.name):
                    if any(edge.to_node in nodes for edge in edge_group):
                        continue
                    edge = edge_group[0]
                    out_port, _ = self.get_ports(all_ports, edge)
                    fake_name = f'fake_{fake_idx}'
                    fake_idx += 1
                    dot.node(fake_name,
                             shape='plaintext',
                             label=' ',
                             fillcolor='black')
                    from_node_id = self.get_from_id(all_ports, edge, out_port)
                    dot.edge(from_node_id,
                             fake_name,
                             xlabel=self.out_label(G, node, edge.from_idx,
                                                   quant_labels))

            if filename:
                dot.render(filename, cleanup=True)
            if view:
                # tempfile.mktemp is deprecated: there is a race between
                # choosing the name and creating the file. Create the file
                # atomically instead and let graphviz overwrite it.
                with tempfile.NamedTemporaryFile(suffix='.gv',
                                                 delete=False) as tmp_file:
                    view_filename = tmp_file.name
                dot.view(view_filename, cleanup=True, quiet=True)
        finally:
            self.reset_name_cache()
Exemplo n.º 7
0
    def quantize_backward(self,
                          G: NNGraph,
                          result,
                          edge_recs,
                          node,
                          force_out=None):
        """Quantize ``node``, re-quantizing producers if needed to meet ``force_out``.

        The node's record is computed from the current input types. If it
        does not satisfy ``force_out``, producers are re-quantized
        recursively and the node is recalculated once. A second failure
        raises.

        Side effects:

        * ``result`` receives the record under ``NodeId(node, None)``, plus
          any fusion sub-records.
        * ``edge_recs`` receives the output type of every out edge, keyed by
          ``edge.params``.

        Args:
            G: graph containing ``node``.
            result: mapping of node ids to quantization records (updated).
            edge_recs: mapping of edge params to output types (updated).
            node: node to quantize.
            force_out: optional required output type; None means no
                constraint.

        Returns:
            The quantization record chosen for ``node``.

        Raises:
            NotImplementedError: when no solution is found after one
                recalculation; when a filter would need more fractional bits
                than its input has; for softmax nodes; and for non-filter
                nodes with several inputs that cannot meet ``force_out``.
        """
        LOG.debug("quantize backwards %s", node.name)
        recalculated = False
        while True:
            in_qs = self.get_in_qs(G, edge_recs, node)
            if self.is_filter_node(node):
                qrec = self._quantize_backward_node_qrec(G, result, node,
                                                         in_qs, force_out)
                out_q = qrec.out_qs[0]
                if force_out and force_out.q is not None and out_q.q < force_out.q:
                    if recalculated:
                        raise NotImplementedError(
                            "no quantization solution found")
                    # Measured against the output type, the same value
                    # compared above (a bare `qrec.q` is not that value).
                    bits_to_gain = force_out.q - out_q.q
                    if bits_to_gain > in_qs[0].q:
                        raise NotImplementedError()
                    # Try to adjust the inputs to satisfy and then
                    # recalculate.
                    pnode = G.in_edges(node.name)[0].from_node
                    self.quantize_backward(G,
                                           result,
                                           edge_recs,
                                           pnode,
                                           force_out=QType(bits=force_out.bits,
                                                           q=in_qs[0].q -
                                                           bits_to_gain,
                                                           signed=True))
            elif isinstance(node, ConcatParameters):
                assert not recalculated
                max_width = max(in_q.bits for in_q in in_qs)
                min_q = min(in_q.q for in_q in in_qs)
                if force_out:
                    if not self.satisfied(force_out.bits, max_width):
                        max_width = force_out.bits
                    if not self.satisfied(force_out.q, min_q):
                        min_q = force_out.q
                LOG.debug("normalizing concat to %s",
                          QType(bits=max_width, q=min_q, signed=True))
                # Force every input that differs onto the common type.
                for pidx, pnode in enumerate(
                        [edge.from_node for edge in G.in_edges(node.name)]):
                    pqrec = in_qs[pidx]
                    if pqrec.q != min_q or pqrec.bits != max_width:
                        self.quantize_backward(G,
                                               result,
                                               edge_recs,
                                               pnode,
                                               force_out=QType(bits=max_width,
                                                               q=min_q,
                                                               signed=True))
                o_q = QType(bits=max_width, q=min_q, signed=True)
                qrec = SymmetricQuantizationRecord(
                    in_qs=self.get_in_qs(G, edge_recs, node),
                    out_qs=[o_q])
            elif isinstance(node, SoftMaxParameters):
                raise NotImplementedError(
                    "softmax kernel cannot change width or q")
            else:
                qrec = self._quantize_backward_node_qrec(G, result, node,
                                                         in_qs, force_out)
                o_q = qrec.out_qs[0]
                # force_out defaults to None: only chase a constraint that
                # was actually given.
                if force_out and not (self.satisfied(force_out.q, o_q.q)
                                      and self.satisfied(force_out.bits,
                                                         o_q.bits)):
                    if recalculated:
                        raise NotImplementedError(
                            "no quantization solution found")
                    if len(G.in_edges(node.name)) > 1:
                        raise NotImplementedError(
                            "Nodes with multiple input edges "
                            "need custom handling")
                    pnode = G.in_edges(node.name)[0].from_node
                    self.quantize_backward(G,
                                           result,
                                           edge_recs,
                                           pnode,
                                           force_out=force_out)

            for edges in G.indexed_out_edges(node.name):
                for edge in edges:
                    edge_recs[edge.params] = qrec.out_qs[edge.from_idx]

            result[NodeId(node, None)] = qrec

            o_q = qrec.out_qs[0]
            if self.satisfied_force(force_out, o_q):
                break
            # Allow exactly one recalculation after producers were adjusted.
            if recalculated:
                raise NotImplementedError("no quantization solution found")
            LOG.debug("recalculate %s", node.name)
            recalculated = True
        LOG.debug("back complete %s %s", node.name, qrec)
        return qrec

    def _quantize_backward_node_qrec(self, G, result, node, in_qs, force_out):
        """Compute the record for one node, quantizing conv fusions as a unit.

        For ``ConvFusionParameters`` the fusion's sub-records are stored into
        ``result`` as a side effect. Other nodes go through ``calculate_q``
        with their activation statistics.
        """
        if isinstance(node, ConvFusionParameters):
            qrec, fusion_qrecs = self.quantize_fusion(G,
                                                      node,
                                                      in_qs,
                                                      force_out=force_out)
            for node_id, fqrec in fusion_qrecs.items():
                result[node_id] = fqrec
            return qrec
        return self.calculate_q(node,
                                self._activation_stats.get(
                                    NodeId(node, None)),
                                in_qs,
                                self._force_width,
                                force_out=force_out)