Beispiel #1
0
def walk_down(subg: GraphView, node: Parameters,
              node_slices: Mapping[Parameters, Sequence[SlicedTensor]]):
    """Propagate sliced-tensor descriptions from ``node`` down through ``subg``.

    ``node_slices`` maps a node to the list of SlicedTensor objects arriving on
    its inputs, indexed by input index. It is filled in (mutated) as the walk
    proceeds, despite the ``Mapping`` annotation.

    A FusionInputParameters node seeds the walk with one slice covering its
    whole shape. Any other node is processed only once every incoming edge
    has delivered its slice; earlier visits return without doing anything.

    Transformations applied to the incoming slices:

    * ``transpose_in`` / ``transpose_out`` are applied on Transposable nodes.
    * ConcatParameters nodes concatenate their inputs.
    * SplitParameters and StridedSliceParameters nodes slice their first input.

    Every other node forwards input slice ``i`` to output ``i``.

    The resulting slices are stored on the consumers' inputs, and the walk
    recurses into each consumer.

    Args:
        subg: graph being walked.
        node: node to process.
        node_slices: per-node list of input slices, updated in place.
    """
    if isinstance(node, FusionInputParameters):
        # inputs have no incoming edges: seed them with a slice of their full shape
        inp_slice = InputSlice.from_shape(node, node.dims.shape)
        dim_slices = node_slices[node] = [
            SlicedTensor([
                SliceElement(tuple([0] * inp_slice.rank), inp_slice.shape,
                             inp_slice)
            ])
        ]
    else:
        in_slices = node_slices.get(node)
        # not every incoming edge has delivered its slice yet; the visit that
        # fills in the last one will process this node
        if (in_slices is None
                or len(in_slices) < subg.num_in_edges(node.name)
                or any(val is None for val in in_slices)):
            return
        # Work on a copy so node_slices[node] keeps the slices as delivered.
        # The node can be visited again after it is complete (the walk can reach
        # it through several producers), and the transposes below must not be
        # applied a second time to already transposed slices.
        dim_slices = list(in_slices)
        if isinstance(node, Transposable) and node.transpose_in:
            for idx, transpose in enumerate(node.transpose_in):
                if transpose:
                    dim_slices[idx] = dim_slices[idx].transpose(transpose)
        if isinstance(node, ConcatParameters):
            dim_slices = [SlicedTensor.concat(*dim_slices, axis=node.axis)]
        elif isinstance(node, SplitParameters):
            dim_slices = dim_slices[0].split(node.act_slices)
        elif isinstance(node, StridedSliceParameters):
            dim_slices = [dim_slices[0].slice(node.act_slice)]

    if isinstance(node, Transposable) and node.transpose_out:
        for idx, transpose in enumerate(node.transpose_out):
            if transpose:
                dim_slices[idx] = dim_slices[idx].transpose(transpose)

    out_edge_sets = subg.indexed_out_edges(node.name)

    # publish output slice ``from_idx`` on each consumer's input ``to_idx``
    for edge_set in out_edge_sets:
        for edge in edge_set:
            dest_slices = node_slices.setdefault(edge.to_node, [])
            missing = edge.to_idx + 1 - len(dest_slices)
            if missing > 0:
                dest_slices.extend([None] * missing)
            dest_slices[edge.to_idx] = dim_slices[edge.from_idx]

    # explore the graph: visit each consumer once (first-seen order), even if
    # it is connected to this node by several edges
    consumers = dict.fromkeys(
        edge.to_node for edge_set in out_edge_sets for edge in edge_set)
    for consumer in consumers:
        walk_down(subg, consumer, node_slices)
Beispiel #2
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Remove duplicated node chains hanging off the same producer.

        Candidates are nodes with a single output index that feeds more than
        one edge. For each candidate, ``self.explore(G, [node])`` returns a list
        of chains ("strings") — presumably chains computing the same thing;
        verify against ``explore``.

        The first chain is kept as the primary. Every other chain is removed
        from ``G``, together with any quantization record of its nodes. The
        consumers of each removed chain's last node are reconnected to the
        primary's last node on the same output index.

        Args:
            G: graph to modify in place.
            set_identity: when true, ``self.set_identity(G)`` is called at the end.

        Returns:
            True if any chain was removed.
        """
        modified_graph = False
        candidates = [node for node in G.nodes()
                      if len(G.indexed_out_edges(node.name)) == 1 and len(G.out_edges(node.name)) > 1]
        while candidates:
            node = candidates.pop(0)
            strings = self.explore(G, [node])
            if not strings:
                continue
            modified_graph = True
            primary = strings.pop(0)
            for pnode in primary:
                if pnode in candidates:
                    candidates.remove(pnode)
            out_edges = []
            for other in strings:
                # read the consumers before the chain's nodes are removed from G
                out_edges.extend(G.out_edges(other[-1].name))
                for other_node in other:
                    if other_node in candidates:
                        candidates.remove(other_node)
                    G.remove(other_node)
                    nid = NodeId(other_node)
                    if G.quantization and nid in G.quantization:
                        del G.quantization[nid]
                LOG.info('removed duplicates from %s %s',
                         primary[0].name, ",".join(dup.name for dup in other))
            pend = primary[-1]
            for edge in out_edges:
                # keep the output index: the duplicate's last node may have had
                # several outputs, and each consumer must stay on the same one
                G.add_edge(
                    NNEdge(from_node=pend, from_idx=edge.from_idx,
                           to_node=edge.to_node, to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return modified_graph
Beispiel #3
0
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        """Remove concat -> split pairs where the split undoes the concat.

        A SplitParameters node is bypassed, together with the ConcatParameters
        node feeding it, when all of the following hold:

        * the split has exactly one input edge, coming from the concat;
        * the concat feeds nothing but that split;
        * there is no concat ``transpose_out`` and no split ``transpose_in``;
        * both nodes use the same axis;
        * split output ``i`` has the same size on that axis as concat input
          ``i``, for every ``i``.

        Each concat input is then wired directly to the consumers of the
        matching split output, and both nodes are removed.

        Args:
            G: graph to modify in place.
            set_identity: when true, ``self.set_identity(G)`` is called at the end.

        Returns:
            True if at least one pair was removed.
        """
        has_modified_graph = False
        # Materialize the candidates first since the graph is edited in the loop.
        # A list (not a set) follows G.nodes() order, so processing and log
        # order do not depend on object hashes.
        split_nodes = [node for node in G.nodes()
                       if isinstance(node, SplitParameters)]
        for split_node in split_nodes:
            in_edges = G.in_edges(split_node.name)
            # exactly one producer is required; a split with no input edge is
            # skipped instead of failing on in_edges[0]
            if len(in_edges) != 1:
                continue
            concat_node = in_edges[0].from_node
            if not isinstance(concat_node, ConcatParameters):
                continue
            # the concat result must not be used anywhere but this split
            if len(G.out_edges(concat_node.name)) > 1:
                continue
            if concat_node.transpose_out or split_node.transpose_in:
                continue
            if concat_node.axis != split_node.axis:
                continue
            axis = concat_node.axis
            split_out_sizes = [
                out_shape[axis] for out_shape in split_node.out_shapes
            ]
            concat_in_sizes = [
                in_dim.shape[axis] for in_dim in concat_node.in_dims
            ]
            # the split must cut the tensor back into exactly the concat's pieces
            if split_out_sizes != concat_in_sizes:
                continue
            has_modified_graph = True
            LOG.info("removing unnecessary concat/split pair %s/%s",
                     concat_node.name, split_node.name)
            concat_in_edges = G.indexed_in_edges(concat_node.name)
            split_out_edges = G.indexed_out_edges(split_node.name)
            G.remove(split_node)
            G.remove(concat_node)
            # Concat input i carries the same data as split output i. zip also
            # copes with trailing split outputs that have no consumers.
            for concat_in_edge, out_bundle in zip(concat_in_edges,
                                                  split_out_edges):
                for out_edge in out_bundle:
                    G.add_edge(
                        NNEdge(from_node=concat_in_edge.from_node,
                               from_idx=concat_in_edge.from_idx,
                               to_node=out_edge.to_node,
                               to_idx=out_edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        """Merge split outputs that feed consecutive inputs of the same concat.

        This applies to a run of consecutive outputs ``i..j`` of a
        SplitParameters node where each output has exactly one consumer, and
        each reaches one ConcatParameters node (via ``search_down``, passing
        Copy/NoOP nodes) on consecutive inputs ``k..k+(j-i)``. Such a run is
        collapsed as follows:

        * the split slices are combined with ``reduce_slices`` into output ``i``;
        * the paths from outputs ``i+1..j`` are removed;
        * the later split outputs and concat inputs are renumbered to close
          the gap.

        The split and concat nodes stay in the graph, to be cleared up by
        remove_noops.

        Runs are searched again after every merge. A merge renumbers the edges
        of the split and concat it touches, so indices collected before it
        would be stale for any other run on the same nodes.

        Args:
            G: graph to modify in place.
            set_identity: accepted for interface compatibility; not used here.
                NOTE(review): other ``_match`` implementations call
                ``self.set_identity(G)`` when this is true — confirm that
                skipping it here is intended.

        Returns:
            True if at least one run was merged.
        """
        has_modified_graph = False
        while True:
            # only the first run is used: the graph changes with every merge
            edge_group = next(self._split_concat_runs(G), None)
            if edge_group is None:
                return has_modified_graph
            self._merge_split_concat_run(G, edge_group)
            has_modified_graph = True

    @staticmethod
    def _split_concat_runs(G: GraphView):
        """Yield runs of 2+ ``search_down`` paths from consecutive split outputs
        to consecutive inputs of one concat."""
        for split_node in G.nodes(node_classes=SplitParameters):
            run = []
            for out_edge_bundle in G.indexed_out_edges(split_node):
                path = None
                # only outputs with exactly one consumer can be merged
                if len(out_edge_bundle) == 1:
                    path = search_down(G,
                                       out_edge_bundle[0],
                                       ConcatParameters,
                                       can_pass=(CopyParameters,
                                                 NoOPParameters))
                if path:
                    concat_node = path[-1].to_node
                    # Merging adjacent split slices (presumably along the split
                    # axis in reduce_slices) only matches the concat if the concat
                    # joins on that same axis with no transposes in between.
                    if (concat_node.axis != split_node.axis
                            or any(split_node.transpose_out or ())
                            or any(concat_node.transpose_in or ())):
                        path = None
                if path and run:
                    prev_edge, this_edge = run[-1][-1], path[-1]
                    if (this_edge.to_node == prev_edge.to_node
                            and this_edge.to_idx == prev_edge.to_idx + 1):
                        run.append(path)
                        continue
                if len(run) > 1:
                    yield run
                run = [path] if path else []
            if len(run) > 1:
                yield run

    @staticmethod
    def _merge_split_concat_run(G: GraphView, edge_group):
        """Collapse one run from ``_split_concat_runs`` onto its first path."""
        first_path, last_path = edge_group[0], edge_group[-1]
        split_node = first_path[0].from_node
        concat_node = first_path[-1].to_node
        first_out, last_out = first_path[0].from_idx, last_path[0].from_idx
        LOG.info(
            "combining outputs %s:%s on split node %s followed by concat %s",
            first_out, last_out, split_node.name, concat_node.name)
        # combine slices and shapes of the grouped outputs into output first_out
        new_slice, new_shape = reduce_slices(
            split_node.act_slices[first_out:last_out + 1],
            split_node.out_shapes[first_out:last_out + 1])
        split_node.act_slices = (split_node.act_slices[:first_out] +
                                 [new_slice] +
                                 split_node.act_slices[last_out + 1:])
        split_node.out_shapes = (split_node.out_shapes[:first_out] +
                                 [new_shape] +
                                 split_node.out_shapes[last_out + 1:])
        # remove all edges and intermediate nodes of every path except the first
        for edge_list in edge_group[1:]:
            remove_edges(G, edge_list)
        # Renumber the split outputs after the run to follow first_out. Every
        # edge of a bundle is moved: later outputs may have zero or several
        # consumers.
        out_edge_bundles = G.indexed_out_edges(split_node)
        for offset, edge_list in enumerate(out_edge_bundles[last_out + 1:]):
            for edge in edge_list:
                G.remove_edge(edge)
                G.add_edge(NNEdge.clone(edge, from_idx=first_out + 1 + offset))
        # renumber the concat inputs after the run to follow first_in
        first_in, last_in = first_path[-1].to_idx, last_path[-1].to_idx
        in_edges = G.indexed_in_edges(concat_node)
        for offset, in_edge in enumerate(in_edges[last_in + 1:]):
            G.remove_edge(in_edge)
            G.add_edge(NNEdge.clone(in_edge, to_idx=first_in + 1 + offset))