def reverse_matmul(G: GraphView, params):
    # reverse edges
    in_edges = G.indexed_in_edges(params.name)
    for edge in in_edges[0:2]:
        G.remove_edge(edge)
    other_idx = 1
    for edge in in_edges[0:2]:
        G.add_edge(
            NNEdge(from_node=edge.from_node,
                   to_node=params,
                   from_idx=edge.from_idx,
                   to_idx=other_idx))
        other_idx = 1 - other_idx
    nid = NodeId(params)
    if G.quantization and nid in G.quantization:
        qrec = G.quantization[nid]
        # swap qrecs
        qrec.in_qs[0], qrec.in_qs[1] = qrec.in_qs[1], qrec.in_qs[0]

    # add transposes
    in_nodes = []
    for idx in range(2):
        tin_params = TransposeParameters(
            G.unique_name(f"{params.name}_tin{idx}"), transpose=(1, 0))
        in_nodes.append(tin_params)
        G.insert_node_before(tin_params, params, to_idx=idx, edge_class=NNEdge)
    tout_params = TransposeParameters(G.unique_name(f"{params.name}_tout"),
                                      transpose=(1, 0))
    G.insert_node_after(params, tout_params)
    return in_nodes, tout_params
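The rewiring above relies on the matrix identity A·B = (Bᵀ·Aᵀ)ᵀ (the same identity called out in the move_constant example further down): swapping the two matmul inputs and wrapping both inputs and the output in (1, 0) transposes leaves the result unchanged. A minimal numpy check of that identity, purely as an illustration and not part of the tool:

import numpy as np

# A @ B == (B.T @ A.T).T - swapping the operands and transposing the inputs
# and the output, as reverse_matmul does with TransposeParameters, is a no-op.
rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4))
B = rng.standard_normal((4, 5))
assert np.allclose(A @ B, (B.T @ A.T).T)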
Example #2
    def match(self, G: GraphView, set_identity: bool = True):
        has_modified = False
        for node in G.nodes(node_classes=ConstantInputParameters):
            out_edges = G.out_edges(node.name)
            if len(out_edges) <= 1:
                continue
            has_modified = True
            LOG.info(
                'node %s has more than one out edge and will be duplicated',
                node.name)
            idx = 1
            for out_edge in out_edges[1:]:
                new_constant = ConstantInputParameters(f'{node.name}_{idx}',
                                                       dims=Dim.unnamed(
                                                           node.dims.shape),
                                                       value=node.value.copy())
                G.remove_edge(out_edge)
                G.add_edge(
                    NNEdge(from_node=new_constant,
                           to_node=out_edge.to_node,
                           to_idx=out_edge.to_idx))
                idx += 1

        if set_identity:
            self.set_identity(G)

        return has_modified
Example #3
    def match(self, G: GraphView, set_identity: bool = True):
        split_nodes = [
            node for node in G.nodes() if isinstance(node, SplitParameters)
        ]
        has_modified_graph = False
        for node in split_nodes:
            # traverse reshapes or transposes that do nothing - check gen
            # find edges connected to concats
            res = self.find_split_concat(G, node)
            if res is None:
                continue
            # TODO(martin) - group edges that have adjacent inputs and outputs
            if G.quantization:
                qrec = G.quantization[NodeId(node)]
            for idx, bundle in enumerate(res):
                if not bundle:
                    continue
                has_modified_graph = True
                copy_node = CopyParameters("%s_copy_%s" % (node.name, idx))
                for edge_set in bundle:
                    first_edge = edge_set[0]
                    G.remove_edge(first_edge)
                    G.add_edge(
                        NNEdge(copy_node,
                               first_edge.to_node,
                               to_idx=first_edge.to_idx))
                G.add_edge(NNEdge(node, copy_node, from_idx=idx))
                if G.quantization:
                    G.quantization[NodeId(copy_node)] = qrec.__class__(
                        in_qs=deepcopy(qrec.out_qs[idx]),
                        out_qs=deepcopy(qrec.out_qs[idx]))
        return has_modified_graph
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        has_modified_graph = False
        for node in [
                node for node in G.nodes() if self.node_does_nothing(G, node)
        ]:
            has_modified_graph = True
            in_edge = G.in_edges(node.name)[0]
            G.remove_edge(in_edge)
            for out_edge in G.out_edges(node.name):
                G.remove_edge(out_edge)
                G.add_edge(
                    NNEdge(in_edge.from_node,
                           out_edge.to_node,
                           from_idx=in_edge.from_idx,
                           to_idx=out_edge.to_idx))
            LOG.info(f'removing {node.name} that does nothing')
            G.remove(node)

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Example #5
    def move_constant(cls, G: GraphView, params, in_qs):
        # look for a constant on one of the inputs; if there is one, we can
        # scale by the second dimension of the second tensor. If the constant
        # is on the first input, move it to the second and transpose the
        # operation
        in_edges = G.indexed_in_edges(params.name)
        in1_node = in_edges[0].from_node
        in2_node = in_edges[1].from_node

        if isinstance(in2_node, ConstantInputParameters):
            return in2_node, in_qs
        elif isinstance(in1_node, ConstantInputParameters):
            if len(params.in_dims) > 2:
                # check if the bias has the correct length to move the constant
                # it must have a length equal to the second tensor's second
                # dimension after the transpose
                bias_size = params.in_dims[2].size()
                in1_shape = params.in_dims[0].shape
                if in1_shape[1] != bias_size:
                    return None, in_qs
            for edge in in_edges[:2]:
                G.remove_edge(edge)
            to_idx = 1
            # swap edges to move constant onto input 2
            for edge in in_edges[:2]:
                new_edge = NNEdge(from_node=edge.from_node,
                                  to_node=edge.to_node,
                                  from_idx=edge.from_idx,
                                  to_idx=to_idx)
                G.add_edge(new_edge)
                to_idx = 1 - to_idx
            # use A.B = (BT.AT)T identity
            tin1 = TransposeParameters(G.unique_name(f'{params.name}_tin1'),
                                       transpose=(1, 0))
            tin2 = TransposeParameters(G.unique_name(f'{params.name}_tin2'),
                                       transpose=(1, 0))
            tout = TransposeParameters(G.unique_name(f'{params.name}_tout'),
                                       transpose=(1, 0))
            G.insert_node_before(tin1, params)
            G.insert_node_before(tin2, params, to_idx=1)
            G.insert_node_after(params, tout)
            LOG.warning('transposes inserted on %s - rerun adjust',
                        params.name)
            return in1_node, [in_qs[1], in_qs[0]] + in_qs[2:]
        else:
            return None, in_qs
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        if G.quantization:
            LOG.warning(
                'match_duplicate_operations does not handle quantized graphs')
            return False

        def same_source_edge_fn(x):
            return f"{x.from_node.__hash__()}##{x.from_idx}"

        def same_dest_edge(x):
            return f"{x.to_node.__hash__()}##{x.to_idx}"

        modified_graph = False
        while True:
            found_more = False
            same_source_edges = [
                list(edge_list) for _, edge_list in groupby(
                    sorted(G.edges(), key=same_source_edge_fn),
                    same_source_edge_fn)
            ]
            # all have the same origin
            same_source_edges = [
                elem for elem in same_source_edges if len(elem) > 1
            ]
            same_dest_edges = []
            same_dest_group_edges = []

            for same_source_edge in same_source_edges:
                same_source_edge = [
                    edge for edge in same_source_edge
                    if isinstance(edge.to_node, ComparableParameters)
                ]
                while same_source_edge:
                    first = same_source_edge.pop(0)

                    others = list(
                        filter(
                            partial(
                                lambda x, y: x.to_node != y.to_node
                                and y.to_node.is_same_operation_as(G, x.to_node),
                                first), same_source_edge))
                    if others:
                        same_dest_edges.append(tuple([first] + others))
                        for other in others:
                            same_source_edge.remove(other)
                        continue

                    other_groups = list(
                        filter(
                            partial(
                                lambda x, y: x.to_node != y.to_node
                                and y.to_node.can_be_grouped_with(x.to_node),
                                first), same_source_edge))
                    if other_groups:
                        same_dest_group_edges.append(
                            tuple([first] + other_groups))
                        for other in other_groups:
                            same_source_edge.remove(other)

            # all are multiple edges that go to something comparable
            save_same_dest_edges = same_dest_edges.copy()
            while same_dest_edges:
                edge_set = same_dest_edges.pop(0)
                keep_node = edge_set[0].to_node
                other_edge_sets = [
                    edges for edges in same_dest_edges
                    if any(edge.to_node == keep_node for edge in edges)
                ]
                for other_edge_set in other_edge_sets:
                    same_dest_edges.remove(other_edge_set)

                nodes_to_delete = set()
                for edge_set in [edge_set] + other_edge_sets:
                    for edge in edge_set:
                        other_node = edge.to_node
                        if other_node == keep_node or other_node in nodes_to_delete:
                            continue
                        nodes_to_delete.add(other_node)
                        for out_edge in G.out_edges(other_node):
                            G.add_edge(
                                NNEdge(from_node=keep_node,
                                       to_node=out_edge.to_node,
                                       to_idx=out_edge.to_idx))
                LOG.info(
                    f'removed duplicates {",".join(node.name for node in nodes_to_delete)} to {keep_node.name}'
                )
                # deduplication changed the graph, so run the grouping again
                modified_graph = True
                found_more = True
                for node in nodes_to_delete:
                    G.remove(node)

            # # all are multiple edges that go to something comparable

            # for edge_set in same_dest_edges:
            #     modified_graph = True
            #     found_more = True
            #     first = edge_set[0]
            #     first_node = first.to_node
            #     dup_nodes = []
            #     for other in edge_set[1::]:
            #         dest_node = other.to_node
            #         dup_nodes.append(dest_node.name)
            #         out_edges = G.out_edges(dest_node.name)
            #         G.remove(dest_node)
            #         for out_edge in out_edges:
            #             G.add_edge(NNEdge(from_node=first_node, to_node=out_edge.to_node,
            #                               from_idx=out_edge.from_idx, to_idx=out_edge.to_idx))
            #     LOG.info(
            #         f'removed duplicates {",".join(dup_nodes)} to {first_node.name}')

            for edge_set in same_dest_group_edges:
                modified_graph = True
                found_more = True
                # we will merge all the convolutions into one
                first = edge_set[0]
                first_node = first.to_node
                in_edges = G.indexed_in_edges(first_node.name)
                first_filter = first_node.filter
                weights_node = in_edges[1].from_node
                biases_node = in_edges[2].from_node
                dup_nodes = []
                num_convs = len(edge_set)
                out_shape = deepcopy(first_node.out_dims[0])
                out_shape.c *= num_convs
                # create a split after the first node splitting on channel axis
                act_slices, out_shapes, axis = SplitParameters.get_splits(
                    out_shape,
                    out_shape.get_order_idx('c'),
                    num_splits=num_convs)
                split1 = SplitParameters(
                    G.unique_name(f'{first_node.name}_split'),
                    act_slices=act_slices,
                    out_shapes=out_shapes,
                    axis=axis)
                out_num = 0
                # first node out edge goes to split
                out_edges = G.out_edges(first_node.name)
                for edge in out_edges:
                    G.remove_edge(edge)
                    G.add_edge(
                        NNEdge(from_node=split1,
                               from_idx=out_num,
                               to_node=edge.to_node,
                               to_idx=edge.to_idx))
                G.add_edge(NNEdge(from_node=first_node, to_node=split1))
                # first split output goes to original output
                for other in edge_set[1:]:
                    out_num += 1
                    node_other = other.to_node
                    dup_nodes.append(node_other.name)
                    in_edges = G.indexed_in_edges(node_other.name)
                    weights_other = in_edges[1].from_node
                    biases_other = in_edges[2].from_node
                    # merge the weights and biases down the output channel axis
                    weights_node.value = np.concatenate(
                        (weights_node.value, weights_other.value),
                        axis=first_filter.get_order_idx('out_c'))
                    weights_node.dims = Dim.unnamed(weights_node.value.shape)
                    biases_node.value = np.concatenate(
                        (biases_node.value, biases_other.value))
                    biases_node.dims = Dim.unnamed(biases_node.value.shape)
                    first_filter.out_c += node_other.filter.out_c
                    # wire edge from split
                    out_edges = G.out_edges(node_other.name)
                    G.remove(node_other)
                    G.remove(weights_other)
                    G.remove(biases_other)
                    for edge in out_edges:
                        G.add_edge(
                            NNEdge(from_node=split1,
                                   from_idx=out_num,
                                   to_node=edge.to_node,
                                   to_idx=edge.to_idx))
                LOG.info(
                    f'merged convolutions {",".join(dup_nodes)} into {first_node.name}'
                )
            if not found_more:
                break

        if set_identity:
            self.set_identity(G)

        return modified_graph
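The convolution-merge branch works because concatenating filters and biases along the output-channel axis and then splitting the merged result back on that same axis reproduces the outputs of the individual convolutions. A small numpy sketch of that equivalence, using a 1x1 (pointwise) convolution written as an einsum purely for illustration:

import numpy as np

# Merging two filters along out_c and splitting the combined output on the
# channel axis gives the same tensors as running the two convolutions apart.
rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8, 4))                           # H x W x in_c
w1, b1 = rng.standard_normal((3, 4)), rng.standard_normal(3)  # out_c x in_c
w2, b2 = rng.standard_normal((5, 4)), rng.standard_normal(5)

def pointwise_conv(inp, w, b):
    return np.einsum('hwc,oc->hwo', inp, w) + b

merged = pointwise_conv(x,
                        np.concatenate((w1, w2), axis=0),  # concat on out_c
                        np.concatenate((b1, b2)))
y1, y2 = np.split(merged, [3], axis=-1)     # split back on the channel axis
assert np.allclose(y1, pointwise_conv(x, w1, b1))
assert np.allclose(y2, pointwise_conv(x, w2, b2))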
Example #7
    def match(self, G: GraphView, set_identity: bool = True) -> bool:
        has_modified_graph = False
        for pad_params in [
                pad for pad in G.nodes() if isinstance(pad, PadParameters)
        ]:
            pad_in_edges = G.in_edges(pad_params.name)
            pad_out_edges = G.out_edges(pad_params.name)
            dont_delete = False
            for pad_out_edge in pad_out_edges:
                filter_like_node, is_1d = self.find_conv(
                    G, pad_out_edge.to_node)
                if not filter_like_node:
                    dont_delete = True
                    continue
                if (not filter_like_node.in_dims_hint
                        or not filter_like_node.in_dims_hint[0]):
                    raise ValueError(
                        f"filter {filter_like_node.name} doesn't have an input hint"
                    )
                in_hint = filter_like_node.in_dims_hint[0]
                if is_1d:
                    if len(pad_params.padding) != 2:
                        LOG.warning(
                            "pad node %s is applied to 1d convolution but has length %s",
                            pad_params.name, len(pad_params.padding))
                        dont_delete = True
                        continue
                    expanded_padding = [
                        pad_params.padding[0], (0, 0), pad_params.padding[1]
                    ]
                else:
                    if len(pad_params.padding) != 3:
                        LOG.warning(
                            "pad node %s is applied to 2d convolution but has length %s",
                            pad_params.name, len(pad_params.padding))
                        dont_delete = True
                        continue
                    expanded_padding = pad_params.padding

                hinted_pad = {
                    in_hint[idx]: pad
                    for idx, pad in enumerate(expanded_padding) if sum(pad) > 0
                }
                key_set = set(hinted_pad.keys())
                key_set -= set(['h', 'w'])
                if len(key_set) > 0:
                    dont_delete = True
                    LOG.error(
                        "node %s has padding on axes %s and cannot be fused with filter %s",
                        pad_params.name, key_set, filter_like_node.name)
                    continue
                if any(pval != 0 for val in pad_params.pad_vals
                       for pval in val):
                    dont_delete = True
                    LOG.error(
                        "node %s has non zero pad values and cannot be fused with filter %s",
                        pad_params.name, filter_like_node.name)
                    continue

                LOG.info("adding padding from: %s to %s filter: %s",
                         pad_params.name, is_1d and "1D" or "2D",
                         filter_like_node.name)

                for key in ['h', 'w']:
                    if key not in hinted_pad:
                        hinted_pad[key] = (0, 0)

                filter_like_node.padding = PadDim(*(list(hinted_pad['h']) +
                                                    list(hinted_pad['w'])))
                filter_like_node.pad_type = "zero"
                has_modified_graph = True
                G.remove_edge(pad_out_edge)
                if is_1d:
                    reshape_node = pad_out_edge.to_node
                    reshape_node.old_shape = self.remove_padding(
                        reshape_node.old_shape, pad_params.padding)
                    reshape_node.shape = self.remove_padding(
                        reshape_node.shape, expanded_padding)
                for in_edge in pad_in_edges:
                    G.add_edge(
                        NNEdge(from_node=in_edge.from_node,
                               to_node=pad_out_edge.to_node,
                               from_idx=in_edge.from_idx,
                               to_idx=pad_out_edge.to_idx))

            if not dont_delete:
                G.remove(pad_params)
                if G.quantization:
                    G.quantization.remove_node(pad_params)

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
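Whether a pad node can be folded into the downstream filter hinges on the hinted_pad construction above: each (before, after) padding pair is mapped onto the axis name taken from the filter's input dimension hint, and only padding that lands on 'h' and 'w' can be absorbed into the filter's own padding. A standalone sketch of that eligibility check, with made-up hint and padding values (illustration only):

# Hypothetical values, purely to illustrate the fusion eligibility check.
in_hint = ['c', 'h', 'w']
expanded_padding = [(0, 0), (1, 1), (2, 2)]
hinted_pad = {in_hint[idx]: pad
              for idx, pad in enumerate(expanded_padding) if sum(pad) > 0}
# -> {'h': (1, 1), 'w': (2, 2)}: nothing remains on other axes, so the pad can
# be fused; any non-zero pad on 'c' would have blocked the fusion
assert not (set(hinted_pad) - {'h', 'w'})
for key in ('h', 'w'):
    hinted_pad.setdefault(key, (0, 0))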
Example #8
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
        has_modified_graph = False
        for pad_params in [pad for pad in G.nodes() if isinstance(pad, PadParameters)]:
            pad_in_edges = G.in_edges(pad_params.name)
            pad_out_edges = G.out_edges(pad_params.name)
            dont_delete = False
            if len(pad_in_edges) == 1 and all(sum(padding) == 0 for padding in pad_params.padding):
                LOG.info("removing zero padding node %s",
                         pad_params.name)
                G.remove(pad_params)
                if G.quantization:
                    G.quantization.remove_node(pad_params)
                dont_delete = True
                in_edge = pad_in_edges[0]
                for out_edge in pad_out_edges:
                    G.add_edge(NNEdge(from_node=in_edge.from_node,
                                      to_node=out_edge.to_node,
                                      from_idx=in_edge.from_idx,
                                      to_idx=out_edge.to_idx))
            else:
                for pad_out_edge in pad_out_edges:
                    filter_like_node, expanded_padding, reshapes = self.find_conv(
                        G, pad_out_edge.to_node, pad_params.padding)
                    if not filter_like_node:
                        dont_delete = True
                        continue
                    if not filter_like_node.in_dims_hint or not filter_like_node.in_dims_hint[0]:
                        raise ValueError(
                            f"filter {filter_like_node.name} doesn't have a input hint")
                    in_hint = filter_like_node.in_dims_hint[0]

                    hinted_pad = {in_hint[idx]: pad for idx,
                                  pad in enumerate(expanded_padding) if sum(pad) > 0}
                    key_set = set(hinted_pad.keys())
                    key_set -= set(['h', 'w'])
                    if len(key_set) > 0:
                        dont_delete = True
                        LOG.error("node %s has padding on axes %s and cannot be fused with filter %s",
                                  pad_params.name, key_set, filter_like_node.name)
                        continue
                    if any(pval != 0 for val in pad_params.pad_vals for pval in val):
                        dont_delete = True
                        LOG.error("node %s has non zero pad values and cannot be fused with filter %s",
                                  pad_params.name, filter_like_node.name)
                        continue

                    LOG.info("adding padding from: %s to filter: %s - has %s reshapes",
                             pad_params.name, filter_like_node.name, len(reshapes))

                    for key in ['h', 'w']:
                        if key not in hinted_pad:
                            hinted_pad[key] = (0, 0)

                    filter_like_node.padding = PadDim(
                        *(list(hinted_pad['h']) + list(hinted_pad['w'])))
                    filter_like_node.pad_type = "zero"
                    has_modified_graph = True
                    G.remove_edge(pad_out_edge)
                    for reshape_node, old_padding, new_padding in reshapes:
                        reshape_node.old_shape = self.remove_padding(
                            reshape_node.old_shape, old_padding)
                        reshape_node.shape = self.remove_padding(
                            reshape_node.shape, new_padding)

                    for in_edge in pad_in_edges:
                        G.add_edge(NNEdge(from_node=in_edge.from_node,
                                          to_node=pad_out_edge.to_node,
                                          from_idx=in_edge.from_idx,
                                          to_idx=pad_out_edge.to_idx))

            if not dont_delete:
                G.remove(pad_params)
                if G.quantization:
                    G.quantization.remove_node(pad_params)

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Example #9
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        modified_graph = False
        concats = set(G.nodes(node_classes=ConcatParameters))
        while concats:
            concat = concats.pop()
            if concat.axis != 0:
                continue
            subgraph = find_concats_up(G, concat)
            found = set(subgraph.nodes(node_classes=ConcatParameters))
            if len(found) <= 1:
                continue
            LOG.info(
                f"Combining concats {','.join([node.name for node in found])}")
            modified_graph = True
            concats -= found

            in_edges = [inp.edge for inp in subgraph.inputs()]
            in_dims = [
                edge.from_node.out_dims[edge.from_idx] for edge in in_edges
            ]
            nodes_to_remove = [
                node for node in subgraph.nodes()
                if node != concat and not isinstance(node, DummyInput)
            ]
            for edge in in_edges:
                G.remove_edge(edge)
            for node in nodes_to_remove:
                if node.name in G:
                    G.remove(node)
                nid = NodeId(node)
                if G.quantization and nid in G.quantization:
                    del G.quantization[nid]

            # remove_internal_graph(G, subgraph)
            out_dim = concat.out_dims[0]
            in_qs = []
            for idx, edge in enumerate(in_edges):
                from_node = edge.from_node
                from_idx = edge.from_idx
                if len(in_dims[idx]) > 1:
                    reshape = ReshapeParameters(
                        G.unique_name(f'{concat.name}_flat{idx}'),
                        old_shape=in_dims[idx],
                        shape=Dim.unnamed([in_dims[idx].size()]))
                    G.add_edge(
                        NNEdge(from_node=from_node,
                               from_idx=from_idx,
                               to_node=reshape))
                    from_node = reshape
                    from_idx = 0
                G.add_edge(
                    NNEdge(from_node=from_node,
                           from_idx=from_idx,
                           to_node=concat,
                           to_idx=idx))
                if in_qs is not None and G.quantization:
                    nid = NodeId(edge.from_node)
                    if nid in G.quantization:
                        qrec = G.quantization[nid]
                        in_qs.append(qrec.out_qs[edge.from_idx])
                    else:
                        in_qs = None
                else:
                    in_qs = None
            if in_qs is not None and G.quantization:
                nid = NodeId(concat)
                if nid in G.quantization:
                    G.quantization[nid].in_qs = in_qs
            reshape = ReshapeParameters(G.unique_name(f'{concat.name}_expand'),
                                        old_shape=Dim.unnamed([out_dim.size()]),
                                        shape=out_dim)
            G.insert_node_after(concat, reshape, edge_class=NNEdge)

        if set_identity:
            self.set_identity(G)

        return modified_graph
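Flattening every input, concatenating once on axis 0 and reshaping the result is equivalent to the nested axis-0 concats it replaces, because an axis-0 concatenation of C-contiguous tensors is just a concatenation of their flat buffers. A numpy check of that assumption (illustration only, not part of the tool):

import numpy as np

a = np.arange(6).reshape(2, 3)
b = np.arange(6, 15).reshape(3, 3)
c = np.arange(15, 18).reshape(1, 3)
nested = np.concatenate((np.concatenate((a, b), axis=0), c), axis=0)
# flatten each input, concatenate once on axis 0, then expand back out
flat = np.concatenate([t.reshape(t.size) for t in (a, b, c)])
assert np.array_equal(flat.reshape(nested.shape), nested)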
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        edge_groups = []
        for node in G.nodes(node_classes=SplitParameters):
            cur_group = None
            for out_edge_bundle in G.indexed_out_edges(node):
                if len(out_edge_bundle) == 1:
                    out_edge = out_edge_bundle[0]
                    concat_node_edges = search_down(G,
                                                    out_edge,
                                                    ConcatParameters,
                                                    can_pass=(CopyParameters,
                                                              NoOPParameters))
                    if concat_node_edges:
                        if cur_group:
                            this_concat_edge = concat_node_edges[-1]
                            last_concat_edge = cur_group[-1][-1]
                            if (this_concat_edge.to_node == last_concat_edge.to_node
                                    and this_concat_edge.to_idx == last_concat_edge.to_idx + 1):
                                cur_group.append(concat_node_edges)
                                continue
                            if len(cur_group) > 1:
                                edge_groups.append(cur_group)
                        cur_group = [concat_node_edges]
                        continue
                if cur_group:
                    if len(cur_group) > 1:
                        edge_groups.append(cur_group)
                    cur_group = None
            if cur_group:
                if len(cur_group) > 1:
                    edge_groups.append(cur_group)
                cur_group = None
        # we leave the splits and concats after this since they will be cleared up by remove_noops
        for edge_group in edge_groups:
            split_node = edge_group[0][0].from_node
            concat_node = edge_group[0][-1].to_node
            from_idx = edge_group[0][0].from_idx
            to_idx = edge_group[-1][0].from_idx
            LOG.info(
                f"combining outputs {from_idx}:{to_idx} on split node {split_node.name} followed by concat {concat_node.name}"
            )
            # combine slices and shapes on edges in group
            new_slice, new_shape = reduce_slices(
                split_node.act_slices[from_idx:to_idx + 1],
                split_node.out_shapes[from_idx:to_idx + 1])
            split_node.act_slices = split_node.act_slices[:from_idx] + [
                new_slice
            ] + split_node.act_slices[to_idx + 1:]
            split_node.out_shapes = split_node.out_shapes[:from_idx] + [
                new_shape
            ] + split_node.out_shapes[to_idx + 1:]
            # remove all edges and intermediate nodes on all edge groups except the first
            for edge_list in edge_group[1:]:
                remove_edges(G, edge_list)
            out_edge_bundles = G.indexed_out_edges(split_node)
            # move edges beyond the edge group after the first index
            for offset, edge_list in enumerate(out_edge_bundles[to_idx + 1:]):
                assert len(edge_list) == 1
                edge = edge_list[0]
                G.remove_edge(edge)
                G.add_edge(NNEdge.clone(edge, from_idx=from_idx + 1 + offset))
            # reindex the in edges in the concat
            from_idx = edge_group[0][-1].to_idx
            to_idx = edge_group[-1][-1].to_idx
            in_edges = G.indexed_in_edges(concat_node)
            for offset, in_edge in enumerate(in_edges[to_idx + 1:]):
                G.remove_edge(in_edge)
                G.add_edge(NNEdge.clone(in_edge, to_idx=from_idx + 1 + offset))

        return bool(edge_groups)
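The reduce_slices step above is justified by a simple observation: adjacent split outputs that are concatenated back in their original order are the same as one wider slice of the split input. A toy numpy version of that reasoning (illustration only):

import numpy as np

# Outputs 0 and 1 of a three-way split on axis 0 feed adjacent concat inputs,
# so they can be collapsed into the single slice x[0:4].
x = np.arange(24).reshape(6, 4)
parts = np.split(x, [2, 4], axis=0)            # slices [0:2], [2:4], [4:6]
assert np.array_equal(np.concatenate(parts[:2], axis=0), x[0:4])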