Example #1
    @classmethod
    def extract(cls, node: Node):
        # Take the number of outputs from the TF proto's num_split attribute
        # and mark the axis/split-size inputs for reordering.
        VariadicSplit.update_node_stat(
            node, {
                'out_ports_count': node.pb.attr['num_split'].i,
                'swap_axis_and_split_size_inputs': True
            })
        return cls.enabled
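
For context, this extractor appears to target TF's SplitV, whose inputs arrive as (value, size_splits, split_dim), while VariadicSplit expects (data, axis, split_lengths); that mismatch is what swap_axis_and_split_size_inputs resolves. A minimal sketch of the reordering, with a hypothetical helper name:

    def swap_axis_and_split_size_inputs(inputs):
        # Reorder (value, size_splits, split_dim) into (data, axis, split_lengths).
        value, size_splits, split_dim = inputs
        return value, split_dim, size_splits

    print(swap_axis_and_split_size_inputs(('data', [2, 13, 10], 2)))
    # ('data', 2, [2, 13, 10])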
Example #2
    def test_negative_variadic_split_axis(self, axis):
        lengths = int64_array([2, 13, 10])
        graph = build_graph(
            self.nodes, self.edges, {
                'split_input_data': {
                    'shape': int64_array([2, 12, 25, 30])
                },
                'split_axis_data': {
                    'value': axis
                },
                'split_lengths_data': {
                    'value': lengths
                },
                'split_op': {
                    'out_ports_count': 4
                },
            })
        node = Node(graph, 'split_op')
        # create the output ports the op declares but the test graph lacks
        for p in range(len(node.out_edges()), node.out_ports_count):
            node.add_output_port(p)

        # infer() must reject an invalid axis; assertRaises also fails
        # the test if no AssertionError is raised at all
        with self.assertRaises(AssertionError) as ctx:
            VariadicSplit.infer(node)
        self.assertEqual(
            ctx.exception.args[0],
            'VariadicSplit `axis` should be scalar or tensor with shape [1], '
            'but it`s not for node split_op')
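
The validation this negative test exercises can be sketched in plain NumPy; a minimal stand-in, assuming the check mirrors the assertion message:

    import numpy as np

    def check_variadic_split_axis(axis, node_name):
        axis = np.array(axis)
        # Accept a 0-d scalar or a 1-d tensor with exactly one element.
        assert axis.ndim == 0 or (axis.ndim == 1 and axis.shape[0] == 1), \
            'VariadicSplit `axis` should be scalar or tensor with shape [1], ' \
            'but it`s not for node {}'.format(node_name)

    check_variadic_split_axis(2, 'split_op')          # passes
    # check_variadic_split_axis([1, 2], 'split_op')   # raises AssertionError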
Example #3
    def test_variadic_split_axis(self, axis):
        lengths = int64_array([2, 13, 10])
        graph = build_graph(
            self.nodes, self.edges, {
                'split_input_data': {
                    'shape': int64_array([2, 12, 25, 30])
                },
                'split_axis_data': {
                    'value': axis
                },
                'split_lengths_data': {
                    'value': lengths
                },
                'split_op': {
                    'out_ports_count': 4
                },
            })
        node = Node(graph, 'split_op')
        for p in range(len(node.out_edges()), node.out_ports_count):
            node.add_output_port(p)

        VariadicSplit.infer(node)

        out_nodes_count = len(node.out_edges())
        self.assertEqual(out_nodes_count, 3)
        for out in range(out_nodes_count):
            self.assertTrue(
                np.all(
                    node.out_node(out).shape == int64_array(
                        [2, 12, lengths[out], 30])))
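
The expected shapes follow from VariadicSplit's inference rule: each output copies the input shape except along the split axis, which takes the matching entry of lengths. An illustrative helper (hypothetical, not the mo implementation):

    import numpy as np

    def variadic_split_shapes(input_shape, axis, lengths):
        shapes = []
        for length in lengths:
            shape = np.array(input_shape)
            shape[axis] = length  # only the split axis changes
            shapes.append(shape)
        return shapes

    print(variadic_split_shapes([2, 12, 25, 30], 2, [2, 13, 10]))
    # [array([ 2, 12,  2, 30]), array([ 2, 12, 13, 30]), array([ 2, 12, 10, 30])]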
Example #4
    def test_variadic_split_value_inference_with_uint32(self):
        axis = int64_array(2)
        # The sum of a Python int and np.uint64 gives float64, but np.split
        # accepts only integer indices and raises an error for floats, so the
        # np.split arguments must be explicitly cast to int; this test covers
        # that case.
        lengths = mo_array([2, 13, 10], dtype=np.uint64)
        input_shape = mo_array([2, 12, 25, 30])
        input_value = np.zeros(input_shape)

        graph = build_graph(
            self.nodes, self.edges, {
                'split_input_data': {
                    'shape': input_shape,
                    'value': input_value
                },
                'split_axis_data': {
                    'value': axis
                },
                'split_lengths_data': {
                    'value': lengths
                },
                'split_op': {
                    'out_ports_count': 4
                },
            })
        node = Node(graph, 'split_op')
        for p in range(len(node.out_edges()), node.out_ports_count):
            node.add_output_port(p)

        VariadicSplit.infer(node)

        out_nodes_count = len(node.out_edges())
        self.assertEqual(out_nodes_count, 3)
        for out in range(out_nodes_count):
            self.assertTrue(
                np.all(
                    node.out_node(out).shape == int64_array(
                        [2, 12, lengths[out], 30])))
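
The promotion pitfall behind this test can be reproduced with plain NumPy (under NumPy 1.x promotion rules; NumPy 2.x treats Python scalars as weakly typed and behaves differently):

    import numpy as np

    lengths = np.array([2, 13, 10], dtype=np.uint64)
    data = np.zeros((2, 12, 25, 30))

    # On NumPy 1.x, uint64 + Python int promotes to float64, and np.split
    # raises TypeError for float indices:
    #     np.split(data, [lengths[0] + 0], axis=2)
    # Explicit int casts keep the indices integral:
    parts = np.split(data, [int(lengths[0]), int(lengths[0] + lengths[1])], axis=2)
    print([p.shape for p in parts])
    # [(2, 12, 2, 30), (2, 12, 13, 30), (2, 12, 10, 30)]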
Example #5
    def find_and_replace_pattern(self, graph: Graph):
        # Iterate over all data nodes and find those with more than one StridedSlice consumer
        for input_data in list(graph.get_data_nodes()):
            # Skip constant data nodes; only runtime tensors are considered
            if input_data.value is not None:
                continue

            if input_data.shape is None:
                continue
            input_shape = shape_array(input_data.shape)

            # Collect all StridedSlice consumers fed by this data node (deduplicated below)
            out_nodes = [node for node in input_data.out_nodes() if node.op == 'StridedSlice' and
                         node.in_node(0).id == input_data.id]

            if len(out_nodes) <= 1:
                continue

            valid_for_replacement = True
            for n in out_nodes:
                if any(not isinstance(s, slice) for s in n.slices):
                    # this is a slice with dynamic dimension. Such operation is not valid for replacement
                    valid_for_replacement = False
            if not valid_for_replacement:
                continue

            # slice objects are not orderable in Python 3; a string key still
            # groups identical slice lists together for unique_by below
            sorted_out_nodes = sorted(out_nodes, key=lambda n: [str(s) for s in n.slices])
            out_nodes = unique_by(sorted_out_nodes, strided_slices_equality)

            for node in out_nodes:
                if len(node.slices) != len(out_nodes[0].slices):
                    valid_for_replacement = False

            # Detect dimension for splitting
            split_channel_dim = None
            for dim_id, s in enumerate(out_nodes[0].slices):
                l, r, stride = s.start, s.stop, s.step
                # if both l and r are None then the dimension is not sliced
                if (l != 0 or r != input_shape[dim_id]) and (l is not None or r is not None):
                    if split_channel_dim is None:
                        split_channel_dim = dim_id
                    else:
                        valid_for_replacement = False

            if split_channel_dim is None:
                valid_for_replacement = False

            # split_dims contains tuples with split range and output data node
            split_dims = []
            for out_id, node in enumerate(out_nodes):
                # Check that the StridedSlice op has stride equal to 1 and slices only the split dimension
                for dim_id, s in enumerate(node.slices):
                    l, r, stride = s.start, s.stop, s.step
                    # We don't support StridedSlice with stride != 1
                    if stride != 1:
                        valid_for_replacement = False
                    if dim_id == split_channel_dim:
                        split_dims.append((s.start, s.stop, node.out_node()))

            if not valid_for_replacement:
                continue

            # Check feature split intersection
            final_data_nodes_list = []
            sorted_split_dims = sorted(split_dims, key=lambda item: (item[0], item[1]))

            # check if we have similar StridedSlice operations with different outputs
            prev_sd = sorted_split_dims[0]
            to_remove = []
            for i in range(1, len(sorted_split_dims)):
                if sorted_split_dims[i][0] == prev_sd[0] and sorted_split_dims[i][1] == prev_sd[1] \
                        and sorted_split_dims[i][2].name != prev_sd[2].name:
                    cur_node = sorted_split_dims[i][2]
                    for out in cur_node.out_nodes():
                        attrs = deepcopy(graph.get_edge_data(cur_node.id, out.id)[0])
                        graph.remove_edge(cur_node.id, out.id)
                        graph.add_edge(prev_sd[2].id, out.id, **attrs)
                    to_remove.append(i)
                else:
                    # advance the reference entry, otherwise only duplicates of
                    # the first range would ever be merged
                    prev_sd = sorted_split_dims[i]

            for ind in reversed(to_remove):
                sorted_split_dims.pop(ind)

            size_splits = []
            prev_r = 0
            for l, r, out in sorted_split_dims:
                # Split dims shouldn't intersect
                if l < prev_r:
                    valid_for_replacement = False
                prev_r = r

            if prev_r > input_shape[split_channel_dim]:
                valid_for_replacement = False

            if not valid_for_replacement:
                continue

            prev_r = 0
            for l, r, out in sorted_split_dims:
                # Save missing tensor part
                if l > prev_r:
                    shape = mo_array(input_shape)
                    size_splits.append(l - prev_r)
                    shape[split_channel_dim] = l - prev_r
                    data_node = Op._create_data_node(graph, 'fake_data_'+out_nodes[0].name, {'shape': shape})
                    add_opoutput(graph, data_node.id, 0, False, keep_output_port=True)
                    final_data_nodes_list.append(data_node)

                prev_r = r
                size_splits.append(r - l)
                final_data_nodes_list.append(out)

            if prev_r < input_shape[split_channel_dim]:
                # Add last part of tensor
                shape = input_shape.copy()
                shape[split_channel_dim] = input_shape[split_channel_dim] - prev_r
                size_splits.append(input_shape[split_channel_dim] - prev_r)
                data_node = Op._create_data_node(graph, 'fake_data_'+out_nodes[0].name, {'shape': shape})
                add_opoutput(graph, data_node.id, 0, False, keep_output_port=True)
                final_data_nodes_list.append(data_node)

            # Insert Squeeze/Unsqueeze for slices with shrink/new axis masks
            for node in out_nodes:
                if np.any(node.shrink_axis_mask):
                    out_node = node.out_node()
                    self.add_squeeze_for_shrink(graph, node)
                    if np.any(node.new_axis_mask):
                        self.add_unsqueeze_for_new(graph, node)

                    for i in range(len(final_data_nodes_list)):
                        if final_data_nodes_list[i].name == out_node.name:
                            final_data_nodes_list[i] = node.out_node()
                            break

            # Insert Split layer and remove old StridedSlice layers
            # 1. Remove connections from input_data to StridedSlice ops
            out_data_nodes = []
            name_for_future_split = out_nodes[0].name
            for node in out_nodes:
                out_data_nodes.append(node.out_node())
                graph.remove_edge(input_data.id, node.id)
                graph.remove_edge(node.id, node.out_node().id)
                graph.remove_node(node.id)
                log.debug("Removed: {}".format(node.id))

            # 2. Create Split layer and reorder outputs
            name = name_for_future_split + "/Split"
            axis_const = Const(graph, {'value': int64_array(split_channel_dim),
                                       'name': name + '/Axis'}).create_node_with_data()
            size_splits_const = Const(graph, {'value': int64_array(size_splits),
                                              'name': name + '/Sizes'}).create_node_with_data()
            split = VariadicSplit(graph, dict(name=name, out_ports_count=len(size_splits)))

            split.create_node_with_data(inputs=[input_data, axis_const, size_splits_const],
                                        data_nodes=final_data_nodes_list)
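
The rewrite rests on a simple tensor identity: several StridedSlice consumers that read contiguous, non-overlapping ranges of one axis carry the same data as a single VariadicSplit over that axis. A NumPy sketch of the equivalence:

    import numpy as np

    x = np.random.rand(2, 12, 25, 30)

    # Three StridedSlice consumers carving up axis 2 ...
    a, b, c = x[:, :, 0:2], x[:, :, 2:15], x[:, :, 15:25]

    # ... match a single split with size_splits = [2, 13, 10] on that axis.
    p0, p1, p2 = np.split(x, [2, 15], axis=2)
    assert np.array_equal(a, p0) and np.array_equal(b, p1) and np.array_equal(c, p2)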