Ejemplo n.º 1
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        has_modified_graph = False

        for node in G.nodes(node_classes=MatMulOpParameters):
            in_edges = [edge for edge in G.indexed_in_edges(node.name)]
            trans_node = in_edges[1].from_node
            if not isinstance(trans_node, TransposeParameters):
                continue
            if isinstance(node, MatMulTransposedParameters):
                new_node = MatMulOpParameters(node.name)
            else:
                new_node = MatMulTransposedParameters(node.name)

            in_trans_edge = [
                edge for edge in G.indexed_in_edges(trans_node.name)
            ][0]
            G.replace_node(node.name, new_node)
            G.remove(trans_node)
            G.add_edge(
                NNEdge(in_trans_edge.from_node,
                       new_node,
                       from_idx=in_trans_edge.from_idx,
                       to_idx=1))
            has_modified_graph = True

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Ejemplo n.º 2
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        fragment = GraphMatcher(
            match_function=lambda state, frag: (frag, state['match']))
        fragment.add_node(MatScaleNodeMatch())
        has_modified_graph = False
        for frag, match in fragment.match_graph(G):
            match_edges = [
                G.indexed_in_edges(node.name)[idx]
                for node, idx in match['inputs']
            ]
            matched_node = list(frag.nodes())[0]
            out_edges = G.out_edges(matched_node.name)
            has_modified_graph = True
            G.remove(matched_node)
            fnode = MatScaleFusionParameters(
                "{}_fusion".format(matched_node.name),
                fusion_type=match['type'],
                subgraph=frag,
                input_mapping=[[(matched_node, 0)], [(matched_node, 1)]])
            G.add_node(fnode)
            for idx, edge in enumerate(match_edges):
                edge.to_node = fnode
                edge.to_idx = idx
                G.add_edge(edge)
            for edge in out_edges:
                edge.from_node = fnode
                G.add_edge(edge)

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
def reverse_matmul(G: GraphView, params):
    # reverse edges
    in_edges = G.indexed_in_edges(params.name)
    for edge in in_edges[0:2:]:
        G.remove_edge(edge)
    other_idx = 1
    for edge in in_edges[0:2:]:
        G.add_edge(
            NNEdge(from_node=edge.from_node,
                   to_node=params,
                   from_idx=edge.from_idx,
                   to_idx=other_idx))
        other_idx = 1 - other_idx
    nid = NodeId(params)
    if G.quantization and nid in G.quantization:
        qrec = G.quantization[nid]
        # swap qrecs
        qrec.in_qs[0], qrec.in_qs[1] = qrec.in_qs[1], qrec.in_qs[0]

    # add transposes
    in_nodes = []
    for idx in range(2):
        tin_params = TransposeParameters(
            G.unique_name(f"{params.name}_tin{idx}"), transpose=(1, 0))
        in_nodes.append(tin_params)
        G.insert_node_before(tin_params, params, to_idx=idx, edge_class=NNEdge)
    tout_params = TransposeParameters(G.unique_name(f"{params.name}_tout"),
                                      transpose=(1, 0))
    G.insert_node_after(params, tout_params)
    return in_nodes, tout_params
Ejemplo n.º 4
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        something_changed = False
        filt_nodes = [node for node in G.nodes()
                      if isinstance(node, (Conv2DParameters, ConvFusionParameters))]
        for filt_node in filt_nodes:
            pnode = filt_node
            if isinstance(filt_node, ConvFusionParameters):
                cnodes = filt_node.contained_nodes()
                filt_node = cnodes[0]
            if not isinstance(filt_node, Conv2DParameters):
                continue
            in_dim = filt_node.in_dims
            filt_dim = filt_node.filter
            if filt_dim.h <= in_dim[0].h and filt_dim.w <= in_dim[0].w:
                continue

            min_h = min(filt_dim.h, in_dim[0].h)
            min_w = min(filt_dim.w, in_dim[0].w)
            if min_h > 1 and min_w > 1:
                LOG.warning("Filter of %s [%dx%d] bigger than input [%dx%d] not optimal but will work on AT",
                            filt_node.name, filt_dim.h, filt_dim.w, in_dim[0].h, in_dim[0].w)
                continue

            ker_h = 1 if min_h == 1 else filt_dim.h
            ker_w = 1 if min_w == 1 else filt_dim.w
            if ker_h == filt_dim.h and ker_w == filt_dim.w:
                continue
            new_filt_dim = Conv2DFilterDim(
                ker_h, ker_w, filt_dim.out_c, in_c=filt_dim.in_c)
            LOG.warning("Converting filter of %s from [%dx%d] -> [%dx%d]",
                        filt_node.name, filt_dim.h, filt_dim.w, new_filt_dim.h, new_filt_dim.w)
            filt_node.filter = new_filt_dim
            new_w_idxs = []
            for dim in filt_dim.order:
                if dim in ('out_c', 'in_c'):
                    new_w_idxs.append(slice(None))
                elif dim == 'h':
                    if new_filt_dim.h == 1:
                        new_w_idxs.append(
                            slice(filt_node.padding.t, filt_node.padding.t + 1))
                    else:
                        new_w_idxs.append(slice(0, new_filt_dim.h))
                elif dim == 'w':
                    if new_filt_dim.w == 1:
                        new_w_idxs.append(
                            slice(filt_node.padding.l, filt_node.padding.l + 1))
                    else:
                        new_w_idxs.append(slice(0, new_filt_dim.w))
            weights_node = G.indexed_in_edges(pnode.name)[1].from_node
            weights_node.value = weights_node.value[tuple(new_w_idxs)]
            weights_node.dims = Dim.unnamed(weights_node.value.shape)
            something_changed = True

        if set_identity:
            self.set_identity(G)

        return something_changed
Ejemplo n.º 5
0
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        has_modified_graph = False
        for split_node in set(
            [node for node in G.nodes() if isinstance(node, SplitParameters)]):
            in_edges = G.in_edges(split_node.name)
            if len(in_edges) > 1:
                continue
            in_edge = in_edges[0]
            if not isinstance(in_edge.from_node, ConcatParameters):
                continue
            concat_node = in_edge.from_node
            if len(G.out_edges(concat_node.name)) > 1:
                continue
            if concat_node.transpose_out or split_node.transpose_in:
                continue
            if concat_node.axis != split_node.axis:
                continue
            axis = concat_node.axis
            split_out_sizes = [
                out_shape[axis] for out_shape in split_node.out_shapes
            ]
            if len(split_out_sizes) != len(concat_node.in_dims):
                continue
            if not all(split_out_sizes[idx] == in_dim.shape[axis]
                       for idx, in_dim in enumerate(concat_node.in_dims)):
                continue
            has_modified_graph = True
            LOG.info("removing unnecessary concat/split pair %s/%s",
                     concat_node.name, split_node.name)
            concat_in_edges = G.indexed_in_edges(concat_node.name)
            split_out_edges = G.indexed_out_edges(split_node.name)
            G.remove(split_node)
            G.remove(concat_node)
            for idx, in_edge in enumerate(concat_in_edges):
                for out_edge in split_out_edges[idx]:
                    G.add_edge(
                        NNEdge(from_node=in_edge.from_node,
                               from_idx=in_edge.from_idx,
                               to_node=out_edge.to_node,
                               to_idx=out_edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Ejemplo n.º 6
0
    def move_constant(cls, G: GraphView, params, in_qs):
        # looks for a constant on one of the inputs
        # if there is one we can scale by the second dimension of the second
        # tensor. If the constant is on the first tensor then move to the second
        # and transpose the operation
        in_edges = G.indexed_in_edges(params.name)
        in1_node = in_edges[0].from_node
        in2_node = in_edges[1].from_node

        if isinstance(in2_node, ConstantInputParameters):
            return in2_node, in_qs
        elif isinstance(in1_node, ConstantInputParameters):
            if len(params.in_dims) > 2:
                # check if the bias has the correct length to move constant
                # it must have a length equal to the second tensors second dimension after transpose
                bias_size = params.in_dims[2].size()
                in1_shape = params.in_dims[0].shape
                if in1_shape[1] != bias_size:
                    return None, in_qs
            for edge in in_edges[:2:]:
                G.remove_edge(edge)
            to_idx = 1
            # swap edges to move constant onto input 2
            for edge in in_edges[:2:]:
                new_edge = NNEdge(from_node=edge.from_node,
                                  to_node=edge.to_node,
                                  from_idx=edge.from_idx,
                                  to_idx=to_idx)
                G.add_edge(new_edge)
                to_idx = 1 - to_idx
            # use A.B = (BT.AT)T identity
            tin1 = TransposeParameters(G.unique_name(f'{params.name}_tin1'),
                                       transpose=(1, 0))
            tin2 = TransposeParameters(G.unique_name(f'{params.name}_tin2'),
                                       transpose=(1, 0))
            tout = TransposeParameters(G.unique_name(f'{params.name}_tout'),
                                       transpose=(1, 0))
            G.insert_node_before(tin1, params)
            G.insert_node_before(tin2, params, to_idx=1)
            G.insert_node_after(params, tout)
            LOG.warning('transposes inserted on %s - rerun adjust',
                        params.name)
            return in1_node, [in_qs[1], in_qs[0]] + in_qs[2::]
        else:
            return None, in_qs
Ejemplo n.º 7
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        fac = MatScalePairMatchFactory()
        has_modified_graph = False

        for frag, match in fac.get_matcher().match_graph(G):
            match_edges = [
                G.indexed_in_edges(node.name)[idx] for node, idx in match
            ]
            first_node = frag.inputs()[0]
            last_node = frag.outputs()[0]
            out_edges = G.out_edges(last_node.name)
            for node in frag.nodes():
                G.remove(node)

            input_mapping = MatScaleFusionParameters.get_mapping_from_edges(
                match_edges)

            fnode = MatScaleFusionParameters(
                "{}_{}_fusion".format(first_node.name, last_node.name),
                fusion_type="vec_scalar",
                subgraph=frag,
                input_mapping=MatScaleFusionParameters.convert_input_mapping(
                    input_mapping))
            has_modified_graph = True
            G.add_node(fnode)
            fnode.in_dims_hint = [None] * 3

            for idx, edge in enumerate(match_edges):
                new_edge = edge.clone(
                    to_node=fnode,
                    to_idx=list(input_mapping[edge.to_node].keys())[0])
                if new_edge.from_node.out_dims_hint:
                    fnode.in_dims_hint[idx] = new_edge.from_node.out_dims_hint[
                        edge.from_idx]
                G.add_edge(new_edge)
            for edge in out_edges:
                new_edge = edge.clone(from_node=fnode)
                G.add_edge(new_edge)

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Ejemplo n.º 8
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):

        candidates = [node for node in G.nodes(node_classes=(SplitParameters, ConcatParameters))]
        need_a_copy_edges = []
        for node in candidates:
            for idx, edge in enumerate(G.indexed_in_edges(node.name)):
                real_from_node, _ = find_real_in_edge(G, edge)
                if isinstance(real_from_node, (InputParameters, ConstantInputParameters)):
                    need_a_copy_edges.append((edge, idx))
        has_modified_graph = False
        for edge in need_a_copy_edges:
            LOG.info(
                "Insert copy on split input %s", edge[0].to_node.name)
            has_modified_graph = True
            cnode = CopyParameters(G.unique_name(f'{edge[0].to_node.name}_copy'))
            G.insert_node_at_edge(cnode, edge[0])
            if G.quantization:
                G.quantization.copy_qrec(edge[0].to_node, 'in', 0, cnode)
        if set_identity:
            self.set_identity(G)
        return has_modified_graph
Ejemplo n.º 9
0
    def match(self, G: GraphView, set_identity: bool = True):
        has_modified_graph = False
        for pad_node in [
                params for params in G.nodes()
                if isinstance(params, PadParameters)
        ]:
            node_list = self.get_node_list(G, pad_node)
            if node_list is None or len(node_list.order) < 2:
                continue
            LOG.info("fusing nodes %s", ",".join(
                (node.name for node in node_list.order)))
            has_modified_graph = True
            subgraph = GraphView()
            padded_input_idx = G.out_edges(node_list.pad.name)[0].to_idx
            subgraph.add_edge(
                NNEdge(from_node=node_list.pad,
                       to_node=node_list.add,
                       to_idx=padded_input_idx))
            last_node = node_list.add
            node_list.add.force_quantized_index = 0
            if node_list.active:
                subgraph.add_edge(
                    NNEdge(from_node=node_list.add, to_node=node_list.active))
                last_node = node_list.active
            if padded_input_idx == 0:
                input_mapping = [[(node_list.pad, 0)], [(node_list.add, 1)]]
            else:
                input_mapping = [[(node_list.add, 0)], [(node_list.pad, 1)]]

            output_mapping = [(last_node, 0)]
            pnode = PaddedAddFusionParameters(
                "PADDED_" + node_list.add.name,
                fusion_type=node_list.fusion_type,
                subgraph=subgraph,
                input_mapping=input_mapping,
                output_mapping=output_mapping)
            if G.quantization:
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                if qrecs:
                    prec = None
                    if isinstance(qrecs[0],
                                  (SymmetricQuantizationRecord,
                                   SymmetricScalableFilterQuantizationRecord)):
                        prec = SymmetricQuantizationRecord(
                            in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                    elif isinstance(qrecs[0],
                                    (MultQuantizationRecord,
                                     MultScalableFilterQuantizationRecord)):
                        prec = MultQuantizationRecord(in_qs=qrecs[0].in_qs,
                                                      out_qs=qrecs[-1].out_qs)
                    elif isinstance(qrecs[0],
                                    (Float32QuantizationRecord,
                                     Float32ScalableFilterQuantizationRecord)):
                        prec = Float32QuantizationRecord(
                            in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                    for node in pnode.contained_nodes():
                        G.quantization.move_to_fusion(node, pnode)
                    G.quantization[NodeId(pnode)] = prec
            if padded_input_idx == 0:
                in_edges = G.in_edges(node_list.pad.name) + G.indexed_in_edges(
                    node_list.add.name)[1::]
            else:
                in_edges = G.indexed_in_edges(
                    node_list.add.name)[0:1:] + G.in_edges(node_list.pad.name)
            out_edges = G.out_edges(last_node.name)
            for node in node_list.order:
                G.remove(node)
            for edge in in_edges:
                G.add_edge(
                    NNEdge(edge.from_node,
                           pnode,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))
            for edge in out_edges:
                G.add_edge(
                    NNEdge(pnode,
                           edge.to_node,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Ejemplo n.º 10
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        has_modified_graph = False
        filter_nodes = [
            node for node in G.nodes() if isinstance(node, FilterParameters)
        ]
        for filter_node in filter_nodes:
            while True:
                out_edges = G.out_edges(filter_node.name)
                # can't fuse if there is a branch
                if len(out_edges) > 1:
                    break
                out_edge = out_edges[0]
                op_node = out_edge.to_node
                # must be a valid matrix op
                if not isinstance(op_node, tuple(OPS.keys())):
                    break
                # other edge to the op must be a constant
                other_idx = 1 if out_edge.to_idx == 0 else 0
                other_in_edge = G.indexed_in_edges(op_node.name)[other_idx]
                if not isinstance(other_in_edge.from_node,
                                  ConstantInputParameters):
                    break
                const_node = other_in_edge.from_node
                remove_constant = len(G.out_edges(const_node.name))

                flat_value = const_node.dqvalue.flatten()
                out_c = filter_node.filter.out_c
                op, weights_and_biases = OPS[op_node.__class__]
                # it would be possible to support mult bias addition by out channel but only supporting a
                # scalar at present
                if len(flat_value) != 1 and (weights_and_biases
                                             or len(flat_value) != out_c):
                    LOG.warning('could not absorb %s into %s', const_node.name,
                                filter_node.name)
                    break
                # If there is quantization then essentially the output of the filter
                # takes the quantization of the output of the operation.
                # The biases will not change since their quantization depends on the weights
                # and input
                fnid = NodeId(filter_node)
                opnid = NodeId(op_node)
                if G.quantization and (fnid in G.quantization
                                       or opnid in G.quantization):
                    if not (fnid in G.quantization
                            and opnid in G.quantization):
                        LOG.warning(
                            'could not absorb %s into %s - graph is partially quantized',
                            const_node.name, filter_node.name)
                        break
                    fqrec = G.quantization[fnid]
                    opqrec = G.quantization[opnid]
                    fqrec.out_qs[0] = opqrec.out_qs[0]

                has_modified_graph = True
                LOG.info("fusing bias in %s into %s", const_node.name,
                         filter_node.name)
                self.fuse_bias(G, filter_node, other_idx, op, flat_value, 2)
                if weights_and_biases:
                    # TODO - need to adjust weights quantization here
                    LOG.info("fusing multiplicative bias in %s into %s",
                             const_node.name, filter_node.name)
                    self.fuse_bias(G, filter_node, other_idx, op, flat_value,
                                   1)

                out_edges = G.out_edges(op_node.name)
                G.remove(op_node)
                if remove_constant:
                    G.remove(const_node)
                for edge in out_edges:
                    G.add_edge(
                        NNEdge(from_node=filter_node,
                               to_node=edge.to_node,
                               to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Ejemplo n.º 11
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        has_modified_graph = False
        has_transposed = False
        for params in G.nodes(node_classes=MatMulOpParameters):
            while True:
                out_edges = G.out_edges(params.name)
                # can't fuse if there is a branch
                if len(out_edges) > 1:
                    break
                out_edge = out_edges[0]
                op_node = out_edge.to_node
                # must be a valid matrix op
                if not isinstance(op_node,
                                  (MatrixAddParameters, MatrixMulParameters)):
                    break
                # other edge to the op must be a constant
                other_idx = 1 if out_edge.to_idx == 0 else 0
                other_in_edge = G.indexed_in_edges(op_node.name)[other_idx]
                if not isinstance(other_in_edge.from_node,
                                  ConstantInputParameters):
                    break
                const_node = other_in_edge.from_node
                remove_constant = len(G.out_edges(const_node.name))

                flat_value = const_node.dqvalue.flatten()
                out_shape = params.out_dims[0].shape
                if len(out_shape) != 2:
                    raise ValueError(
                        f'strange outputs shape of {out_shape} for matmul {params.name}'
                    )
                if len(flat_value) != out_shape[0] and len(
                        flat_value) != out_shape[1]:
                    LOG.info(
                        "can't fuse %s into %s - value shape is not correct for bias",
                        const_node.name, params.name)
                    break
                has_bias = len(params.in_dims) == 3
                if isinstance(op_node, MatrixAddParameters):
                    if has_bias:
                        if len(flat_value.shape) != len(params.in_dims[2]):
                            LOG.info(
                                "can't fuse %s into %s - bias shape is not the same",
                                const_node.name, params.name)
                            break
                        bias_node = G.indexed_in_edges(
                            params.name)[2].from_node
                        LOG.info(
                            "folding additive bias from %s into existing bias on %s",
                            op_node.name, params.name)
                        bias_node.value = bias_node.dq_value + flat_value
                    else:
                        if len(flat_value) == out_shape[1]:
                            # matmul needs to be transposed to fuse this
                            reverse_matmul(G, params)
                            has_transposed = True
                        bias_node = ConstantInputParameters(
                            G.unique_name(f'{params.name}_bias'),
                            value=flat_value,
                            dims=Dim.unnamed(flat_value.shape))
                        G.add_edge(
                            NNEdge(from_node=bias_node,
                                   to_node=params,
                                   to_idx=2))
                        # extend the inward transpose
                        if params.transpose_in:
                            params.transpose_in = params.transpose_in + [None]
                        LOG.info(
                            "folding additive bias from %s into new bias on %s",
                            op_node.name, params.name)
                else:
                    params_in = G.indexed_in_edges(params.name)
                    consts = [
                        isinstance(edge.from_node, ConstantInputParameters)
                        for edge in params_in
                    ]
                    if not any(consts):
                        break
                    mult_const_node = params_in[1].from_node if consts[
                        1] else params_in[0].from_node
                    mult_const_node.value = mult_const_node.dqvalue * const_node.dqvalue
                    if has_bias:
                        bias_node = params_in[2].from_node
                        bias_node.value = bias_node.dqvalue * const_node.dqvalue

                    LOG.info(
                        "folding multaplicative bias from %s into new bias on %s",
                        op_node.name, params.name)

                out_edges = G.out_edges(op_node.name)
                G.remove(op_node)
                if remove_constant:
                    G.remove(const_node)
                for edge in out_edges:
                    G.add_edge(
                        NNEdge(from_node=params,
                               to_node=edge.to_node,
                               to_idx=edge.to_idx))
                G.add_dimensions()
                if G.quantization:
                    quantizer = UnifiedQuantizer.from_quantized_graph(G)
                    quantizer.quantize(G, start_nodes=[params])
                    RemoveUnnecessaryQuantizeOperators().match(G)

        if has_transposed:
            G.adjust_order()

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        if G.quantization:
            LOG.warning(
                'match_duplicate_operations does not handle quantized graphs')
            return False

        def same_source_edge_fn(x):
            return f"{x.from_node.__hash__()}##{x.from_idx}"

        def same_dest_edge(x):
            return f"{x.to_node.__hash__()}##{x.to_idx}"

        modified_graph = False
        while True:
            found_more = False
            same_source_edges = [
                list(edge_list) for _, edge_list in groupby(
                    sorted(G.edges(), key=same_source_edge_fn),
                    same_source_edge_fn)
            ]
            # all have the same origin
            same_source_edges = [
                elem for elem in same_source_edges if len(elem) > 1
            ]
            same_dest_edges = []
            same_dest_group_edges = []

            for same_source_edge in same_source_edges:
                same_source_edge = [
                    edge for edge in same_source_edge
                    if isinstance(edge.to_node, ComparableParameters)
                ]
                while same_source_edge:
                    first = same_source_edge.pop(0)

                    others = list(
                        filter(
                            partial(
                                lambda x, y: x.to_node != y.to_node and y.
                                to_node.is_same_operation_as(G, x.to_node),
                                first), same_source_edge))
                    if others:
                        same_dest_edges.append(tuple([first] + others))
                        for other in others:
                            same_source_edge.remove(other)
                        continue

                    other_groups = list(
                        filter(
                            partial(
                                lambda x, y: x.to_node != y.to_node and y.
                                to_node.can_be_grouped_with(x.to_node), first),
                            same_source_edge))
                    if other_groups:
                        same_dest_group_edges.append(
                            tuple([first] + other_groups))
                        for other in other_groups:
                            same_source_edge.remove(other)

            # all are multiple edges that go to something comparable
            save_same_dest_edges = same_dest_edges.copy()
            while same_dest_edges:
                edge_set = same_dest_edges.pop(0)
                keep_node = edge_set[0].to_node
                other_edge_sets = [
                    edges for edges in same_dest_edges
                    if any(edge.to_node == keep_node for edge in edges)
                ]
                for other_edge_set in other_edge_sets:
                    same_dest_edges.remove(other_edge_set)

                nodes_to_delete = set()
                for edge_set in [edge_set] + other_edge_sets:
                    for edge in edge_set:
                        other_node = edge.to_node
                        if other_node == keep_node or other_node in nodes_to_delete:
                            continue
                        nodes_to_delete.add(other_node)
                        for out_edge in G.out_edges(other_node):
                            G.add_edge(
                                NNEdge(from_node=keep_node,
                                       to_node=out_edge.to_node,
                                       to_idx=out_edge.to_idx))
                LOG.info(
                    f'removed duplicates {",".join(node.name for node in nodes_to_delete)} to {keep_node.name}'
                )
                for node in nodes_to_delete:
                    G.remove(node)

            # # all are multiple edges that go to something comparable

            # for edge_set in same_dest_edges:
            #     modified_graph = True
            #     found_more = True
            #     first = edge_set[0]
            #     first_node = first.to_node
            #     dup_nodes = []
            #     for other in edge_set[1::]:
            #         dest_node = other.to_node
            #         dup_nodes.append(dest_node.name)
            #         out_edges = G.out_edges(dest_node.name)
            #         G.remove(dest_node)
            #         for out_edge in out_edges:
            #             G.add_edge(NNEdge(from_node=first_node, to_node=out_edge.to_node,
            #                               from_idx=out_edge.from_idx, to_idx=out_edge.to_idx))
            #     LOG.info(
            #         f'removed duplicates {",".join(dup_nodes)} to {first_node.name}')

            for edge_set in same_dest_group_edges:
                modified_graph = True
                found_more = True
                # we will merge all the convolutions into one
                first = edge_set[0]
                first_node = first.to_node
                in_edges = G.indexed_in_edges(first_node.name)
                first_filter = first_node.filter
                weights_node = in_edges[1].from_node
                biases_node = in_edges[2].from_node
                dup_nodes = []
                num_convs = len(edge_set)
                out_shape = deepcopy(first_node.out_dims[0])
                out_shape.c *= num_convs
                # create a split after the first node splitting on channel axis
                act_slices, out_shapes, axis = SplitParameters.get_splits(
                    out_shape,
                    out_shape.get_order_idx('c'),
                    num_splits=num_convs)
                split1 = SplitParameters(
                    G.unique_name(f'{first_node.name}_split'),
                    act_slices=act_slices,
                    out_shapes=out_shapes,
                    axis=axis)
                out_num = 0
                # first node out edge goes to split
                out_edges = G.out_edges(first_node.name)
                for edge in out_edges:
                    G.remove_edge(edge)
                    G.add_edge(
                        NNEdge(from_node=split1,
                               from_idx=out_num,
                               to_node=edge.to_node,
                               to_idx=edge.to_idx))
                G.add_edge(NNEdge(from_node=first_node, to_node=split1))
                # first split output goes to original output
                for other in edge_set[1::]:
                    out_num += 1
                    node_other = other.to_node
                    dup_nodes.append(node_other.name)
                    in_edges = G.indexed_in_edges(node_other.name)
                    weights_other = in_edges[1].from_node
                    biases_other = in_edges[2].from_node
                    # merge the weights and biases diwn output channel
                    weights_node.value = np.concatenate(
                        (weights_node.value, weights_other.value),
                        axis=first_filter.get_order_idx('out_c'))
                    weights_node.dims = Dim.unnamed(weights_node.value.shape)
                    biases_node.value = np.concatenate(
                        (biases_node.value, biases_other.value))
                    biases_node.dims = Dim.unnamed(biases_node.value.shape)
                    first_filter.out_c += node_other.filter.out_c
                    # wire edge from split
                    out_edges = G.out_edges(node_other.name)
                    G.remove(node_other)
                    G.remove(weights_other)
                    G.remove(biases_other)
                    for edge in out_edges:
                        G.add_edge(
                            NNEdge(from_node=split1,
                                   from_idx=out_num,
                                   to_node=edge.to_node,
                                   to_idx=edge.to_idx))
                LOG.info(
                    f'merged convolutions {",".join(dup_nodes)} into {first_node.name}'
                )
            if not found_more:
                break

        if set_identity:
            self.set_identity(G)

        return modified_graph
Ejemplo n.º 13
0
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        edge_groups = []
        for node in G.nodes(node_classes=SplitParameters):
            cur_group = None
            for out_edge_bundle in G.indexed_out_edges(node):
                if len(out_edge_bundle) == 1:
                    out_edge = out_edge_bundle[0]
                    concat_node_edges = search_down(G,
                                                    out_edge,
                                                    ConcatParameters,
                                                    can_pass=(CopyParameters,
                                                              NoOPParameters))
                    if concat_node_edges:
                        if cur_group:
                            this_concat_edge = concat_node_edges[-1]
                            last_concat_edge = cur_group[-1][-1]
                            if this_concat_edge.to_node == last_concat_edge.to_node and this_concat_edge.to_idx == last_concat_edge.to_idx + 1:
                                cur_group.append(concat_node_edges)
                                continue
                            if len(cur_group) > 1:
                                edge_groups.append(cur_group)
                        cur_group = [concat_node_edges]
                        continue
                if cur_group:
                    if len(cur_group) > 1:
                        edge_groups.append(cur_group)
                    cur_group = None
            if cur_group:
                if len(cur_group) > 1:
                    edge_groups.append(cur_group)
                cur_group = None
        # we leave the splits and concats after this since they will be cleared up by remove_noops
        for edge_group in edge_groups:
            split_node = edge_group[0][0].from_node
            concat_node = edge_group[0][-1].to_node
            from_idx = edge_group[0][0].from_idx
            to_idx = edge_group[-1][0].from_idx
            LOG.info(
                f"combining outputs {from_idx}:{to_idx} on split node {split_node.name} followed by concat {concat_node.name}"
            )
            # combine slices and shapes on edges in group
            new_slice, new_shape = reduce_slices(
                split_node.act_slices[from_idx:to_idx + 1],
                split_node.out_shapes[from_idx:to_idx + 1])
            split_node.act_slices = split_node.act_slices[:from_idx] + [
                new_slice
            ] + split_node.act_slices[to_idx + 1:]
            split_node.out_shapes = split_node.out_shapes[:from_idx] + [
                new_shape
            ] + split_node.out_shapes[to_idx + 1:]
            # remove all edges and intermediate nodes on all edge groups except the first
            for edge_list in edge_group[1:]:
                remove_edges(G, edge_list)
            out_edge_bundles = G.indexed_out_edges(split_node)
            # move edges beyond the edge group after the first index
            for offset, edge_list in enumerate(out_edge_bundles[to_idx + 1:]):
                assert len(edge_list) == 1
                edge = edge_list[0]
                G.remove_edge(edge)
                G.add_edge(NNEdge.clone(edge, from_idx=from_idx + 1 + offset))
            # reindex the in edges in the concat
            from_idx = edge_group[0][-1].to_idx
            to_idx = edge_group[-1][-1].to_idx
            in_edges = G.indexed_in_edges(concat_node)
            for offset, in_edge in enumerate(in_edges[to_idx + 1:]):
                G.remove_edge(in_edge)
                G.add_edge(NNEdge.clone(in_edge, to_idx=from_idx + 1 + offset))

        return bool(edge_groups)
Ejemplo n.º 14
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        has_modified_graph = False
        for pad_node in [
                params for params in G.nodes()
                if isinstance(params, PadParameters)
        ]:
            node_list = self.get_node_list(G, pad_node)
            if node_list is None or len(node_list.order) < 2:
                continue
            LOG.info("fusing nodes %s", ",".join(
                (node.name for node in node_list.order)))
            has_modified_graph = True
            subgraph = GraphView()
            padded_input_idx = G.out_edges(node_list.pad.name)[0].to_idx
            subgraph.add_edge(
                NNEdge(from_node=node_list.pad,
                       to_node=node_list.add,
                       to_idx=padded_input_idx))
            last_node = node_list.add
            node_list.add.force_quantized_index = 0
            if node_list.active:
                subgraph.add_edge(
                    NNEdge(from_node=node_list.add, to_node=node_list.active))
                last_node = node_list.active
            if padded_input_idx == 0:
                input_mapping = [[(node_list.pad, 0)], [(node_list.add, 1)]]
            else:
                input_mapping = [[(node_list.add, 0)], [(node_list.pad, 1)]]

            output_mapping = [(last_node, 0)]
            pnode = PaddedAddFusionParameters(
                "PADDED_" + node_list.add.name,
                fusion_type=node_list.fusion_type,
                subgraph=subgraph,
                input_mapping=input_mapping,
                output_mapping=output_mapping)
            if G.quantization:
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                # if there are quantization stats then clear them. They need to be created again
                G.quantization.stats = None
                if qrecs:
                    prec = QRec.copy_ktype(qrecs[0],
                                           in_qs=qrecs[0].in_qs,
                                           out_qs=qrecs[-1].out_qs)
                    for node in pnode.contained_nodes():
                        G.quantization.move_to_fusion(node, pnode)
                    G.quantization[NodeId(pnode)] = prec
            if padded_input_idx == 0:
                in_edges = G.in_edges(node_list.pad.name) + \
                    G.indexed_in_edges(node_list.add.name)[1::]
            else:
                in_edges = G.indexed_in_edges(
                    node_list.add.name)[0:1:] + G.in_edges(node_list.pad.name)
            out_edges = G.out_edges(last_node.name)
            for node in node_list.order:
                G.remove(node)
            for edge in in_edges:
                G.add_edge(
                    NNEdge(edge.from_node,
                           pnode,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))
            for edge in out_edges:
                G.add_edge(
                    NNEdge(pnode,
                           edge.to_node,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph