예제 #1
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Collapse chains of fusable nodes into single fusion nodes.

        For every node whose class is a key of ``VALID_FUSIONS``,
        ``self.get_node_list`` is asked for a fusable chain. Chains of two or
        more nodes are linked into a linear ``GraphView`` subgraph. The chain
        is then replaced in ``G`` by one node built from the chain's
        ``fusions_class``. Quantization records, when present, are moved into
        the fusion.

        Args:
            G: graph rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called at the end.

        Returns:
            True if at least one chain was fused.
        """
        modified = False

        for candidate in G.nodes(node_classes=tuple(VALID_FUSIONS.keys())):
            chain = self.get_node_list(G, candidate,
                                       FusionMatch(self._default_ktype))
            if chain is None or len(chain.order) < 2:
                continue
            LOG.info("fusing nodes %s", ",".join(
                (member.name for member in chain.order)))
            modified = True

            # link the chain members one after another inside the subgraph
            fused_subgraph = GraphView()
            tail = None
            for member in chain.order:
                if tail is not None:
                    fused_subgraph.add_edge(
                        NNEdge(from_node=tail, to_node=member))
                tail = member

            head = chain.node
            # the head may take several inputs; the fusion exposes a single
            # output taken from output 0 of the chain's last node
            input_mapping = [[(head, in_idx)]
                             for in_idx in range(G.num_in_edges(head.name))]
            output_mapping = [(tail, 0)]
            fusion = chain.fusions_class(head.name + '_fusion',
                                         fusion_type=chain.fusion_type,
                                         subgraph=fused_subgraph,
                                         input_mapping=input_mapping,
                                         output_mapping=output_mapping)

            if G.quantization:
                # statistics no longer match the graph; they must be regenerated
                G.quantization.stats = None
                member_qrecs = G.quantization.get_all(fusion.contained_nodes())
                if member_qrecs:
                    # fused record: inputs of the first member, outputs of the last
                    fused_qrec = QRec.copy_ktype(member_qrecs[0],
                                                 in_qs=member_qrecs[0].in_qs,
                                                 out_qs=member_qrecs[-1].out_qs)
                    for inner in fusion.contained_nodes():
                        G.quantization.move_to_fusion(inner, fusion)
                    G.quantization[NodeId(fusion)] = fused_qrec

            head_in_edges = G.in_edges(head.name)
            tail_out_edges = G.out_edges(tail.name)
            for member in chain.order:
                G.remove(member)
            for edge in head_in_edges:
                G.add_edge(NNEdge(edge.from_node, fusion,
                                  from_idx=edge.from_idx, to_idx=edge.to_idx))
            for edge in tail_out_edges:
                G.add_edge(NNEdge(fusion, edge.to_node,
                                  from_idx=edge.from_idx, to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return modified
예제 #2
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Fold reverse nodes around RNN nodes into the RNN's ``revert`` flag.

        Only ``ReverseParameters`` nodes with exactly one consumer, that
        consumer being an ``RNNBaseParameters`` node, are considered. A reverse
        feeding RNN input 0 is removed, its producers are wired straight into
        the RNN, and the RNN is marked ``revert``. Any reverse consuming output
        0 of that RNN is then removed as well. Reverses on other inputs or
        outputs only produce a warning.

        Args:
            G: graph rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called at the end.

        Returns:
            True if any reverse was folded.
        """
        def feeds_single_rnn(candidate):
            if not isinstance(candidate, ReverseParameters):
                return False
            consumers = G.out_edges(candidate.name)
            return len(consumers) == 1 and isinstance(consumers[0].to_node,
                                                      RNNBaseParameters)

        # Only works for reverses connected to one RNN node
        reverse_nodes = {candidate for candidate in G.nodes()
                         if feeds_single_rnn(candidate)}

        changed = False
        for rev in reverse_nodes:
            feeding_edges = G.in_edges(rev.name)
            rnn_edge = G.out_edges(rev.name)[0]
            rnn = rnn_edge.to_node
            if rnn_edge.to_idx != 0:
                LOG.warning("reverse on rnn input %s", rnn_edge.to_idx)
                continue
            assert not rnn.revert, "RNN node is already reversed!"
            rnn.revert = True
            LOG.info("fusing reverses into node %s", rnn.name)
            changed = True
            G.remove(rev)
            for edge in feeding_edges:
                G.add_edge(NNEdge(edge.from_node, rnn,
                                  from_idx=edge.from_idx,
                                  to_idx=rnn_edge.to_idx))

            # drop any reverse that consumes the RNN's first output
            for out_edge in G.out_edges(rnn.name):
                if not isinstance(out_edge.to_node, ReverseParameters):
                    continue
                if out_edge.from_idx != 0:
                    LOG.warning("reverse on rnn output %s", out_edge.from_idx)
                    continue
                trailing = G.out_edges(out_edge.to_node.name)
                G.remove(out_edge.to_node)
                for rev_out in trailing:
                    G.add_edge(NNEdge(out_edge.from_node, rev_out.to_node,
                                      from_idx=out_edge.from_idx,
                                      to_idx=rev_out.to_idx))

        if set_identity:
            self.set_identity(G)

        return changed
예제 #3
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
        """Remove pad nodes, either as no-ops or by folding them into filters.

        A pad node with a single input whose padding sums to zero on every
        axis is bypassed and deleted.

        Otherwise, for each consumer, ``self.find_conv`` looks for a
        filter-like node. The padding is moved onto that filter (as ``PadDim``
        with ``pad_type = "zero"``) when all of the following hold:

        * the filter has an input dims hint;
        * padding is only on the 'h'/'w' axes;
        * every pad value is zero.

        The consumer is then reconnected to the pad's producers. The pad node
        is deleted only if no consumer had to keep it.

        Args:
            G: graph rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called at the end.

        Returns:
            True if the graph was modified (including removal of zero padding).

        Raises:
            ValueError: if a matched filter has no input dims hint.
        """
        has_modified_graph = False
        for pad_params in [pad for pad in G.nodes() if isinstance(pad, PadParameters)]:
            pad_in_edges = G.in_edges(pad_params.name)
            pad_out_edges = G.out_edges(pad_params.name)
            # set when the pad must not be removed at the end of this iteration
            # (either still needed by a consumer, or already removed here)
            dont_delete = False
            if len(pad_in_edges) == 1 and all(sum(padding) == 0 for padding in pad_params.padding):
                LOG.info("removing zero padding node %s",
                         pad_params.name)
                G.remove(pad_params)
                if G.quantization:
                    G.quantization.remove_node(pad_params)
                # already removed above
                dont_delete = True
                # removing the node is a graph modification; previously this was not reported
                has_modified_graph = True
                in_edge = pad_in_edges[0]
                for out_edge in pad_out_edges:
                    G.add_edge(NNEdge(from_node=in_edge.from_node,
                                      to_node=out_edge.to_node,
                                      from_idx=in_edge.from_idx,
                                      to_idx=out_edge.to_idx))
            else:
                for pad_out_edge in pad_out_edges:
                    filter_like_node, expanded_padding, reshapes = self.find_conv(
                        G, pad_out_edge.to_node, pad_params.padding)
                    if not filter_like_node:
                        dont_delete = True
                        continue
                    if not filter_like_node.in_dims_hint or not filter_like_node.in_dims_hint[0]:
                        raise ValueError(
                            f"filter {filter_like_node.name} doesn't have a input hint")
                    in_hint = filter_like_node.in_dims_hint[0]

                    # map hinted axis name -> padding, keeping only axes actually padded
                    hinted_pad = {in_hint[idx]: pad for idx,
                                  pad in enumerate(expanded_padding) if sum(pad) > 0}
                    key_set = set(hinted_pad.keys())
                    key_set -= set(['h', 'w'])
                    if len(key_set) > 0:
                        dont_delete = True
                        LOG.error("node %s has padding on axes %s and cannot be fused with filter %s",
                                  pad_params.name, key_set, filter_like_node.name)
                        continue
                    if any(pval != 0 for val in pad_params.pad_vals for pval in val):
                        dont_delete = True
                        LOG.error("node %s has non zero pad values and cannot be fused with filter %s",
                                  pad_params.name, filter_like_node.name)
                        continue

                    LOG.info("adding padding from: %s to filter: %s - has %s reshapes",
                             pad_params.name, filter_like_node.name, len(reshapes))

                    for key in ('h', 'w'):
                        hinted_pad.setdefault(key, (0, 0))

                    filter_like_node.padding = PadDim(
                        *(list(hinted_pad['h']) + list(hinted_pad['w'])))
                    filter_like_node.pad_type = "zero"
                    has_modified_graph = True
                    G.remove_edge(pad_out_edge)
                    # reshapes between pad and filter must no longer account for the padding
                    for reshape_node, old_padding, new_padding in reshapes:
                        reshape_node.old_shape = self.remove_padding(
                            reshape_node.old_shape, old_padding)
                        reshape_node.shape = self.remove_padding(
                            reshape_node.shape, new_padding)

                    for in_edge in pad_in_edges:
                        G.add_edge(NNEdge(from_node=in_edge.from_node,
                                          to_node=pad_out_edge.to_node,
                                          from_idx=in_edge.from_idx,
                                          to_idx=pad_out_edge.to_idx))

            if not dont_delete:
                G.remove(pad_params)
                if G.quantization:
                    G.quantization.remove_node(pad_params)

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
예제 #4
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Remove the unpack node that selects cells from multi-cell RNN outputs.

        RNN nodes with ``n_output_cells > 1`` are paired with their unpack via
        ``self.find_unpack``. The candidates are filtered by ``validate_slices``
        and ``validate_multi_branch``. For each surviving unpack node,
        ``process_path`` is run on every RNN path, then the unpack node is
        bypassed.

        If the unpack is a ``StridedSliceParameters`` with ``changes_shape``
        set, a ``ReshapeParameters`` node is inserted in its place. That node
        takes over the unpack's quantization record.

        Args:
            G: graph rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called at the end.

        Returns:
            False if no candidate survived validation, otherwise True.
        """
        rnn_nodes = [
            self.find_unpack(G, node) for node in G.nodes()
            if isinstance(node, RNNBaseParameters) and node.n_output_cells > 1
        ]
        rnn_nodes_by_slice = self.validate_slices(G, rnn_nodes)
        rnn_nodes_by_slice = self.validate_multi_branch(G, rnn_nodes_by_slice)
        if not rnn_nodes_by_slice:
            return False

        for unpack_node, rnn_unpacks in rnn_nodes_by_slice.items():
            modified_nodes = set()
            for rnn_unpack in rnn_unpacks:
                self.process_path(G, rnn_unpack, modified_nodes)
            # since process path will have removed all unnecessary nodes the edges will be correct here
            out_edges = G.out_edges(unpack_node.name)
            in_edges = G.in_edges(unpack_node.name)
            assert len(in_edges
                       ) == 1, "expecting unpack node to have only one in edge"
            in_edge = in_edges[0]
            changes_shape = unpack_node.changes_shape if isinstance(
                unpack_node, StridedSliceParameters) else False

            LOG.info("Eliminating last cell unpack: %s", unpack_node.name)
            G.remove(unpack_node)

            # Here the strided slice can change the output shape of the RNN
            # so insert a reshape to do the shape change
            if changes_shape:
                reshape = ReshapeParameters(
                    unpack_node.name + '_reshape',
                    old_shape=Dim.unnamed(unpack_node.post_slice_shape),
                    shape=Dim.unnamed(unpack_node.out_shape))
                G.add_edge(
                    NNEdge(from_node=in_edge.from_node,
                           to_node=reshape,
                           from_idx=in_edge.from_idx))
                for out_edge in out_edges:
                    G.add_edge(
                        NNEdge(from_node=reshape,
                               to_node=out_edge.to_node,
                               to_idx=out_edge.to_idx))
                if G.quantization:
                    # the reshape inherits the record of the unpack node it
                    # replaces (that record is deleted just below). The original
                    # code referenced an undefined name ``unpack`` here.
                    G.quantization[NodeId(reshape)] = G.quantization[NodeId(
                        unpack_node)]
            else:
                for out_edge in out_edges:
                    G.add_edge(
                        NNEdge(from_node=in_edge.from_node,
                               to_node=out_edge.to_node,
                               from_idx=in_edge.from_idx,
                               to_idx=out_edge.to_idx))
            if G.quantization:
                del G.quantization[NodeId(unpack_node)]

        if set_identity:
            self.set_identity(G)

        return True
예제 #5
0
 def match_function(self, G: GraphView):
     """Return the matches in ``G`` of a single-node ``ConcatMatcher`` fragment."""
     pattern = GraphView()
     pattern.add_node(ConcatMatcher('0'))
     matches = G.match_fragment(pattern)
     return matches
예제 #6
0
    def match_function(self, G: GraphView):
        """Return matches of ``relu6 -> matmul <- const(1/6)`` in ``G``.

        The fragment has three nodes:

        * a ``ReluActivationParameters`` with ``upper_bound == 6``, feeding
          input 0 of
        * a ``MatrixMulParameters``, whose input 1 is
        * a ``ConstantInputParameters`` for which ``check_equals(G, node, 1/6)``
          holds.
        """
        def is_relu6(node):
            return isinstance(node, ReluActivationParameters) and node.upper_bound == 6

        def is_matmul(node):
            return isinstance(node, MatrixMulParameters)

        def is_one_sixth(node):
            return (isinstance(node, ConstantInputParameters)
                    and check_equals(G, node, 1.0 / 6.0))

        pattern = GraphView()
        for name, matcher in (('0', is_relu6), ('1', is_matmul), ('2', is_one_sixth)):
            pattern.add_node(MatchNode(name, matcher=matcher))
        pattern.add_edge(Edge('0', '1', to_idx=0))
        pattern.add_edge(Edge('2', '1', to_idx=1))

        return G.match_fragment(pattern)
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        """Merge runs of split outputs that feed consecutive inputs of one concat.

        A group is a run of consecutive split outputs that satisfies all of:

        * each output has a single consumer;
        * each path is found by ``search_down`` through Copy/NoOP nodes;
        * the paths reach consecutive inputs of the same concat.

        Each group is collapsed into one split output carrying the combined
        slice (``reduce_slices``). Split outputs and concat inputs after the
        group are renumbered. The split and concat are left in place and are
        expected to be cleaned up by ``remove_noops``.

        Groups are re-discovered after every merge, because a merge renumbers
        edges and would leave previously collected groups with stale indices.

        Args:
            G: graph rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called at the end.

        Returns:
            True if any group was merged.
        """

        def close_group(group, groups):
            # only runs of at least two outputs are worth merging
            if group and len(group) > 1:
                groups.append(group)

        def find_groups():
            # each group is a list of edge paths (split edge ... concat edge)
            groups = []
            for node in G.nodes(node_classes=SplitParameters):
                cur_group = None
                for out_edge_bundle in G.indexed_out_edges(node):
                    concat_node_edges = None
                    if len(out_edge_bundle) == 1:
                        concat_node_edges = search_down(G,
                                                        out_edge_bundle[0],
                                                        ConcatParameters,
                                                        can_pass=(CopyParameters,
                                                                  NoOPParameters))
                    if not concat_node_edges:
                        close_group(cur_group, groups)
                        cur_group = None
                        continue
                    if cur_group:
                        this_concat_edge = concat_node_edges[-1]
                        last_concat_edge = cur_group[-1][-1]
                        if (this_concat_edge.to_node == last_concat_edge.to_node
                                and this_concat_edge.to_idx == last_concat_edge.to_idx + 1):
                            cur_group.append(concat_node_edges)
                            continue
                        close_group(cur_group, groups)
                    cur_group = [concat_node_edges]
                close_group(cur_group, groups)
            return groups

        def merge_group(edge_group):
            split_node = edge_group[0][0].from_node
            concat_node = edge_group[0][-1].to_node
            from_idx = edge_group[0][0].from_idx
            to_idx = edge_group[-1][0].from_idx
            LOG.info(
                "combining outputs %s:%s on split node %s followed by concat %s",
                from_idx, to_idx, split_node.name, concat_node.name)
            # combine slices and shapes on edges in group
            new_slice, new_shape = reduce_slices(
                split_node.act_slices[from_idx:to_idx + 1],
                split_node.out_shapes[from_idx:to_idx + 1])
            split_node.act_slices = (split_node.act_slices[:from_idx] + [new_slice]
                                     + split_node.act_slices[to_idx + 1:])
            split_node.out_shapes = (split_node.out_shapes[:from_idx] + [new_shape]
                                     + split_node.out_shapes[to_idx + 1:])
            # remove all edges and intermediate nodes on all edge groups except the first
            for edge_list in edge_group[1:]:
                remove_edges(G, edge_list)
            # move outputs beyond the group down to follow the merged output;
            # such an output may have any number of consumers (or none)
            out_edge_bundles = G.indexed_out_edges(split_node)
            for offset, edge_list in enumerate(out_edge_bundles[to_idx + 1:]):
                for edge in list(edge_list):
                    G.remove_edge(edge)
                    G.add_edge(NNEdge.clone(edge, from_idx=from_idx + 1 + offset))
            # reindex the in edges in the concat
            first_in = edge_group[0][-1].to_idx
            last_in = edge_group[-1][-1].to_idx
            in_edges = G.indexed_in_edges(concat_node)
            for offset, in_edge in enumerate(in_edges[last_in + 1:]):
                G.remove_edge(in_edge)
                G.add_edge(NNEdge.clone(in_edge, to_idx=first_in + 1 + offset))

        # each merge removes at least one split output, so this terminates
        has_modified_graph = False
        while True:
            edge_groups = find_groups()
            if not edge_groups:
                break
            merge_group(edge_groups[0])
            has_modified_graph = True

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
예제 #8
0
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        """Replace strided slices that tile one tensor along one axis with a split.

        Slice nodes are grouped by the producer output they read. A lone slice
        is handed to ``self.slice_to_split``. A group of several slices becomes
        one ``SplitParameters`` node when all of the following hold:

        * the slices differ on exactly one axis;
        * that axis uses stride 1;
        * the slices are contiguous and non-overlapping on that axis, i.e.
          each slice's stop equals the next slice's start.

        Each slice entry is treated as ``(start, stop, step)``; this follows
        from the stride check (``sl[2]``) and size computation
        (``sl[1] - sl[0]``) below.

        Args:
            G: graph rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called at the end.

        Returns:
            True if any group of slices was replaced by a split.
        """
        has_modified_graph = False
        slices_by_origin = {}
        for slice_node in [
                node for node in G.nodes()
                if isinstance(node, StridedSliceParameters)
        ]:
            in_edge = G.in_edges(slice_node.name)[0]
            slices_by_origin.setdefault(
                (in_edge.from_node, in_edge.from_idx), []).append(slice_node)

        for slice_nodes in slices_by_origin.values():
            # slices[axis] is a tuple holding every node's slice on that axis
            slices = list(zip(*[node.act_slice for node in slice_nodes]))
            if len(slice_nodes) == 1:
                self.slice_to_split(G, slice_nodes, slices)
                continue

            diff_slices = [(idx, elems) for idx, elems in enumerate(slices)
                           if not all(elems[0] == elem for elem in elems[1:])]
            if len(diff_slices) != 1:
                continue
            axis, axis_slices = diff_slices[0]
            # strides must be one
            if any(sl[2] != 1 for sl in axis_slices):
                continue
            # slices must be consecutive and non overlapping: each stop is the next start
            ordered = sorted(axis_slices, key=lambda x: x[0])
            if not all(cur[1] == nxt[0] for cur, nxt in zip(ordered, ordered[1:])):
                continue
            szes = [sl[1] - sl[0] for sl in ordered]
            slice_nodes = sorted(slice_nodes,
                                 key=lambda x: x.act_slice[axis][0])
            act_slices, out_shapes, axis = SplitParameters.get_splits(
                slice_nodes[0].in_dims[0].shape, axis, splits=szes)
            params = SplitParameters(slice_nodes[0].name + '_split',
                                     act_slices=act_slices,
                                     out_shapes=out_shapes,
                                     axis=axis)
            in_edge = G.in_edges(slice_nodes[0].name)[0]
            G.add_edge(
                NNEdge(from_node=in_edge.from_node,
                       to_node=params,
                       from_idx=in_edge.from_idx))
            sub_names = []
            # slice nodes are sorted by start, so output idx follows slice order
            for idx, node in enumerate(slice_nodes):
                sub_names.append(node.name)
                out_edges = G.out_edges(node.name)
                G.remove(node)
                for out_edge in out_edges:
                    G.add_edge(
                        NNEdge(from_node=params,
                               to_node=out_edge.to_node,
                               from_idx=idx,
                               to_idx=out_edge.to_idx))
            if G.quantization:
                G.add_dimensions()
                quantizer = UnifiedQuantizer.from_quantized_graph(G)
                quantizer.quantize(G, start_nodes=[params])
                RemoveUnnecessaryQuantizeOperators().match(G)

            LOG.info('replaced slice nodes %s with split node %s',
                     ",".join(sub_names), params.name)

            has_modified_graph = True

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
예제 #9
0
    def match(self, G: GraphView, set_identity: bool = True) -> bool:
        """Replace single-index gathers that together cover an axis with one split.

        Gathers are grouped by the producer output they read. A group is
        converted when all of the following hold:

        * every gather uses the same axis;
        * every gather selects exactly one index (0-d indices or 1-d of length 1);
        * the selected indices are distinct and their count equals the size of
          that axis, so they cover it without duplicates or gaps.

        Each gather becomes one output of a ``SplitParameters`` node, ordered
        by index. For a non-zero axis the split gets ``transpose_in`` /
        ``transpose_out`` that move the axis to the front and back.

        Args:
            G: graph rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called at the end.

        Returns:
            True if any group of gathers was converted.
        """
        def single_index(gather):
            # the one index a gather selects, or None if it selects zero or several
            shape = gather.indices.shape
            if len(shape) == 0:
                return int(gather.indices)
            if len(shape) == 1 and shape[0] == 1:
                return int(gather.indices[0])
            return None

        has_modified_graph = False
        gathers_by_origin = {}
        for gather in [
                node for node in G.nodes()
                if isinstance(node, GatherParameters)
        ]:
            in_edge = G.in_edges(gather.name)[0]
            gathers_by_origin.setdefault(
                (in_edge.from_node, in_edge.from_idx), []).append(gather)

        for (src_node, src_idx), gathers in gathers_by_origin.items():
            axis = gathers[0].axis
            if any(gather.axis != axis for gather in gathers):
                continue
            # This is too difficult to handle if a gather selects multiple slices
            selected = [single_index(gather) for gather in gathers]
            if any(index is None for index in selected):
                continue
            # sort gathers by the index they select
            ordered = sorted(zip(selected, gathers), key=lambda pair: pair[0])
            indices = [index for index, _ in ordered]
            gathers = [gather for _, gather in ordered]
            # All the indices must be distinct and cover the axis (this could be relaxed but
            # then needs to handle gaps)
            in_shape = src_node.out_dims[src_idx].shape
            in_shape_without_axis = in_shape[:axis] + in_shape[axis + 1:]
            if len(set(indices)) != len(indices) or len(indices) != in_shape[axis]:
                continue
            # good for a split
            LOG.info("gathers from %s[%s] converted to a split",
                     src_node.name, src_idx)
            splits = []
            shapes = []
            out_edges = []
            for index, gather in zip(indices, gathers):
                splits.append([(index, index + 1, 1)])
                shapes.append(in_shape_without_axis)
                out_edges.append(G.out_edges(gather.name))
                G.remove(gather)
            params = SplitParameters("%s_split" % src_node.name,
                                     act_slices=splits,
                                     out_shapes=shapes,
                                     axis=axis)
            if axis != 0:
                # a true permutation moving ``axis`` to the front (the original
                # listed ``axis`` twice by starting the tail range at ``axis``)
                trans = [axis] + list(range(0, axis)) + list(
                    range(axis + 1, len(in_shape)))
                params.transpose_out = [[
                    trans.index(idx) for idx in range(len(trans))
                ]]
                params.transpose_in = [trans]
            for idx, edges in enumerate(out_edges):
                for edge in edges:
                    G.add_edge(
                        NNEdge(from_node=params,
                               to_node=edge.to_node,
                               from_idx=idx,
                               to_idx=edge.to_idx))
            G.add_edge(
                NNEdge(from_node=src_node,
                       to_node=params,
                       from_idx=src_idx))
            has_modified_graph = True

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
예제 #10
0
 def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
     """Expand node-level transposes into explicit ``TransposeParameters`` nodes.

     Handles every ``Transposable`` node that is not itself a
     ``TransposeParameters``:

     * each non-None ``transpose_in[i]`` becomes a transpose inserted on the
       in edge with ``to_idx == i``;
     * each non-None ``transpose_out[i]`` becomes a transpose inserted on the
       out edges with ``from_idx == i``.

     Dimension hints (when present for that index) and quantization records
     are carried over. The node's ``transpose_in`` / ``transpose_out`` are
     then cleared.

     Args:
         G: graph rewritten in place.
         set_identity: when true, ``self.set_identity(G)`` is called at the end.

     Returns:
         True if any transpose node was inserted.
     """
     # get a list of all the nodes that are transposable but not transposes
     # Need to do this first to avoid mutating it when doing the modifications
     tnodes = list(
         filter(
             lambda n: isinstance(n, Transposable) and not isinstance(
                 n, TransposeParameters), G.nodes()))
     has_modified_graph = False
     for node in tnodes:
         if node.transpose_in:
             for idx, edge in enumerate(G.in_edges(node.name)):
                 if edge.to_idx >= len(node.transpose_in):
                     continue
                 trans = node.transpose_in[edge.to_idx]
                 if trans is None:
                     continue
                 LOG.info("Expand transpose in on node %s", node.name)
                 has_modified_graph = True
                 in_params = TransposeParameters("%s_TIN_%s" %
                                                 (node.name, idx),
                                                 transpose=trans,
                                                 block_search_up=True)
                 if node.in_dims_hint and node.in_dims_hint[edge.to_idx]:
                     in_hint = node.in_dims_hint[edge.to_idx]
                     out_hint = apply_transpose_to_hint(in_hint, trans)
                     in_params.in_dims_hint = [in_hint.copy()]
                     in_params.out_dims_hint = [out_hint.copy()]
                     node.in_dims_hint[edge.to_idx] = out_hint
                 if G.quantization:
                     G.quantization.copy_qrec(node, 'in', edge.to_idx,
                                              in_params)
                 G.insert_node(in_params,
                               edge.from_node.name,
                               edge.to_node.name,
                               from_idx=edge.from_idx,
                               to_idx=edge.to_idx,
                               edge_class=NNEdge)
             node.transpose_in = None
         if node.transpose_out:
             for idx, edge in enumerate(G.out_edges(node.name)):
                 if edge.from_idx >= len(node.transpose_out):
                     continue
                 trans = node.transpose_out[edge.from_idx]
                 if trans is None:
                     continue
                 LOG.info("Expand transpose out on node %s", node.name)
                 has_modified_graph = True
                 out_params = TransposeParameters("%s_TOUT_%s" %
                                                  (node.name, idx),
                                                  transpose=trans,
                                                  block_search_down=True)
                 # like the input side, skip outputs whose hint entry is empty
                 # instead of passing None to apply_reverse_transpose_to_hint
                 if node.out_dims_hint and node.out_dims_hint[edge.from_idx]:
                     out_hint = node.out_dims_hint[edge.from_idx]
                     in_hint = apply_reverse_transpose_to_hint(
                         out_hint, trans)
                     out_params.in_dims_hint = [in_hint.copy()]
                     out_params.out_dims_hint = [out_hint.copy()]
                     node.out_dims_hint[edge.from_idx] = in_hint
                 if G.quantization:
                     G.quantization.copy_qrec(node, 'out', edge.from_idx,
                                              out_params)
                 G.insert_node(out_params,
                               edge.from_node.name,
                               edge.to_node.name,
                               from_idx=edge.from_idx,
                               to_idx=edge.to_idx,
                               edge_class=NNEdge)
             node.transpose_out = None
     if set_identity:
         self.set_identity(G)
     return has_modified_graph
예제 #11
0
    def match_function(self, G: GraphView):
        """Return matches of a filter whose output is added to a constant.

        The fragment has three nodes:

        * a ``FilterParameters`` node feeding input 0 of
        * a ``MatrixAddParameters`` node, whose input 1 comes from
        * a ``ConstantInputParameters`` node.
        """
        def of_type(cls):
            # each call creates its own closure over ``cls``
            return lambda node: isinstance(node, cls)

        pattern = GraphView()
        pattern.add_node(MatchNode('0', matcher=of_type(FilterParameters)))
        pattern.add_node(MatchNode('1', matcher=of_type(MatrixAddParameters)))
        pattern.add_node(MatchNode('2', matcher=of_type(ConstantInputParameters)))
        for source, target_idx in (('0', 0), ('2', 1)):
            pattern.add_edge(Edge(source, '1', to_idx=target_idx))

        return G.match_fragment(pattern)
예제 #12
0
    def match_function(self, G: GraphView):
        """Return matches of a valid Conv2D optionally followed by activation/pooling.

        Node '0' is a ``Conv2DParameters`` accepted by
        ``self.valid_convolution``. The followers chained after it are chosen
        as follows:

        * if both ``self.match_activation`` and ``self.match_pool`` are set,
          both are chained, in the order given by
          ``self.pool_after_activation``;
        * if only one of them is set, just that one is chained;
        * if neither is set, the fragment is the convolution alone.
        """
        pattern = GraphView()
        pattern.add_node(MatchNode('0', matcher=lambda node:
                                   isinstance(node, Conv2DParameters) and
                                   self.valid_convolution(node)))

        if self.match_activation and self.match_pool:
            if self.pool_after_activation:
                followers = (self.add_activation, self.add_pooling)
            else:
                followers = (self.add_pooling, self.add_activation)
        elif self.match_activation:
            followers = (self.add_activation,)
        elif self.match_pool:
            followers = (self.add_pooling,)
        else:
            followers = ()

        # add all follower nodes first, then chain '0' -> '1' -> '2'
        names = ['0']
        for position, add_follower in enumerate(followers, start=1):
            add_follower(str(position), pattern)
            names.append(str(position))
        for upstream, downstream in zip(names, names[1:]):
            pattern.add_edge(Edge(upstream, downstream))

        return G.match_fragment(pattern)
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
        """Bypass and delete every node for which ``self.node_does_nothing`` is true.

        Each such node's first in edge is connected directly to all of its
        consumers, and the node is then removed.

        Args:
            G: graph rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called at the end.

        Returns:
            True if any node was removed.
        """
        # collect first so the graph is not mutated while being scanned
        idle_nodes = [candidate for candidate in G.nodes()
                      if self.node_does_nothing(G, candidate)]
        for idle in idle_nodes:
            source = G.in_edges(idle.name)[0]
            G.remove_edge(source)
            for sink in G.out_edges(idle.name):
                G.remove_edge(sink)
                G.add_edge(NNEdge(source.from_node, sink.to_node,
                                  from_idx=source.from_idx,
                                  to_idx=sink.to_idx))
            LOG.info(f'removing {idle.name} that does nothing')
            G.remove(idle)

        if set_identity:
            self.set_identity(G)

        return bool(idle_nodes)
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fold constant arithmetic that follows a filter into the filter's bias.

        Starting at each ``FilterParameters`` node, the single-consumer chain
        is followed, passing through ``ReshapeParameters``. While the next node
        is an ``OPS`` operation whose other input is a
        ``ConstantInputParameters`` value, that value is folded via
        ``self.fuse_bias``. The constant must be a scalar; for additive-only
        ops it may also have ``filter.out_c`` elements. The op node is then
        removed.

        Args:
            G: graph rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called at the end.

        Returns:
            True if any operation was folded.
        """
        has_modified_graph = False
        filter_nodes = [node for node in G.nodes(
        ) if isinstance(node, FilterParameters)]
        for params in filter_nodes:
            filter_node = params
            seen_reshape = []
            while True:
                out_edges = G.out_edges(params.name)
                # can't fuse if there is a branch or no consumer at all
                if len(out_edges) != 1:
                    break
                out_edge = out_edges[0]
                op_node = out_edge.to_node
                if isinstance(op_node, ReshapeParameters):
                    seen_reshape = [op_node]
                    params = op_node
                    continue
                # must be a valid matrix op
                if not isinstance(op_node, tuple(OPS.keys())):
                    break
                # other edge to the op must be a constant
                other_idx = 1 if out_edge.to_idx == 0 else 0
                other_in_edge = G.indexed_in_edges(op_node.name)[other_idx]
                if not isinstance(other_in_edge.from_node, ConstantInputParameters):
                    break
                const_node = other_in_edge.from_node
                # only delete the constant when this op is its sole consumer;
                # otherwise other nodes still read it
                remove_constant = len(G.out_edges(const_node.name)) == 1

                flat_value = const_node.dqvalue.flatten()
                out_c = filter_node.filter.out_c
                op, weights_and_biases = OPS[op_node.__class__]
                # it would be possible to support mult bias addition by out channel but only supporting a
                # scalar at present
                if len(flat_value) != 1 and (weights_and_biases or len(flat_value) != out_c):
                    LOG.warning('could not absorb %s into %s',
                                const_node.name, filter_node.name)
                    break
                # If there is quantization then essentially the output of the filter
                # takes the quantization of the output of the operation.
                # The biases will not change since their quantization depends on the weights
                # and input
                fnid = NodeId(filter_node)
                opnid = NodeId(op_node)
                if G.quantization and (fnid in G.quantization or opnid in G.quantization):
                    if not (fnid in G.quantization and opnid in G.quantization):
                        LOG.warning(
                            'could not absorb %s into %s - graph is partially quantized', const_node.name, filter_node.name)
                        break
                    fqrec = G.quantization[fnid]
                    opqrec = G.quantization[opnid]
                    fqrec.out_qs[0] = opqrec.out_qs[0]

                has_modified_graph = True
                LOG.info("fusing bias in %s into %s",
                         const_node.name, filter_node.name)
                self.fuse_bias(G, filter_node, other_idx, op, flat_value, 2)
                if weights_and_biases:
                    # TODO - need to adjust weights quantization here
                    LOG.info("fusing multiplicative bias in %s into %s",
                             const_node.name, filter_node.name)
                    self.fuse_bias(G, filter_node, other_idx,
                                   op, flat_value, 1)

                out_edges = G.out_edges(op_node.name)
                G.remove(op_node)
                if remove_constant:
                    G.remove(const_node)
                # reconnect consumers to the last reshape seen, else to the filter
                from_node = seen_reshape[-1] if seen_reshape else filter_node
                for edge in out_edges:
                    G.add_edge(NNEdge(from_node=from_node,
                                      to_node=edge.to_node, to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
예제 #15
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
        """Collapse chains of axis-0 concats into a single concat.

        For each ConcatParameters node with ``axis == 0``, ``find_concats_up``
        returns the subgraph of concats feeding it. When that subgraph holds
        more than one concat, the upstream concats are removed and every
        external input of the subgraph is wired directly into the final
        concat. Inputs whose dims have ``len() > 1`` (presumably rank > 1) are
        first flattened through a ``<concat>_flat<idx>`` ReshapeParameters.
        A ``<concat>_expand`` ReshapeParameters inserted after the concat
        reshapes its flat output back to the concat's original ``out_dims[0]``.

        When the graph is quantized, records of removed nodes are deleted.
        If every input producer has a record, the concat's ``in_qs`` are set
        from the producers' ``out_qs``.

        Args:
            G: graph to transform in place.
            set_identity: when True, call ``self.set_identity(G)`` at the end.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            True if at least one chain of concats was combined.
        """
        modified_graph = False
        concats = set(G.nodes(node_classes=ConcatParameters))
        while concats:
            concat = concats.pop()
            if concat.axis != 0:
                continue
            subgraph = find_concats_up(G, concat)
            found = set(subgraph.nodes(node_classes=ConcatParameters))
            # a lone concat has nothing to combine with
            if len(found) <= 1:
                continue
            LOG.info(
                f"Combining concats {','.join([node.name for node in found])}")
            modified_graph = True
            # every concat in this chain is handled here; don't visit them again
            concats -= found

            # edges entering the chain from outside and the dims they carry
            # (captured before any edge is removed)
            in_edges = [inp.edge for inp in subgraph.inputs()]
            in_dims = [
                edge.from_node.out_dims[edge.from_idx] for edge in in_edges
            ]
            # keep the final concat; DummyInput nodes presumably stand in for
            # the external producers inside the subgraph (TODO confirm) and are
            # never removed from G
            nodes_to_remove = [
                node for node in subgraph.nodes()
                if node != concat and not isinstance(node, DummyInput)
            ]
            for edge in in_edges:
                G.remove_edge(edge)
            for node in nodes_to_remove:
                # the node may already have been removed from G
                if node.name in G:
                    G.remove(node)
                nid = NodeId(node)
                if G.quantization and nid in G.quantization:
                    del G.quantization[nid]

            # remove_internal_graph(G, subgraph)
            out_dim = concat.out_dims[0]
            # in_qs collects the producers' output qtypes in input order; it
            # becomes None as soon as any producer has no quantization record
            in_qs = []
            for idx, edge in enumerate(in_edges):
                from_node = edge.from_node
                from_idx = edge.from_idx
                # flatten multi-axis inputs so the concat joins flat tensors
                if len(in_dims[idx]) > 1:
                    reshape = ReshapeParameters(
                        G.unique_name(f'{concat.name}_flat{idx}'),
                        old_shape=in_dims[idx],
                        shape=Dim.unnamed([in_dims[idx].size()]))
                    G.add_edge(
                        NNEdge(from_node=from_node,
                               from_idx=from_idx,
                               to_node=reshape))
                    from_node = reshape
                    from_idx = 0
                G.add_edge(
                    NNEdge(from_node=from_node,
                           from_idx=from_idx,
                           to_node=concat,
                           to_idx=idx))
                if in_qs is not None and G.quantization:
                    # use the original producer (not the new reshape)
                    nid = NodeId(edge.from_node)
                    if nid in G.quantization:
                        qrec = G.quantization[nid]
                        in_qs.append(qrec.out_qs[edge.from_idx])
                    else:
                        in_qs = None
                else:
                    in_qs = None
            if in_qs is not None and G.quantization:
                nid = NodeId(concat)
                if nid in G.quantization:
                    G.quantization[nid].in_qs = in_qs
            # the concat now produces a flat tensor; restore its original shape
            reshape = ReshapeParameters(G.unique_name(f'{concat.name}_expand'),
                                        old_shape=Dim.unnamed([out_dim.size()
                                                               ]),
                                        shape=out_dim)
            G.insert_node_after(concat, reshape, edge_class=NNEdge)

        if set_identity:
            self.set_identity(G)

        return modified_graph
    def match(self, G: GraphView, set_identity: bool = True):
        """Fill in missing quantization records by copying qtypes from neighbours.

        Walks ``G.quantization.sorted_iterator(G)`` and selects every record
        that is None or lacks ``in_qs``/``out_qs``. Records for fused nodes
        only produce a warning, and nodes no longer in ``G`` are skipped. Every
        other selected node gets a new MultQuantizationRecord:

        * input qtypes come from the predecessors' ``out_qs``. This is done
          when all predecessors are quantized or the node has no successors.
        * output qtypes come from the successors' ``in_qs``. This is done when
          all successors are quantized or the node has no predecessors.

        A node with no predecessors uses its output qtypes on both sides. A
        node with no successors uses its input qtypes on both sides. If any
        neighbour used as a source is not a MultQuantizationRecord, the node is
        skipped.

        Args:
            G: graph whose ``quantization`` is updated in place. Nothing is
               done if the graph has no quantization.
            set_identity: when True, call ``self.set_identity(G)`` at the end.

        Raises:
            NotImplementedError: raised in two cases: input qtypes cannot be
                derived for a node that has predecessors, or output qtypes
                cannot be derived for a node that has successors.

        NOTE(review): unlike the other match methods in this file, this returns
        None rather than a "graph modified" flag.
        """
        if not G.quantization:
            return
        # snapshot the incomplete records first; the loop adds new ones
        for nid in [
                nid for nid, qrec in G.quantization.sorted_iterator(G)
                if qrec is None or not (qrec.in_qs and qrec.out_qs)
        ]:
            if nid.fnode_name:
                LOG.warning("can't add quantization to fused node %s",
                            nid.fnode_name)
                continue
            if nid.node_name not in G:
                # previous fusions may have removed nodes from the graph
                continue

            node = nid.get_node(G)
            predecessors = [NodeId(pred) for pred in G.predecessors(node.name)]
            # G.successors yields one list of successors per output
            successors = [
                NodeId(succ) for succs in G.successors(node.name)
                for succ in succs
            ]
            # derive inputs from predecessors if they are all quantized (or
            # there is no other side to take them from)
            go_back = not successors or (predecessors
                                         and all(pred in G.quantization
                                                 for pred in predecessors))
            # derive outputs from successors under the symmetric condition
            go_forward = not predecessors or (successors
                                              and all(succ in G.quantization
                                                      for succ in successors))

            # NOTE(review): reached whenever neither side is fully quantized,
            # not only for disconnected nodes as the message suggests
            if not (go_back or go_forward):
                LOG.warning(
                    "node %s is not connected to anything and has no quantization",
                    node.name)
                continue

            if go_forward:
                out_qrecs = set(G.quantization[nid] for nid in successors)
                if not all(
                        isinstance(out_qrec, MultQuantizationRecord)
                        for out_qrec in out_qrecs):
                    continue
                # (output idx, successor's input qtype) pairs; reduce_qtypes
                # presumably merges them into one qtype per output - TODO confirm
                out_qtypes = reduce_qtypes([
                    (edge.from_idx,
                     G.quantization[NodeId(edge.to_node)].in_qs[edge.to_idx])
                    for edge in G.out_edges(node.name)
                ])
            else:
                out_qtypes = None
            if go_back:
                in_qrecs = set(G.quantization[nid] for nid in predecessors)
                if not all(
                        isinstance(in_qrec, MultQuantizationRecord)
                        for in_qrec in in_qrecs):
                    continue
                # (input idx, predecessor's output qtype) pairs
                in_qtypes = reduce_qtypes([(edge.to_idx, G.quantization[NodeId(
                    edge.from_node)].out_qs[edge.from_idx])
                                           for edge in G.in_edges(node.name)])
            else:
                in_qtypes = None

            if not in_qtypes:
                if not predecessors:
                    LOG.info("setting quantization on input node %s",
                             node.name)
                    qrec = MultQuantizationRecord(in_qs=deepcopy(out_qtypes),
                                                  out_qs=deepcopy(out_qtypes))
                else:
                    raise NotImplementedError(
                        "propagating qrecs not implemented")
            elif not out_qtypes:
                if not successors:
                    LOG.info("setting quantization on output node %s",
                             node.name)
                    qrec = MultQuantizationRecord(in_qs=deepcopy(in_qtypes),
                                                  out_qs=deepcopy(in_qtypes))
                else:
                    raise NotImplementedError(
                        "propagating qrecs not implemented")
            else:
                LOG.info("setting quantization on node %s", node.name)
                qrec = MultQuantizationRecord(in_qs=deepcopy(in_qtypes),
                                              out_qs=deepcopy(out_qtypes))

            G.quantization[nid] = qrec

        if set_identity:
            self.set_identity(G)
예제 #17
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
        """Replace each group from ``find_connected_groups`` with an expression node.

        ``G.replace_fragment`` swaps each group (fragment) into ``G`` as an
        ExpressionFusionParameters node named ``expr_<n>``. ``external_edges``
        supplies the edges that cross the fragment boundary, keyed by producer
        output ``(node, idx)`` with a set of edges per key.

        When the graph is quantized:

        * the new node gets a ``'scaled'`` QRec whose ``in_qs``/``out_qs`` are
          taken from the fragment's boundary nodes;
        * the min/max of those qtypes are recorded as symbol statistics;
        * any QuantizeParameters directly after the expression is removed;
        * once all fragments are done, the whole graph is re-quantized with
          NewQuantizer.

        Args:
            G: graph to transform in place.
            set_identity: when True, call ``self.set_identity(G)`` at the end.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            True if at least one fragment was fused.
        """
        has_modified_graph = False
        to_quantize = []
        for frag in find_connected_groups(G):
            # fresh statistics holder for the symbols of this fragment
            Symbol.set_default_control(SymbolStats())
            has_modified_graph = True
            in_edges, out_edges = external_edges(G, frag)
            # per external input: the (node, input idx) pairs it feeds inside frag
            in_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_group]
                          for edge_group in in_edges.values()]
            in_dims = [
                from_node.out_dims[from_idx]
                for from_node, from_idx in in_edges
            ]
            out_dims = [
                from_node.out_dims[from_idx]
                for from_node, from_idx in out_edges
            ]
            out_mapping = list(out_edges.keys())
            constant_inputs = [
                node_edge_idx[0] for node_edge_idx in in_edges
                if isinstance(node_edge_idx[0], ConstantInputParameters)
            ]
            LOG.debug(
                "inputs coming from: %s",
                ",".join(f"{from_node.__repr__()}:{from_idx}"
                         for from_node, from_idx in in_edges))
            LOG.info("fusing nodes: %s into expr_%s",
                     ",".join(node.__repr__() for node in frag.nodes()),
                     self._expr_num)
            expr = ExpressionFusionParameters(
                G.unique_name(f"expr_{self._expr_num}"),
                subgraph=frag,
                qrecs=G.quantization,
                input_mapping=in_mapping,
                output_mapping=out_mapping,
                in_dims=in_dims,
                out_dims=out_dims,
                constant_inputs=constant_inputs)
            in_edge_mapping = list(in_edges.keys())
            out_edge_mapping = [[(edge.to_node, edge.to_idx)
                                 for edge in edge_set]
                                for edge_set in out_edges.values()]
            # NOTE(review): set.union(*...) raises TypeError when the fragment
            # has no external input (or output) edges
            G.replace_fragment(
                frag,
                expr,
                frag_in_edges=list(set.union(*in_edges.values())),
                frag_out_edges=list(set.union(*out_edges.values())),
                edge_in_mapping=in_edge_mapping,
                edge_out_mapping=out_edge_mapping,
                edge_class=NNEdge)
            if G.quantization:
                qrecs = G.quantization
                # input qtype is taken from the first consumer of each external
                # input; assumes all consumers agree - TODO confirm
                in_qs = [
                    qrecs[NodeId(in_map[0][0])].in_qs[in_map[0][1]]
                    for in_map in in_mapping
                ]
                out_qs = [
                    qrecs[NodeId(node)].out_qs[idx]
                    for node, idx in out_mapping
                ]
                # the SymbolStats installed for this fragment above
                stats = Symbol.CURRENT_CONTROL.stats
                func_col = expr.func_col
                for idx, qtype in enumerate(in_qs):
                    symbol = func_col.variables[func_col.input_names[idx]]
                    stats[symbol.name] = {
                        'min': qtype.min_val,
                        'max': qtype.max_val
                    }
                for idx, qtype in enumerate(out_qs):
                    symbol = func_col.variables[func_col.output_names[idx]]
                    stats[symbol.name] = {
                        'min': qtype.min_val,
                        'max': qtype.max_val
                    }
                nid = NodeId(expr)
                G.quantization[nid] = QRec(in_qs=in_qs,
                                           out_qs=out_qs,
                                           expression=stats,
                                           ktype='scaled')
                if G.quantization.stats:
                    G.quantization.stats[nid] = {
                        'range_in': [{
                            'min': qtype.min_val,
                            'max': qtype.max_val
                        } for qtype in in_qs],
                        'range_out': [{
                            'min': qtype.min_val,
                            'max': qtype.max_val
                        } for qtype in out_qs],
                        'expression':
                        stats.copy()
                    }

                # delete any quantize parameters on outputs to allow the quantizer
                # to fuse them into the expression
                out_edges = G.out_edges(expr.name)
                for edge in out_edges:
                    if isinstance(edge.to_node, QuantizeParameters):
                        G.remove_and_reconnect(edge.to_node, edge_class=NNEdge)
                        if NodeId(edge.to_node) in G.quantization:
                            del G.quantization[NodeId(edge.to_node)]
                to_quantize.append(expr)

            self._expr_num += 1

        # re-quantize once after all fragments so the new expressions get
        # consistent records
        if to_quantize:
            quantizer = NewQuantizer.from_quantized_graph(G)
            quantizer.quantize()

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
예제 #18
0
 def _match(self, G: GraphView, node: Node, edge: Edge):
     """Return True for a ConcatParameters node that has exactly one input edge.

     ``edge`` is accepted for the matcher interface but not consulted.
     """
     if not isinstance(node, ConcatParameters):
         return False
     return G.num_in_edges(node.name) == 1
예제 #19
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Remove a slice/split that keeps only the last cell of an RNN output.

        For every RNNBaseParameters node, ``self.find_unpack`` returns a chain
        whose first element is the RNN and whose last element is the "unpack"
        node (a StridedSliceParameters or SplitParameters). The chain is falsy
        when nothing matches.

        If the unpack has a single consumer and keeps exactly the last time
        step on axis 0, the following happens:

        * the unpack and any nodes between it and the RNN are removed;
        * the RNN is switched to ``n_output_cells = 1`` and takes the unpack's
          output dims;
        * the RNN is wired to the unpack's consumer.

        A strided slice that changes shape is replaced by a reshape.

        Args:
            G: graph to transform in place.
            set_identity: when True, call ``self.set_identity(G)`` at the end.

        Returns:
            True if at least one unpack was eliminated.
        """
        rnn_nodes = [
            self.find_unpack(G, node) for node in G.nodes()
            if isinstance(node, RNNBaseParameters)
        ]
        has_modified_graph = False
        for rnn_unpack in rnn_nodes:
            if not rnn_unpack:
                continue
            unpack_node = rnn_unpack[-1]
            rnn_node = rnn_unpack[0]
            time_axis = 0
            if isinstance(unpack_node, StridedSliceParameters):
                # act_slice entries are presumably (start, stop, step)
                if unpack_node.act_slice[time_axis][1] != rnn_node.n_cells:
                    LOG.debug("can't remove %s. Slice not equal to cells",
                              unpack_node.name)
                    continue
                if unpack_node.act_slice[time_axis][2] != 1:
                    LOG.debug("can't remove %s. Slice not of length 1",
                              unpack_node.name)
                    continue
                if unpack_node.act_slice[time_axis][0] != rnn_node.n_cells - 1:
                    LOG.debug("can't remove %s. Slice isn't last cell",
                              unpack_node.name)
                    continue
                out_edges = G.out_edges(unpack_node.name)
                # only one consumer can be reconnected to the RNN; with more
                # the others would be left without an input
                if len(out_edges) != 1:
                    LOG.debug("can't remove %s. Not exactly one output edge",
                              unpack_node.name)
                    continue
                out_edge = out_edges[0]
                changes_shape = unpack_node.changes_shape
            elif isinstance(unpack_node, SplitParameters):
                out_edges = G.out_edges(unpack_node.name)
                if len(out_edges) != 1:
                    LOG.debug("can't remove %s. Not exactly one output edge",
                              unpack_node.name)
                    continue
                out_edge = out_edges[0]
                # the consumer must read the split's last output
                if out_edge.from_idx != len(unpack_node.act_slices) - 1:
                    LOG.debug("can't remove %s. Not last output",
                              unpack_node.name)
                    continue
                act_slice = unpack_node.act_slices[-1]
                if act_slice[time_axis][1] != rnn_node.n_cells:
                    LOG.debug("can't remove %s. Slice not equal to cells",
                              unpack_node.name)
                    continue
                if act_slice[time_axis][0] != rnn_node.n_cells - 1:
                    LOG.debug("can't remove %s. Slice isn't last cell",
                              unpack_node.name)
                    continue
                changes_shape = False
            else:
                continue

            has_modified_graph = True
            LOG.info("Eliminating last cell unpack: %s", unpack_node.name)
            for node in rnn_unpack[1:-1]:
                LOG.info("Eliminating others: %s", node.name)
                if G.quantization:
                    node_nid = NodeId(node)
                    if node_nid in G.quantization:
                        del G.quantization[node_nid]
                G.remove(node)
            G.remove(unpack_node)
            rnn_node.n_output_cells = 1
            rnn_node.out_dims[0] = unpack_node.out_dims[out_edge.from_idx]
            if unpack_node.out_dims_hint and unpack_node.out_dims_hint[
                    out_edge.from_idx]:
                rnn_node.out_dims_hint = [
                    unpack_node.out_dims_hint[out_edge.from_idx]
                ]
            else:
                rnn_node.out_dims_hint = None
            unpack_nid = NodeId(unpack_node)
            # Here the strided slice can change the output shape of the RNN
            # so insert a reshape to do the shape change
            if changes_shape:
                reshape = ReshapeParameters(
                    unpack_node.name + '_reshape',
                    old_shape=Dim.unnamed(unpack_node.post_slice_shape),
                    shape=Dim.unnamed(unpack_node.out_shape))
                G.add_edge(NNEdge(rnn_node, reshape))
                G.add_edge(
                    NNEdge(reshape, out_edge.to_node, to_idx=out_edge.to_idx))
                # the reshape inherits the record of the slice it replaces
                if G.quantization and unpack_nid in G.quantization:
                    G.quantization[NodeId(reshape)] = G.quantization[unpack_nid]
            else:
                G.add_edge(
                    NNEdge(rnn_node, out_edge.to_node, to_idx=out_edge.to_idx))
            if G.quantization and unpack_nid in G.quantization:
                del G.quantization[unpack_nid]

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
예제 #20
0
    def match(self, G: GraphView, set_identity: bool = True) -> bool:
        """Remove a ConcatParameters whose output is immediately split back apart.

        The pair is removed when all of the following hold:

        * a SplitParameters has a single input, and that input comes from a
          ConcatParameters with a single consumer;
        * neither node transposes;
        * both nodes use the same axis;
        * each split output size on that axis equals the size of the
          corresponding concat input.

        Each concat input is then wired straight to the consumers of the
        matching split output.

        Args:
            G: graph to transform in place.
            set_identity: when True, call ``self.set_identity(G)`` at the end.

        Returns:
            True if at least one concat/split pair was removed.
        """
        has_modified_graph = False
        for split_node in {node for node in G.nodes() if isinstance(node, SplitParameters)}:
            in_edges = G.in_edges(split_node.name)
            # need exactly one producer (a dangling split has none)
            if len(in_edges) != 1:
                continue
            in_edge = in_edges[0]
            if not isinstance(in_edge.from_node, ConcatParameters):
                continue
            concat_node = in_edge.from_node
            if len(G.out_edges(concat_node.name)) > 1:
                continue
            if concat_node.transpose_out or split_node.transpose_in:
                continue
            if concat_node.axis != split_node.axis:
                continue
            axis = concat_node.axis
            split_out_sizes = [out_shape[axis] for out_shape in split_node.out_shapes]
            if len(split_out_sizes) != len(concat_node.in_dims):
                continue
            if not all(split_out_sizes[idx] == in_dim.shape[axis] for idx, in_dim in enumerate(concat_node.in_dims)):
                continue
            has_modified_graph = True
            LOG.info("removing unnecessary concat/split pair %s/%s", concat_node.name, split_node.name)
            concat_in_edges = G.indexed_in_edges(concat_node.name)
            split_out_edges = G.indexed_out_edges(split_node.name)
            G.remove(split_node)
            G.remove(concat_node)
            # drop records of the removed nodes, as the other matchers do
            if G.quantization:
                for removed in (split_node, concat_node):
                    nid = NodeId(removed)
                    if nid in G.quantization:
                        del G.quantization[nid]
            for idx, in_edge in enumerate(concat_in_edges):
                for out_edge in split_out_edges[idx]:
                    G.add_edge(NNEdge(from_node=in_edge.from_node, from_idx=in_edge.from_idx,
                                      to_node=out_edge.to_node, to_idx=out_edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
def reverse_matmul(G: GraphView, params):
    """Rewrite the matmul ``params`` in place so it computes ``(B^T @ A^T)^T``.

    The two matrix operands (inputs 0 and 1) are swapped. A ``(1, 0)``
    TransposeParameters is inserted in front of each operand and one after the
    output. Assuming ``params`` computes ``in0 @ in1``, the result is
    unchanged. Any input beyond index 1 (e.g. a bias) is left where it is.
    If ``params`` has a quantization record, its first two ``in_qs`` are
    swapped to follow the operands.

    Args:
        G: graph containing ``params``; modified in place.
        params: the matmul node to reverse.

    Returns:
        ``(in_nodes, tout_params)``: the input transposes (for inputs 0 and 1)
        and the output transpose.
    """
    in_edges = G.indexed_in_edges(params.name)
    operand_edges = in_edges[:2]
    for edge in operand_edges:
        G.remove_edge(edge)
    # operand 0 moves to input 1 and operand 1 to input 0
    for idx, edge in enumerate(operand_edges):
        G.add_edge(
            NNEdge(from_node=edge.from_node,
                   to_node=params,
                   from_idx=edge.from_idx,
                   to_idx=1 - idx))
    nid = NodeId(params)
    if G.quantization and nid in G.quantization:
        qrec = G.quantization[nid]
        # swap qrecs
        qrec.in_qs[0], qrec.in_qs[1] = qrec.in_qs[1], qrec.in_qs[0]

    # add transposes
    in_nodes = []
    for idx in range(2):
        tin_params = TransposeParameters(
            G.unique_name(f"{params.name}_tin{idx}"), transpose=(1, 0))
        in_nodes.append(tin_params)
        G.insert_node_before(tin_params, params, to_idx=idx, edge_class=NNEdge)
    tout_params = TransposeParameters(G.unique_name(f"{params.name}_tout"),
                                      transpose=(1, 0))
    # use NNEdge like the input-side insertions above
    G.insert_node_after(params, tout_params, edge_class=NNEdge)
    return in_nodes, tout_params
예제 #22
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Fuse each Conv2D with its following pool/activation into ConvFusionParameters.

        ``self.get_node_list`` supplies the chain (``order``) starting at the
        convolution. Chains shorter than two nodes are ignored. Two special
        cases apply:

        * a ``conv_active_pool`` chain ending in average pooling is trimmed to
          its first two nodes;
        * a ``conv_pool_active`` chain with average pooling is skipped unless
          its activation is ``relu``.

        The chain is rebuilt as a linear subgraph inside the fusion node, the
        fusion is wired in place of the original nodes, and quantization
        records are moved into it.

        Args:
            G: graph to transform in place.
            set_identity: when True, call ``self.set_identity(G)`` at the end.

        Returns:
            True if at least one fusion was made.
        """
        modified = False
        # record types of the first fused node -> record type for the fusion
        # (checked in this order; the first match wins)
        record_map = (
            ((SymmetricQuantizationRecord, SymmetricScalableFilterQuantizationRecord),
             SymmetricQuantizationRecord),
            ((MultQuantizationRecord, MultScalableFilterQuantizationRecord),
             MultQuantizationRecord),
            ((Float32QuantizationRecord, Float32ScalableFilterQuantizationRecord),
             Float32QuantizationRecord),
        )
        convs = [candidate for candidate in G.nodes() if isinstance(candidate, Conv2DParameters)]
        for conv in convs:
            chain = self.get_node_list(G, conv)
            if chain is None or len(chain.order) < 2:
                continue
            kind = chain.fusion_type
            if kind == 'conv_active_pool' and chain.pool.pool_type == "average":
                chain.order = chain.order[:2]
                chain.pool = None
            elif (kind == 'conv_pool_active' and chain.pool.pool_type == "average"
                  and chain.active.activation != "relu"):
                continue
            LOG.info("fusing nodes %s", ",".join(member.name for member in chain.order))
            modified = True

            subgraph = GraphView()
            for upstream, downstream in zip(chain.order, chain.order[1:]):
                subgraph.add_edge(NNEdge(from_node=upstream, to_node=downstream))
            tail = chain.order[-1]

            pnode = ConvFusionParameters(
                chain.conv.name + '_fusion',
                fusion_type=chain.fusion_type,
                subgraph=subgraph,
                in_dims_hint=chain.conv.in_dims_hint,
                out_dims_hint=chain.conv.out_dims_hint,
                input_mapping=[[(chain.conv, idx)] for idx in range(3)],
                output_mapping=[(tail, 0)])

            if G.quantization:
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                if qrecs:
                    head_rec, tail_rec = qrecs[0], qrecs[-1]
                    prec = None
                    for source_types, fused_type in record_map:
                        if isinstance(head_rec, source_types):
                            prec = fused_type(in_qs=head_rec.in_qs, out_qs=tail_rec.out_qs)
                            break
                    for member in pnode.contained_nodes():
                        G.quantization.move_to_fusion(member, pnode)
                    G.quantization[NodeId(pnode)] = prec

            # capture the external edges before the fused nodes are removed
            incoming = G.in_edges(chain.conv.name)
            outgoing = G.out_edges(tail.name)
            for member in chain.order:
                G.remove(member)
            for edge in incoming:
                G.add_edge(NNEdge(edge.from_node, pnode, from_idx=edge.from_idx, to_idx=edge.to_idx))
            for edge in outgoing:
                G.add_edge(NNEdge(pnode, edge.to_node, from_idx=edge.from_idx, to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return modified
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fold constant add/mul ops that follow a MatMulOpParameters into the matmul.

        The walk starts at each matmul and passes through ReshapeParameters.
        It continues while the next node is a MatrixAddParameters or
        MatrixMulParameters whose other input is a ConstantInputParameters;
        each such constant is folded into the matmul:

        * additive constants are added to the existing bias (input 2), or
          become a new bias. If the constant's length matches the second
          output dimension, the matmul is first rewritten with
          ``reverse_matmul`` (the bias presumably applies along the first
          output dimension).
        * multiplicative constants are multiplied into the matmul's constant
          operand, and into its bias if present.

        The folded op is removed, together with its constant if nothing else
        consumes it. The op's consumers are reconnected to the matmul, or to
        the last reshape / output transpose. A quantized graph is re-quantized
        after each fold.

        Args:
            G: graph to transform in place.
            set_identity: when True, call ``self.set_identity(G)`` at the end.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            True if at least one constant was folded.
        """
        has_modified_graph = False
        has_transposed = False
        for params in G.nodes(node_classes=MatMulOpParameters):
            matmul = params
            seen_reshape = []
            while True:
                out_edges = G.out_edges(params.name)
                # can't fuse if there is a branch (or nothing consumes the output)
                if len(out_edges) != 1:
                    break
                out_edge = out_edges[0]
                op_node = out_edge.to_node
                if isinstance(op_node, ReshapeParameters):
                    # look through reshapes; the last one becomes the output node
                    seen_reshape.append(op_node)
                    params = op_node
                    continue
                # must be a valid matrix op
                if not isinstance(op_node,
                                  (MatrixAddParameters, MatrixMulParameters)):
                    break
                # other edge to the op must be a constant
                other_idx = 1 if out_edge.to_idx == 0 else 0
                other_in_edge = G.indexed_in_edges(op_node.name)[other_idx]
                if not isinstance(other_in_edge.from_node,
                                  ConstantInputParameters):
                    break
                const_node = other_in_edge.from_node
                # only delete the constant if the folded op is its sole consumer
                remove_constant = len(G.out_edges(const_node.name)) == 1

                flat_value = const_node.dqvalue.flatten()
                out_shape = matmul.out_dims[0].shape
                if len(out_shape) != 2:
                    raise ValueError(
                        f'strange outputs shape of {out_shape} for matmul {matmul.name}'
                    )
                if len(flat_value) != out_shape[0] and len(
                        flat_value) != out_shape[1]:
                    LOG.info(
                        "can't fuse %s into %s - value shape is not correct for bias",
                        const_node.name, matmul.name)
                    break
                has_bias = len(matmul.in_dims) == 3
                out_node = seen_reshape[-1] if seen_reshape else matmul
                if isinstance(op_node, MatrixAddParameters):
                    if has_bias:
                        if len(flat_value.shape) != len(matmul.in_dims[2]):
                            LOG.info(
                                "can't fuse %s into %s - bias shape is not the same",
                                const_node.name, matmul.name)
                            break
                        bias_node = G.indexed_in_edges(
                            matmul.name)[2].from_node
                        LOG.info(
                            "folding additive bias from %s into existing bias on %s",
                            op_node.name, matmul.name)
                        bias_node.value = bias_node.dqvalue + flat_value
                    else:
                        if len(flat_value) == out_shape[1]:
                            # matmul needs to be transposed to fuse this; the
                            # output transpose becomes the output node unless a
                            # reshape follows
                            _, trans_node = reverse_matmul(G, matmul)
                            if not seen_reshape:
                                out_node = trans_node
                            has_transposed = True
                        bias_node = ConstantInputParameters(
                            G.unique_name(f'{matmul.name}_bias'),
                            value=flat_value,
                            dims=Dim.unnamed(flat_value.shape))
                        G.add_edge(
                            NNEdge(from_node=bias_node,
                                   to_node=matmul,
                                   to_idx=2))
                        LOG.info(
                            "folding additive bias from %s into new bias on %s",
                            op_node.name, matmul.name)
                else:
                    params_in = G.indexed_in_edges(matmul.name)
                    consts = [
                        isinstance(edge.from_node, ConstantInputParameters)
                        for edge in params_in
                    ]
                    # the scale must go into a constant matrix operand
                    # (input 0 or 1); a constant bias alone is not enough
                    if not any(consts[:2]):
                        break
                    mult_const_node = params_in[1].from_node if consts[
                        1] else params_in[0].from_node
                    mult_const_node.value = mult_const_node.dqvalue * const_node.dqvalue
                    if has_bias:
                        bias_node = params_in[2].from_node
                        bias_node.value = bias_node.dqvalue * const_node.dqvalue

                    LOG.info(
                        "folding multaplicative bias from %s into new bias on %s",
                        op_node.name, matmul.name)

                has_modified_graph = True
                out_edges = G.out_edges(op_node.name)
                G.remove(op_node)
                if remove_constant:
                    G.remove(const_node)
                for edge in out_edges:
                    G.add_edge(
                        NNEdge(from_node=out_node,
                               to_node=edge.to_node,
                               to_idx=edge.to_idx))
                G.add_dimensions()
                if G.quantization:
                    quantizer = NewQuantizer.from_quantized_graph(G)
                    quantizer.quantize()
                    RemoveUnnecessaryQuantizeOperators().match(G)

        if has_transposed:
            G.adjust_order()

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
예제 #24
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse a PadParameters feeding an add (plus optional activation).

        ``self.get_node_list`` supplies the pad, the add, the optional
        activation and their ``order``. The result is a
        PaddedAddFusionParameters named ``PADDED_<add>``. The fusion has two
        inputs, and fusion input ``i`` feeds add input ``i``. On the padded
        side, the input enters through the pad's single input (index 0).

        Args:
            G: graph to transform in place.
            set_identity: when True, call ``self.set_identity(G)`` at the end.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            True if at least one fusion was made.
        """
        has_modified_graph = False
        for pad_node in [
                params for params in G.nodes()
                if isinstance(params, PadParameters)
        ]:
            node_list = self.get_node_list(G, pad_node)
            if node_list is None or len(node_list.order) < 2:
                continue
            LOG.info("fusing nodes %s", ",".join(
                (node.name for node in node_list.order)))
            has_modified_graph = True
            subgraph = GraphView()
            # which add input the pad feeds, and the other add input
            padded_input_idx = G.out_edges(node_list.pad.name)[0].to_idx
            other_input_idx = 1 - padded_input_idx
            subgraph.add_edge(
                NNEdge(from_node=node_list.pad,
                       to_node=node_list.add,
                       to_idx=padded_input_idx))
            last_node = node_list.add
            node_list.add.force_quantized_index = 0
            if node_list.active:
                subgraph.add_edge(
                    NNEdge(from_node=node_list.add, to_node=node_list.active))
                last_node = node_list.active
            # the pad has a single input (index 0) whichever add input it feeds
            input_mapping = [None, None]
            input_mapping[padded_input_idx] = [(node_list.pad, 0)]
            input_mapping[other_input_idx] = [(node_list.add, other_input_idx)]

            output_mapping = [(last_node, 0)]
            pnode = PaddedAddFusionParameters(
                "PADDED_" + node_list.add.name,
                fusion_type=node_list.fusion_type,
                subgraph=subgraph,
                input_mapping=input_mapping,
                output_mapping=output_mapping)
            if G.quantization:
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                # TODO - stats
                if qrecs:
                    prec = QRec.copy_ktype(qrecs[1],
                                           in_qs=qrecs[1].in_qs,
                                           out_qs=qrecs[-1].out_qs)
                    for node in pnode.contained_nodes():
                        G.quantization.move_to_fusion(node, pnode)
                    G.quantization[NodeId(pnode)] = prec
            # external input edges paired with the fusion input they must feed.
            # The pad's own input edge has to_idx 0, so the target index must be
            # set explicitly rather than copied from the edge.
            add_in_edges = G.indexed_in_edges(node_list.add.name)
            pad_inputs = [(edge, padded_input_idx)
                          for edge in G.in_edges(node_list.pad.name)]
            other_inputs = [(add_in_edges[other_input_idx], other_input_idx)]
            if padded_input_idx == 0:
                in_edges = pad_inputs + other_inputs
            else:
                in_edges = other_inputs + pad_inputs
            out_edges = G.out_edges(last_node.name)
            for node in node_list.order:
                G.remove(node)
            for edge, fusion_idx in in_edges:
                G.add_edge(
                    NNEdge(edge.from_node,
                           pnode,
                           from_idx=edge.from_idx,
                           to_idx=fusion_idx))
            for edge in out_edges:
                G.add_edge(
                    NNEdge(pnode,
                           edge.to_node,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse each fully connected node with the node chain matched after it.

        Every ``FcParameters`` node found in ``G`` (snapshotted before any
        rewriting) is passed to ``self.get_node_list`` together with an
        activation set. When the returned match holds at least two nodes, the
        chain is packed into a single ``LinearFusionParameters`` node that takes
        its place in ``G``. Quantization records, if present, are moved into
        the fusion.

        Args:
            G: Graph rewritten in place.
            set_identity: When true, ``self.set_identity(G)`` is called once all
                fusions are done.
            **kwargs: Only ``group_identity`` is read. The value
                ``'pow2_match_group'`` selects ``VALID_ACTIVATIONS_POW2``; any
                other value, or no value, selects ``VALID_ACTIVATIONS_SQ8``.

        Returns:
            True if at least one chain was fused, False otherwise.
        """
        activations = (VALID_ACTIVATIONS_POW2
                       if kwargs.get('group_identity') == 'pow2_match_group'
                       else VALID_ACTIVATIONS_SQ8)
        fused_any = False

        # Snapshot the candidates up front because G is mutated in the loop.
        candidates = [nd for nd in G.nodes() if isinstance(nd, FcParameters)]
        for candidate in candidates:
            match = self.get_node_list(G, candidate, activations)
            # A lone linear node has nothing to fuse with.
            if match is None or len(match.order) < 2:
                continue

            chain = list(match.order)
            LOG.info("fusing nodes %s",
                     ",".join(member.name for member in chain))
            fused_any = True

            # Rebuild the matched chain as a straight-line subgraph.
            subgraph = GraphView()
            for upstream, downstream in zip(chain, chain[1:]):
                subgraph.add_edge(
                    NNEdge(from_node=upstream, to_node=downstream))
            tail = chain[-1]

            linear = match.linear
            pnode = LinearFusionParameters(
                linear.name + '_fusion',
                fusion_type=match.fusion_type,
                subgraph=subgraph,
                # One mapping per input index 0..2 of the linear node
                # (presumably input, weights, biases - confirm).
                input_mapping=[[(linear, idx)] for idx in range(3)],
                # The fused node's single output comes from the chain's tail.
                output_mapping=[(tail, 0)])

            if G.quantization:
                # TODO - stats
                records = G.quantization.get_all(pnode.contained_nodes())
                if records:
                    # The fused record takes its input quantization from the
                    # first contained node and its output from the last.
                    fused_rec = QRec.copy_ktype(records[0],
                                                in_qs=records[0].in_qs,
                                                out_qs=records[-1].out_qs)
                    for inner in pnode.contained_nodes():
                        G.quantization.move_to_fusion(inner, pnode)
                    G.quantization[NodeId(pnode)] = fused_rec

            # Capture the external connections before the chain is removed.
            incoming = G.in_edges(linear.name)
            outgoing = G.out_edges(tail.name)
            for member in chain:
                G.remove(member)
            for edge in incoming:
                G.add_edge(NNEdge(edge.from_node, pnode,
                                  from_idx=edge.from_idx,
                                  to_idx=edge.to_idx))
            for edge in outgoing:
                G.add_edge(NNEdge(pnode, edge.to_node,
                                  from_idx=edge.from_idx,
                                  to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return fused_any