Esempio n. 1
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Remove duplicated node chains and keep a single copy.

        Candidates are nodes with exactly one indexed output that fans out
        over more than one edge. For each candidate ``self.explore`` returns
        a list of node chains; the first chain is kept, every other chain is
        removed from the graph (and from ``G.quantization`` when present) and
        the out edges of the removed chains' last nodes are re-attached to
        the last node of the kept chain.

        Args:
            G: graph that is modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.
            **kwargs: accepted and ignored.

        Returns:
            True if ``self.explore`` returned a non-empty result for at
            least one candidate, otherwise False.
        """
        changed = False
        pending = [
            node for node in G.nodes()
            if len(G.indexed_out_edges(node.name)) == 1
            and len(G.out_edges(node.name)) > 1
        ]
        while pending:
            start = pending.pop(0)
            chains = self.explore(G, [start])
            if not chains:
                continue
            changed = True
            # the first chain survives, its nodes need no further exploration
            keeper = chains.pop(0)
            for kept in keeper:
                if kept in pending:
                    pending.remove(kept)
            dangling_edges = []
            for dup_chain in chains:
                # remember where the duplicate chain led before deleting it
                dangling_edges.extend(G.out_edges(dup_chain[-1].name))
                for dup in dup_chain:
                    if dup in pending:
                        pending.remove(dup)
                    G.remove(dup)
                    qkey = NodeId(dup)
                    if G.quantization and qkey in G.quantization:
                        del G.quantization[qkey]
                dup_names = ",".join(dup.name for dup in dup_chain)
                LOG.info(
                    f'removed duplicates from {keeper[0].name} {dup_names}')
            tail = keeper[-1]
            for dangling in dangling_edges:
                G.add_edge(
                    NNEdge(from_node=tail,
                           to_node=dangling.to_node,
                           to_idx=dangling.to_idx))

        if set_identity:
            self.set_identity(G)

        return changed
Esempio n. 2
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Wrap every node matched by ``MatScaleNodeMatch`` in a fusion node.

        Each matched node is removed from the graph and replaced by a
        ``MatScaleFusionParameters`` node that takes over the matched input
        edges (renumbered in match order) and all of the original out edges.

        Args:
            G: graph that is modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.
            **kwargs: accepted and ignored.

        Returns:
            True if at least one match was replaced, otherwise False.
        """
        def fragment_and_match(state, frag):
            # hand back the matched fragment together with the match record
            return (frag, state['match'])

        matcher = GraphMatcher(match_function=fragment_and_match)
        matcher.add_node(MatScaleNodeMatch())
        changed = False
        for frag, match in matcher.match_graph(G):
            incoming = [
                G.indexed_in_edges(src.name)[src_idx]
                for src, src_idx in match['inputs']
            ]
            target = list(frag.nodes())[0]
            outgoing = G.out_edges(target.name)
            changed = True
            G.remove(target)
            fused = MatScaleFusionParameters(
                "{}_fusion".format(target.name),
                fusion_type=match['type'],
                subgraph=frag,
                input_mapping=[[(target, 0)], [(target, 1)]])
            G.add_node(fused)
            # the existing edge objects are retargeted and re-added
            for position, in_edge in enumerate(incoming):
                in_edge.to_node = fused
                in_edge.to_idx = position
                G.add_edge(in_edge)
            for out_edge in outgoing:
                out_edge.from_node = fused
                G.add_edge(out_edge)

        if set_identity:
            self.set_identity(G)

        return changed
Esempio n. 3
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Absorb a transpose on the second matmul input into the matmul.

        For every ``MatMulOpParameters`` node whose input 1 comes from a
        ``TransposeParameters`` node, the matmul is replaced by the opposite
        variant (plain <-> ``MatMulTransposedParameters``) under the same
        name, the transpose is removed and its producer is wired directly to
        input 1 of the replacement.

        Args:
            G: graph that is modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.
            **kwargs: accepted and ignored.

        Returns:
            True if at least one transpose was absorbed, otherwise False.
        """
        changed = False

        for matmul in G.nodes(node_classes=MatMulOpParameters):
            transpose = list(G.indexed_in_edges(matmul.name))[1].from_node
            if not isinstance(transpose, TransposeParameters):
                continue
            # flip the matmul flavour since the transpose disappears
            replacement_cls = (MatMulOpParameters if isinstance(
                matmul, MatMulTransposedParameters) else
                               MatMulTransposedParameters)
            replacement = replacement_cls(matmul.name)

            feed = list(G.indexed_in_edges(transpose.name))[0]
            G.replace_node(matmul.name, replacement)
            G.remove(transpose)
            G.add_edge(
                NNEdge(feed.from_node,
                       replacement,
                       from_idx=feed.from_idx,
                       to_idx=1))
            changed = True

        if set_identity:
            self.set_identity(G)

        return changed
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        """Bypass and delete every node that ``self.node_does_nothing`` flags.

        The first in edge of each flagged node is connected straight to the
        destination of each of its out edges, then the node is removed.

        Args:
            G: graph that is modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.
            **kwargs: accepted and ignored.

        Returns:
            True if at least one node was removed, otherwise False.
        """
        # snapshot the victims first, the graph is edited while we walk them
        removable = [
            node for node in G.nodes() if self.node_does_nothing(G, node)
        ]
        for idle in removable:
            feed = G.in_edges(idle.name)[0]
            G.remove_edge(feed)
            for drain in G.out_edges(idle.name):
                G.remove_edge(drain)
                G.add_edge(
                    NNEdge(feed.from_node,
                           drain.to_node,
                           from_idx=feed.from_idx,
                           to_idx=drain.to_idx))
            LOG.info(f'removing {idle.name} that does nothing')
            G.remove(idle)
        has_modified_graph = bool(removable)

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 5
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse chains that start at a ``VALID_FUSIONS`` node into one node.

        ``self.get_node_list`` supplies the ordered chain of nodes to fuse.
        Chains with fewer than two nodes are skipped. The chain is copied
        into a linear subgraph, wrapped in ``node_list.fusions_class`` and
        substituted for the original nodes, keeping the in edges of the
        first node and the out edges of the last one. Quantization records
        of the contained nodes are moved under the fusion when present.

        Args:
            G: graph that is modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.
            **kwargs: accepted and ignored.

        Returns:
            True if at least one chain was fused, otherwise False.
        """
        changed = False

        for candidate in G.nodes(node_classes=tuple(VALID_FUSIONS.keys())):
            found = self.get_node_list(G, candidate,
                                       FusionMatch(self._default_ktype))
            if found is None or len(found.order) < 2:
                continue
            LOG.info("fusing nodes %s", ",".join(
                (member.name for member in found.order)))
            changed = True

            # rebuild the chain as a linear subgraph; at least two members
            # are guaranteed by the length check above
            subgraph = GraphView()
            members = iter(found.order)
            tail = next(members)
            for member in members:
                subgraph.add_edge(NNEdge(from_node=tail, to_node=member))
                tail = member

            # assumption here is that the first node could have multiple
            # inputs but definitely has only one output
            input_mapping = [
                [(found.node, in_idx)]
                for in_idx in range(G.num_in_edges(found.node.name))
            ]
            output_mapping = [(tail, 0)]
            fusion = found.fusions_class(found.node.name + '_fusion',
                                         fusion_type=found.fusion_type,
                                         subgraph=subgraph,
                                         input_mapping=input_mapping,
                                         output_mapping=output_mapping)
            if G.quantization:
                # TODO - stats
                qrecs = G.quantization.get_all(fusion.contained_nodes())
                if qrecs:
                    # fusion record: inputs of the first, outputs of the last
                    fusion_qrec = QRec.copy_ktype(qrecs[0],
                                                  in_qs=qrecs[0].in_qs,
                                                  out_qs=qrecs[-1].out_qs)
                    for inner in fusion.contained_nodes():
                        G.quantization.move_to_fusion(inner, fusion)
                    G.quantization[NodeId(fusion)] = fusion_qrec

            feeds = G.in_edges(found.node.name)
            drains = G.out_edges(tail.name)
            for member in found.order:
                G.remove(member)
            for feed in feeds:
                G.add_edge(
                    NNEdge(feed.from_node,
                           fusion,
                           from_idx=feed.from_idx,
                           to_idx=feed.to_idx))
            for drain in drains:
                G.add_edge(
                    NNEdge(fusion,
                           drain.to_node,
                           from_idx=drain.from_idx,
                           to_idx=drain.to_idx))

        if set_identity:
            self.set_identity(G)

        return changed
Esempio n. 6
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Fuse convolutions with the pool/activation nodes that follow them.

        ``self.get_node_list`` supplies the ordered chain starting at each
        ``Conv2DParameters`` node. Two special cases are applied:

        * ``conv_active_pool`` with an average pool: only the first two
          nodes of the chain are fused and the pool is dropped from it.
        * ``conv_pool_active`` with an average pool and a non-relu
          activation: the chain is not fused at all.

        The chain is wrapped in a ``ConvFusionParameters`` node which
        replaces the original nodes in the graph.

        Args:
            G: graph that is modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.

        Returns:
            True if at least one chain was fused, otherwise False.
        """
        # (record types accepted for the first node, record type created)
        record_kinds = (
            ((SymmetricQuantizationRecord,
              SymmetricScalableFilterQuantizationRecord),
             SymmetricQuantizationRecord),
            ((MultQuantizationRecord, MultScalableFilterQuantizationRecord),
             MultQuantizationRecord),
            ((Float32QuantizationRecord,
              Float32ScalableFilterQuantizationRecord),
             Float32QuantizationRecord),
        )
        changed = False
        convs = [params for params in G.nodes() if isinstance(params, Conv2DParameters)]
        for conv_node in convs:
            found = self.get_node_list(G, conv_node)
            if found is None or len(found.order) < 2:
                continue
            kind = found.fusion_type
            if kind == 'conv_active_pool' and found.pool.pool_type == "average":
                found.order = found.order[:2:]
                found.pool = None
            elif (kind == 'conv_pool_active'
                  and found.pool.pool_type == "average"
                  and found.active.activation != "relu"):
                continue
            LOG.info("fusing nodes %s", ",".join((member.name for member in found.order)))
            changed = True

            # rebuild the chain as a linear subgraph; at least two members
            # remain after the adjustments above
            subgraph = GraphView()
            members = iter(found.order)
            tail = next(members)
            for member in members:
                subgraph.add_edge(NNEdge(from_node=tail, to_node=member))
                tail = member

            fusion = ConvFusionParameters(
                found.conv.name + '_fusion',
                fusion_type=found.fusion_type,
                subgraph=subgraph,
                in_dims_hint=found.conv.in_dims_hint,
                out_dims_hint=found.conv.out_dims_hint,
                input_mapping=[[(found.conv, in_idx)] for in_idx in range(3)],
                output_mapping=[(tail, 0)])
            if G.quantization:
                qrecs = G.quantization.get_all(fusion.contained_nodes())
                if qrecs:
                    # stays None when the first record is of no known kind
                    fusion_qrec = None
                    for accepted, record_cls in record_kinds:
                        if isinstance(qrecs[0], accepted):
                            fusion_qrec = record_cls(in_qs=qrecs[0].in_qs,
                                                     out_qs=qrecs[-1].out_qs)
                            break
                    for inner in fusion.contained_nodes():
                        G.quantization.move_to_fusion(inner, fusion)
                    G.quantization[NodeId(fusion)] = fusion_qrec

            feeds = G.in_edges(found.conv.name)
            drains = G.out_edges(tail.name)
            for member in found.order:
                G.remove(member)
            for feed in feeds:
                G.add_edge(NNEdge(feed.from_node, fusion,
                                  from_idx=feed.from_idx, to_idx=feed.to_idx))
            for drain in drains:
                G.add_edge(NNEdge(fusion, drain.to_node,
                                  from_idx=drain.from_idx, to_idx=drain.to_idx))

        if set_identity:
            self.set_identity(G)

        return changed
Esempio n. 7
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Eliminate unpack/slice nodes that select cells of an RNN output.

        RNN nodes producing more than one output cell are paired with their
        unpack node by ``self.find_unpack``; the pairs are then filtered by
        ``self.validate_slices`` and ``self.validate_multi_branch``. For
        every surviving unpack node ``self.process_path`` is run on each of
        its RNN paths, after which the unpack node is removed and its
        producer is connected to its consumers. When the unpack node is a
        ``StridedSliceParameters`` that changes shape, a reshape node is
        inserted in its place to keep the output shape.

        Args:
            G: graph that is modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end (only reached when something was matched).
            **kwargs: accepted and ignored.

        Returns:
            False if no unpack node qualified, otherwise True.

        Raises:
            AssertionError: if an unpack node does not have exactly one in
                edge.
        """
        rnn_nodes = [
            self.find_unpack(G, node) for node in G.nodes()
            if isinstance(node, RNNBaseParameters) and node.n_output_cells > 1
        ]
        rnn_nodes_by_slice = self.validate_slices(G, rnn_nodes)
        rnn_nodes_by_slice = self.validate_multi_branch(G, rnn_nodes_by_slice)
        if not rnn_nodes_by_slice:
            return False

        for unpack_node, rnn_unpacks in rnn_nodes_by_slice.items():
            modified_nodes = set()
            for rnn_unpack in rnn_unpacks:
                self.process_path(G, rnn_unpack, modified_nodes)
            # since process path will have removed all unnecessary nodes the edges will be correct here
            out_edges = G.out_edges(unpack_node.name)
            in_edges = G.in_edges(unpack_node.name)
            assert len(in_edges
                       ) == 1, "expecting unpack node to have only one in edge"
            in_edge = in_edges[0]
            changes_shape = unpack_node.changes_shape if isinstance(
                unpack_node, StridedSliceParameters) else False

            LOG.info("Eliminating last cell unpack: %s", unpack_node.name)
            G.remove(unpack_node)

            # Here the strided slice can change the output shape of the RNN
            # so insert a reshape to do the shape change
            if changes_shape:
                reshape = ReshapeParameters(
                    unpack_node.name + '_reshape',
                    old_shape=Dim.unnamed(unpack_node.post_slice_shape),
                    shape=Dim.unnamed(unpack_node.out_shape))
                G.add_edge(
                    NNEdge(from_node=in_edge.from_node,
                           to_node=reshape,
                           from_idx=in_edge.from_idx))
                for out_edge in out_edges:
                    G.add_edge(
                        NNEdge(from_node=reshape,
                               to_node=out_edge.to_node,
                               to_idx=out_edge.to_idx))
                if G.quantization:
                    # the reshape inherits the record of the node it replaces;
                    # this must be read before the record is deleted below
                    G.quantization[NodeId(reshape)] = G.quantization[NodeId(
                        unpack_node)]
            else:
                for out_edge in out_edges:
                    G.add_edge(
                        NNEdge(from_node=in_edge.from_node,
                               to_node=out_edge.to_node,
                               from_idx=in_edge.from_idx,
                               to_idx=out_edge.to_idx))
            if G.quantization:
                del G.quantization[NodeId(unpack_node)]

        if set_identity:
            self.set_identity(G)

        return True
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse fully connected nodes with the activation that follows them.

        The set of activations allowed in a fusion depends on the
        ``group_identity`` keyword: ``'pow2_match_group'`` selects
        ``VALID_ACTIVATIONS_POW2``, anything else ``VALID_ACTIVATIONS_SQ8``.
        ``self.get_node_list`` supplies the ordered chain for each
        ``FcParameters`` node; chains shorter than two nodes are skipped.
        The chain is wrapped in a ``LinearFusionParameters`` node which
        replaces the original nodes in the graph.

        Args:
            G: graph that is modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.
            **kwargs: only ``group_identity`` is read.

        Returns:
            True if at least one chain was fused, otherwise False.
        """
        changed = False
        valid_activations = (VALID_ACTIVATIONS_POW2
                             if kwargs.get('group_identity') == 'pow2_match_group'
                             else VALID_ACTIVATIONS_SQ8)
        fc_nodes = [params for params in G.nodes() if isinstance(params, FcParameters)]
        for fc_node in fc_nodes:
            found = self.get_node_list(G, fc_node, valid_activations)
            if found is None or len(found.order) < 2:
                continue
            LOG.info("fusing nodes %s", ",".join(
                (member.name for member in found.order)))
            changed = True

            # rebuild the chain as a linear subgraph; at least two members
            # are guaranteed by the length check above
            subgraph = GraphView()
            members = iter(found.order)
            tail = next(members)
            for member in members:
                subgraph.add_edge(NNEdge(from_node=tail, to_node=member))
                tail = member

            fusion = LinearFusionParameters(
                found.linear.name + '_fusion',
                fusion_type=found.fusion_type,
                subgraph=subgraph,
                input_mapping=[[(found.linear, in_idx)] for in_idx in range(3)],
                output_mapping=[(tail, 0)])
            if G.quantization:
                # TODO - stats
                qrecs = G.quantization.get_all(fusion.contained_nodes())
                if qrecs:
                    # fusion record: inputs of the first, outputs of the last
                    fusion_qrec = QRec.copy_ktype(
                        qrecs[0], in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                    for inner in fusion.contained_nodes():
                        G.quantization.move_to_fusion(inner, fusion)
                    G.quantization[NodeId(fusion)] = fusion_qrec

            feeds = G.in_edges(found.linear.name)
            drains = G.out_edges(tail.name)
            for member in found.order:
                G.remove(member)
            for feed in feeds:
                G.add_edge(NNEdge(feed.from_node, fusion,
                                  from_idx=feed.from_idx, to_idx=feed.to_idx))
            for drain in drains:
                G.add_edge(NNEdge(fusion, drain.to_node,
                                  from_idx=drain.from_idx, to_idx=drain.to_idx))

        if set_identity:
            self.set_identity(G)

        return changed
Esempio n. 9
0
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        """Remove a concat that is immediately undone by a split.

        A split qualifies when its single producer is a concat that feeds
        nothing else, neither node carries a transpose, both work on the
        same axis and the split produces exactly the sizes along that axis
        that the concat consumed. Both nodes are then removed and every
        concat producer is wired to the consumers of the matching split
        output.

        Args:
            G: graph that is modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.
            **kwargs: accepted and ignored.

        Returns:
            True if at least one concat/split pair was removed, otherwise
            False.
        """
        changed = False
        split_nodes = {
            node for node in G.nodes() if isinstance(node, SplitParameters)
        }
        for split_node in split_nodes:
            split_feeds = G.in_edges(split_node.name)
            if len(split_feeds) > 1:
                continue
            concat_node = split_feeds[0].from_node
            if not isinstance(concat_node, ConcatParameters):
                continue
            if len(G.out_edges(concat_node.name)) > 1:
                continue
            if concat_node.transpose_out or split_node.transpose_in:
                continue
            axis = concat_node.axis
            if axis != split_node.axis:
                continue
            # the split must hand back exactly the pieces that were joined
            split_sizes = [shape[axis] for shape in split_node.out_shapes]
            if len(split_sizes) != len(concat_node.in_dims):
                continue
            sizes_agree = all(
                size == in_dim.shape[axis]
                for size, in_dim in zip(split_sizes, concat_node.in_dims))
            if not sizes_agree:
                continue
            changed = True
            LOG.info("removing unnecessary concat/split pair %s/%s",
                     concat_node.name, split_node.name)
            producers = G.indexed_in_edges(concat_node.name)
            consumers = G.indexed_out_edges(split_node.name)
            G.remove(split_node)
            G.remove(concat_node)
            for slot, producer in enumerate(producers):
                for consumer in consumers[slot]:
                    G.add_edge(
                        NNEdge(from_node=producer.from_node,
                               from_idx=producer.from_idx,
                               to_node=consumer.to_node,
                               to_idx=consumer.to_idx))

        if set_identity:
            self.set_identity(G)

        return changed
Esempio n. 10
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Fold reverse nodes around an RNN into the RNN's ``revert`` flag.

        Only reverse nodes whose single out edge goes to an
        ``RNNBaseParameters`` node are considered, and only when that edge
        arrives on input 0. The RNN is flagged with ``revert = True``, the
        reverse node is removed and its producers are connected to the RNN.
        Reverse nodes hanging off output 0 of the same RNN are removed as
        well and bypassed.

        Args:
            G: graph that is modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.

        Returns:
            True if at least one reverse was folded into an RNN, otherwise
            False.

        Raises:
            AssertionError: if the target RNN already has ``revert`` set.
        """
        def feeds_one_rnn(node):
            # Only works for reverses connected to one RNN node
            return (isinstance(node, ReverseParameters)
                    and len(G.out_edges(node.name)) == 1
                    and isinstance(G.out_edges(node.name)[0].to_node,
                                   RNNBaseParameters))

        reverse_nodes = {node for node in G.nodes() if feeds_one_rnn(node)}

        changed = False
        for reverse_node in reverse_nodes:
            feeds = G.in_edges(reverse_node.name)
            rnn_edge = G.out_edges(reverse_node.name)[0]
            if rnn_edge.to_idx != 0:
                LOG.warning("reverse on rnn input %s", rnn_edge.to_idx)
                continue
            rnn_node = rnn_edge.to_node
            assert not rnn_node.revert, "RNN node is already reversed!"
            rnn_node.revert = True
            LOG.info("fusing reverses into node %s", rnn_node.name)
            changed = True
            G.remove(reverse_node)
            for feed in feeds:
                G.add_edge(
                    NNEdge(feed.from_node,
                           rnn_node,
                           from_idx=feed.from_idx,
                           to_idx=rnn_edge.to_idx))

            # drop reverses on the RNN's first output as well
            for rnn_out in G.out_edges(rnn_node.name):
                trailing = rnn_out.to_node
                if not isinstance(trailing, ReverseParameters):
                    continue
                if rnn_out.from_idx != 0:
                    LOG.warning("reverse on rnn output %s", rnn_out.from_idx)
                    continue
                trailing_outs = G.out_edges(trailing.name)
                G.remove(trailing)
                for trailing_out in trailing_outs:
                    G.add_edge(
                        NNEdge(rnn_out.from_node,
                               trailing_out.to_node,
                               from_idx=rnn_out.from_idx,
                               to_idx=trailing_out.to_idx))

        if set_identity:
            self.set_identity(G)

        return changed
Esempio n. 11
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Remove the nodes reported as redundant by ``find_redundant_relus``.

        Every graph input seeds the search with a ``(is_positive, max)``
        status on its out edges. Constant inputs with a value get
        ``(True, max_value)`` when all their (dequantized, if the graph has
        quantized parameters) values are non-negative; every other input
        gets ``(False, False)``. The nodes returned by
        ``find_redundant_relus`` are then bypassed and removed.

        Args:
            G: graph that is modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.

        Returns:
            True if at least one node was removed, otherwise False.
        """
        visited_edges = {}
        nodes_to_remove = []
        has_modified_graph = False
        for node in G.inputs():
            # Start from "nothing known" for every input so a constant
            # without a value can neither leave status unbound nor inherit
            # the status computed for the previous input.
            status = (False, False)
            # check if constantinput. if is then check if positive and check max value
            if isinstance(node,
                          ConstantInputParameters) and node.value is not None:
                if G.has_quantized_parameters:
                    qrec = G.quantization[NodeId(node)]
                    qtype = qrec.out_qs[0]
                    if hasattr(qtype, 'wrapped'):
                        qtype = qtype.wrapped
                    val = qtype.dequantize(node.value)
                else:
                    val = node.value
                if val.min() >= 0:
                    status = (True, val.max())

            for edge in G.out_edges(node.name):
                visited_edges[edge] = status
                nodes_to_remove += find_redundant_relus(
                    G, edge.to_node, visited_edges)
        for node in nodes_to_remove:
            has_modified_graph = True
            # Only relus so only one in edge
            in_edge = G.in_edges(node.name)[0]
            for edge in G.out_edges(node.name):
                G.add_edge(
                    NNEdge(from_node=in_edge.from_node,
                           from_idx=in_edge.from_idx,
                           to_node=edge.to_node,
                           to_idx=edge.to_idx))
            G.remove(node)

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 12
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse fragments found by ``MatScalePairMatchFactory`` into one node.

        Every matched fragment is removed from the graph and replaced by a
        ``MatScaleFusionParameters`` node of type ``"vec_scalar"``. The
        matched input edges are cloned onto the fusion node using the
        mapping from ``get_mapping_from_edges``; the out edges of the
        fragment's first output node are cloned with the fusion node as
        their source. ``in_dims_hint`` of the fusion is filled from the
        producers' ``out_dims_hint`` where available.

        Args:
            G: graph that is modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.
            **kwargs: accepted and ignored.

        Returns:
            True if at least one fragment was fused, otherwise False.
        """
        factory = MatScalePairMatchFactory()
        changed = False

        for frag, match in factory.get_matcher().match_graph(G):
            incoming = [
                G.indexed_in_edges(target.name)[in_idx]
                for target, in_idx in match
            ]
            head = frag.inputs()[0]
            tail = frag.outputs()[0]
            outgoing = G.out_edges(tail.name)
            for inner in frag.nodes():
                G.remove(inner)

            input_mapping = MatScaleFusionParameters.get_mapping_from_edges(
                incoming)
            converted_mapping = MatScaleFusionParameters.convert_input_mapping(
                input_mapping)

            fused = MatScaleFusionParameters(
                "{}_{}_fusion".format(head.name, tail.name),
                fusion_type="vec_scalar",
                subgraph=frag,
                input_mapping=converted_mapping)
            changed = True
            G.add_node(fused)
            fused.in_dims_hint = [None] * 3

            for position, in_edge in enumerate(incoming):
                # first key recorded for the original destination node
                fused_idx = list(input_mapping[in_edge.to_node].keys())[0]
                rewired = in_edge.clone(to_node=fused, to_idx=fused_idx)
                producer_hints = rewired.from_node.out_dims_hint
                if producer_hints:
                    fused.in_dims_hint[position] = producer_hints[
                        in_edge.from_idx]
                G.add_edge(rewired)
            for out_edge in outgoing:
                G.add_edge(out_edge.clone(from_node=fused))

        if set_identity:
            self.set_identity(G)

        return changed
Esempio n. 13
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Remove reshapes that do nothing and merge redundant followers.

        The graph is processed repeatedly until a full pass makes no change.
        Each pass:

        1. removes every ``ReshapeParameters`` node without a transpose
           whose new shape equals its old shape, reconnecting its
           neighbours;
        2. looks for the first reshape for which ``self.validate_reshape``
           returns a result, removes the candidate nodes it reports,
           reconnects their consumers to the reshape and sets the reshape's
           shape to the reported output shape.

        Args:
            G: graph that is modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.
            **kwargs: accepted and ignored.

        Returns:
            True if any pass changed the graph, otherwise False.
        """
        # modified_graph only drives the fixed-point loop and is always False
        # when the loop ends, so the overall result is tracked separately.
        has_modified_graph = False
        modified_graph = True
        while modified_graph:
            modified_graph = False
            for reshape in G.nodes(node_classes=(ReshapeParameters, )):
                if not reshape.has_transpose and reshape.shape.shape == reshape.old_shape.shape:
                    modified_graph = True
                    LOG.info('removing reshape that does nothing %s',
                             reshape.name)
                    G.remove_and_reconnect(reshape, edge_class=NNEdge)
                    nid = NodeId(reshape)
                    if G.quantization and nid in G.quantization:
                        del G.quantization[nid]
            for reshape in G.nodes(node_classes=(ReshapeParameters, )):
                res = self.validate_reshape(G, reshape)
                if res:
                    LOG.info('unnecessary reshape found after %s',
                             reshape.name)
                    modified_graph = True
                    (reshape, candidates, out_shape) = res
                    for candidate in candidates:
                        LOG.info(
                            'removing unnecessary reshape or transpose %s',
                            candidate.name)
                        edges = G.out_edges(candidate.name)
                        G.remove(candidate)
                        nid = NodeId(candidate)
                        if G.quantization and nid in G.quantization:
                            del G.quantization[nid]
                        for edge in edges:
                            G.add_edge(
                                NNEdge(from_node=reshape,
                                       to_node=edge.to_node,
                                       to_idx=edge.to_idx))
                    reshape.shape = Dim.unnamed(out_shape)
                    # the graph changed under the iteration, start a new pass
                    break
            if modified_graph:
                has_modified_graph = True

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
        """Hoist an operation duplicated on a split's outputs above the split.

        For every ``SplitParameters`` node, ``self.moveable_same_operation_edges``
        is asked for the out edges that carry the same operation. When it
        returns any, all nodes directly after the split are removed and
        bypassed, and one of them (the first by name) is re-inserted on the
        split's single in edge. If the graph is quantized it is re-quantized
        afterwards.

        Args:
            G: graph that is modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.
            **kwargs: accepted and ignored.

        Returns:
            True if at least one split was rewritten, otherwise False.

        Raises:
            AssertionError: if a rewritten split does not have exactly one
                in edge.
        """
        has_modified_graph = False
        for node in G.nodes(node_classes=SplitParameters):
            same_op_edges = self.moveable_same_operation_edges(G, node)
            if not same_op_edges:
                continue
            has_modified_graph = True
            in_edges = G.in_edges(node.name)
            assert len(in_edges) == 1
            # sort by name to ensure that operation is repeatable
            same_op_edges.sort(key=lambda x: x.to_node.name)
            # the node that survives and is moved in front of the split
            keep_node = same_op_edges[0].to_node
            LOG.info('split node %s has duplicate operations on its out edges',
                     node.name)
            LOG.info('moving %s before split node %s', keep_node.name,
                     node.name)
            for edge in G.out_edges(node.name):
                # capture the consumers before the node disappears
                node_out_edges = G.out_edges(edge.to_node.name)
                # keep_node is removed here too; it is re-inserted below
                G.remove(edge.to_node)
                if edge.to_node != keep_node:
                    LOG.info('deleting duplicate node %s', edge.to_node.name)
                    if G.quantization:
                        nid = NodeId(edge.to_node)
                        if nid in G.quantization:
                            del G.quantization[nid]
                # connect the split output straight to the old consumers
                for out_edge in node_out_edges:
                    G.add_edge(
                        NNEdge(from_node=node,
                               from_idx=edge.from_idx,
                               to_node=out_edge.to_node,
                               to_idx=out_edge.to_idx))
            # place the surviving operation on the split's input edge
            G.insert_node_at_edge(keep_node, in_edges[0], edge_class=NNEdge)
            if G.quantization:
                # rebuild quantization from the existing quantized graph
                quantizer = NewQuantizer.from_quantized_graph(G)
                quantizer.quantize()

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 15
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Merge groups of global pooling nodes into their first node.

        ``self.reductions`` returns the group of nodes belonging to each
        ``GlobalPoolingParameters`` node; groups with a single node are left
        alone. The group is folded with ``reduce_reduction`` to obtain the
        combined axes, resulting shape and keep-dims flag. The first node of
        the group takes over the combined axes, the remaining nodes are
        removed, and a reshape is appended when dimensions are kept but the
        resulting rank differs from the first node's input rank.

        Args:
            G: graph that is modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.
            **kwargs: accepted and ignored.

        Returns:
            True if at least one group was merged, otherwise False.
        """
        worklist = list(G.nodes(node_classes=GlobalPoolingParameters))
        changed = False
        while worklist:
            group = self.reductions(G, worklist.pop())
            if len(group) <= 1:
                continue
            changed = True
            reduction_axes, new_shape, has_keepdims, _ = reduce(
                reduce_reduction, group, None)
            first, last = group[0], group[-1]
            first.axis = sorted(reduction_axes)
            first.keep_dims = has_keepdims
            drains = G.out_edges(last.name)
            if G.quantization:
                # the merged node produces what the last node used to produce
                last_qrec = G.quantization[NodeId(last)]
                G.quantization[NodeId(first)].out_qs = last_qrec.out_qs
            for absorbed in group[1::]:
                G.remove(absorbed.name)
                qkey = NodeId(absorbed)
                if G.quantization and qkey in G.quantization:
                    del G.quantization[qkey]
            source = first
            rank_differs = len(new_shape) != len(first.in_dims[0].shape)
            if has_keepdims and rank_differs:
                reshape = ReshapeParameters(
                    G.unique_name(f'{first.name}_reshape'),
                    shape=Dim.unnamed(new_shape))
                if G.quantization:
                    G.quantization.copy_qrec(last_qrec, 'out', 0, reshape)
                G.add_edge(NNEdge(first, reshape))
                source = reshape
            for drain in drains:
                G.add_edge(NNEdge(source, drain.to_node, to_idx=drain.to_idx))

        if set_identity:
            self.set_identity(G)

        return changed
Esempio n. 16
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Remove the nodes reported as redundant by ``find_redundant_relus``.

        Every graph input seeds the search with a ``(is_positive, max)``
        status on its out edges: ``(True, max_value)`` for constant inputs
        that have a value whose dequantized minimum is non-negative,
        ``(False, False)`` for everything else. The nodes returned by
        ``find_redundant_relus`` are then removed and bypassed.

        Args:
            G: graph that is modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.
            **kwargs: accepted and ignored.

        Returns:
            True if at least one node was removed, otherwise False.
        """
        unknown = (False, False)
        edge_status = {}
        redundant = []
        for source in G.inputs():
            # check if constantinput. if is then check if positive and check max value
            if isinstance(source,
                          ConstantInputParameters) and source.value is not None:
                val = source.dqvalue
                status = (True, np.max(val)) if np.min(val) >= 0 else unknown
            else:
                status = unknown

            for out_edge in G.out_edges(source.name):
                edge_status[out_edge] = status
                redundant += find_redundant_relus(G, out_edge.to_node,
                                                  edge_status)

        changed = False
        for relu in redundant:
            changed = True
            # Only relus so only one in edge
            LOG.info("removing redundant relu %s", relu.name)
            feed = G.in_edges(relu.name)[0]
            drains = G.out_edges(relu.name)
            G.remove(relu)
            for drain in drains:
                G.add_edge(
                    NNEdge(from_node=feed.from_node,
                           from_idx=feed.from_idx,
                           to_node=drain.to_node,
                           to_idx=drain.to_idx))

        if set_identity:
            self.set_identity(G)

        return changed
Esempio n. 17
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fold a constant add or multiply that follows a matmul into the matmul.

        For every ``MatMulOpParameters`` node whose single consumer is a
        ``MatrixAddParameters`` or ``MatrixMulParameters`` node fed by a
        constant on its other input, the constant is merged into the matmul
        (as a new or updated bias, or by scaling the matmul's constant operand
        and bias) and the add/multiply node is removed. The check is repeated
        on the same matmul so that chains of such nodes are folded in turn.

        Args:
            G: graph that is rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called before
                returning.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            True if at least one node was folded into a matmul, else False.

        Raises:
            ValueError: if a candidate matmul's first output is not 2D.
        """
        has_modified_graph = False
        has_transposed = False
        for params in G.nodes(node_classes=MatMulOpParameters):
            while True:
                out_edges = G.out_edges(params.name)
                # can't fuse if there is a branch - and there is nothing to
                # fuse if the matmul has no consumer at all
                if len(out_edges) != 1:
                    break
                out_edge = out_edges[0]
                op_node = out_edge.to_node
                # must be a valid matrix op
                if not isinstance(op_node,
                                  (MatrixAddParameters, MatrixMulParameters)):
                    break
                # other edge to the op must be a constant
                other_idx = 1 if out_edge.to_idx == 0 else 0
                other_in_edge = G.indexed_in_edges(op_node.name)[other_idx]
                if not isinstance(other_in_edge.from_node,
                                  ConstantInputParameters):
                    break
                const_node = other_in_edge.from_node
                # only delete the constant if the op being folded is its sole
                # consumer - otherwise other nodes would lose their input
                remove_constant = len(G.out_edges(const_node.name)) == 1

                flat_value = const_node.dqvalue.flatten()
                out_shape = params.out_dims[0].shape
                if len(out_shape) != 2:
                    raise ValueError(
                        f'strange outputs shape of {out_shape} for matmul {params.name}'
                    )
                if len(flat_value) != out_shape[0] and len(
                        flat_value) != out_shape[1]:
                    LOG.info(
                        "can't fuse %s into %s - value shape is not correct for bias",
                        const_node.name, params.name)
                    break
                # a third input on the matmul is an existing bias
                has_bias = len(params.in_dims) == 3
                if isinstance(op_node, MatrixAddParameters):
                    if has_bias:
                        if len(flat_value.shape) != len(params.in_dims[2]):
                            LOG.info(
                                "can't fuse %s into %s - bias shape is not the same",
                                const_node.name, params.name)
                            break
                        bias_node = G.indexed_in_edges(
                            params.name)[2].from_node
                        LOG.info(
                            "folding additive bias from %s into existing bias on %s",
                            op_node.name, params.name)
                        bias_node.value = bias_node.dqvalue + flat_value
                    else:
                        if len(flat_value) == out_shape[1]:
                            # matmul needs to be transposed to fuse this
                            reverse_matmul(G, params)
                            has_transposed = True
                        bias_node = ConstantInputParameters(
                            G.unique_name(f'{params.name}_bias'),
                            value=flat_value,
                            dims=Dim.unnamed(flat_value.shape))
                        G.add_edge(
                            NNEdge(from_node=bias_node,
                                   to_node=params,
                                   to_idx=2))
                        # extend the inward transpose
                        if params.transpose_in:
                            params.transpose_in = params.transpose_in + [None]
                        LOG.info(
                            "folding additive bias from %s into new bias on %s",
                            op_node.name, params.name)
                else:
                    params_in = G.indexed_in_edges(params.name)
                    consts = [
                        isinstance(edge.from_node, ConstantInputParameters)
                        for edge in params_in
                    ]
                    # a scale can only be absorbed by a constant matmul input
                    if not any(consts):
                        break
                    mult_const_node = params_in[1].from_node if consts[
                        1] else params_in[0].from_node
                    mult_const_node.value = mult_const_node.dqvalue * const_node.dqvalue
                    if has_bias:
                        bias_node = params_in[2].from_node
                        bias_node.value = bias_node.dqvalue * const_node.dqvalue

                    LOG.info(
                        "folding multaplicative bias from %s into new bias on %s",
                        op_node.name, params.name)

                # splice the folded op out: its consumers now read the matmul
                out_edges = G.out_edges(op_node.name)
                G.remove(op_node)
                if remove_constant:
                    G.remove(const_node)
                for edge in out_edges:
                    G.add_edge(
                        NNEdge(from_node=params,
                               to_node=edge.to_node,
                               to_idx=edge.to_idx))
                has_modified_graph = True
                G.add_dimensions()
                if G.quantization:
                    quantizer = UnifiedQuantizer.from_quantized_graph(G)
                    quantizer.quantize(G, start_nodes=[params])
                    RemoveUnnecessaryQuantizeOperators().match(G)

        if has_transposed:
            G.adjust_order()

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Wrap each convolution and the nodes matched after it in a fusion node.

        For every ``Conv2DParameters`` node, ``self.get_node_list`` supplies
        the chain of nodes to fuse. Chains of at least two nodes are replaced
        in ``G`` by one ``ConvFusionParameters`` node holding the chain as a
        subgraph.

        Args:
            G: graph that is rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called before
                returning.
            **kwargs: ``group_identity`` selects the set of valid activations
                (``'pow2_match_group'`` picks the POW2 set, anything else SQ8).

        Returns:
            True if at least one fusion node was created, else False.
        """
        if kwargs.get('group_identity') == 'pow2_match_group':
            allowed_activations = VALID_ACTIVATIONS_POW2
        else:
            allowed_activations = VALID_ACTIVATIONS_SQ8

        graph_changed = False
        conv_nodes = [
            node for node in G.nodes() if isinstance(node, Conv2DParameters)
        ]
        for conv_node in conv_nodes:
            found = self.get_node_list(G, conv_node, allowed_activations)
            if found is None or len(found.order) < 2:
                continue
            if found.fusion_type == 'conv_active_pool':
                # an average pool after the activation is left out of the fusion
                if found.pool.pool_type == "average":
                    found.order = found.order[:2]
                    found.pool = None
            elif found.fusion_type == 'conv_pool_active':
                # NOTE: This is only for old POW2 kernels - SQ8 can handle this
                if found.pool.pool_type == "average" and found.active.activation != "relu":
                    continue

            LOG.info("fusing nodes %s",
                     ",".join(member.name for member in found.order))
            graph_changed = True

            # chain the matched nodes together inside the fusion's subgraph
            subgraph = GraphView()
            for upstream, downstream in zip(found.order, found.order[1:]):
                subgraph.add_edge(
                    NNEdge(from_node=upstream, to_node=downstream))
            tail_node = found.order[-1]

            pnode = ConvFusionParameters(
                found.conv.name + '_fusion',
                fusion_type=found.fusion_type,
                subgraph=subgraph,
                in_dims_hint=found.conv.in_dims_hint,
                out_dims_hint=found.conv.out_dims_hint,
                in_dims=deepcopy(found.conv.in_dims),
                out_dims=deepcopy(tail_node.out_dims),
                input_mapping=[[(found.conv, idx)] for idx in range(3)],
                output_mapping=[(tail_node, 0)])

            if G.quantization:
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                if qrecs:
                    # TODO - stats
                    fused_qrec = QRec.copy_ktype(
                        qrecs[0],
                        in_qs=deepcopy(qrecs[0].in_qs),
                        out_qs=deepcopy(qrecs[-1].out_qs))
                    for inner in pnode.contained_nodes():
                        G.quantization.move_to_fusion(inner, pnode)
                    G.quantization[NodeId(pnode)] = fused_qrec

            # capture the boundary edges before the chain is removed
            inbound = G.in_edges(found.conv.name)
            outbound = G.out_edges(tail_node.name)
            for member in found.order:
                G.remove(member)
            for edge in inbound:
                G.add_edge(
                    NNEdge(edge.from_node,
                           pnode,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))
            for edge in outbound:
                G.add_edge(
                    NNEdge(pnode,
                           edge.to_node,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return graph_changed
Esempio n. 19
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Remove slice/split nodes that only select an RNN's last cell output.

        ``self.find_unpack`` is asked for the node chain that follows each
        ``RNNBaseParameters`` node. When the chain ends in a strided slice or
        split that keeps just the final cell, every node after the RNN is
        removed and the RNN is set to emit a single output cell.

        Args:
            G: graph that is rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called before
                returning.

        Returns:
            True if at least one unpack chain was eliminated, else False.
        """
        chains = [
            self.find_unpack(G, candidate) for candidate in G.nodes()
            if isinstance(candidate, RNNBaseParameters)
        ]
        graph_changed = False
        for chain in chains:
            if not chain:
                continue
            rnn_node, unpack_node = chain[0], chain[-1]
            if rnn_node.transpose_out:
                time_axis = rnn_node.transpose_out[0].index(0)
            else:
                time_axis = 0

            if isinstance(unpack_node, StridedSliceParameters):
                time_slice = unpack_node.act_slice[time_axis]
                if time_slice[1] != rnn_node.n_cells:
                    LOG.debug("can't remove %s. Slice not equal to cells",
                              unpack_node.name)
                    continue
                if time_slice[2] != 1:
                    LOG.debug("can't remove %s. Slice not of length 1",
                              unpack_node.name)
                    continue
                if time_slice[0] != rnn_node.n_cells - 1:
                    LOG.debug("can't remove %s. Slice isn't last cell",
                              unpack_node.name)
                    continue
                out_edge = G.out_edges(unpack_node.name)[0]
            elif isinstance(unpack_node, SplitParameters):
                split_out_edges = G.out_edges(unpack_node.name)
                if len(split_out_edges) > 1:
                    LOG.debug("can't remove %s. More than one output edge",
                              unpack_node.name)
                    continue
                out_edge = split_out_edges[0]
                if out_edge.from_idx != len(unpack_node.act_slices) - 1:
                    LOG.debug("can't remove %s. Not last output",
                              unpack_node.name)
                    continue
                time_slice = unpack_node.act_slices[-1][time_axis]
                if time_slice[1] != rnn_node.n_cells:
                    LOG.debug("can't remove %s. Slice not equal to cells",
                              unpack_node.name)
                    continue
                if time_slice[0] != rnn_node.n_cells - 1:
                    LOG.debug("can't remove %s. Slice isn't last cell",
                              unpack_node.name)
                    continue
            else:
                continue

            graph_changed = True
            # drop everything that follows the RNN in the matched chain
            for trailing in chain[1:]:
                LOG.info("Eliminating last cell unpack: %s", trailing.name)
                if G.quantization:
                    del G.quantization[NodeId(trailing)]
                G.remove(trailing)
            # the RNN now produces what the unpack node used to produce
            rnn_node.n_output_cells = 1
            rnn_node.out_dims[0] = unpack_node.out_dims[out_edge.from_idx]
            rnn_node.out_dims_hint = [
                unpack_node.out_dims_hint[out_edge.from_idx]
            ]
            rnn_node.transpose_out = None
            G.add_edge(
                NNEdge(rnn_node, out_edge.to_node, to_idx=out_edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return graph_changed
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Remove duplicated operations and merge groupable ones.

        Edges leaving the same node output are grouped. Among their
        ``ComparableParameters`` destinations:

        * nodes for which ``is_same_operation_as`` holds are duplicates: one
          is kept, the consumers of the others are rewired to it and the
          others are removed;
        * nodes for which ``can_be_grouped_with`` holds are merged into the
          first one by concatenating weights and biases along the output
          channel, with a new split node handing each original consumer its
          share of the output.

        The scan repeats until a pass changes nothing. Quantized graphs are
        not handled and are left untouched.

        Args:
            G: graph that is rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called before
                returning (not reached for quantized graphs).
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            True if any node was removed or merged, else False.
        """
        if G.quantization:
            LOG.warning(
                'match_duplicate_operations does not handle quantized graphs')
            return False

        def same_source_edge_fn(x):
            return f"{x.from_node.__hash__()}##{x.from_idx}"

        modified_graph = False
        while True:
            found_more = False
            # groupby needs its input sorted by the same key
            same_source_edges = [
                list(edge_list) for _, edge_list in groupby(
                    sorted(G.edges(), key=same_source_edge_fn),
                    same_source_edge_fn)
            ]
            # all have the same origin
            same_source_edges = [
                elem for elem in same_source_edges if len(elem) > 1
            ]
            same_dest_edges = []
            same_dest_group_edges = []

            for same_source_edge in same_source_edges:
                same_source_edge = [
                    edge for edge in same_source_edge
                    if isinstance(edge.to_node, ComparableParameters)
                ]
                while same_source_edge:
                    first = same_source_edge.pop(0)

                    others = [
                        edge for edge in same_source_edge
                        if first.to_node != edge.to_node
                        and edge.to_node.is_same_operation_as(G, first.to_node)
                    ]
                    if others:
                        same_dest_edges.append(tuple([first] + others))
                        for other in others:
                            same_source_edge.remove(other)
                        continue

                    other_groups = [
                        edge for edge in same_source_edge
                        if first.to_node != edge.to_node
                        and edge.to_node.can_be_grouped_with(first.to_node)
                    ]
                    if other_groups:
                        same_dest_group_edges.append(
                            tuple([first] + other_groups))
                        for other in other_groups:
                            same_source_edge.remove(other)

            # all are multiple edges that go to something comparable
            while same_dest_edges:
                edge_set = same_dest_edges.pop(0)
                keep_node = edge_set[0].to_node
                # fold in every other set that also involves the kept node
                other_edge_sets = [
                    edges for edges in same_dest_edges
                    if any(edge.to_node == keep_node for edge in edges)
                ]
                for other_edge_set in other_edge_sets:
                    same_dest_edges.remove(other_edge_set)

                nodes_to_delete = set()
                for merged_set in [edge_set] + other_edge_sets:
                    for edge in merged_set:
                        other_node = edge.to_node
                        if other_node == keep_node or other_node in nodes_to_delete:
                            continue
                        nodes_to_delete.add(other_node)
                        # consumers of the duplicate now read the same output
                        # index of the node that is kept
                        for out_edge in G.out_edges(other_node.name):
                            G.add_edge(
                                NNEdge(from_node=keep_node,
                                       from_idx=out_edge.from_idx,
                                       to_node=out_edge.to_node,
                                       to_idx=out_edge.to_idx))
                LOG.info(
                    f'removed duplicates {",".join(node.name for node in nodes_to_delete)} to {keep_node.name}'
                )
                for node in nodes_to_delete:
                    G.remove(node)
                if nodes_to_delete:
                    # report the change and rescan: merging duplicates can
                    # expose new duplicates among their consumers
                    modified_graph = True
                    found_more = True

            for edge_set in same_dest_group_edges:
                modified_graph = True
                found_more = True
                # we will merge all the convolutions into one
                first = edge_set[0]
                first_node = first.to_node
                in_edges = G.indexed_in_edges(first_node.name)
                first_filter = first_node.filter
                weights_node = in_edges[1].from_node
                biases_node = in_edges[2].from_node
                dup_nodes = []
                num_convs = len(edge_set)
                out_shape = deepcopy(first_node.out_dims[0])
                out_shape.c *= num_convs
                # create a split after the first node splitting on channel axis
                act_slices, out_shapes, axis = SplitParameters.get_splits(
                    out_shape,
                    out_shape.get_order_idx('c'),
                    num_splits=num_convs)
                split1 = SplitParameters(
                    G.unique_name(f'{first_node.name}_split'),
                    act_slices=act_slices,
                    out_shapes=out_shapes,
                    axis=axis)
                out_num = 0
                # first node out edge goes to split
                out_edges = G.out_edges(first_node.name)
                for edge in out_edges:
                    G.remove_edge(edge)
                    G.add_edge(
                        NNEdge(from_node=split1,
                               from_idx=out_num,
                               to_node=edge.to_node,
                               to_idx=edge.to_idx))
                G.add_edge(NNEdge(from_node=first_node, to_node=split1))
                # first split output goes to original output
                for other in edge_set[1::]:
                    out_num += 1
                    node_other = other.to_node
                    dup_nodes.append(node_other.name)
                    in_edges = G.indexed_in_edges(node_other.name)
                    weights_other = in_edges[1].from_node
                    biases_other = in_edges[2].from_node
                    # merge the weights and biases down output channel
                    weights_node.value = np.concatenate(
                        (weights_node.value, weights_other.value),
                        axis=first_filter.get_order_idx('out_c'))
                    weights_node.dims = Dim.unnamed(weights_node.value.shape)
                    biases_node.value = np.concatenate(
                        (biases_node.value, biases_other.value))
                    biases_node.dims = Dim.unnamed(biases_node.value.shape)
                    first_filter.out_c += node_other.filter.out_c
                    # wire edge from split
                    out_edges = G.out_edges(node_other.name)
                    G.remove(node_other)
                    G.remove(weights_other)
                    G.remove(biases_other)
                    for edge in out_edges:
                        G.add_edge(
                            NNEdge(from_node=split1,
                                   from_idx=out_num,
                                   to_node=edge.to_node,
                                   to_idx=edge.to_idx))
                LOG.info(
                    f'merged convolutions {",".join(dup_nodes)} into {first_node.name}'
                )
            if not found_more:
                break

        if set_identity:
            self.set_identity(G)

        return modified_graph
Esempio n. 21
0
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        """Replace sibling strided slices of one tensor with a single split.

        ``StridedSliceParameters`` nodes are grouped by the node output they
        read. A group is replaced by one ``SplitParameters`` node when its
        slices differ on exactly one axis, all have stride 1 on that axis and
        are consecutive there (each slice ends where the next one starts).
        Groups with a single slice are handed to ``self.slice_to_split``.

        Args:
            G: graph that is rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called before
                returning.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            True if at least one group of slices was replaced by a split.
        """
        has_modified_graph = False
        slices_by_origin = {}
        for slice_node in [
                node for node in G.nodes()
                if isinstance(node, StridedSliceParameters)
        ]:
            in_edge = G.in_edges(slice_node.name)[0]
            group = slices_by_origin.setdefault(
                (in_edge.from_node, in_edge.from_idx), [])
            group.append(slice_node)
        for slice_nodes in slices_by_origin.values():
            # transpose: one tuple per axis holding that axis' slice per node
            slices = list(zip(*[node.act_slice for node in slice_nodes]))
            if len(slice_nodes) == 1:
                self.slice_to_split(G, slice_nodes, slices)
                continue

            diff_slices = [(idx, elems) for idx, elems in enumerate(slices)
                           if not all(elems[0] == elem for elem in elems[1::])]
            if len(diff_slices) != 1:
                continue
            # strides must be one
            if any(sl[2] != 1 for sl in diff_slices[0][1]):
                continue
            # check if slices are consecutive and non overlapping: each slice
            # is (start, end, step) so one must end where the next starts
            slices = sorted(diff_slices[0][1], key=lambda x: x[0])
            if not all(sl[1] == slices[i + 1][0]
                       for i, sl in enumerate(slices[:-1:])):
                continue
            szes = [sl[1] - sl[0] for sl in slices]
            axis = diff_slices[0][0]
            slice_nodes = sorted(slice_nodes,
                                 key=lambda x: x.act_slice[axis][0])
            act_slices, out_shapes, axis = SplitParameters.get_splits(
                slice_nodes[0].in_dims[0].shape, axis, splits=szes)
            params = SplitParameters(slice_nodes[0].name + '_split',
                                     act_slices=act_slices,
                                     out_shapes=out_shapes,
                                     axis=axis)
            # look the input edge up again here rather than reusing the
            # grouping key: an earlier replacement may have rewired it
            in_edge = G.in_edges(slice_nodes[0].name)[0]
            G.add_edge(
                NNEdge(from_node=in_edge.from_node,
                       to_node=params,
                       from_idx=in_edge.from_idx))
            sub_names = []
            for idx, node in enumerate(slice_nodes):
                sub_names.append(node.name)
                out_edges = G.out_edges(node.name)
                G.remove(node)
                # consumers of slice idx now read split output idx
                for out_edge in out_edges:
                    G.add_edge(
                        NNEdge(from_node=params,
                               to_node=out_edge.to_node,
                               from_idx=idx,
                               to_idx=out_edge.to_idx))
            if G.quantization:
                G.add_dimensions()
                quantizer = UnifiedQuantizer.from_quantized_graph(G)
                quantizer.quantize(G, start_nodes=[params])
                RemoveUnnecessaryQuantizeOperators().match(G)

            LOG.info(
                f'replaced slice nodes {",".join(sub_names)} with split node {sub_names[0]}'
            )

            has_modified_graph = True

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 22
0
    def match(self, G: GraphView, set_identity: bool = True) -> bool:
        """Move explicit zero padding into the filter node that consumes it.

        For each ``PadParameters`` node, ``self.find_conv`` is asked for a
        filter-like node behind every output edge. Where one is found and the
        padding is zero-valued and only on the ``h``/``w`` axes, the padding
        is written onto the filter and the edge is rewired to bypass the pad.
        The pad node is removed once every output edge was handled this way.

        Args:
            G: graph that is rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called before
                returning.

        Returns:
            True if padding was moved onto at least one filter, else False.

        Raises:
            ValueError: if a matched filter has no hint for its first input.
        """
        graph_changed = False
        pad_nodes = [
            node for node in G.nodes() if isinstance(node, PadParameters)
        ]
        for pad_params in pad_nodes:
            pad_in_edges = G.in_edges(pad_params.name)
            pad_out_edges = G.out_edges(pad_params.name)
            # set as soon as one consumer still needs the pad node
            keep_pad = False
            for pad_out_edge in pad_out_edges:
                filter_like_node, is_1d = self.find_conv(
                    G, pad_out_edge.to_node)
                if not filter_like_node:
                    keep_pad = True
                    continue
                hints = filter_like_node.in_dims_hint
                if not hints or not hints[0]:
                    raise ValueError(
                        f"filter {filter_like_node.name} doesn't have a input hint"
                    )
                in_hint = hints[0]

                padding = pad_params.padding
                if is_1d:
                    if len(padding) != 2:
                        LOG.warning(
                            "pad node %s is applied to 1d convolution but has length %s",
                            pad_params.name, len(pad_params.padding))
                        keep_pad = True
                        continue
                    # insert an unpadded middle axis between the two 1D axes
                    expanded_padding = [padding[0], (0, 0), padding[1]]
                else:
                    if len(padding) != 3:
                        LOG.warning(
                            "pad node %s is applied to 2d convolution but has length %s",
                            pad_params.name, len(pad_params.padding))
                        keep_pad = True
                        continue
                    expanded_padding = padding

                # axis name -> (before, after) for every axis that is padded
                hinted_pad = {
                    in_hint[axis_idx]: axis_pad
                    for axis_idx, axis_pad in enumerate(expanded_padding)
                    if sum(axis_pad) > 0
                }
                key_set = set(hinted_pad) - {'h', 'w'}
                if key_set:
                    LOG.error(
                        "node %s has padding on axes %s and cannot be fused with filter %s",
                        pad_params.name, key_set, filter_like_node.name)
                    keep_pad = True
                    continue
                if any(pval != 0 for val in pad_params.pad_vals
                       for pval in val):
                    LOG.error(
                        "node %s has non zero pad values and cannot be fused with filter %s",
                        pad_params.name, filter_like_node.name)
                    keep_pad = True
                    continue

                LOG.info("adding padding from: %s to %s filter: %s",
                         pad_params.name, "1D" if is_1d else "2D",
                         filter_like_node.name)

                for key in ('h', 'w'):
                    hinted_pad.setdefault(key, (0, 0))

                filter_like_node.padding = PadDim(*hinted_pad['h'],
                                                  *hinted_pad['w'])
                filter_like_node.pad_type = "zero"
                graph_changed = True
                G.remove_edge(pad_out_edge)
                if is_1d:
                    # the node directly after the pad is adjusted so that its
                    # shapes no longer include the padding
                    reshape_node = pad_out_edge.to_node
                    reshape_node.old_shape = self.remove_padding(
                        reshape_node.old_shape, pad_params.padding)
                    reshape_node.shape = self.remove_padding(
                        reshape_node.shape, expanded_padding)
                # connect the pad's producers straight to its consumer
                for in_edge in pad_in_edges:
                    G.add_edge(
                        NNEdge(from_node=in_edge.from_node,
                               to_node=pad_out_edge.to_node,
                               from_idx=in_edge.from_idx,
                               to_idx=pad_out_edge.to_idx))

            if not keep_pad:
                G.remove(pad_params)
                if G.quantization:
                    G.quantization.remove_node(pad_params)

        if set_identity:
            self.set_identity(G)

        return graph_changed
Esempio n. 23
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
        """Remove no-op pad nodes and move zero padding into filter nodes.

        A ``PadParameters`` node with a single input whose padding is zero on
        every axis is spliced out of the graph. For any other pad node,
        ``self.find_conv`` is asked for a filter-like node behind each output
        edge; where one is found and the padding is zero-valued and only on
        the ``h``/``w`` axes, the padding is written onto the filter, the
        reshapes reported by ``find_conv`` are adjusted and the edge is
        rewired to bypass the pad. The pad node is removed once every output
        edge was handled this way.

        Args:
            G: graph that is rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called before
                returning.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            True if any pad node was removed or any padding was moved onto a
            filter, else False.

        Raises:
            ValueError: if a matched filter has no hint for its first input.
        """
        has_modified_graph = False
        for pad_params in [pad for pad in G.nodes() if isinstance(pad, PadParameters)]:
            pad_in_edges = G.in_edges(pad_params.name)
            pad_out_edges = G.out_edges(pad_params.name)
            dont_delete = False
            if len(pad_in_edges) == 1 and all(sum(padding) == 0 for padding in pad_params.padding):
                LOG.info("removing zero padding node %s",
                         pad_params.name)
                G.remove(pad_params)
                if G.quantization:
                    G.quantization.remove_node(pad_params)
                # removing the node changes the graph and must be reported
                has_modified_graph = True
                # already removed here, so skip the removal further down
                dont_delete = True
                in_edge = pad_in_edges[0]
                for out_edge in pad_out_edges:
                    G.add_edge(NNEdge(from_node=in_edge.from_node,
                                      to_node=out_edge.to_node,
                                      from_idx=in_edge.from_idx,
                                      to_idx=out_edge.to_idx))
            else:
                for pad_out_edge in pad_out_edges:
                    filter_like_node, expanded_padding, reshapes = self.find_conv(
                        G, pad_out_edge.to_node, pad_params.padding)
                    if not filter_like_node:
                        dont_delete = True
                        continue
                    if not filter_like_node.in_dims_hint or not filter_like_node.in_dims_hint[0]:
                        raise ValueError(
                            f"filter {filter_like_node.name} doesn't have a input hint")
                    in_hint = filter_like_node.in_dims_hint[0]

                    # axis name -> (before, after) for every padded axis
                    hinted_pad = {in_hint[idx]: pad for idx,
                                  pad in enumerate(expanded_padding) if sum(pad) > 0}
                    key_set = set(hinted_pad.keys())
                    key_set -= set(['h', 'w'])
                    if len(key_set) > 0:
                        dont_delete = True
                        LOG.error("node %s has padding on axes %s and cannot be fused with filter %s",
                                  pad_params.name, key_set, filter_like_node.name)
                        continue
                    if any(pval != 0 for val in pad_params.pad_vals for pval in val):
                        dont_delete = True
                        LOG.error("node %s has non zero pad values and cannot be fused with filter %s",
                                  pad_params.name, filter_like_node.name)
                        continue

                    LOG.info("adding padding from: %s to filter: %s - has %s reshapes",
                             pad_params.name, filter_like_node.name, len(reshapes))

                    for key in ['h', 'w']:
                        if key not in hinted_pad:
                            hinted_pad[key] = (0, 0)

                    filter_like_node.padding = PadDim(
                        *(list(hinted_pad['h']) + list(hinted_pad['w'])))
                    filter_like_node.pad_type = "zero"
                    has_modified_graph = True
                    G.remove_edge(pad_out_edge)
                    # strip the padding from the shapes of each reshape that
                    # find_conv reported
                    for reshape_node, old_padding, new_padding in reshapes:
                        reshape_node.old_shape = self.remove_padding(
                            reshape_node.old_shape, old_padding)
                        reshape_node.shape = self.remove_padding(
                            reshape_node.shape, new_padding)

                    # connect the pad's producers straight to its consumer
                    for in_edge in pad_in_edges:
                        G.add_edge(NNEdge(from_node=in_edge.from_node,
                                          to_node=pad_out_edge.to_node,
                                          from_idx=in_edge.from_idx,
                                          to_idx=pad_out_edge.to_idx))

            if not dont_delete:
                G.remove(pad_params)
                if G.quantization:
                    G.quantization.remove_node(pad_params)
                # also covers a pad that had no output edges to fuse
                has_modified_graph = True

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 24
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Fuse pad -> add (-> activation) chains into single fusion nodes.

        For every PadParameters node, ``self.get_node_list`` is asked for the
        chain starting at that pad. Chains of at least two nodes are moved
        into a subgraph wrapped by a new PaddedAddFusionParameters node named
        ``PADDED_<add name>``, which takes over the external input and output
        edges of the chain. Quantization records, when present, are moved
        under the fusion node.

        Args:
            G: graph that is searched and modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called before
                returning.

        Returns:
            True if at least one chain was fused.
        """
        has_modified_graph = False
        # Snapshot the pad nodes up front because the loop edits the graph.
        pad_nodes = [
            params for params in G.nodes()
            if isinstance(params, PadParameters)
        ]
        for pad_node in pad_nodes:
            node_list = self.get_node_list(G, pad_node)
            if node_list is None or len(node_list.order) < 2:
                continue
            LOG.info("fusing nodes %s", ",".join(
                (node.name for node in node_list.order)))
            has_modified_graph = True

            # Rebuild the chain inside the fusion's private subgraph.
            subgraph = GraphView()
            padded_input_idx = G.out_edges(node_list.pad.name)[0].to_idx
            subgraph.add_edge(
                NNEdge(from_node=node_list.pad,
                       to_node=node_list.add,
                       to_idx=padded_input_idx))
            last_node = node_list.add
            node_list.add.force_quantized_index = 0
            if node_list.active:
                subgraph.add_edge(
                    NNEdge(from_node=node_list.add, to_node=node_list.active))
                last_node = node_list.active
            # NOTE(review): the second mapping below refers to input 1 of the
            # pad node; confirm that this is what the fusion expects.
            if padded_input_idx == 0:
                input_mapping = [[(node_list.pad, 0)], [(node_list.add, 1)]]
            else:
                input_mapping = [[(node_list.add, 0)], [(node_list.pad, 1)]]

            output_mapping = [(last_node, 0)]
            pnode = PaddedAddFusionParameters(
                "PADDED_" + node_list.add.name,
                fusion_type=node_list.fusion_type,
                subgraph=subgraph,
                input_mapping=input_mapping,
                output_mapping=output_mapping)
            if G.quantization:
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                if qrecs:
                    # The fusion record spans the first node's inputs and the
                    # last node's outputs; prec stays None for record types
                    # that are not listed here.
                    prec = None
                    if isinstance(qrecs[0],
                                  (SymmetricQuantizationRecord,
                                   SymmetricScalableFilterQuantizationRecord)):
                        prec = SymmetricQuantizationRecord(
                            in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                    elif isinstance(qrecs[0],
                                    (MultQuantizationRecord,
                                     MultScalableFilterQuantizationRecord)):
                        prec = MultQuantizationRecord(in_qs=qrecs[0].in_qs,
                                                      out_qs=qrecs[-1].out_qs)
                    elif isinstance(qrecs[0],
                                    (Float32QuantizationRecord,
                                     Float32ScalableFilterQuantizationRecord)):
                        prec = Float32QuantizationRecord(
                            in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                    for node in pnode.contained_nodes():
                        G.quantization.move_to_fusion(node, pnode)
                    G.quantization[NodeId(pnode)] = prec

            # External inputs of the fusion as (edge, fusion input index).
            # The pad's in-edge carries the input index of the pad, not the
            # index of the add input that the pad was feeding, so it is
            # re-targeted to padded_input_idx. Reusing edge.to_idx would
            # attach both inputs to the same fusion input whenever the pad
            # feeds input 1 of the add.
            pad_inputs = [(edge, padded_input_idx)
                          for edge in G.in_edges(node_list.pad.name)]
            add_in_edges = G.indexed_in_edges(node_list.add.name)
            if padded_input_idx == 0:
                in_edges = pad_inputs + [(edge, edge.to_idx)
                                         for edge in add_in_edges[1::]]
            else:
                in_edges = [(edge, edge.to_idx)
                            for edge in add_in_edges[0:1:]] + pad_inputs
            out_edges = G.out_edges(last_node.name)

            # Swap the chain for the fusion node and reconnect it.
            for node in node_list.order:
                G.remove(node)
            for edge, to_idx in in_edges:
                G.add_edge(
                    NNEdge(edge.from_node,
                           pnode,
                           from_idx=edge.from_idx,
                           to_idx=to_idx))
            for edge in out_edges:
                G.add_edge(
                    NNEdge(pnode,
                           edge.to_node,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 25
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Wrap matmul (+ constant add) (+ activation) chains in fusion nodes.

        Each MatMulOpParameters node is passed to ``self.get_node_list``.
        Chains of two or more nodes are replaced by a MatMulOpFusionParameters
        node named ``<matmul name>_fusion``. When the chain contains an add,
        the constant feeding that add is connected to input 2 of the fusion.

        Args:
            G: graph that is searched and modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called before
                returning.

        Returns:
            True if at least one chain was fused.
        """
        has_modified_graph = False
        # The candidate list is materialised first since the graph is edited.
        candidates = [
            params for params in G.nodes()
            if isinstance(params, MatMulOpParameters)
        ]
        for matmul_node in candidates:
            node_list = self.get_node_list(G, matmul_node)
            if node_list is None or len(node_list.order) < 2:
                continue
            fused_names = ",".join((node.name for node in node_list.order))
            LOG.info("fusing nodes %s", fused_names)
            has_modified_graph = True

            # Only the matmul -> activation link lives in the subgraph.
            subgraph = GraphView()
            if node_list.active is not None:
                inner_edge = NNEdge(from_node=node_list.matmul,
                                    to_node=node_list.active)
                subgraph.add_edge(inner_edge)

            # Fusion inputs 0 and 1 feed the matmul; input 2 is the bias.
            input_mapping = [[(node_list.matmul, idx)] for idx in range(2)]
            if node_list.add:
                input_mapping += [[(node_list.matmul, 2)]]
            if node_list.active:
                output_mapping = [(node_list.active, 0)]
            else:
                output_mapping = [(node_list.matmul, 0)]

            pnode = MatMulOpFusionParameters(
                node_list.matmul.name + '_fusion',
                fusion_type=node_list.fusion_type,
                subgraph=subgraph,
                input_mapping=input_mapping,
                output_mapping=output_mapping)

            if G.quantization:
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                if qrecs:
                    first_rec = qrecs[0]
                    # prec remains None for unlisted record types
                    prec = None
                    if isinstance(first_rec,
                                  (SymmetricQuantizationRecord,
                                   SymmetricScalableFilterQuantizationRecord)):
                        prec = SymmetricQuantizationRecord(
                            in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                    elif isinstance(first_rec,
                                    (MultQuantizationRecord,
                                     MultScalableFilterQuantizationRecord)):
                        prec = MultQuantizationRecord(
                            in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                    elif isinstance(first_rec,
                                    (Float32QuantizationRecord,
                                     Float32ScalableFilterQuantizationRecord)):
                        prec = Float32QuantizationRecord(
                            in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                    for contained in pnode.contained_nodes():
                        G.quantization.move_to_fusion(contained, pnode)
                    G.quantization[NodeId(pnode)] = prec

            # Capture the external edges before the chain is removed.
            in_edges = G.in_edges(node_list.matmul.name)
            if node_list.add:
                constant_edges = [
                    add_edge for add_edge in G.in_edges(node_list.add.name)
                    if isinstance(add_edge.from_node, ConstantInputParameters)
                ]
                bias_edge = constant_edges[0]
            out_edges = G.out_edges(node_list.order[-1].name)

            for fused_node in node_list.order:
                G.remove(fused_node)

            for in_edge in in_edges:
                G.add_edge(
                    NNEdge(in_edge.from_node,
                           pnode,
                           from_idx=in_edge.from_idx,
                           to_idx=in_edge.to_idx))
            if node_list.add:
                G.add_edge(
                    NNEdge(bias_edge.from_node,
                           pnode,
                           from_idx=bias_edge.from_idx,
                           to_idx=2))
            for out_edge in out_edges:
                G.add_edge(
                    NNEdge(pnode,
                           out_edge.to_node,
                           from_idx=out_edge.from_idx,
                           to_idx=out_edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 26
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Absorb constant matrix operations that follow a filter into it.

        For each FilterParameters node, the single consumer of the filter is
        inspected repeatedly. While that consumer is one of the operations in
        ``OPS`` and its other input is a ConstantInputParameters node, the
        constant is folded into the filter through ``self.fuse_bias``, the
        operation node is removed and its consumers are connected directly to
        the filter. The constant node is removed only when the absorbed
        operation was its sole consumer.

        Args:
            G: graph that is searched and modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called before
                returning.
            **kwargs: accepted but not used by this method.

        Returns:
            True if at least one operation was absorbed.
        """
        has_modified_graph = False
        filter_nodes = [
            node for node in G.nodes() if isinstance(node, FilterParameters)
        ]
        for filter_node in filter_nodes:
            # keep absorbing while the filter is followed by a suitable op
            while True:
                out_edges = G.out_edges(filter_node.name)
                # can't fuse if there is a branch (or nothing to fuse with)
                if len(out_edges) != 1:
                    break
                out_edge = out_edges[0]
                op_node = out_edge.to_node
                # must be a valid matrix op
                if not isinstance(op_node, tuple(OPS.keys())):
                    break
                # other edge to the op must be a constant
                other_idx = 1 if out_edge.to_idx == 0 else 0
                other_in_edge = G.indexed_in_edges(op_node.name)[other_idx]
                if not isinstance(other_in_edge.from_node,
                                  ConstantInputParameters):
                    break
                const_node = other_in_edge.from_node
                # The constant may only be deleted when this op is its sole
                # consumer; otherwise other nodes would lose their input.
                remove_constant = len(G.out_edges(const_node.name)) == 1

                flat_value = const_node.dqvalue.flatten()
                out_c = filter_node.filter.out_c
                op, weights_and_biases = OPS[op_node.__class__]
                # it would be possible to support mult bias addition by out channel but only supporting a
                # scalar at present
                if len(flat_value) != 1 and (weights_and_biases
                                             or len(flat_value) != out_c):
                    LOG.warning('could not absorb %s into %s', const_node.name,
                                filter_node.name)
                    break
                # If there is quantization then essentially the output of the filter
                # takes the quantization of the output of the operation.
                # The biases will not change since their quantization depends on the weights
                # and input
                fnid = NodeId(filter_node)
                opnid = NodeId(op_node)
                if G.quantization and (fnid in G.quantization
                                       or opnid in G.quantization):
                    if not (fnid in G.quantization
                            and opnid in G.quantization):
                        LOG.warning(
                            'could not absorb %s into %s - graph is partially quantized',
                            const_node.name, filter_node.name)
                        break
                    fqrec = G.quantization[fnid]
                    opqrec = G.quantization[opnid]
                    fqrec.out_qs[0] = opqrec.out_qs[0]

                has_modified_graph = True
                LOG.info("fusing bias in %s into %s", const_node.name,
                         filter_node.name)
                self.fuse_bias(G, filter_node, other_idx, op, flat_value, 2)
                if weights_and_biases:
                    # TODO - need to adjust weights quantization here
                    LOG.info("fusing multiplicative bias in %s into %s",
                             const_node.name, filter_node.name)
                    self.fuse_bias(G, filter_node, other_idx, op, flat_value,
                                   1)

                # splice the op out: its consumers now read from the filter
                out_edges = G.out_edges(op_node.name)
                G.remove(op_node)
                if remove_constant:
                    G.remove(const_node)
                for edge in out_edges:
                    G.add_edge(
                        NNEdge(from_node=filter_node,
                               to_node=edge.to_node,
                               to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 27
0
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        """Replace StridedSlice nodes that share an input with one Split node.

        Slice nodes are grouped by the (producer node, output index) they read
        from. A group with a single slice is handed to ``self.slice_to_split``.
        For larger groups the slices are combined with ``combine_slices`` and,
        when that succeeds, the group is replaced by a SplitParameters node.
        The split is preceded by a transpose and/or reshape when needed, and
        each of its outputs is followed by a matching reshape and/or
        transpose.

        Args:
            G: graph that is searched and modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called before
                returning.
            **kwargs: accepted but not used by this method.

        Returns:
            True if at least one multi-slice group was replaced by a split.
        """
        has_modified_graph = False
        # (producer node, producer output index) -> slice nodes reading it
        slices_by_origin = {}
        for slice_node in [
                node for node in G.nodes()
                if isinstance(node, StridedSliceParameters)
        ]:
            in_edge = G.in_edges(slice_node.name)[0]
            group = slices_by_origin.setdefault(
                (in_edge.from_node, in_edge.from_idx), [])
            group.append(slice_node)
        # NOTE: in_edge is the (node, idx) key tuple here, and is rebound to a
        # graph edge further down in the loop body.
        for in_edge, slice_nodes in slices_by_origin.items():
            # regroup act_slice per axis: slices[axis][i] belongs to slice_nodes[i]
            slices = list(zip(*[node.act_slice for node in slice_nodes]))
            if len(slice_nodes) == 1:
                # NOTE(review): has_modified_graph is not updated on this path;
                # confirm whether slice_to_split can change the graph.
                self.slice_to_split(G, slice_nodes, slices)
                continue

            # strides must be one
            if any(sl[2] != 1 for sl_axis in slices for sl in sl_axis):
                continue

            # axes on which at least one node slices differently from the first
            diff_axes = list([
                idx for idx, elems in enumerate(slices)
                if not all(elems[0] == elem for elem in elems[1::])
            ])
            not_diff_axes = [
                idx for idx in range(len(slices)) if idx not in diff_axes
            ]
            diff_slices = [
                sl for idx, sl in enumerate(slices) if idx in diff_axes
            ]
            axis_lengths = in_edge[0].out_dims[in_edge[1]].shape
            # NOTE(review): max(diff_axes) raises ValueError when every node
            # has identical slices (diff_axes empty) - confirm that cannot
            # happen here.
            if not_diff_axes and min(not_diff_axes) < max(diff_axes):
                # move the differing axes to the front
                transpose_from = tuple(range(len(slices)))
                transpose_to = tuple(diff_axes + not_diff_axes)
                axis_lengths = [axis_lengths[idx] for idx in transpose_to]
            else:
                transpose_from = transpose_to = None
            diff_axis_lengths = axis_lengths[0:len(diff_axes):]

            diff_slices = combine_slices(diff_axis_lengths, diff_slices,
                                         slice_nodes)
            # combine_slices returning None means this group is left as is
            if diff_slices is None:
                continue

            if len(diff_axes) > 1:
                # collapse the differing axes into a single leading axis
                # NOTE(review): the concatenation assumes axis_lengths is a
                # list on this path - confirm the type of `.shape`.
                reshape_from = axis_lengths
                reshape_to = [np.prod(diff_axis_lengths)] + \
                    axis_lengths[len(diff_axes)::]
            else:
                reshape_from = None
                reshape_to = slice_nodes[0].in_dims[0].shape
                if transpose_from:
                    reshape_to = [reshape_to[idx] for idx in transpose_to]

            sizes, shapes, sorted_nodes = slices_to_sizes(
                diff_slices, axis_lengths[len(diff_axes)::])

            name_prefix = sorted_nodes[0].name

            # from here on in_edge is the graph edge feeding the first sorted node
            in_edge = G.in_edges(sorted_nodes[0].name)[0]
            in_node = in_edge.from_node
            in_idx = in_edge.from_idx

            # optional transpose in front of the split
            if transpose_from:
                params = TransposeParameters(G.unique_name(name_prefix +
                                                           '_tin'),
                                             transpose=transpose_to)
                G.add_edge(
                    NNEdge(from_node=in_node, to_node=params, from_idx=in_idx))
                in_node = params
                in_idx = 0

            # optional reshape in front of the split
            if reshape_from:
                params = ReshapeParameters(G.unique_name(name_prefix +
                                                         '_reshape'),
                                           old_shape=Dim.unnamed(reshape_from),
                                           shape=Dim.unnamed(reshape_to))
                G.add_edge(
                    NNEdge(from_node=in_node, to_node=params, from_idx=in_idx))
                in_node = params
                in_idx = 0

            # the split itself always works on axis 0 of reshape_to
            act_slices, out_shapes, axis = SplitParameters.get_splits(
                reshape_to, 0, splits=sizes)
            split_node = SplitParameters(G.unique_name(name_prefix + '_split'),
                                         act_slices=act_slices,
                                         out_shapes=out_shapes,
                                         axis=axis)

            G.add_edge(
                NNEdge(from_node=in_node, from_idx=in_idx, to_node=split_node))

            sub_names = []
            # output idx of the split replaces the idx-th sorted slice node
            for idx, node in enumerate(sorted_nodes):
                sub_names.append(node.name)
                out_edges = G.out_edges(node.name)
                G.remove(node)
                # a separate reshape/transpose chain is built for every
                # consumer edge of the removed slice node
                for out_edge in out_edges:
                    params = split_node
                    out_idx = idx
                    if reshape_from:
                        from_node = params
                        params = ReshapeParameters(
                            G.unique_name(name_prefix + f'_reshape{idx}'),
                            shape=Dim.unnamed(shapes[idx]))
                        G.add_edge(
                            NNEdge(from_node=from_node,
                                   to_node=params,
                                   from_idx=out_idx))
                        out_idx = 0
                    if transpose_from:
                        from_node = params
                        params = TransposeParameters(
                            G.unique_name(name_prefix + f'_tout{idx}'),
                            transpose=reverse_transpose(transpose_to))
                        G.add_edge(
                            NNEdge(from_node=from_node,
                                   to_node=params,
                                   from_idx=out_idx))
                        out_idx = 0

                    G.add_edge(
                        NNEdge(from_node=params,
                               to_node=out_edge.to_node,
                               from_idx=out_idx,
                               to_idx=out_edge.to_idx))
            # re-run quantization so the newly inserted nodes get records
            if G.quantization:
                G.add_dimensions()
                quantizer = NewQuantizer.from_quantized_graph(G)
                quantizer.quantize()
                RemoveUnnecessaryQuantizeOperators().match(G)

            LOG.info(
                f'replaced slice nodes {",".join(sub_names)} with split node {split_node.name}'
            )

            has_modified_graph = True

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 28
0
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        """Replace gathers that together cover an axis with one Split node.

        Gather nodes are grouped by the (producer node, output index) they
        read from. A group is converted only when every gather selects a
        single index on the same axis and the selected indices are distinct
        and cover the whole axis. The gathers are then removed and a
        SplitParameters node named ``<producer name>_split`` is connected in
        their place, with one output per gather in ascending index order.

        Args:
            G: graph that is searched and modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called before
                returning.
            **kwargs: accepted but not used by this method.

        Returns:
            True if at least one group of gathers was replaced.
        """

        def selected_index(indices):
            """Return the single selected index, or None if there are several."""
            shape = indices.shape
            if len(shape) == 0:
                return int(indices)
            if len(shape) == 1 and shape[0] == 1:
                # index the element rather than calling int() on the array
                return int(indices[0])
            return None

        has_modified_graph = False
        gathers_by_origin = {}
        for gather in [
                node for node in G.nodes()
                if isinstance(node, GatherParameters)
        ]:
            in_edge = G.in_edges(gather.name)[0]
            group = gathers_by_origin.setdefault(
                (in_edge.from_node, in_edge.from_idx), [])
            group.append(gather)
        # in_edge is the (producer node, output index) key tuple from here on
        for in_edge, gathers in gathers_by_origin.items():
            # This is too difficult to handle if there are multiple slices:
            # every gather, including the first, must pick exactly one index
            # on a common axis.
            axis = gathers[0].axis
            index_by_gather = [(selected_index(gather.indices), gather)
                               for gather in gathers]
            if any(gather.axis != axis or index is None
                   for index, gather in index_by_gather):
                continue
            # sort the gathers by the index that they select
            index_by_gather.sort(key=lambda pair: pair[0])
            indices = [index for index, _ in index_by_gather]
            # All the indices must be independent and cover the whole axis (this
            # could be relaxed but then needs to handle gaps)
            in_shape = in_edge[0].out_dims[in_edge[1]].shape
            in_shape_without_axis = in_shape[:axis:] + in_shape[axis + 1::]
            if len(set(indices)) != len(indices) or len(
                    indices) != in_shape[axis]:
                continue
            # good for a split
            LOG.info("gathers from %s[%s] converted to a split",
                     in_edge[0].name, in_edge[1])
            splits = []
            shapes = []
            out_edges = []
            for index, gather in index_by_gather:
                splits.append([(index, index + 1, 1)])
                # NOTE(review): the same shape is used for scalar and for
                # one-element 1-D indices - confirm the gather output shape
                # in the 1-D case.
                shapes.append(in_shape_without_axis)
                out_edges.append(G.out_edges(gather.name))
                G.remove(gather)
            params = SplitParameters("%s_split" % in_edge[0].name,
                                     act_slices=splits,
                                     out_shapes=shapes,
                                     axis=axis)
            if axis != 0:
                # Permutation that brings the gathered axis to the front. The
                # tail must start after `axis`, otherwise the axis is listed
                # twice and the inverse lookup below raises ValueError.
                trans = [axis] + list(range(0, axis)) + list(
                    range(axis + 1, len(in_shape)))
                # inverse permutation of trans
                params.transpose_out = [[
                    trans.index(idx) for idx in range(len(trans))
                ]]
                params.transpose_in = [trans]
            # split output idx takes over the consumers of the idx-th gather
            for idx, edges in enumerate(out_edges):
                for edge in edges:
                    G.add_edge(
                        NNEdge(from_node=params,
                               to_node=edge.to_node,
                               from_idx=idx,
                               to_idx=edge.to_idx))
            G.add_edge(
                NNEdge(from_node=in_edge[0],
                       to_node=params,
                       from_idx=in_edge[1]))
            has_modified_graph = True

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 29
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse pooling nodes with the nodes chained after them.

        Every PoolingParameters / GlobalPoolingParameters node is passed to
        ``self.get_node_list`` together with the activation sets selected by
        the ``group_identity`` keyword argument. Chains of two or more nodes
        are replaced by an ActivationFusion node named ``<pool name>_fusion``.

        Args:
            G: graph that is searched and modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called before
                returning.
            **kwargs: ``group_identity`` selects the POW2 activation sets when
                equal to ``'pow2_match_group'``, otherwise the SQ8 sets.

        Returns:
            True if at least one chain was fused.
        """
        has_modified_graph = False
        use_pow2 = kwargs.get('group_identity') == 'pow2_match_group'
        valid_activations = (VALID_ACTIVATIONS_POW2
                             if use_pow2 else VALID_ACTIVATIONS_SQ8)
        valid_activations_wo_pool = (VALID_ACTIVATIONS_POW2_WO_POOL if use_pow2
                                     else VALID_ACTIVATIONS_SQ8_WO_POOL)
        pool_classes = (PoolingParameters, GlobalPoolingParameters)
        for pool_node in G.nodes(node_classes=pool_classes):
            node_list = self.get_node_list(G, pool_node, valid_activations,
                                           valid_activations_wo_pool)
            if node_list is None or len(node_list.order) < 2:
                continue
            fused_names = ",".join((node.name for node in node_list.order))
            LOG.info("fusing nodes %s", fused_names)
            has_modified_graph = True

            # Link the chain members one after another inside the subgraph.
            subgraph = GraphView()
            chain = iter(node_list.order)
            last_node = next(chain)
            for follower in chain:
                subgraph.add_edge(
                    NNEdge(from_node=last_node, to_node=follower))
                last_node = follower

            pnode = ActivationFusion(node_list.pool.name + '_fusion',
                                     fusion_type=node_list.fusion_type,
                                     subgraph=subgraph,
                                     input_mapping=[[(node_list.pool, 0)]],
                                     output_mapping=[(last_node, 0)])

            if G.quantization:
                # TODO - stats
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                if qrecs:
                    prec = QRec.copy_ktype(qrecs[0],
                                           in_qs=qrecs[0].in_qs,
                                           out_qs=qrecs[-1].out_qs)
                    for contained in pnode.contained_nodes():
                        G.quantization.move_to_fusion(contained, pnode)
                        if not isinstance(contained, GlobalPoolingParameters):
                            continue
                        # Global pooling fused with activations need to have only the activation scale
                        inner_rec = G.quantization[NodeId(pnode, contained)]
                        inner_rec.out_qs[0] = deepcopy(inner_rec.in_qs[0])
                        inner_rec.out_qs[0].dtype = np.int32
                    G.quantization[NodeId(pnode)] = prec

            # Record the external edges, drop the chain, reconnect the fusion.
            in_edges = G.in_edges(node_list.pool.name)
            out_edges = G.out_edges(last_node.name)
            for fused_node in node_list.order:
                G.remove(fused_node)
            for in_edge in in_edges:
                G.add_edge(
                    NNEdge(in_edge.from_node,
                           pnode,
                           from_idx=in_edge.from_idx,
                           to_idx=in_edge.to_idx))
            for out_edge in out_edges:
                G.add_edge(
                    NNEdge(pnode,
                           out_edge.to_node,
                           from_idx=out_edge.from_idx,
                           to_idx=out_edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 30
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Merge duplicated comparable operations that read the same output.

        Edges are grouped by their source (node and output index). Within a
        group, destination nodes that are ComparableParameters and report
        ``is_same_operation_as`` each other are treated as duplicates: all but
        the first are removed and their consumers are connected to the first.

        A node is removed at most once, a node is never treated as a duplicate
        of itself, and a node that was already removed is never used as the
        surviving node.

        Args:
            G: graph that is searched and modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called before
                returning.

        Returns:
            True if at least one duplicate node was removed.
        """

        def source_key(edge):
            return f"{edge.from_node.__hash__()}##{edge.from_idx}"

        modified_graph = False
        # groupby only merges adjacent items, hence the sort on the same key
        fanouts = [
            list(edge_list) for _, edge_list in groupby(
                sorted(G.edges(), key=source_key), source_key)
        ]
        # only outputs that feed more than one edge can have duplicates
        fanouts = [edges for edges in fanouts if len(edges) > 1]

        same_dest_edges = []
        for fanout in fanouts:
            pending = [
                edge for edge in fanout
                if isinstance(edge.to_node, ComparableParameters)
            ]
            while pending:
                first = pending.pop(0)
                # partition the remaining edges against the first one
                others, rest = [], []
                for edge in pending:
                    if edge.to_node.is_same_operation_as(first.to_node):
                        others.append(edge)
                    else:
                        rest.append(edge)
                if others:
                    same_dest_edges.append(tuple([first] + others))
                pending = rest

        # all are multiple edges that go to something comparable

        # Names of nodes already removed. The same node can show up in several
        # edge sets when it shares more than one source with its duplicate.
        removed_names = set()
        for edge_set in same_dest_edges:
            first_node = edge_set[0].to_node
            if first_node.name in removed_names:
                # the would-be survivor is gone; rewiring to it would leave
                # its consumers without a producer
                continue
            dup_nodes = []
            for other in edge_set[1::]:
                dest_node = other.to_node
                # skip a node fed twice by the same output (it is not its own
                # duplicate) and nodes removed by an earlier edge set
                if dest_node is first_node or dest_node.name in removed_names:
                    continue
                modified_graph = True
                dup_nodes.append(dest_node.name)
                removed_names.add(dest_node.name)
                out_edges = G.out_edges(dest_node.name)
                G.remove(dest_node)
                for out_edge in out_edges:
                    G.add_edge(
                        NNEdge(from_node=first_node,
                               to_node=out_edge.to_node,
                               from_idx=out_edge.from_idx,
                               to_idx=out_edge.to_idx))
            if dup_nodes:
                LOG.info(
                    f'removed duplicates {",".join(dup_nodes)} to {first_node.name}'
                )

        if set_identity:
            self.set_identity(G)

        return modified_graph