def reverse_matmul(G: GraphView, params):
    """Exchange the two operands of ``params`` and surround it with transposes.

    The edges feeding inputs 0 and 1 are swapped, the matching entries of the
    node's quantization record (when one exists) are swapped as well, and a
    ``(1, 0)`` transpose is inserted in front of both inputs and after the
    output.

    Args:
        G: graph that contains ``params``.
        params: the node whose operands are reversed.

    Returns:
        A tuple ``(in_nodes, tout_params)`` holding the list of the two input
        transpose nodes and the output transpose node.
    """
    operand_edges = G.indexed_in_edges(params.name)[:2]
    for old_edge in operand_edges:
        G.remove_edge(old_edge)
    # whatever fed input 0 now feeds input 1 and vice versa
    for swapped_idx, old_edge in zip((1, 0), operand_edges):
        G.add_edge(
            NNEdge(from_node=old_edge.from_node,
                   to_node=params,
                   from_idx=old_edge.from_idx,
                   to_idx=swapped_idx))

    # keep the quantization record in step with the swapped inputs
    node_id = NodeId(params)
    if G.quantization and node_id in G.quantization:
        qrec = G.quantization[node_id]
        qrec.in_qs[0], qrec.in_qs[1] = qrec.in_qs[1], qrec.in_qs[0]

    # transpose both inputs ...
    in_nodes = []
    for input_idx in (0, 1):
        transpose_in = TransposeParameters(
            G.unique_name(f"{params.name}_tin{input_idx}"), transpose=(1, 0))
        in_nodes.append(transpose_in)
        G.insert_node_before(
            transpose_in, params, to_idx=input_idx, edge_class=NNEdge)
    # ... and the output
    tout_params = TransposeParameters(
        G.unique_name(f"{params.name}_tout"), transpose=(1, 0))
    G.insert_node_after(params, tout_params)
    return in_nodes, tout_params
Esempio n. 2
0
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        """Eliminate strided slices that select their whole input.

        A slice whose ``slice_shape`` equals its input shape is removed
        outright when its output shape is the same, otherwise it is replaced
        by a reshape to the slice's output shape. Quantization records are
        dropped or moved to the replacement node accordingly.

        Returns:
            True when at least one slice was removed or replaced.
        """
        has_modified_graph = False
        # snapshot the nodes first since the graph is edited inside the loop
        slice_nodes = list(G.nodes(node_classes=StridedSliceParameters))
        for node in slice_nodes:
            if node.slice_shape != tuple(node.in_dims[0].shape):
                continue
            has_modified_graph = True
            nid = NodeId(node)
            is_noop = node.slice_shape == node.out_shape
            if is_noop:
                LOG.info(
                    f'removing strided slice {node.name} that does nothing')
                G.remove_and_reconnect(node, edge_class=NNEdge)
                if G.quantization and nid in G.quantization:
                    del G.quantization[nid]
                continue

            # the slice only changes the shape so a reshape is enough
            reshape = ReshapeParameters(
                G.unique_name(f'{node.name}_reshape'),
                old_shape=node.slice_shape,
                shape=node.out_shape)
            LOG.info(
                f'replacing strided slice {node.name} with reshape {reshape.name}'
            )
            G.replace_node(node, reshape)
            if G.quantization and nid in G.quantization:
                G.quantization[NodeId(reshape)] = G.quantization[nid]
                del G.quantization[nid]

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 3
0
    def move_constant(cls, G: GraphView, params, in_qs):
        """Make sure a constant operand of ``params`` sits on its second input.

        When the second input is already a constant nothing is changed. When
        only the first input is a constant the two operand edges are swapped
        and transposes are inserted on both inputs and on the output, using
        the identity ``A.B = (BT.AT)T``.

        Args:
            G: graph that contains ``params``.
            params: the node whose inputs are inspected.
            in_qs: input quantization types of ``params``.

        Returns:
            A tuple ``(constant_node, in_qs)``. ``constant_node`` is None when
            neither of the first two inputs is a constant or when an existing
            bias does not have the length required to move the constant. The
            first two entries of ``in_qs`` are exchanged when the operands
            were swapped.
        """
        in_edges = G.indexed_in_edges(params.name)
        first_src = in_edges[0].from_node
        second_src = in_edges[1].from_node

        # constant already on the second input - nothing to do
        if isinstance(second_src, ConstantInputParameters):
            return second_src, in_qs
        # no constant on either operand
        if not isinstance(first_src, ConstantInputParameters):
            return None, in_qs

        if len(params.in_dims) > 2:
            # a bias is present: it can only stay in place if its length
            # equals the second dimension of the first tensor
            bias_size = params.in_dims[2].size()
            in1_shape = params.in_dims[0].shape
            if in1_shape[1] != bias_size:
                return None, in_qs

        operand_edges = in_edges[:2]
        for old_edge in operand_edges:
            G.remove_edge(old_edge)
        # swap the edges so that the constant lands on input 2
        for swapped_idx, old_edge in zip((1, 0), operand_edges):
            G.add_edge(
                NNEdge(from_node=old_edge.from_node,
                       to_node=old_edge.to_node,
                       from_idx=old_edge.from_idx,
                       to_idx=swapped_idx))

        # A.B = (BT.AT)T so transpose both inputs and the output
        tin1 = TransposeParameters(
            G.unique_name(f'{params.name}_tin1'), transpose=(1, 0))
        tin2 = TransposeParameters(
            G.unique_name(f'{params.name}_tin2'), transpose=(1, 0))
        tout = TransposeParameters(
            G.unique_name(f'{params.name}_tout'), transpose=(1, 0))
        G.insert_node_before(tin1, params)
        G.insert_node_before(tin2, params, to_idx=1)
        G.insert_node_after(params, tout)
        LOG.warning('transposes inserted on %s - rerun adjust',
                    params.name)
        return first_src, [in_qs[1], in_qs[0]] + in_qs[2:]
Esempio n. 4
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Insert a copy node in front of every split selected by ``search_up_for_input``.

        The copy is placed on the split's first input edge; when the graph is
        quantized the split's input 0 quantization is copied onto the new node.

        Returns:
            True when at least one copy node was inserted.
        """
        has_modified_graph = False
        # evaluate the predicate for all splits before the graph is edited
        splits_needing_copy = [
            split for split in G.nodes(node_classes=SplitParameters)
            if search_up_for_input(G, split)
        ]
        for split in splits_needing_copy:
            LOG.info("Insert copy on split input %s", split.name)
            has_modified_graph = True
            copy_node = CopyParameters(G.unique_name(f'{split.name}_copy'))
            first_in_edge = G.in_edges(split.name)[0]
            G.insert_node_at_edge(copy_node, first_in_edge)
            if G.quantization:
                G.quantization.copy_qrec(split, 'in', 0, copy_node)

        if set_identity:
            self.set_identity(G)
        return has_modified_graph
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse a relu6 followed by a multiply by 1/6 into one hsigmoid node.

        A match is a ``ReluActivationParameters`` node with ``upper_bound == 6``
        whose only consumer is a two input ``MatrixMulParameters`` node, the
        other input of which is a ``ConstantInputParameters`` used nowhere
        else and accepted by ``check_equals(G, constant, 1.0/6.0)``. The three
        nodes are replaced by a ``HSigmoidActivationParameters`` with
        ``offset=0``.

        On a quantized graph the new node gets a record built from the relu's
        input types and the multiply's output types, and the records of the
        removed nodes are dropped.

        Args:
            G: graph to modify in place.
            set_identity: when True ``self.set_identity(G)`` is called at the
                end, as the other matchers do.

        Returns:
            True when at least one fusion was performed.
        """
        something_changed = False
        # snapshot the candidates since the graph is edited inside the loop
        relu6_nodes = [
            node for node in G.nodes(node_classes=ReluActivationParameters)
            if node.upper_bound == 6
        ]
        for relu_node in relu6_nodes:
            out_edges = G.out_edges(relu_node)
            if len(out_edges) != 1 or not isinstance(out_edges[0].to_node, MatrixMulParameters):
                continue
            mul_node = out_edges[0].to_node
            in_edges = G.in_edges(mul_node)
            if len(in_edges) != 2:
                continue
            # the multiply input that does not come from the relu
            other_edge = (set(in_edges) - {out_edges[0]}).pop()
            constant_node = other_edge.from_node
            # the constant is removed below so it must not feed anything else
            if len(G.out_edges(constant_node)) != 1:
                continue
            if (not isinstance(constant_node, ConstantInputParameters) or
                    not check_equals(G, constant_node, 1.0/6.0)):
                continue

            something_changed = True
            activation = HSigmoidActivationParameters(
                G.unique_name(f'{mul_node.name}_hsigmoid'), offset=0)

            # capture the boundary edges before the nodes disappear
            in_edges = G.in_edges(relu_node)
            out_edges = G.out_edges(mul_node)

            nodes_to_replace = [relu_node, mul_node, constant_node]

            LOG.info(f'fusing {", ".join(node.name for node in nodes_to_replace)} into HSIGMOID {activation.name}')
            G.remove_all(nodes_to_replace)

            for in_edge in in_edges:
                G.add_edge(NNEdge.clone(in_edge, to_node=activation, to_idx=0))
            for out_edge in out_edges:
                G.add_edge(NNEdge.clone(
                    out_edge, from_node=activation, from_idx=0))

            if G.quantization:
                reluqrec = G.quantization[NodeId(relu_node)]
                mulqrec = G.quantization[NodeId(mul_node)]
                pqrec = QRec.copy_ktype(
                    reluqrec, in_qs=reluqrec.in_qs, out_qs=mulqrec.out_qs)
                # none of the fused nodes exists any more so do not leave
                # their records behind
                for removed_node in nodes_to_replace:
                    removed_id = NodeId(removed_node)
                    if removed_id in G.quantization:
                        del G.quantization[removed_id]
                G.quantization[NodeId(activation)] = pqrec

        if set_identity:
            self.set_identity(G)

        return something_changed
Esempio n. 6
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Insert a copy on split/concat inputs that come from graph inputs or constants.

        Every input edge of each ``SplitParameters`` and ``ConcatParameters``
        node is resolved with ``find_real_in_edge``; when the resolved source
        is an ``InputParameters`` or ``ConstantInputParameters`` node a
        ``CopyParameters`` node is inserted on that edge. On a quantized graph
        the quantization of the input the copy was inserted on is copied onto
        the new node.

        Args:
            G: graph to modify in place.
            set_identity: when True ``self.set_identity(G)`` is called at the end.

        Returns:
            True when at least one copy node was inserted.
        """
        candidates = [node for node in G.nodes(node_classes=(SplitParameters, ConcatParameters))]
        # collect everything first so that the graph is not edited while
        # its edges are being inspected
        need_a_copy_edges = []
        for node in candidates:
            for idx, edge in enumerate(G.indexed_in_edges(node.name)):
                real_from_node, _ = find_real_in_edge(G, edge)
                if isinstance(real_from_node, (InputParameters, ConstantInputParameters)):
                    need_a_copy_edges.append((edge, idx))

        has_modified_graph = False
        for edge, in_idx in need_a_copy_edges:
            to_node = edge.to_node
            LOG.info(
                "Insert copy on split input %s", to_node.name)
            has_modified_graph = True
            cnode = CopyParameters(G.unique_name(f'{to_node.name}_copy'))
            G.insert_node_at_edge(cnode, edge)
            if G.quantization:
                # use the index of the input that received the copy - a
                # concat can have the copy on an input other than 0
                G.quantization.copy_qrec(to_node, 'in', in_idx, cnode)
        if set_identity:
            self.set_identity(G)
        return has_modified_graph
Esempio n. 7
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Collapse each group of global pooling nodes returned by ``self.reductions``.

        The first node of a group is kept and takes over the combined
        reduction axes and keep-dims flag produced by folding
        ``reduce_reduction`` over the group; the remaining nodes are removed.
        When keep-dims is set and the combined shape has a different rank from
        the kept node's input, a reshape to that shape is appended. The
        consumers of the last node of the group are reconnected to the result.

        Returns:
            True when at least one group of more than one node was merged.
        """
        modified_graph = False
        pending = list(G.nodes(node_classes=GlobalPoolingParameters))
        while pending:
            node_group = self.reductions(G, pending.pop())
            if len(node_group) <= 1:
                continue
            modified_graph = True
            reduction_axes, new_shape, has_keepdims, _ = reduce(
                reduce_reduction, node_group, None)

            kept_node = node_group[0]
            last_node = node_group[-1]
            kept_node.axis = sorted(reduction_axes)
            kept_node.keep_dims = has_keepdims

            # remember the consumers before the tail of the group is removed
            out_edges = G.out_edges(last_node.name)
            if G.quantization:
                last_qrec = G.quantization[NodeId(last_node)]
                G.quantization[NodeId(kept_node)].out_qs = last_qrec.out_qs

            for absorbed in node_group[1:]:
                G.remove(absorbed.name)
                absorbed_id = NodeId(absorbed)
                if G.quantization and absorbed_id in G.quantization:
                    del G.quantization[absorbed_id]

            # node that the original consumers are reattached to
            out_source = kept_node
            in_rank = len(kept_node.in_dims[0].shape)
            if has_keepdims and len(new_shape) != in_rank:
                rparams = ReshapeParameters(
                    G.unique_name(f'{kept_node.name}_reshape'),
                    shape=Dim.unnamed(new_shape))
                if G.quantization:
                    G.quantization.copy_qrec(last_qrec, 'out', 0, rparams)
                G.add_edge(NNEdge(kept_node, rparams))
                out_source = rparams

            for edge in out_edges:
                G.add_edge(
                    NNEdge(out_source, edge.to_node, to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return modified_graph
Esempio n. 8
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fold a constant add or multiply that follows a matmul into the matmul.

        For every ``MatMulOpParameters`` node the single consumer is examined
        repeatedly. When it is a ``MatrixAddParameters`` or
        ``MatrixMulParameters`` node whose other input is a constant with as
        many elements as one of the matmul's two output dimensions, it is
        folded:

        * add - the constant is added to an existing bias or becomes a new
          bias on input 2. When its length matches the second output
          dimension the matmul is first reversed with ``reverse_matmul``.
        * multiply - the constant is multiplied into a constant operand of the
          matmul and into the bias when there is one.

        The folded node is removed, its consumers are reconnected to the
        matmul, dimensions are recomputed and a quantized graph is
        requantized starting at the matmul.

        Args:
            G: graph to modify in place.
            set_identity: when True ``self.set_identity(G)`` is called at the end.

        Returns:
            True when at least one node was folded.

        Raises:
            ValueError: if a candidate matmul output is not two dimensional.
        """
        has_modified_graph = False
        has_transposed = False
        # snapshot the matmuls since the graph is edited inside the loop
        for params in list(G.nodes(node_classes=MatMulOpParameters)):
            while True:
                out_edges = G.out_edges(params.name)
                # can't fuse if there is a branch or nothing downstream
                if len(out_edges) != 1:
                    break
                out_edge = out_edges[0]
                op_node = out_edge.to_node
                # must be a valid matrix op
                if not isinstance(op_node,
                                  (MatrixAddParameters, MatrixMulParameters)):
                    break
                # other edge to the op must be a constant
                other_idx = 1 if out_edge.to_idx == 0 else 0
                other_in_edge = G.indexed_in_edges(op_node.name)[other_idx]
                if not isinstance(other_in_edge.from_node,
                                  ConstantInputParameters):
                    break
                const_node = other_in_edge.from_node
                # the constant may only be removed when the folded op is its
                # sole consumer
                remove_constant = len(G.out_edges(const_node.name)) == 1

                flat_value = const_node.dqvalue.flatten()
                out_shape = params.out_dims[0].shape
                if len(out_shape) != 2:
                    raise ValueError(
                        f'strange outputs shape of {out_shape} for matmul {params.name}'
                    )
                if len(flat_value) != out_shape[0] and len(
                        flat_value) != out_shape[1]:
                    LOG.info(
                        "can't fuse %s into %s - value shape is not correct for bias",
                        const_node.name, params.name)
                    break
                has_bias = len(params.in_dims) == 3
                if isinstance(op_node, MatrixAddParameters):
                    if has_bias:
                        # NOTE(review): this compares the rank of the flat
                        # value with len() of the bias Dim - confirm that a
                        # length comparison was not intended
                        if len(flat_value.shape) != len(params.in_dims[2]):
                            LOG.info(
                                "can't fuse %s into %s - bias shape is not the same",
                                const_node.name, params.name)
                            break
                        bias_node = G.indexed_in_edges(
                            params.name)[2].from_node
                        LOG.info(
                            "folding additive bias from %s into existing bias on %s",
                            op_node.name, params.name)
                        bias_node.value = bias_node.dqvalue + flat_value
                    else:
                        if len(flat_value) == out_shape[1]:
                            # matmul needs to be transposed to fuse this
                            reverse_matmul(G, params)
                            has_transposed = True
                        bias_node = ConstantInputParameters(
                            G.unique_name(f'{params.name}_bias'),
                            value=flat_value,
                            dims=Dim.unnamed(flat_value.shape))
                        G.add_edge(
                            NNEdge(from_node=bias_node,
                                   to_node=params,
                                   to_idx=2))
                        # extend the inward transpose
                        if params.transpose_in:
                            params.transpose_in = params.transpose_in + [None]
                        LOG.info(
                            "folding additive bias from %s into new bias on %s",
                            op_node.name, params.name)
                else:
                    params_in = G.indexed_in_edges(params.name)
                    consts = [
                        isinstance(edge.from_node, ConstantInputParameters)
                        for edge in params_in
                    ]
                    # scaling needs a constant operand to absorb the factor
                    if not any(consts):
                        break
                    mult_const_node = params_in[1].from_node if consts[
                        1] else params_in[0].from_node
                    mult_const_node.value = mult_const_node.dqvalue * const_node.dqvalue
                    if has_bias:
                        bias_node = params_in[2].from_node
                        bias_node.value = bias_node.dqvalue * const_node.dqvalue

                    LOG.info(
                        "folding multaplicative bias from %s into new bias on %s",
                        op_node.name, params.name)

                # the op has been absorbed: drop it and rewire its consumers
                has_modified_graph = True
                out_edges = G.out_edges(op_node.name)
                G.remove(op_node)
                if remove_constant:
                    G.remove(const_node)
                for edge in out_edges:
                    G.add_edge(
                        NNEdge(from_node=params,
                               to_node=edge.to_node,
                               to_idx=edge.to_idx))
                G.add_dimensions()
                if G.quantization:
                    quantizer = UnifiedQuantizer.from_quantized_graph(G)
                    quantizer.quantize(G, start_nodes=[params])
                    RemoveUnnecessaryQuantizeOperators().match(G)

        if has_transposed:
            G.adjust_order()

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 9
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Replace every node set returned by ``self.find_sets`` with one expression node.

        For each set a sub-graph holding the nodes and their internal edges is
        built and wrapped in an ``ExpressionFusionParameters`` node which then
        replaces the fragment in ``G``. On a quantized graph a scaled
        quantization record is created for the new node from the records of
        the fused nodes, ``QuantizeParameters`` nodes directly on its outputs
        are removed and, once all sets are processed, the graph is
        requantized starting from the new nodes.

        Args:
            G: graph to modify in place.
            set_identity: when True ``self.set_identity(G)`` is called at the end.

        Returns:
            True when at least one node set was fused.
        """
        has_modified_graph = False
        # expression nodes that need to be (re)quantized at the end
        to_quantize = []
        node_sets = self.find_sets(G)
        for node_set in node_sets:
            # install a new SymbolStats for this fusion - presumably this is
            # what Symbol.CURRENT_CONTROL refers to further down
            Symbol.set_default_control(SymbolStats())
            has_modified_graph = True
            # in_edges and out_edges are mappings keyed by (from_node, from_idx)
            # whose values are groups of edges
            in_edges, out_edges, internal_edges = group_edges(G, node_set)
            # build the sub-graph that the expression node will carry
            frag = GraphView()
            for node in node_set:
                frag.add_node(node)
            for edge in internal_edges:
                frag.add_edge(edge)
            # for every external source: the (node, input index) pairs it feeds
            in_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_group]
                          for edge_group in in_edges.values()]
            in_dims = [
                from_node.out_dims[from_idx]
                for from_node, from_idx in in_edges
            ]
            out_dims = [
                from_node.out_dims[from_idx]
                for from_node, from_idx in out_edges
            ]
            # (node, output index) pairs inside the set that feed the outside
            out_mapping = list(out_edges.keys())
            constant_inputs = [
                node_edge_idx[0] for node_edge_idx in in_edges
                if isinstance(node_edge_idx[0], ConstantInputParameters)
            ]
            LOG.debug(
                "inputs coming from: %s",
                ",".join(f"{from_node.__repr__()}:{from_idx}"
                         for from_node, from_idx in in_edges))
            LOG.info("fusing nodes: %s into expr_%s",
                     ",".join(node.__repr__() for node in node_set),
                     self._expr_num)
            expr = ExpressionFusionParameters(
                G.unique_name(f"expr_{self._expr_num}"),
                subgraph=frag,
                qrecs=G.quantization,
                input_mapping=in_mapping,
                output_mapping=out_mapping,
                in_dims=in_dims,
                out_dims=out_dims,
                constant_inputs=constant_inputs)
            # the edge mappings are given in the same order as the
            # input/output mappings passed to the expression node above
            in_edge_mapping = list(in_edges.keys())
            out_edge_mapping = [[(edge.to_node, edge.to_idx)
                                 for edge in edge_set]
                                for edge_set in out_edges.values()]
            G.replace_fragment(
                frag,
                expr,
                frag_in_edges=list(set.union(*in_edges.values())),
                frag_out_edges=list(set.union(*out_edges.values())),
                edge_in_mapping=in_edge_mapping,
                edge_out_mapping=out_edge_mapping,
                edge_class=NNEdge)
            if G.quantization:
                qrecs = G.quantization
                # input types are taken from the first consumer of each input
                in_qs = [
                    qrecs[NodeId(in_map[0][0])].in_qs[in_map[0][1]]
                    for in_map in in_mapping
                ]
                out_qs = [
                    qrecs[NodeId(node)].out_qs[idx]
                    for node, idx in out_mapping
                ]
                # record the min/max of every input and output variable of
                # the expression in the statistics
                stats = Symbol.CURRENT_CONTROL.stats
                func_col = expr.func_col
                for idx, qtype in enumerate(in_qs):
                    symbol = func_col.variables[func_col.input_names[idx]]
                    stats[symbol.name] = {
                        'min': qtype.min_val,
                        'max': qtype.max_val
                    }
                for idx, qtype in enumerate(out_qs):
                    symbol = func_col.variables[func_col.output_names[idx]]
                    stats[symbol.name] = {
                        'min': qtype.min_val,
                        'max': qtype.max_val
                    }
                G.quantization[NodeId(expr)] = QRec(in_qs=in_qs,
                                                    out_qs=out_qs,
                                                    expression=stats,
                                                    ktype='scaled')
                # delete any quantize parameters on outputs to allow the quantizer
                # to fuse them into the expression
                out_edges = G.out_edges(expr.name)
                for edge in out_edges:
                    if isinstance(edge.to_node, QuantizeParameters):
                        G.remove_and_reconnect(edge.to_node)
                        if NodeId(edge.to_node) in G.quantization:
                            del G.quantization[NodeId(edge.to_node)]
                to_quantize.append(expr)

            # counter used in the name of the next expression node
            self._expr_num += 1

        if to_quantize:
            quantizer = UnifiedQuantizer.from_quantized_graph(G)
            G.quantization = quantizer.quantize(G, start_nodes=to_quantize)

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Remove duplicated operations and merge groupable ones fed by the same output.

        Edges are grouped by their source ``(node, output index)``. Among the
        ``ComparableParameters`` consumers of one source:

        * nodes for which ``is_same_operation_as`` holds are duplicates - one
          is kept, the consumers of the others are rewired to it and the
          others are removed;
        * nodes for which ``can_be_grouped_with`` holds are merged into the
          first one by concatenating weights and biases along the output
          channel, and a split after the merged node feeds the original
          consumers.

        The search is repeated until a pass changes nothing. Quantized graphs
        are not handled.

        Args:
            G: graph to modify in place.
            set_identity: when True ``self.set_identity(G)`` is called at the end.

        Returns:
            True when the graph was modified, False otherwise (including when
            the graph is quantized).
        """
        if G.quantization:
            LOG.warning(
                'match_duplicate_operations does not handle quantized graphs')
            return False

        def same_source_edge_fn(x):
            return f"{x.from_node.__hash__()}##{x.from_idx}"

        modified_graph = False
        while True:
            found_more = False
            same_source_edges = [
                list(edge_list) for _, edge_list in groupby(
                    sorted(G.edges(), key=same_source_edge_fn),
                    same_source_edge_fn)
            ]
            # all have the same origin
            same_source_edges = [
                elem for elem in same_source_edges if len(elem) > 1
            ]
            same_dest_edges = []
            same_dest_group_edges = []

            for same_source_edge in same_source_edges:
                same_source_edge = [
                    edge for edge in same_source_edge
                    if isinstance(edge.to_node, ComparableParameters)
                ]
                while same_source_edge:
                    first = same_source_edge.pop(0)

                    # exact duplicates of the first consumer
                    others = [
                        edge for edge in same_source_edge
                        if first.to_node != edge.to_node
                        and edge.to_node.is_same_operation_as(G, first.to_node)
                    ]
                    if others:
                        same_dest_edges.append(tuple([first] + others))
                        for other in others:
                            same_source_edge.remove(other)
                        continue

                    # consumers that can be merged with the first one
                    other_groups = [
                        edge for edge in same_source_edge
                        if first.to_node != edge.to_node
                        and edge.to_node.can_be_grouped_with(first.to_node)
                    ]
                    if other_groups:
                        same_dest_group_edges.append(
                            tuple([first] + other_groups))
                        for other in other_groups:
                            same_source_edge.remove(other)

            # all are multiple edges that go to something comparable
            while same_dest_edges:
                edge_set = same_dest_edges.pop(0)
                keep_node = edge_set[0].to_node
                # other duplicate sets that involve the kept node are
                # handled together with this one
                other_edge_sets = [
                    edges for edges in same_dest_edges
                    if any(edge.to_node == keep_node for edge in edges)
                ]
                for other_edge_set in other_edge_sets:
                    same_dest_edges.remove(other_edge_set)

                nodes_to_delete = set()
                for dup_edge_set in [edge_set] + other_edge_sets:
                    for edge in dup_edge_set:
                        other_node = edge.to_node
                        if other_node == keep_node or other_node in nodes_to_delete:
                            continue
                        nodes_to_delete.add(other_node)
                        # consumers of the duplicate now read the same
                        # output of the kept node
                        for out_edge in G.out_edges(other_node):
                            G.add_edge(
                                NNEdge(from_node=keep_node,
                                       from_idx=out_edge.from_idx,
                                       to_node=out_edge.to_node,
                                       to_idx=out_edge.to_idx))
                LOG.info(
                    f'removed duplicates {",".join(node.name for node in nodes_to_delete)} to {keep_node.name}'
                )
                for node in nodes_to_delete:
                    G.remove(node)
                if nodes_to_delete:
                    # removing duplicates changes the graph and can expose
                    # further matches so report it and run another pass
                    modified_graph = True
                    found_more = True

            for edge_set in same_dest_group_edges:
                modified_graph = True
                found_more = True
                # we will merge all the convolutions into one
                first = edge_set[0]
                first_node = first.to_node
                in_edges = G.indexed_in_edges(first_node.name)
                first_filter = first_node.filter
                weights_node = in_edges[1].from_node
                biases_node = in_edges[2].from_node
                dup_nodes = []
                num_convs = len(edge_set)
                out_shape = deepcopy(first_node.out_dims[0])
                out_shape.c *= num_convs
                # create a split after the first node splitting on channel axis
                act_slices, out_shapes, axis = SplitParameters.get_splits(
                    out_shape,
                    out_shape.get_order_idx('c'),
                    num_splits=num_convs)
                split1 = SplitParameters(
                    G.unique_name(f'{first_node.name}_split'),
                    act_slices=act_slices,
                    out_shapes=out_shapes,
                    axis=axis)
                out_num = 0
                # first split output goes to the original consumers of the
                # first node
                out_edges = G.out_edges(first_node.name)
                for edge in out_edges:
                    G.remove_edge(edge)
                    G.add_edge(
                        NNEdge(from_node=split1,
                               from_idx=out_num,
                               to_node=edge.to_node,
                               to_idx=edge.to_idx))
                # first node now feeds the split
                G.add_edge(NNEdge(from_node=first_node, to_node=split1))
                for other in edge_set[1::]:
                    out_num += 1
                    node_other = other.to_node
                    dup_nodes.append(node_other.name)
                    in_edges = G.indexed_in_edges(node_other.name)
                    weights_other = in_edges[1].from_node
                    biases_other = in_edges[2].from_node
                    # merge the weights and biases down the output channel
                    weights_node.value = np.concatenate(
                        (weights_node.value, weights_other.value),
                        axis=first_filter.get_order_idx('out_c'))
                    weights_node.dims = Dim.unnamed(weights_node.value.shape)
                    biases_node.value = np.concatenate(
                        (biases_node.value, biases_other.value))
                    biases_node.dims = Dim.unnamed(biases_node.value.shape)
                    first_filter.out_c += node_other.filter.out_c
                    # wire the consumers of the merged node to the split
                    out_edges = G.out_edges(node_other.name)
                    G.remove(node_other)
                    G.remove(weights_other)
                    G.remove(biases_other)
                    for edge in out_edges:
                        G.add_edge(
                            NNEdge(from_node=split1,
                                   from_idx=out_num,
                                   to_node=edge.to_node,
                                   to_idx=edge.to_idx))
                LOG.info(
                    f'merged convolutions {",".join(dup_nodes)} into {first_node.name}'
                )
            if not found_more:
                break

        if set_identity:
            self.set_identity(G)

        return modified_graph
Esempio n. 11
0
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        has_modified_graph = False
        slices_by_origin = {}
        for slice_node in [
                node for node in G.nodes()
                if isinstance(node, StridedSliceParameters)
        ]:
            in_edge = G.in_edges(slice_node.name)[0]
            group = slices_by_origin.setdefault(
                (in_edge.from_node, in_edge.from_idx), [])
            group.append(slice_node)
        for in_edge, slice_nodes in slices_by_origin.items():
            slices = list(zip(*[node.act_slice for node in slice_nodes]))
            if len(slice_nodes) == 1:
                self.slice_to_split(G, slice_nodes, slices)
                continue

            # strides must be one
            if any(sl[2] != 1 for sl_axis in slices for sl in sl_axis):
                continue

            diff_axes = list([
                idx for idx, elems in enumerate(slices)
                if not all(elems[0] == elem for elem in elems[1::])
            ])
            not_diff_axes = [
                idx for idx in range(len(slices)) if idx not in diff_axes
            ]
            diff_slices = [
                sl for idx, sl in enumerate(slices) if idx in diff_axes
            ]
            axis_lengths = in_edge[0].out_dims[in_edge[1]].shape
            if not_diff_axes and min(not_diff_axes) < max(diff_axes):
                transpose_from = tuple(range(len(slices)))
                transpose_to = tuple(diff_axes + not_diff_axes)
                axis_lengths = [axis_lengths[idx] for idx in transpose_to]
            else:
                transpose_from = transpose_to = None
            diff_axis_lengths = axis_lengths[0:len(diff_axes):]

            diff_slices = combine_slices(diff_axis_lengths, diff_slices,
                                         slice_nodes)
            if diff_slices is None:
                continue

            if len(diff_axes) > 1:
                reshape_from = axis_lengths
                reshape_to = [np.prod(diff_axis_lengths)] + \
                    axis_lengths[len(diff_axes)::]
            else:
                reshape_from = None
                reshape_to = slice_nodes[0].in_dims[0].shape
                if transpose_from:
                    reshape_to = [reshape_to[idx] for idx in transpose_to]

            sizes, shapes, sorted_nodes = slices_to_sizes(
                diff_slices, axis_lengths[len(diff_axes)::])

            name_prefix = sorted_nodes[0].name

            in_edge = G.in_edges(sorted_nodes[0].name)[0]
            in_node = in_edge.from_node
            in_idx = in_edge.from_idx

            if transpose_from:
                params = TransposeParameters(G.unique_name(name_prefix +
                                                           '_tin'),
                                             transpose=transpose_to)
                G.add_edge(
                    NNEdge(from_node=in_node, to_node=params, from_idx=in_idx))
                in_node = params
                in_idx = 0

            if reshape_from:
                params = ReshapeParameters(G.unique_name(name_prefix +
                                                         '_reshape'),
                                           old_shape=Dim.unnamed(reshape_from),
                                           shape=Dim.unnamed(reshape_to))
                G.add_edge(
                    NNEdge(from_node=in_node, to_node=params, from_idx=in_idx))
                in_node = params
                in_idx = 0

            act_slices, out_shapes, axis = SplitParameters.get_splits(
                reshape_to, 0, splits=sizes)
            split_node = SplitParameters(G.unique_name(name_prefix + '_split'),
                                         act_slices=act_slices,
                                         out_shapes=out_shapes,
                                         axis=axis)

            G.add_edge(
                NNEdge(from_node=in_node, from_idx=in_idx, to_node=split_node))

            sub_names = []
            for idx, node in enumerate(sorted_nodes):
                sub_names.append(node.name)
                out_edges = G.out_edges(node.name)
                G.remove(node)
                for out_edge in out_edges:
                    params = split_node
                    out_idx = idx
                    if reshape_from:
                        from_node = params
                        params = ReshapeParameters(
                            G.unique_name(name_prefix + f'_reshape{idx}'),
                            shape=Dim.unnamed(shapes[idx]))
                        G.add_edge(
                            NNEdge(from_node=from_node,
                                   to_node=params,
                                   from_idx=out_idx))
                        out_idx = 0
                    if transpose_from:
                        from_node = params
                        params = TransposeParameters(
                            G.unique_name(name_prefix + f'_tout{idx}'),
                            transpose=reverse_transpose(transpose_to))
                        G.add_edge(
                            NNEdge(from_node=from_node,
                                   to_node=params,
                                   from_idx=out_idx))
                        out_idx = 0

                    G.add_edge(
                        NNEdge(from_node=params,
                               to_node=out_edge.to_node,
                               from_idx=out_idx,
                               to_idx=out_edge.to_idx))
            if G.quantization:
                G.add_dimensions()
                quantizer = NewQuantizer.from_quantized_graph(G)
                quantizer.quantize()
                RemoveUnnecessaryQuantizeOperators().match(G)

            LOG.info(
                f'replaced slice nodes {",".join(sub_names)} with split node {split_node.name}'
            )

            has_modified_graph = True

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 12
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Collapse groups of axis-0 concats into one concat over flat inputs.

        For every ConcatParameters node with ``axis == 0`` the subgraph
        returned by ``find_concats_up`` is examined. When it holds more than
        one concat, every node of that subgraph except the starting concat
        (and any DummyInput) is removed, each original subgraph input is
        wired straight into the surviving concat (through a flattening
        reshape when the input has more than one dimension), and a reshape
        restoring the concat's original output dimension is inserted after
        it.

        Args:
            G: graph that is modified in place.
            set_identity: when true, ``self.set_identity(G)`` is called once
                all concats have been processed.
            **kwargs: accepted but not used by this method.

        Returns:
            True if at least one group of concats was combined.
        """
        changed = False
        pending = set(G.nodes(node_classes=ConcatParameters))
        while pending:
            concat = pending.pop()
            if concat.axis != 0:
                continue
            subgraph = find_concats_up(G, concat)
            group = set(subgraph.nodes(node_classes=ConcatParameters))
            if len(group) <= 1:
                continue
            group_names = ','.join(member.name for member in group)
            LOG.info(f"Combining concats {group_names}")
            changed = True
            # members of this group must not be visited again
            pending -= group

            # snapshot the subgraph inputs and their dims before the graph
            # is taken apart
            in_edges = [graph_input.edge for graph_input in subgraph.inputs()]
            in_dims = [
                in_edge.from_node.out_dims[in_edge.from_idx]
                for in_edge in in_edges
            ]
            stale_nodes = [
                candidate for candidate in subgraph.nodes()
                if candidate != concat
                and not isinstance(candidate, DummyInput)
            ]
            for in_edge in in_edges:
                G.remove_edge(in_edge)
            for stale in stale_nodes:
                if stale.name in G:
                    G.remove(stale)
                stale_id = NodeId(stale)
                if G.quantization and stale_id in G.quantization:
                    del G.quantization[stale_id]

            out_dim = concat.out_dims[0]
            # in_qs collects the producers' output qtypes; it drops to None as
            # soon as one of them cannot be resolved and then stays None
            in_qs = []
            for idx, (in_edge, in_dim) in enumerate(zip(in_edges, in_dims)):
                src_node = in_edge.from_node
                src_idx = in_edge.from_idx
                if len(in_dim) > 1:
                    # flatten multi-dimensional inputs before concatenating
                    flat_name = G.unique_name(f'{concat.name}_flat{idx}')
                    flat_shape = Dim.unnamed([in_dim.size()])
                    flatten = ReshapeParameters(flat_name,
                                                old_shape=in_dim,
                                                shape=flat_shape)
                    G.add_edge(
                        NNEdge(from_node=src_node,
                               from_idx=src_idx,
                               to_node=flatten))
                    src_node = flatten
                    src_idx = 0
                G.add_edge(
                    NNEdge(from_node=src_node,
                           from_idx=src_idx,
                           to_node=concat,
                           to_idx=idx))
                if in_qs is None or not G.quantization:
                    in_qs = None
                else:
                    producer_id = NodeId(in_edge.from_node)
                    if producer_id in G.quantization:
                        producer_qrec = G.quantization[producer_id]
                        in_qs.append(producer_qrec.out_qs[in_edge.from_idx])
                    else:
                        in_qs = None
            if in_qs is not None and G.quantization:
                concat_id = NodeId(concat)
                if concat_id in G.quantization:
                    G.quantization[concat_id].in_qs = in_qs
            # the concat now yields a flat vector; reshape it back to the
            # output dimension the concat had before the rewrite
            expand_name = G.unique_name(f'{concat.name}_expand')
            expand = ReshapeParameters(expand_name,
                                       old_shape=Dim.unnamed([out_dim.size()]),
                                       shape=out_dim)
            G.insert_node_after(concat, expand, edge_class=NNEdge)

        if set_identity:
            self.set_identity(G)

        return changed