Code example #1
    def replace_pattern(graph: Graph, match: dict):
        node = match['conv']
        node_name = node.soft_get('name', node.id)

        # create Reshape before convolution
        # shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
        i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
        shape = Cast(graph, {
            'name': node_name + '/to_float',
            'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)
        }).create_node()
        i_shape.in_port(0).connect(node.in_port(0).get_source())
        shape.in_port(0).connect(i_shape.out_port(0))

        N = node_to_get_shape_value_of_indices(shape, [0])
        H = node_to_get_shape_value_of_indices(shape, [1])

        div = create_op_with_const_inputs(
            graph, Div, {1: float_array([node.patch_stride])},
            {'name': node_name + '/div_stride_h'})
        div.in_port(0).connect(H.out_port(0))

        concat = create_op_with_const_inputs(
            graph, Concat, {
                2: float_array([1]),
                3: float_array([node.patch_stride])
            }, {
                'name': node_name + '/concat_all_dims',
                'in_ports_count': 4,
                'axis': 0
            })
        concat.in_port(0).connect(N.out_port(0))
        concat.in_port(1).connect(div.out_port(0))

        reshape_pattern = Cast(graph, {
            'name': node_name + '/to_int',
            'dst_type': np.int64
        }).create_node()
        concat.out_port(0).connect(reshape_pattern.in_port(0))

        reshape_in = Reshape(graph, {
            'name': node_name + '/reshape_in'
        }).create_node()
        reshape_in.in_port(1).connect(reshape_pattern.out_port(0))

        # create Reshape after Convolution
        reshape_out = create_op_node_with_second_input(
            graph, Reshape, int64_array([0, -1]),
            {'name': node_name + '/reshape_out'})

        # connect input_reshape_node
        source = node.in_port(0).get_source()
        node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
        reshape_in.in_port(0).connect(source)
        # connect output_reshape_node
        node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
        node.out_port(0).connect(reshape_out.in_port(0))
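
What this subgraph computes at run time is easiest to see with concrete numbers. Below is a minimal numpy sketch of the shape arithmetic assembled above (Shape -> Cast -> Gather -> Div -> Concat -> Cast); the input shape [8, 40] and patch_stride = 8 are hypothetical values chosen for illustration, not taken from the source.

    import numpy as np

    in_shape = np.array([8, 40], dtype=np.int64)     # output of the Shape node
    as_float = in_shape.astype(np.float32)           # Cast '/to_float'
    N, H = as_float[[0]], as_float[[1]]              # the two Gather helpers
    div = H / np.float32(8.0)                        # Div by patch_stride
    full = np.concatenate([N, div, [1.0], [8.0]])    # Concat '/concat_all_dims'
    pattern = full.astype(np.int64)                  # Cast '/to_int'
    assert list(pattern) == [8, 5, 1, 8]             # [N, H/stride, 1, stride]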
Code example #2
File: MatMulNormalizer.py Project: SDxKeeper/dldt
    def replace_pattern(self, graph: Graph, match: dict):
        matmul = match['matmul']
        reshape = match['reshape']
        other_input_port_idx = 0 if matmul.in_port(0).get_source().node.id == match['other_input'].id else 1
        shape_source = matmul.in_port(other_input_port_idx).get_source()
        initial_reshape_pattern = reshape.in_port(1).data.get_value()
        if len(initial_reshape_pattern) != 2:
            return

        reshape_is_A_input = matmul.in_port(0).get_source().node.id == reshape.id
        if reshape_is_A_input:
            idx = -1 if matmul.transpose_b else -2
        else:
            idx = -2 if matmul.transpose_a else -1
        idx = get_canonical_axis_index(initial_reshape_pattern, idx)

        shape_name = shape_source.node.soft_get('name', shape_source.node.id)
        shape = Shape(graph, {'name': shape_name + '/Shape'}).create_node()
        shape.in_port(0).connect(shape_source)
        C = node_to_get_shape_value_of_indices(shape, [idx])
        N = Const(graph, {'name': shape_name + '/MinusOne', 'value': int64_array([-1])}).create_node()

        # rank 2 is already guaranteed by the early return above
        if reshape_is_A_input:
            reshape_pattern = [C, N] if matmul.transpose_a else [N, C]
        else:
            reshape_pattern = [N, C] if matmul.transpose_b else [C, N]
        new_reshape_pattern = new_shape_node_from_shape_nodes(reshape_pattern)
        reshape.in_port(1).get_connection().set_source(new_reshape_pattern.out_port(0))
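
new_shape_node_from_shape_nodes is what fuses the dynamic C node and the constant N node into one reshape pattern. The sketch below is a hedged approximation of the role it plays, assuming it simply concatenates 1-D shape-value nodes along axis 0 with the same Concat op class seen in the other examples; concat_shape_nodes is a hypothetical stand-in, not the library function.

    def concat_shape_nodes(graph: Graph, shape_nodes: list, name: str):
        # Join 1-D shape-value nodes into a single shape tensor along axis 0.
        concat = Concat(graph, {'axis': 0,
                                'in_ports_count': len(shape_nodes),
                                'name': name + '/ShapeConcat'}).create_node()
        for idx, node in enumerate(shape_nodes):
            concat.in_port(idx).connect(node.out_port(0))
        return concat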
Code example #3
    def squeeze_initial_states(graph: Graph, match: dict):
        """
        Squeeze input initial states of recurrent node to 2-D shape.
        """
        hidden_init_port = 5
        cell_init_port = 6

        rnn_layer = match['rnn_layer']
        # Add input ports to rnn_layer
        rnn_layer.add_sequence_of_ports(type='in', rng=range(7))
        rnn_layer_name = rnn_layer.soft_get('name', rnn_layer.id)

        assert hidden_init_port in rnn_layer.in_nodes()
        hidden_size = rnn_layer.hidden_size
        shape = Shape(graph, dict(name=rnn_layer_name + '/ShapeOf')).create_node()
        rnn_layer.in_port(0).get_source().connect(shape.in_port(0))

        batch = node_to_get_shape_value_of_indices(shape, int64_array([rnn_layer.batch_dim]))
        new_dim = create_op_node_with_second_input(graph, Concat, second_input_value=int64_array([hidden_size]),
                                                   op_attrs=dict(name=rnn_layer_name + '/HiddenStateResizeDim',
                                                                 in_ports_count=2, axis=0), input_node=batch)
        reshape_h = Reshape(graph, dict(name=rnn_layer_name + '/HiddenStateResize', override_output_shape=True)).create_node()
        new_dim.out_port(0).connect(reshape_h.in_port(1))
        rnn_layer.in_port(hidden_init_port).get_connection().insert_node(reshape_h)

        if rnn_layer.op == 'LSTM':
            assert cell_init_port in rnn_layer.in_nodes()

            reshape_c = Reshape(graph, dict(name=rnn_layer_name + '/CellStateResize', override_output_shape=True)).create_node()
            new_dim.out_port(0).connect(reshape_c.in_port(1))
            rnn_layer.in_port(cell_init_port).get_connection().insert_node(reshape_c)
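
The reshape pattern assembled here is [batch, hidden_size]: the batch dimension is gathered dynamically from the input shape, and hidden_size enters as the constant second input of the Concat. A minimal numpy sketch with hypothetical values (input shape [3, 10, 128], batch_dim = 0, hidden_size = 64):

    import numpy as np

    in_shape = np.array([3, 10, 128], dtype=np.int64)   # ShapeOf result
    batch = in_shape[[0]]                               # Gather(batch_dim)
    hidden = np.array([64], dtype=np.int64)             # constant Concat input
    new_dim = np.concatenate([batch, hidden])           # Concat node
    assert list(new_dim) == [3, 64]                     # 2-D state shape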
Code example #4
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['flatten']
        name = node.soft_get('name', node.id)

        assert node.has_valid('axis'), 'Flatten {} should have `axis` attribute extracted, but it\'s not'.format(name)
        axis = node.axis

        if axis == 0:
            dim = Const(graph, {'value': int64_array([1, -1])}).create_node()
        elif axis == 1:
            dim = Const(graph, {'value': int64_array([0, -1])}).create_node()
        else:
            shape = Shape(graph, {'name': name + '/input_shape'}).create_node()

            idxs = list(range(axis)) if axis > 0 else list(range(axis, 0))

            axis_shape_portion = node_to_get_shape_value_of_indices(shape, idxs)
            first_dims = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]),
                                                          {'keep_dims': True})
            second_dims = Const(graph, {'value': int64_array([-1])}).create_node()

            node.in_port(0).get_source().connect(shape.in_port(0))
            axis_shape_portion.out_port(0).connect(first_dims.in_port(0))

            order_of_dims = [first_dims, second_dims] if axis > 0 else [second_dims, first_dims]

            dim = new_shape_node_from_shape_nodes(order_of_dims)

        reshape_node = Reshape(graph, {'name': node.id + '/Reshape'}).create_node()
        reshape_node.in_port(1).connect(dim.out_port(0))

        node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
        node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
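
For the general case (axis other than 0 or 1), the pass reproduces Flatten semantics as the reshape pattern [prod(shape[:axis]), -1]. A numpy sketch with a hypothetical input of shape [2, 3, 4, 5] and axis = 2:

    import numpy as np

    x = np.zeros([2, 3, 4, 5])
    axis = 2
    first = int(np.prod(x.shape[:axis]))   # ReduceProd over shape[0:axis] -> 6
    flattened = x.reshape([first, -1])     # reshape pattern [6, -1]
    assert flattened.shape == (6, 20)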
Code example #5
    def replace_pattern(self, graph: Graph, match: dict):
        if not self.is_applicable(match):
            return

        unsqueeze_node = match['unsqueeze']
        unsqueeze_name = unsqueeze_node.soft_get('name', unsqueeze_node.id)
        second_input_of_unsqueeze = unsqueeze_node.in_port(1).get_connection().get_source().node
        d_idx = int(second_input_of_unsqueeze.value)
        axis = d_idx - 1

        shape_node = Shape(graph,
                           dict(name=unsqueeze_name + '/Shape')).create_node()
        axis_len_node = node_to_get_shape_value_of_indices(shape_node, [axis])

        second_input_of_tile = match['tile'].in_port(1).get_connection().get_source().node
        scale = int64_array([second_input_of_tile.value[d_idx]])
        float_scale = float32_array([second_input_of_tile.value[d_idx]])
        mul_node = create_op_with_const_inputs(
            graph, Mul, {1: scale}, {'name': unsqueeze_name + '/Mul'})

        axis_len_node.out_port(0).connect(mul_node.in_port(0))

        interp_node = create_op_with_const_inputs(
            graph, Interpolate, {
                2: float_scale,
                3: int64_array([axis])
            }, {
                'mode': 'nearest',
                'antialias': 0,
                'pads_begin': int64_array([0]),
                'pads_end': int64_array([0]),
                'coordinate_transformation_mode': 'half_pixel',
                'nearest_mode': 'round_prefer_floor',
                'cube_coeff': -0.75,
                'version': 'opset4',
                'shape_calculation_mode': 'scales',
                'in_ports_count': 4,
                'maybe_part_of_sequence': True
            })
        mul_node.out_port(0).connect(interp_node.in_port(1))

        reshape_node = match['reshape']
        reshape_node.out_port(0).get_connection().set_source(
            interp_node.out_port(0))
        reshape_name = reshape_node.soft_get('name', reshape_node.id)
        rename_nodes([(reshape_node, reshape_name + '/delete'),
                      (interp_node, reshape_name)])

        unsqueeze_connection = unsqueeze_node.in_port(0).get_connection()
        unsqueeze_connection.set_destination(interp_node.in_port(0))
        unsqueeze_connection.get_source().connect(shape_node.in_port(0))
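
The equivalence this pass relies on: Unsqueeze + Tile + Reshape along one axis duplicates every slice of that axis, which is exactly nearest-neighbor upsampling, hence the Interpolate replacement. A numpy sketch with hypothetical values d_idx = 2 (so axis = 1) and scale = 2:

    import numpy as np

    x = np.arange(6).reshape(1, 2, 3)        # input of the Unsqueeze
    t = np.expand_dims(x, 2)                 # Unsqueeze at d_idx
    t = np.tile(t, (1, 1, 2, 1))             # Tile with factor 2 at d_idx
    t = t.reshape(1, 4, 3)                   # Reshape merges the two axes
    assert np.array_equal(t[:, 0], t[:, 1])  # every slice appears twice
    assert np.array_equal(t[:, 2], t[:, 3])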
Code example #6
    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(op='ATen', operator='embedding_bag'):
            assert node.soft_get('mode') == 0, 'ATen::embedding_bag has unsupported mode, only "sum" ' \
                                               'mode is supported for node {}.'.format(node.id)
            node_name = node.soft_get('name', node.id)
            rename_node(node, node_name + '/TBR')
            is_packed = False
            if len(node.in_ports()) < 3 or node.in_port(2).disconnected():
                is_packed = True
                embedding_bag = EmbeddingBagPackedSum(graph, {
                    'name': node_name
                }).create_node()
            else:
                embedding_bag = EmbeddingBagOffsetsSum(graph, {
                    'name': node_name
                }).create_node()
                node.in_port(2).get_connection().set_destination(
                    embedding_bag.in_port(2))
            rename_node(embedding_bag, node_name)
            node.in_port(0).get_connection().set_destination(
                embedding_bag.in_port(0))
            node.in_port(1).get_connection().set_destination(
                embedding_bag.in_port(1))
            node.out_port(0).get_connection().set_source(
                embedding_bag.out_port(0))
            if len(node.in_ports()) == 4 and not node.in_port(3).disconnected():
                if is_packed:
                    node.in_port(3).get_connection().set_destination(
                        embedding_bag.in_port(2))
                else:
                    # connect per_sample_weights
                    node.in_port(3).get_connection().set_destination(
                        embedding_bag.in_port(4))

                    weights_shape_node = Shape(
                        graph, {
                            'name': node_name + '/WeightsShape'
                        }).create_node()

                    weights_rank_node = Rank(graph, {
                        'name': node_name + '/WeightsRank'
                    }).create_node()
                    last_dim_node = get_canonical_axis_index_node(
                        weights_rank_node, -1)
                    weights_last_dim = get_shape_values_by_indices_node(
                        weights_shape_node, last_dim_node)

                    weights_first_dim = node_to_get_shape_value_of_indices(
                        weights_shape_node, [0])

                    zero_col_node = create_op_with_const_inputs(
                        graph, Broadcast, {0: int64_array([0])},
                        {'name': node_name + '/Broadcast'})
                    zero_col_node.in_port(1).connect(
                        weights_last_dim.out_port(0))

                    default_embeddings_node = create_op_with_const_inputs(
                        graph, Unsqueeze, {1: int64_array(0)},
                        {'name': node_name + '/Unsqueeze'})
                    default_embeddings_node.in_port(0).connect(
                        zero_col_node.out_port(0))

                    # expand embedding table with zeros
                    weights_concat = Concat(
                        graph, {
                            'axis': 0,
                            'in_ports_count': 2,
                            'name': node_name + '/Concat'
                        }).create_node()
                    embedding_bag.in_port(0).get_connection().set_destination(
                        weights_concat.in_port(0))
                    weights_concat.in_port(0).get_connection().add_destination(
                        weights_shape_node.in_port(0))
                    weights_concat.in_port(0).get_connection().add_destination(
                        weights_rank_node.in_port(0))
                    weights_concat.in_port(1).connect(
                        default_embeddings_node.out_port(0))
                    weights_concat.out_port(0).connect(
                        embedding_bag.in_port(0))

                    # point default index to expanded part of embedding table
                    weights_first_dim.out_port(0).connect(
                        embedding_bag.in_port(3))
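
The reason for the zero row: in EmbeddingBagOffsetsSum each bag sums the table rows selected by its indices, and default_index fills empty bags. Appending an all-zero row and pointing default_index (port 3) at it makes empty bags contribute zeros. A numpy sketch of that semantics, with illustrative values:

    import numpy as np

    weights = np.array([[1., 2.], [3., 4.]])
    table = np.concatenate([weights, np.zeros((1, 2))])  # expanded table
    indices, offsets = np.array([0, 1]), np.array([0, 2])
    default_index = len(weights)                         # the appended zero row
    bags = []
    for b, start in enumerate(offsets):
        end = offsets[b + 1] if b + 1 < len(offsets) else len(indices)
        idx = indices[start:end] if start < end else np.array([default_index])
        bags.append(table[idx].sum(axis=0))
    print(np.stack(bags))  # [[4. 6.], [0. 0.]] -- the second bag is empty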
Code example #7
    def find_and_replace_pattern(self, graph: Graph):
        reverse_nodes = graph.get_op_nodes(op='Reverse')
        for reverse in reverse_nodes:
            reverse_name = reverse.soft_get('name', reverse.id)

            assert reverse.in_port(1).disconnected()
            assert reverse.has_valid('axis')

            in_shape_rank = len(reverse.in_port(0).data.get_shape())
            # 1. Add new dimension as batch for rank = 1 to have batch != seq_axis
            if in_shape_rank == 1:
                unsq_node = create_op_node_with_second_input(
                    graph, Unsqueeze, int64_array([0]),
                    {'name': reverse_name + "/Unsqueeze"})
                reverse.in_port(0).get_source().connect(unsq_node.in_port(0))
                new_in = unsq_node.out_port(0)
                batch_axis = 0
                seq_axis = 1
            else:
                new_in = reverse.in_port(0).get_source()
                seq_axis = reverse['axis']
                batch_axis = 0 if seq_axis != 0 else 1

            # 2. For ReverseSequence 1-port input is seq_lengths => create this input node as
            # shape[seq_axis] broadcasted to shape[batch_axis]
            # in ---> ShapeOf ----> Gather(seq_axis)  ----> Broadcast----->
            #            |                                      |
            #            | -------> Gather(batch_axis)----------|
            shape_node = Shape(graph, {
                'name': reverse_name + "/Shape"
            }).create_node()
            new_in.connect(shape_node.in_port(0))
            seq_axis_node = node_to_get_shape_value_of_indices(
                shape_node, [seq_axis])
            batch_node = node_to_get_shape_value_of_indices(
                shape_node, [batch_axis])
            broadcast_node = Broadcast(graph, {
                'name': reverse_name + "/Broadcast"
            }).create_node()
            broadcast_node.in_port(0).connect(seq_axis_node.out_port(0))
            broadcast_node.in_port(1).connect(batch_node.out_port(0))

            # 3. Create new ReverseSequence node and reconnect all inputs/outputs to it
            rename_node(reverse, reverse_name + '/to_delete')
            reverse_sequence = ReverseSequence(
                graph, {
                    'name': reverse_name,
                    'seq_axis': seq_axis,
                    'batch_axis': batch_axis
                }).create_node()
            reverse_sequence.in_port(0).connect(new_in)
            reverse_sequence.in_port(1).connect(broadcast_node.out_port(0))

            # 4. remove added dimension for rank = 1
            if in_shape_rank == 1:
                rename_node(reverse_sequence,
                            reverse_name + '/ReverseSequence')
                squeeze_node = create_op_node_with_second_input(
                    graph, Squeeze, int64_array([0]), {'name': reverse_name})
                squeeze_node.in_port(0).connect(reverse_sequence.out_port(0))
                reverse.out_port(0).get_connection().set_source(
                    squeeze_node.out_port(0))
            else:
                reverse.out_port(0).get_connection().set_source(
                    reverse_sequence.out_port(0))

        # 5. Delete old Reverse node
        graph.remove_nodes_from([reverse.id for reverse in reverse_nodes])
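
The substitution is valid because ReverseSequence with every entry of seq_lengths equal to shape[seq_axis] (which is what the Broadcast above produces) is a full flip of the sequence axis. A numpy sketch for a hypothetical input of shape [2, 3] with seq_axis = 1:

    import numpy as np

    x = np.arange(6).reshape(2, 3)
    seq_lengths = np.full(x.shape[0], x.shape[1])   # the Broadcast result
    out = np.stack([row[:n][::-1] for row, n in zip(x, seq_lengths)])
    assert np.array_equal(out, np.flip(x, axis=1))  # same as Reverse(axis=1)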
Code example #8
    def replace_pattern(graph: Graph, match: dict):
        node = match['pool']
        node_name = node.soft_get('name', node.id)

        if node.pool_step is None:
            node.stride = int64_array([1, 1, node.window[-1], node.window[-1]])

        # create Reshape before pooling
        # shape = [in_shape[0], pool_stride, 1, in_shape[1]/pool_stride]
        i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()

        dst_dtype = np.float32  # even if data_type=FP16 use float32 for shape values
        shape = Cast(graph, {
            'name': node_name + '/to_float',
            'dst_type': dst_dtype
        }).create_node()
        i_shape.in_port(0).connect(node.in_port(0).get_source())
        shape.in_port(0).connect(i_shape.out_port(0))

        N = node_to_get_shape_value_of_indices(shape, [0])
        H = node_to_get_shape_value_of_indices(shape, [1])

        div = create_op_with_const_inputs(
            graph, Div, {1: float32_array([node.pool_stride])},
            {'name': node_name + '/div_stride_h'})
        div.in_port(0).connect(H.out_port(0))

        concat = create_op_with_const_inputs(
            graph, Concat, {
                1: float32_array([node.pool_stride]),
                2: float32_array([1])
            }, {
                'name': node_name + '/concat_all_dims',
                'in_ports_count': 4,
                'axis': 0
            })
        concat.in_port(0).connect(N.out_port(0))
        concat.in_port(3).connect(div.out_port(0))

        reshape_pattern = Cast(graph, {
            'name': node_name + '/to_int',
            'dst_type': np.int64
        }).create_node()
        concat.out_port(0).connect(reshape_pattern.in_port(0))

        reshape_in = Reshape(graph, {
            'name': node_name + '/reshape_in'
        }).create_node()
        reshape_in.in_port(1).connect(reshape_pattern.out_port(0))

        # create Reshape after pooling
        reshape_out = create_op_node_with_second_input(
            graph, Reshape, int64_array([0, -1]),
            {'name': node_name + '/reshape_out'})

        # connect input_reshape_node
        source = node.in_port(0).get_source()
        node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
        reshape_in.in_port(0).connect(source)
        # connect output_reshape_node
        node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
        node.out_port(0).connect(reshape_out.in_port(0))
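
The pooling variant builds the same kind of shape subgraph as code example #1 but with a different dimension order, [N, pool_stride, 1, H/pool_stride]: the constants sit at Concat ports 1 and 2 and the Div result at port 3. A numpy sketch with a hypothetical input shape [8, 40] and pool_stride = 8:

    import numpy as np

    in_shape = np.array([8, 40], dtype=np.float32)  # Shape output after Cast
    N, H = in_shape[[0]], in_shape[[1]]             # the two Gather helpers
    pattern = np.concatenate([N, [8.0], [1.0], H / 8.0]).astype(np.int64)
    assert list(pattern) == [8, 8, 1, 5]            # [N, stride, 1, H/stride]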
Code example #9
    def replace_pattern(graph: Graph, match: dict):
        node = match['conv']
        node_name = node.soft_get('name', node.id)

        # create Reshape before convolution
        # if a Transpose will be applied (new models):
        #   shape = [in_shape[0], t = in_shape[1]/(patch_stride*t), patch_stride, C=1]
        # else (old models; kept to avoid failures on GNA and should be removed once GNA is updated):
        #   shape = [in_shape[0], t = in_shape[1]/(patch_stride*t), C=1, patch_stride]
        sp_dim_1 = 1
        if node.has_valid('patch_stride'):
            channel_dim = 2
            sp_dim_2 = 3
            frame_height = node.patch_stride
        else:
            channel_dim = 3
            sp_dim_2 = 2
            frame_height = node.height_in

        i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
        i_shape.in_port(0).connect(node.in_port(0).get_source())

        N = node_to_get_shape_value_of_indices(i_shape, [0])
        H = node_to_get_shape_value_of_indices(i_shape, [1])

        div = create_op_with_const_inputs(
            graph, Div, {1: int64_array([frame_height * node.kernel[1]])},
            {'name': node_name + '/div_stride_h'})
        div.in_port(0).connect(H.out_port(0))

        concat = create_op_with_const_inputs(
            graph, Concat, {
                sp_dim_2: int64_array([frame_height]),
                channel_dim: int64_array([node.kernel[1]])
            }, {
                'name': node_name + '/concat_all_dims',
                'in_ports_count': 4,
                'axis': 0
            })
        concat.in_port(0).connect(N.out_port(0))
        concat.in_port(sp_dim_1).connect(div.out_port(0))

        reshape_in = Reshape(graph, {
            'name': node_name + '/reshape_in'
        }).create_node()
        reshape_in.in_port(1).connect(concat.out_port(0))

        # change layout from NHWC to NCHW
        # should be replaced by common Permute logic in future
        transpose = None
        if channel_dim == 3 and node.channel_dims == 1:
            transpose = create_op_node_with_second_input(
                graph, Transpose, int64_array([0, 3, 1, 2]),
                {'name': node.name + '/Transpose'}, reshape_in)

        # create Reshape after Convolution
        reshape_out = create_op_node_with_second_input(
            graph, Reshape, int64_array([0, -1]),
            {'name': node_name + '/reshape_out'})

        # connect input_reshape_node
        source = node.in_port(0).get_source()
        node.in_port(0).get_connection().set_source(
            transpose.out_port(0) if transpose else reshape_in.out_port(0))
        reshape_in.in_port(0).connect(source)
        # connect output_reshape_node
        node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
        node.out_port(0).connect(reshape_out.in_port(0))
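
The Transpose inserted for the NHWC case applies the permutation [0, 3, 1, 2], moving channels ahead of the spatial dimensions. A numpy sketch with a hypothetical tensor of shape [1, 5, 8, 3]:

    import numpy as np

    x = np.zeros([1, 5, 8, 3])      # N, H, W, C
    y = x.transpose([0, 3, 1, 2])   # permutation used by the pass
    assert y.shape == (1, 3, 5, 8)  # N, C, H, W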