Beispiel #1
0
    def _common(cls, node, **kwargs):
        """Create a SplitParameters node for ``node`` and connect it to its input.

        Required entries in ``kwargs``: ``all_nodes``, ``G``, ``axis`` and
        ``opts``.  Optional entries: ``splits``, ``num_splits`` and
        ``input_idx`` (position of the tensor to split in ``node.input``,
        default 0).

        Every output tensor of ``node`` is registered in ``all_nodes`` as a
        ``(params, output index, ProvisionalDim)`` tuple.

        Returns:
            The SplitParameters instance that was created.
        """
        all_nodes = kwargs['all_nodes']
        G = kwargs['G']
        axis = kwargs['axis']
        opts = kwargs['opts']
        split_sizes = kwargs.get('splits')
        split_count = kwargs.get('num_splits')
        source_pos = kwargs.get('input_idx', 0)

        # Resolve every input tensor first, then pick the one being split.
        resolved = [all_nodes[tensor_name] for tensor_name in node.input]
        source = resolved[source_pos]

        act_slices, pout_shapes, axis = SplitParameters.get_splits(
            source[2].shape,
            axis,
            splits=split_sizes,
            num_splits=split_count)
        out_shapes = list(
            map(BackendHandler.remove_unspecified_dim, pout_shapes))
        params = SplitParameters(node.name,
                                 act_slices=act_slices,
                                 out_shapes=out_shapes,
                                 axis=axis)

        if opts.get('load_quantization'):
            qrec = BackendHandler.load_tf_quantization([node.input[0]],
                                                       node.output)
            G.quantization[NodeId(params)] = qrec

        G.add_edge(
            NNEdge(from_node=source[0],
                   to_node=params,
                   from_idx=source[1],
                   to_idx=0))

        # Publish each output together with its provisional shape.
        for out_idx, out_name in enumerate(node.output):
            provisional = ProvisionalDim(pout_shapes[out_idx])
            all_nodes[out_name] = (params, out_idx, provisional)
        return params
    def _common(cls, node, **kwargs):
        all_nodes = kwargs['all_nodes']
        G = kwargs['G']
        axis = kwargs['axis']
        splits = kwargs.get('splits')
        opts = kwargs['opts']
        input_idx = kwargs.get('input_idx', 0)
        num_splits = kwargs.get('num_splits')
        # eliminates a silly split / unpack that does dothing - seen in dtln1.tflite
        if splits is None and num_splits == 1 and len(node.output) == 1:
            all_nodes[node.output[0]] = all_nodes[node.input[0]]
            return None

        inputs = [all_nodes[inp] for inp in node.input]
        x = inputs[input_idx]
        x_shape = x[2].shape
        if axis and axis < 0:
            axis = axis + len(x_shape)
        act_slices, pout_shapes, axis = SplitParameters.get_splits(
            x_shape, axis, splits=splits, num_splits=num_splits)
        out_shapes = [
            BackendHandler.remove_unspecified_dim(shape)
            for shape in pout_shapes
        ]
        params = SplitParameters(node.name,
                                 act_slices=act_slices,
                                 out_shapes=out_shapes,
                                 axis=axis)

        if opts.get('load_quantization'):
            G.quantization[NodeId(
                params)] = BackendHandler.load_tf_quantization([node.input[0]],
                                                               node.output)

        G.add_edge(
            NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))

        for idx, tensor in enumerate(node.output):
            all_nodes[tensor] = (params, idx, ProvisionalDim(pout_shapes[idx]))
        return params
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Remove duplicated operations and merge groupable ones sharing an input.

        Edges are grouped by the (node, output index) they leave.  Within a
        group, destination nodes that are ComparableParameters are handled in
        two ways:

        * nodes reporting ``is_same_operation_as`` are duplicates: one is
          kept, the consumers of the others are rewired to it and the others
          are removed;
        * nodes reporting ``can_be_grouped_with`` are merged into the first
          one by concatenating weights and biases along the output channel
          and appending a SplitParameters node that hands each original
          consumer its share of the channels.

        The scan repeats until a pass changes nothing.

        Args:
            G: graph to modify in place.  Quantized graphs are not handled.
            set_identity: when True ``self.set_identity(G)`` is called at the
                end.

        Returns:
            True if the graph was modified, False otherwise (including when
            the graph is quantized and is therefore skipped).
        """
        if G.quantization:
            LOG.warning(
                'match_duplicate_operations does not handle quantized graphs')
            return False

        def same_source_edge_fn(x):
            # Edges sharing this key leave the same output of the same node.
            return f"{x.from_node.__hash__()}##{x.from_idx}"

        modified_graph = False
        while True:
            found_more = False
            same_source_edges = [
                list(edge_list) for _, edge_list in groupby(
                    sorted(G.edges(), key=same_source_edge_fn),
                    same_source_edge_fn)
            ]
            # all have the same origin
            same_source_edges = [
                elem for elem in same_source_edges if len(elem) > 1
            ]
            same_dest_edges = []
            same_dest_group_edges = []

            for same_source_edge in same_source_edges:
                same_source_edge = [
                    edge for edge in same_source_edge
                    if isinstance(edge.to_node, ComparableParameters)
                ]
                while same_source_edge:
                    first = same_source_edge.pop(0)

                    # Exact duplicates of the first destination.
                    others = [
                        edge for edge in same_source_edge
                        if first.to_node != edge.to_node
                        and edge.to_node.is_same_operation_as(G, first.to_node)
                    ]
                    if others:
                        same_dest_edges.append(tuple([first] + others))
                        for other in others:
                            same_source_edge.remove(other)
                        continue

                    # Not duplicates, but mergeable with the first destination.
                    other_groups = [
                        edge for edge in same_source_edge
                        if first.to_node != edge.to_node
                        and edge.to_node.can_be_grouped_with(first.to_node)
                    ]
                    if other_groups:
                        same_dest_group_edges.append(
                            tuple([first] + other_groups))
                        for other in other_groups:
                            same_source_edge.remove(other)

            # all are multiple edges that go to something comparable
            while same_dest_edges:
                edge_set = same_dest_edges.pop(0)
                keep_node = edge_set[0].to_node
                # Fold in every other set that also reaches the kept node.
                other_edge_sets = [
                    edges for edges in same_dest_edges
                    if any(edge.to_node == keep_node for edge in edges)
                ]
                for other_edge_set in other_edge_sets:
                    same_dest_edges.remove(other_edge_set)

                nodes_to_delete = set()
                for current_set in [edge_set] + other_edge_sets:
                    for edge in current_set:
                        other_node = edge.to_node
                        if other_node == keep_node or other_node in nodes_to_delete:
                            continue
                        nodes_to_delete.add(other_node)
                        for out_edge in G.out_edges(other_node):
                            # Keep the producer output index so consumers of
                            # multi-output duplicates stay on the same output.
                            G.add_edge(
                                NNEdge(from_node=keep_node,
                                       from_idx=out_edge.from_idx,
                                       to_node=out_edge.to_node,
                                       to_idx=out_edge.to_idx))
                LOG.info(
                    f'removed duplicates {",".join(node.name for node in nodes_to_delete)} to {keep_node.name}'
                )
                for node in nodes_to_delete:
                    G.remove(node)
                if nodes_to_delete:
                    # Removing duplicates changes the graph and may expose
                    # further duplicates downstream, so report it and rescan.
                    modified_graph = True
                    found_more = True

            for edge_set in same_dest_group_edges:
                modified_graph = True
                found_more = True
                # we will merge all the convolutions into one
                first = edge_set[0]
                first_node = first.to_node
                in_edges = G.indexed_in_edges(first_node.name)
                first_filter = first_node.filter
                # Inputs 1 and 2 of the node are used as weights and biases.
                weights_node = in_edges[1].from_node
                biases_node = in_edges[2].from_node
                dup_nodes = []
                num_convs = len(edge_set)
                out_shape = deepcopy(first_node.out_dims[0])
                out_shape.c *= num_convs
                # create a split after the first node splitting on channel axis
                act_slices, out_shapes, axis = SplitParameters.get_splits(
                    out_shape,
                    out_shape.get_order_idx('c'),
                    num_splits=num_convs)
                split1 = SplitParameters(
                    G.unique_name(f'{first_node.name}_split'),
                    act_slices=act_slices,
                    out_shapes=out_shapes,
                    axis=axis)
                out_num = 0
                # first node out edge goes to split
                out_edges = G.out_edges(first_node.name)
                for edge in out_edges:
                    G.remove_edge(edge)
                    G.add_edge(
                        NNEdge(from_node=split1,
                               from_idx=out_num,
                               to_node=edge.to_node,
                               to_idx=edge.to_idx))
                G.add_edge(NNEdge(from_node=first_node, to_node=split1))
                # first split output goes to original output
                for other in edge_set[1::]:
                    out_num += 1
                    node_other = other.to_node
                    dup_nodes.append(node_other.name)
                    in_edges = G.indexed_in_edges(node_other.name)
                    weights_other = in_edges[1].from_node
                    biases_other = in_edges[2].from_node
                    # merge the weights and biases down the output channel
                    weights_node.value = np.concatenate(
                        (weights_node.value, weights_other.value),
                        axis=first_filter.get_order_idx('out_c'))
                    weights_node.dims = Dim.unnamed(weights_node.value.shape)
                    biases_node.value = np.concatenate(
                        (biases_node.value, biases_other.value))
                    biases_node.dims = Dim.unnamed(biases_node.value.shape)
                    first_filter.out_c += node_other.filter.out_c
                    # wire edge from split; the out edges are captured before
                    # the node is removed from the graph
                    out_edges = G.out_edges(node_other.name)
                    G.remove(node_other)
                    G.remove(weights_other)
                    G.remove(biases_other)
                    for edge in out_edges:
                        G.add_edge(
                            NNEdge(from_node=split1,
                                   from_idx=out_num,
                                   to_node=edge.to_node,
                                   to_idx=edge.to_idx))
                LOG.info(
                    f'merged convolutions {",".join(dup_nodes)} into {first_node.name}'
                )
            if not found_more:
                break

        if set_identity:
            self.set_identity(G)

        return modified_graph
Beispiel #4
0
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        """Replace gathers that fully partition one axis of a tensor with a split.

        GatherParameters nodes are grouped by the (node, output index) they
        read.  A group is converted only when all its gathers use the same
        axis, each selects exactly one index, and together the indices are
        exactly 0..N-1 for the N entries of that axis.  The gathers are then
        removed and a single SplitParameters node feeds their consumers.

        Args:
            G: graph to modify in place.
            set_identity: when True ``self.set_identity(G)`` is called at the
                end.

        Returns:
            True if at least one group of gathers was replaced.
        """

        def single_index(gather):
            # The index selected by a gather holding a scalar or 1-element array.
            indices = gather.indices
            return int(indices if len(indices.shape) == 0 else indices[0])

        has_modified_graph = False
        gathers_by_origin = {}
        for gather in [
                node for node in G.nodes()
                if isinstance(node, GatherParameters)
        ]:
            in_edge = G.in_edges(gather.name)[0]
            group = gathers_by_origin.setdefault(
                (in_edge.from_node, in_edge.from_idx), [])
            group.append(gather)
        for in_edge, gathers in gathers_by_origin.items():
            # This is too difficult to handle if there are multiple slices:
            # every gather (including the first) must share the axis and pick
            # a single index, since each becomes one unit-wide split output.
            axis = gathers[0].axis
            if not all(gather.axis == axis and len(gather.indices.shape) <= 1
                       and all(dim == 1 for dim in gather.indices.shape)
                       for gather in gathers):
                continue
            # sort all the indices
            gathers = sorted(gathers, key=single_index)
            indices = [single_index(gather) for gather in gathers]
            # All the indices must be independent and cover the whole axis
            # (this could be relaxed but then needs to handle gaps). As the
            # gathers are sorted this means the indices are exactly 0..N-1.
            in_shape = in_edge[0].out_dims[in_edge[1]].shape
            in_shape_without_axis = in_shape[:axis:] + in_shape[axis + 1::]
            if indices != list(range(in_shape[axis])):
                continue
            # good for a split
            LOG.info("gathers from %s[%s] converted to a split",
                     in_edge[0].name, in_edge[1])
            splits = []
            shapes = []
            out_edges = []
            for gather, index in zip(gathers, indices):
                splits.append([tuple([index, index + 1, 1])])
                shapes.append(in_shape_without_axis)
                # Capture the consumers before the gather is removed.
                out_edges.append(G.out_edges(gather.name))
                G.remove(gather)
            params = SplitParameters("%s_split" % in_edge[0].name,
                                     act_slices=splits,
                                     out_shapes=shapes,
                                     axis=axis)
            if axis != 0:
                # Permutation that moves the split axis to the front; every
                # axis must appear exactly once, so skip `axis` in the tail.
                trans = [axis] + list(range(0, axis)) + list(
                    range(axis + 1, len(in_shape)))
                # transpose_out is the inverse permutation of transpose_in.
                params.transpose_out = [[
                    trans.index(idx) for idx in range(len(trans))
                ]]
                params.transpose_in = [trans]
            for idx, edges in enumerate(out_edges):
                for edge in edges:
                    G.add_edge(
                        NNEdge(from_node=params,
                               to_node=edge.to_node,
                               from_idx=idx,
                               to_idx=edge.to_idx))
            G.add_edge(
                NNEdge(from_node=in_edge[0],
                       to_node=params,
                       from_idx=in_edge[1]))
            has_modified_graph = True

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Beispiel #5
0
    def _common(cls, node, **kwargs):
        """Import an ONNX Split node as a SplitParameters node.

        Required entries in ``kwargs``: ``all_nodes``, ``G`` and
        ``valid_name``.  The split sizes come from the second input when
        ``cls.SINCE_VERSION >= 13`` and from the ``split`` attribute
        otherwise.  When no sizes are supplied the axis is divided equally
        between the node outputs.

        If the input is a constant, the split is evaluated immediately and
        each output is registered as a ConstantInputParameters node.

        Returns:
            The created SplitParameters, or None when the input was constant
            and the node was reduced to constants.

        Raises:
            AssertionError: if the axis is out of bounds, the split dimension
                is undefined, attribute sizes are invalid, or the dimension
                cannot be divided equally between the outputs.
        """
        all_nodes = kwargs['all_nodes']
        G = kwargs['G']
        valid_name = kwargs['valid_name']
        inputs = [all_nodes[inp] for inp in node.input]

        x = inputs[0]
        x_shape = x[2].shape
        axis = node.attrs.get('axis', 0)

        if axis < 0:
            axis += len(x_shape)

        assert axis < len(x_shape) and axis >= 0,\
            "axis %s is out of bounds - input dims %s in node %s" % (axis, x_shape, valid_name)

        split_dim = x_shape[axis]
        assert split_dim is not None, "split dimension must be defined"

        split = None
        if cls.SINCE_VERSION >= 13:
            # From opset 13 the sizes are an optional second input.
            if len(inputs) > 1:
                split = cls.get_constant(inputs[1])
        else:
            attr_split = node.attrs.get('split')
            if attr_split:
                split = np.array(attr_split)
                assert sum(
                    split
                ) == split_dim, "split sizes should add up to total size %s" % valid_name
                assert np.all(
                    split > 0
                ), "split sizes should be greater than zero %s" % valid_name

        if split is None:
            # No sizes given by either route: split equally between outputs.
            num_outputs = len(node.output)
            assert split_dim % num_outputs == 0,\
                "no split attribute or value and dimension is not divisible by number of outputs %s" % valid_name
            split = np.array([split_dim // num_outputs] * num_outputs)

        split = split.tolist()
        act_slices = []
        out_shapes = []
        out_pshapes = []
        cur = 0
        for piece_size in split:
            # Slices only cover the known (not None) dimensions.
            act_slices.append(
                tuple(
                    (cur, cur + piece_size, 1) if didx == axis else (0, dim, 1)
                    for didx, dim in enumerate(x_shape) if dim is not None))
            out_pshape = tuple(piece_size if didx == axis else dim
                               for didx, dim in enumerate(x_shape))
            out_shapes.append(
                tuple(dim for dim in out_pshape if dim is not None))
            out_pshapes.append(ProvisionalDim(out_pshape))
            cur += piece_size
        # Re-express the axis relative to the known dimensions only.
        axis -= sum(1 if dim is None else 0 for dim in x_shape[:axis:])
        params = SplitParameters(valid_name,
                                 act_slices=act_slices,
                                 out_shapes=out_shapes,
                                 axis=axis)
        if cls.is_constant(x):
            logger.info("reducing %s to %s constant(s)", valid_name,
                        len(out_shapes))
            values = params.numpy_split(cls.get_constant(x))
            for idx, out_pshape in enumerate(out_pshapes):
                # NOTE(review): every constant is created with the same name
                # (valid_name) - confirm that duplicate names are acceptable.
                cparams = ConstantInputParameters(valid_name,
                                                  value=values[idx])
                all_nodes[node.output[idx]] = (cparams, 0, out_pshape, x[3])
            return None

        G.add_edge(
            NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))
        for idx, out_pshape in enumerate(out_pshapes):
            all_nodes[node.output[idx]] = (params, idx, out_pshape, x[3])
        return params
Beispiel #6
0
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        """Replace groups of strided slices reading one tensor with a split.

        StridedSliceParameters nodes are grouped by the (node, output index)
        they read.  A lone slice is handed to ``self.slice_to_split``.  For a
        larger group, the axes on which the slices differ are moved to the
        front with a transpose when needed, collapsed with a reshape when
        there are several, and then a single SplitParameters node on axis 0
        replaces the slices.  Matching reshape/transpose nodes are inserted
        on each output to restore the original layout.

        Args:
            G: graph to modify in place.
            set_identity: when True ``self.set_identity(G)`` is called at the
                end.

        Returns:
            True if at least one multi-slice group was replaced.
        """
        has_modified_graph = False
        slices_by_origin = {}
        for slice_node in [
                node for node in G.nodes()
                if isinstance(node, StridedSliceParameters)
        ]:
            in_edge = G.in_edges(slice_node.name)[0]
            group = slices_by_origin.setdefault(
                (in_edge.from_node, in_edge.from_idx), [])
            group.append(slice_node)
        for in_edge, slice_nodes in slices_by_origin.items():
            # Transpose so that slices[axis] holds that axis' slice per node.
            slices = list(zip(*[node.act_slice for node in slice_nodes]))
            if len(slice_nodes) == 1:
                self.slice_to_split(G, slice_nodes, slices)
                continue

            # strides must be one
            if any(sl[2] != 1 for sl_axis in slices for sl in sl_axis):
                continue

            diff_axes = [
                idx for idx, elems in enumerate(slices)
                if not all(elems[0] == elem for elem in elems[1::])
            ]
            # Identical slices differ on no axis: there is nothing to split
            # on, and max(diff_axes) below would raise on the empty list.
            if not diff_axes:
                continue
            not_diff_axes = [
                idx for idx in range(len(slices)) if idx not in diff_axes
            ]
            diff_slices = [
                sl for idx, sl in enumerate(slices) if idx in diff_axes
            ]
            axis_lengths = in_edge[0].out_dims[in_edge[1]].shape
            if not_diff_axes and min(not_diff_axes) < max(diff_axes):
                # Differing axes are not leading: transpose them to the front.
                transpose_from = tuple(range(len(slices)))
                transpose_to = tuple(diff_axes + not_diff_axes)
                axis_lengths = [axis_lengths[idx] for idx in transpose_to]
            else:
                transpose_from = transpose_to = None
            diff_axis_lengths = axis_lengths[0:len(diff_axes):]

            diff_slices = combine_slices(diff_axis_lengths, diff_slices,
                                         slice_nodes)
            if diff_slices is None:
                continue

            if len(diff_axes) > 1:
                # Collapse all the differing axes into a single leading axis.
                reshape_from = axis_lengths
                reshape_to = [np.prod(diff_axis_lengths)] + \
                    axis_lengths[len(diff_axes)::]
            else:
                reshape_from = None
                reshape_to = slice_nodes[0].in_dims[0].shape
                if transpose_from:
                    reshape_to = [reshape_to[idx] for idx in transpose_to]

            sizes, shapes, sorted_nodes = slices_to_sizes(
                diff_slices, axis_lengths[len(diff_axes)::])

            name_prefix = sorted_nodes[0].name

            in_edge = G.in_edges(sorted_nodes[0].name)[0]
            in_node = in_edge.from_node
            in_idx = in_edge.from_idx

            if transpose_from:
                params = TransposeParameters(G.unique_name(name_prefix +
                                                           '_tin'),
                                             transpose=transpose_to)
                G.add_edge(
                    NNEdge(from_node=in_node, to_node=params, from_idx=in_idx))
                in_node = params
                in_idx = 0

            if reshape_from:
                params = ReshapeParameters(G.unique_name(name_prefix +
                                                         '_reshape'),
                                           old_shape=Dim.unnamed(reshape_from),
                                           shape=Dim.unnamed(reshape_to))
                G.add_edge(
                    NNEdge(from_node=in_node, to_node=params, from_idx=in_idx))
                in_node = params
                in_idx = 0

            act_slices, out_shapes, axis = SplitParameters.get_splits(
                reshape_to, 0, splits=sizes)
            split_node = SplitParameters(G.unique_name(name_prefix + '_split'),
                                         act_slices=act_slices,
                                         out_shapes=out_shapes,
                                         axis=axis)

            G.add_edge(
                NNEdge(from_node=in_node, from_idx=in_idx, to_node=split_node))

            sub_names = []
            for idx, node in enumerate(sorted_nodes):
                sub_names.append(node.name)
                # Capture the consumers before the slice is removed.
                out_edges = G.out_edges(node.name)
                G.remove(node)
                for out_edge in out_edges:
                    params = split_node
                    out_idx = idx
                    if reshape_from:
                        # Undo the collapse of the differing axes.
                        from_node = params
                        params = ReshapeParameters(
                            G.unique_name(name_prefix + f'_reshape{idx}'),
                            shape=Dim.unnamed(shapes[idx]))
                        G.add_edge(
                            NNEdge(from_node=from_node,
                                   to_node=params,
                                   from_idx=out_idx))
                        out_idx = 0
                    if transpose_from:
                        # Undo the input transpose.
                        from_node = params
                        params = TransposeParameters(
                            G.unique_name(name_prefix + f'_tout{idx}'),
                            transpose=reverse_transpose(transpose_to))
                        G.add_edge(
                            NNEdge(from_node=from_node,
                                   to_node=params,
                                   from_idx=out_idx))
                        out_idx = 0

                    G.add_edge(
                        NNEdge(from_node=params,
                               to_node=out_edge.to_node,
                               from_idx=out_idx,
                               to_idx=out_edge.to_idx))
            if G.quantization:
                # Re-quantize so the inserted nodes get quantization records.
                G.add_dimensions()
                quantizer = NewQuantizer.from_quantized_graph(G)
                quantizer.quantize()
                RemoveUnnecessaryQuantizeOperators().match(G)

            LOG.info(
                f'replaced slice nodes {",".join(sub_names)} with split node {split_node.name}'
            )

            has_modified_graph = True

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Beispiel #7
0
 def slice_to_split(G, slice_nodes, slices):
     """Turn a single unit-stride slice on one axis into a split node.

     Only the first entry of ``slice_nodes`` (and the first slice of each
     axis in ``slices``) is considered.  Nothing happens unless every stride
     is 1 and the slice length differs from the input dimension on exactly
     one axis.  Otherwise the slice node is removed and a SplitParameters
     node with the same name takes its place; the parts of the axis before
     and after the slice go to OutputParameters nodes named via
     ``G.unique_name('unused')``.

     Returns None in every case.
     """
     target = slice_nodes[0]
     in_dims = target.in_dims[0].shape
     per_axis = [axis_slices[0] for axis_slices in slices]
     if not all(sl[2] == 1 for sl in per_axis):
         return
     lengths = [sl[1] - sl[0] for sl in per_axis]
     # axes on which the slice does not span the whole input dimension
     changed_axes = [
         idx for idx, (length, dim) in enumerate(zip(lengths, in_dims))
         if length != dim
     ]
     if len(changed_axes) != 1:
         return
     # good to convert to a split
     axis = changed_axes[0]
     start, stop = per_axis[axis][0], per_axis[axis][1]
     axis_dim = in_dims[axis]

     def unused_sink():
         # Destination list for a split output that nobody consumes.
         sink = OutputParameters(G.unique_name('unused'))
         sink.at_options.allocate = 1
         return ((sink, 0), )

     splits = []
     outs = []
     if start > 0:
         # leading part of the axis, in front of the slice
         splits.append(start)
         outs.append(unused_sink())
     splits.append(stop - start)
     outs.append([(edge.to_node, edge.to_idx)
                  for edge in G.out_edges(target.name)])
     if stop < axis_dim:
         # trailing part of the axis, behind the slice
         splits.append(axis_dim - stop)
         outs.append(unused_sink())
     in_edge = G.in_edges(target.name)[0]
     G.remove(target)
     act_slices, out_shapes, axis = SplitParameters.get_splits(
         in_dims, axis, splits=splits)
     LOG.info(
         'replacing strided slice %s with split with %s redundant outputs',
         target.name,
         len(outs) - 1)
     if axis != 0:
         LOG.warning('adjust needs to be rerun')
     split_params = SplitParameters(target.name,
                                    act_slices=act_slices,
                                    out_shapes=out_shapes,
                                    axis=axis)
     G.add_edge(
         NNEdge(from_node=in_edge.from_node,
                from_idx=in_edge.from_idx,
                to_node=split_params))
     # connect every split output to its consumers (or its unused sink)
     for out_idx, destinations in enumerate(outs):
         for dest_node, dest_idx in destinations:
             G.add_edge(
                 NNEdge(from_node=split_params,
                        from_idx=out_idx,
                        to_node=dest_node,
                        to_idx=dest_idx))
     if G.quantization:
         G.add_dimensions()
         quantizer = NewQuantizer.from_quantized_graph(G)
         quantizer.quantize()
         RemoveUnnecessaryQuantizeOperators().match(G)
Beispiel #8
0
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        """Replace strided slices that partition one axis of a tensor with a split.

        StridedSliceParameters nodes are grouped by the (node, output index)
        they read.  A lone slice is handed to ``self.slice_to_split``.  A
        larger group is converted only when the slices differ on exactly one
        axis, have unit stride on it, are back-to-back without overlap and
        together cover that axis from 0 to its full length.

        Args:
            G: graph to modify in place.
            set_identity: when True ``self.set_identity(G)`` is called at the
                end.

        Returns:
            True if at least one multi-slice group was replaced.
        """
        has_modified_graph = False
        slices_by_origin = {}
        for slice_node in [
                node for node in G.nodes()
                if isinstance(node, StridedSliceParameters)
        ]:
            in_edge = G.in_edges(slice_node.name)[0]
            group = slices_by_origin.setdefault(
                (in_edge.from_node, in_edge.from_idx), [])
            group.append(slice_node)
        for slice_nodes in slices_by_origin.values():
            # Transpose so that slices[axis] holds that axis' slice per node.
            slices = list(zip(*[node.act_slice for node in slice_nodes]))
            if len(slice_nodes) == 1:
                self.slice_to_split(G, slice_nodes, slices)
                continue

            diff_slices = [(idx, elems) for idx, elems in enumerate(slices)
                           if not all(elems[0] == elem for elem in elems[1::])]
            if len(diff_slices) != 1:
                continue
            axis, axis_slices = diff_slices[0]
            # strides must be one
            if any(sl[2] != 1 for sl in axis_slices):
                continue
            # check if slices are consecutive and non overlapping: each slice
            # is (start, stop, step), so one must stop where the next starts
            ordered = sorted(axis_slices, key=lambda x: x[0])
            if not all(sl[1] == following[0]
                       for sl, following in zip(ordered, ordered[1:])):
                continue
            # The split is described by sizes only, so the slices must also
            # cover the axis from its start to its end (slice_to_split pads
            # with unused outputs for the same reason).
            in_shape = slice_nodes[0].in_dims[0].shape
            if ordered[0][0] != 0 or ordered[-1][1] != in_shape[axis]:
                continue
            szes = [sl[1] - sl[0] for sl in ordered]
            slice_nodes = sorted(slice_nodes,
                                 key=lambda x: x.act_slice[axis][0])
            act_slices, out_shapes, axis = SplitParameters.get_splits(
                in_shape, axis, splits=szes)
            params = SplitParameters(slice_nodes[0].name + '_split',
                                     act_slices=act_slices,
                                     out_shapes=out_shapes,
                                     axis=axis)
            in_edge = G.in_edges(slice_nodes[0].name)[0]
            G.add_edge(
                NNEdge(from_node=in_edge.from_node,
                       to_node=params,
                       from_idx=in_edge.from_idx))
            sub_names = []
            for idx, node in enumerate(slice_nodes):
                sub_names.append(node.name)
                # Capture the consumers before the slice is removed.
                out_edges = G.out_edges(node.name)
                G.remove(node)
                for out_edge in out_edges:
                    G.add_edge(
                        NNEdge(from_node=params,
                               to_node=out_edge.to_node,
                               from_idx=idx,
                               to_idx=out_edge.to_idx))
            if G.quantization:
                # Re-quantize starting from the new split node.
                G.add_dimensions()
                quantizer = UnifiedQuantizer.from_quantized_graph(G)
                quantizer.quantize(G, start_nodes=[params])
                RemoveUnnecessaryQuantizeOperators().match(G)

            LOG.info(
                f'replaced slice nodes {",".join(sub_names)} with split node {params.name}'
            )

            has_modified_graph = True

        if set_identity:
            self.set_identity(G)

        return has_modified_graph