コード例 #1
0
ファイル: split_test.py プロジェクト: yding10/openvino
    def test_negative_variadic_split_axis(self, axis):
        """Check that VariadicSplit.infer rejects an `axis` input of unsupported shape.

        The split input has shape [2, 12, 25, 30] and `lengths` sums to 25
        (the size of dimension 2), so the failure being tested is the one
        reported for `axis`. `axis` is supplied by the caller (presumably a
        parametrizing decorator outside this view).

        `assertRaises` makes the test fail when no AssertionError is raised at
        all, instead of passing silently.
        """
        lengths = int64_array([2, 13, 10])
        graph = build_graph(
            self.nodes, self.edges, {
                'split_input_data': {
                    'shape': int64_array([2, 12, 25, 30])
                },
                'split_axis_data': {
                    'value': axis
                },
                'split_lengths_data': {
                    'value': lengths
                },
                'split_op': {
                    'out_ports_count': 4
                },
            })
        node = Node(graph, 'split_op')
        # Create the output ports that are not backed by an existing out edge,
        # up to the declared out_ports_count.
        for p in range(len(node.out_edges()), node.out_ports_count):
            node.add_output_port(p)

        with self.assertRaises(AssertionError) as ctx:
            VariadicSplit.infer(node)
        self.assertEqual(
            ctx.exception.args[0],
            'VariadicSplit `axis` should be scalar or tensor with shape [1], '
            'but it`s not for node split_op')
コード例 #2
0
 def extract(cls, node: Node):
     """Mark `node` as a VariadicSplit using attributes from its protobuf definition.

     The number of output ports comes from the integer `num_split` attribute
     of `node.pb`. `swap_axis_and_split_size_inputs` is always set to True
     (presumably because the source framework orders the axis and split-size
     inputs differently from VariadicSplit -- confirm against the op).

     Returns:
         `cls.enabled`.
     """
     num_split = node.pb.attr['num_split'].i
     split_attrs = dict(out_ports_count=num_split,
                        swap_axis_and_split_size_inputs=True)
     VariadicSplit.update_node_stat(node, split_attrs)
     return cls.enabled
コード例 #3
0
ファイル: split_test.py プロジェクト: yding10/openvino
    def test_variadic_split_axis(self, axis):
        """Check VariadicSplit.infer output shapes for a valid `axis` input.

        `lengths` (2 + 13 + 10 = 25) exactly covers dimension 2 of the
        [2, 12, 25, 30] input, so output i must have shape
        [2, 12, lengths[i], 30]. `axis` is supplied by the caller (presumably
        a parametrizing decorator outside this view) and must therefore
        designate dimension 2.
        """
        lengths = int64_array([2, 13, 10])
        graph = build_graph(
            self.nodes, self.edges, {
                'split_input_data': {
                    'shape': int64_array([2, 12, 25, 30])
                },
                'split_axis_data': {
                    'value': axis
                },
                'split_lengths_data': {
                    'value': lengths
                },
                'split_op': {
                    'out_ports_count': 4
                },
            })
        node = Node(graph, 'split_op')
        # Create the output ports that are not backed by an existing out edge,
        # up to the declared out_ports_count.
        for p in range(len(node.out_edges()), node.out_ports_count):
            node.add_output_port(p)

        VariadicSplit.infer(node)

        # 'split_op' declares 4 output ports, but after inference exactly one
        # connected output per entry of `lengths` is expected.
        out_nodes_count = len(node.out_edges())
        self.assertEqual(out_nodes_count, len(lengths))
        for out in range(out_nodes_count):
            actual_shape = node.out_node(out).shape
            expected_shape = int64_array([2, 12, lengths[out], 30])
            # np.array_equal also compares ranks; np.all(a == b) could pass
            # through broadcasting for a mis-ranked shape.
            self.assertTrue(
                np.array_equal(actual_shape, expected_shape),
                'output {}: expected shape {}, got {}'.format(out, expected_shape, actual_shape))
コード例 #4
0
    def replace_pattern(self, graph: Graph, match: dict):
        """Split the fused 'WR' weights input of an LSTM cell into two inputs.

        The tensor on input port 3 of ``match['lstm_cell']`` is reshaped to
        ``[4 * hidden_size, hidden_size + input_size]`` and split along axis 1
        into ``[4 * hidden_size, input_size]`` (reconnected to port 3) and
        ``[4 * hidden_size, hidden_size]`` (reconnected to port 4) --
        presumably the W and R weight matrices. Whatever was connected to
        port 4 before is moved to port 5. ``input_size`` is derived from the
        element count of 'WR'.

        Raises:
            AssertionError: if the 'WR' input shape is undefined, if its
                element count is not a multiple of ``4 * hidden_size``, or if
                input port 5 is already connected.
        """
        node = match['lstm_cell']
        lstm_cell_name = node.soft_get('name', node.id)
        hidden_size = node.get_attrs()["hidden_size"]

        WR_shape = node.in_port(3).data.get_shape()
        assert WR_shape is not None, "Undefined 'WR' input shape for LSTM Cell node '{}'".format(
            lstm_cell_name)

        # 'WR' is reinterpreted as 4 * hidden_size rows of (hidden_size + input_size)
        # elements. Integer arithmetic keeps input_size an exact integer; true
        # division produced a float that int64_array silently truncated when the
        # element count did not divide evenly.
        num_elements_in_WR = np.prod(WR_shape)
        num_rows = 4 * hidden_size
        assert num_elements_in_WR % num_rows == 0, (
            "Number of elements in 'WR' input ({}) of LSTM Cell node '{}' is not a multiple of "
            "4 * hidden_size ({})".format(num_elements_in_WR, lstm_cell_name, num_rows))
        input_size = num_elements_in_WR // num_rows - hidden_size

        # Reshape
        reshape = create_op_node_with_second_input(
            graph, Reshape,
            int64_array([4 * hidden_size, hidden_size + input_size]),
            {'name': lstm_cell_name + '/Dims'})

        # VariadicSplit along axis 1 into [input_size, hidden_size] columns
        const_axis = Const(graph, {'value': 1}).create_node()
        const_size_splits = Const(
            graph, {
                'value': int64_array([input_size, hidden_size])
            }).create_node()
        split = VariadicSplit(graph, {
            'name': lstm_cell_name + '/Split',
            'out_ports_count': 2
        }).create_node()
        const_axis.out_port(0).connect(split.in_port(1))
        const_size_splits.out_port(0).connect(split.in_port(2))

        # LSTM Cell: route the original 'WR' producer through Reshape -> Split
        node.in_port(3).get_connection().set_destination(reshape.in_port(0))
        reshape.out_port(0).connect(split.in_port(0))

        # Free port 4 for the second split output by moving its producer to port 5
        node.add_input_port(5, skip_if_exist=True)
        assert node.in_port(5).disconnected()
        node.in_port(4).get_connection().set_destination(node.in_port(5))

        split.out_port(0).connect(node.in_port(3))
        split.out_port(1).connect(node.in_port(4))
コード例 #5
0
    def find_and_replace_pattern(self, graph: Graph):
        """Merge sibling StridedSlice ops that cut one tensor along one axis into a VariadicSplit.

        For every non-constant data node that feeds (as input 0) two or more
        StridedSlice ops that remain distinct after `unique_by(...,
        strided_slices_equality)`, the slices are validated:
          * all ops have the same number of slice entries;
          * the first op differs from the full input range in exactly one
            dimension (`split_channel_dim`);
          * every slice in every op has step 1;
          * the ranges along `split_channel_dim` do not overlap and do not
            exceed the input size along that dimension.
        If validation passes, the StridedSlice ops are removed and replaced
        with a single VariadicSplit on `split_channel_dim`. Gaps before,
        between, and after the sliced ranges become extra 'fake_data_*' output
        data nodes that are registered via `add_opoutput`.

        The graph is modified in place; nothing is returned.
        """
        # Iterate over all data nodes; only those with at least two distinct
        # StridedSlice consumers are processed further.
        for input_data in list(graph.get_data_nodes()):
            # We don't use constant data nodes
            if input_data.value is not None:
                continue

            input_shape = np.array(input_data.shape)

            # Get all unique StridedSlice consumers that read input_data on input 0
            out_nodes = [node for node in input_data.out_nodes() if node.op == 'StridedSlice' and node.in_node(0).name == input_data.name]
            sorted_out_nodes = sorted(out_nodes, key=lambda n: list(n.slices))
            out_nodes = unique_by(sorted_out_nodes, strided_slices_equality)
            if len(out_nodes) <= 1:
                continue

            valid_for_replacement = True

            for node in out_nodes:
                if len(node.slices) != len(out_nodes[0].slices):
                    valid_for_replacement = False

            # Detect dimension for splitting: the single dimension where the first
            # op does not take the full [0, input_shape[dim]) range.
            # NOTE(review): assumes `slices` hold normalized non-negative start/stop
            # values (no None / negative indices) -- confirm.
            # NOTE(review): only out_nodes[0] is checked here; other ops are never
            # verified to take the full range in the remaining dimensions.
            split_channel_dim = None
            for dim_id, s in enumerate(out_nodes[0].slices):
                l, r, stride = s.start, s.stop, s.step
                if l != 0 or r != input_shape[dim_id]:
                    if split_channel_dim is None:
                        split_channel_dim = dim_id
                    else:
                        valid_for_replacement = False

            if split_channel_dim is None:
                valid_for_replacement = False

            # split_dims contains tuples (start, stop, output data node) along split_channel_dim
            split_dims = []
            for out_id, node in enumerate(out_nodes):
                # Reject any slice with step != 1 and record each op's range along split_channel_dim
                for id, s in enumerate(node.slices):
                    l, r, stride = s.start, s.stop, s.step
                    # We don't support StridedSlice with stride != 1
                    if stride != 1:
                        valid_for_replacement = False
                    if id == split_channel_dim:
                        split_dims.append((s.start, s.stop, node.out_node()))

            if not valid_for_replacement:
                continue

            # Check feature split intersection
            final_data_nodes_list = []
            sorted_split_dims = sorted(split_dims, key=lambda item: (item[0], item[1]))

            # Check if we have similar StridedSlice operations with different outputs:
            # consumers of a duplicate range are rewired to the first op's output and
            # the duplicate is dropped from the list.
            # NOTE(review): prev_sd is never advanced, so only duplicates of the first
            # (smallest) range are merged; duplicates of later ranges make the overlap
            # check below reject the whole group. Also, this rewiring happens before
            # the final validity check, so it persists even if the group is skipped.
            prev_sd = sorted_split_dims[0]
            to_remove = []
            for i in range(1, len(sorted_split_dims)):
                if sorted_split_dims[i][0] == prev_sd[0] and sorted_split_dims[i][1] == prev_sd[1] and sorted_split_dims[i][2].name != prev_sd[2].name:
                    cur_node = sorted_split_dims[i][2]
                    for out in cur_node.out_nodes():
                        attrs = deepcopy(graph.get_edge_data(cur_node.id, out.id)[0])
                        graph.remove_edge(cur_node.id, out.id)
                        graph.add_edge(prev_sd[2].id, out.id, **attrs)
                    to_remove.append(i)

            # Pop from the end so earlier indices stay valid
            for ind in reversed(to_remove):
                sorted_split_dims.pop(ind)

            size_splits = []
            prev_r = 0
            for l, r, out in sorted_split_dims:
                # Split dims shouldn't intersect
                if l < prev_r:
                    valid_for_replacement = False
                prev_r = r

            if prev_r > input_shape[split_channel_dim]:
                valid_for_replacement = False

            if not valid_for_replacement:
                continue

            # Build size_splits and the ordered list of Split output data nodes,
            # inserting placeholder outputs for uncovered gaps.
            prev_r = 0
            for l, r, out in sorted_split_dims:
                # Save missing tensor part
                if l > prev_r:
                    shape = np.array(input_shape)
                    size_splits.append(l - prev_r)
                    shape[split_channel_dim] = l - prev_r
                    data_node = Op._create_data_node(graph, 'fake_data_'+out_nodes[0].name, {'shape': shape})
                    add_opoutput(graph, data_node.id, 0, False)
                    final_data_nodes_list.append(data_node)

                prev_r = r
                size_splits.append(r - l)
                final_data_nodes_list.append(out)

            if prev_r < input_shape[split_channel_dim]:
                # Add last part of tensor
                shape = input_shape.copy()
                shape[split_channel_dim] = input_shape[split_channel_dim] - prev_r
                size_splits.append(input_shape[split_channel_dim] - prev_r)
                data_node = Op._create_data_node(graph, 'fake_data_'+out_nodes[0].name, {'shape': shape})
                add_opoutput(graph, data_node.id, 0, False)
                final_data_nodes_list.append(data_node)

            # For ops with a non-zero shrink_axis_mask, insert Squeeze (and, if
            # new_axis_mask is set, Unsqueeze) and substitute node.out_node() for
            # the old output in final_data_nodes_list (presumably the helpers insert
            # a new data node after the op -- confirm).
            # NOTE(review): the inner shrink check repeats the outer one, and an op
            # with only new_axis_mask set (shrink mask all zero) gets no Unsqueeze --
            # confirm this is intended.
            for node in out_nodes:
                if not np.all([x == 0 for x in node.shrink_axis_mask]):
                    out_node = node.out_node()
                    if np.any(node['shrink_axis_mask']):
                        self.add_squeeze_for_shrink(graph, node)
                    if np.any(node['new_axis_mask']):
                        self.add_unsqueeze_for_new(graph, node)

                    for i in range(len(final_data_nodes_list)):
                        if final_data_nodes_list[i].name == out_node.name:
                            final_data_nodes_list[i] = node.out_node()
                            break

            # Insert Split layer and remove old StridedSlice layers
            # 1. Remove connections from input_data to StridedSlice ops, then the ops
            #    themselves (out_data_nodes is collected but not used below).
            out_data_nodes = []
            name_for_future_split = out_nodes[0].name
            for node in out_nodes:
                out_data_nodes.append(node.out_node())
                graph.remove_edge(input_data.id, node.id)
                graph.remove_edge(node.id, node.out_node().id)
                graph.remove_node(node.id)
                log.debug("Removed: {}".format(node.id))

            # 2. Create Split layer whose outputs are the data nodes in
            #    final_data_nodes_list, in ascending range order
            name = name_for_future_split + "/Split"
            axis_const = Const(graph, {'value': int64_array(split_channel_dim),
                                       'name': name + '/Axis'}).create_node_with_data()
            size_splits_const = Const(graph, {'value': int64_array(size_splits),
                                              'name': name + '/Sizes'}).create_node_with_data()
            split = VariadicSplit(graph, dict(name=name, out_ports_count=len(size_splits)))

            split.create_node_with_data(inputs=[input_data, axis_const, size_splits_const],
                                        data_nodes=final_data_nodes_list)