Example #1
    def find_and_replace_pattern(self, graph: Graph):
        layout = graph.graph['layout']
        for eltwise_op_node in graph.get_op_nodes(is_eltwise=True):
            out_shape = eltwise_op_node.out_port().data.get_shape()
            if 4 <= len(out_shape) <= 5:
                out_features = out_shape[get_features_dim(
                    layout, len(out_shape))]
                for port, node in eltwise_op_node.in_nodes().items():
                    if len(node.shape) != len(out_shape) and len(
                            node.shape) == 1 and out_features == node.shape[0]:
                        new_shape = shape_for_layout(
                            layout,
                            batch=1,
                            features=out_features,
                            height=1,
                            width=1,
                            depth=1 if len(out_shape) == 5 else None)
                        dim_const = Const(graph, {
                            'value': new_shape,
                            'name': node.id + '/Dim'
                        }).create_node()
                        reshape_op = Reshape(graph,
                                             attrs={
                                                 'dim': new_shape,
                                                 'name': node.id + '/Broadcast'
                                             }).create_node()

                        eltwise_op_node.in_port(port).get_source().node.out_port(0) \
                            .get_connection().set_destination(reshape_op.in_port(0))
                        reshape_op.in_port(1).connect(dim_const.out_port(0))

                        reshape_op.out_port(0).connect(
                            eltwise_op_node.in_port(port))
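What this pass computes can be checked with plain NumPy: a 1-D per-channel operand of an NCHW eltwise is reshaped to [1, C, 1, 1] so that implicit broadcasting lines up with the features dimension. A minimal sketch (shapes are made up; NCHW layout assumed, matching the shape_for_layout call above):

    import numpy as np

    data = np.zeros((2, 8, 4, 4))    # NCHW eltwise output, C == 8
    per_channel = np.arange(8.0)     # 1-D operand whose length equals C

    # shape_for_layout(NCHW, batch=1, features=8, height=1, width=1) -> [1, 8, 1, 1]
    reshaped = per_channel.reshape(1, 8, 1, 1)
    assert (data + reshaped).shape == data.shape
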
    def squeeze_initial_states(graph: Graph, match: dict):
        """
        Squeeze input initial states of recurrent node to 2-D shape.
        """
        hidden_init_port = 5
        cell_init_port = 6

        rnn_layer = match['rnn_layer']
        # Add input ports to rnn_layer
        rnn_layer.add_sequence_of_ports(type='in', rng=range(7))
        rnn_layer_name = rnn_layer.soft_get('name', rnn_layer.id)

        assert hidden_init_port in rnn_layer.in_nodes()
        hidden_size = rnn_layer.hidden_size
        shape = Shape(graph, dict(name=rnn_layer_name + '/ShapeOf')).create_node()
        rnn_layer.in_port(0).get_source().connect(shape.in_port(0))

        batch = node_to_get_shape_value_of_indices(shape, int64_array([rnn_layer.batch_dim]))
        new_dim = create_op_node_with_second_input(graph, Concat, second_input_value=int64_array([hidden_size]),
                                                   op_attrs=dict(name=rnn_layer_name + '/HiddenStateResizeDim',
                                                                 in_ports_count=2, axis=0), input_node=batch)
        reshape_h = Reshape(graph, dict(name=rnn_layer_name + '/HiddenStateResize', override_output_shape=True)).create_node()
        new_dim.out_port(0).connect(reshape_h.in_port(1))
        rnn_layer.in_port(hidden_init_port).get_connection().insert_node(reshape_h)

        if rnn_layer.op == 'LSTM':
            assert cell_init_port in rnn_layer.in_nodes()

            reshape_c = Reshape(graph, dict(name=rnn_layer_name + '/CellStateResize', override_output_shape=True)).create_node()
            new_dim.out_port(0).connect(reshape_c.in_port(1))
            rnn_layer.in_port(cell_init_port).get_connection().insert_node(reshape_c)
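The inserted Reshape turns a 3-D initial state [num_directions, batch, hidden_size] into the 2-D [batch, hidden_size] layout the recurrent layer expects; new_dim above assembles exactly that target shape at runtime. A NumPy sketch with made-up sizes:

    import numpy as np

    batch, hidden_size = 4, 16
    h_init = np.zeros((1, batch, hidden_size))  # [num_directions, batch, hidden_size]

    h_2d = h_init.reshape(batch, hidden_size)   # Concat(batch, hidden_size) as the Reshape dim
    assert h_2d.shape == (4, 16)
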
    def replace_op(self, graph: Graph, node: Node):
        node_name = node.soft_get('name', node.id)
        assert node.has_valid(
            'axis'
        ), 'The node "{}" does not have mandatory attribute "axis"'.format(
            node_name)

        flatten_node = FlattenONNX(graph, {
            'name': node_name + '/FlattenONNX_',
            'axis': node.axis
        }).create_node()
        shape_node = Shape(graph, {
            'name': node_name + '/ShapeOf_'
        }).create_node()
        logsoftmax_node = LogSoftmax(graph, {
            'name': node_name + '/LogSoftmax_',
            'axis': 1
        }).create_node()
        reshape_node = Reshape(graph, {}).create_node()

        rename_nodes([(node, node_name + '/delete'),
                      (reshape_node, node_name)])

        shape_node.out_port(0).connect(reshape_node.in_port(1))
        logsoftmax_node.out_port(0).connect(reshape_node.in_port(0))
        flatten_node.out_port(0).connect(logsoftmax_node.in_port(0))

        source = node.in_port(0).get_source()

        flatten_node.in_port(0).connect(source)
        shape_node.in_port(0).connect(source)

        return [reshape_node.id]
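Numerically the decomposition above is: flatten to 2-D at `axis`, take LogSoftmax along axis 1, then reshape back to the shape recovered by ShapeOf. A rough NumPy sketch of the same computation (a reference for the idea, not the exact kernel):

    import numpy as np

    def logsoftmax_via_flatten(x, axis):
        flat = x.reshape(int(np.prod(x.shape[:axis])), -1)   # FlattenONNX(axis)
        shifted = flat - flat.max(axis=1, keepdims=True)     # numerically stable log-softmax
        log_sm = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
        return log_sm.reshape(x.shape)                       # Reshape back to ShapeOf(x)

    x = np.random.rand(2, 3, 4)
    print(logsoftmax_via_flatten(x, axis=1).shape)           # (2, 3, 4)
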
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['flatten']
        name = node.soft_get('name', node.id)

        assert node.has_valid('axis'), 'Flatten {} should have `axis` attribute extracted, but it\'s not'.format(name)
        axis = node.axis

        if axis == 0:
            dim = Const(graph, {'value': int64_array([1, -1])}).create_node()
        elif axis == 1:
            dim = Const(graph, {'value': int64_array([0, -1])}).create_node()
        else:
            shape = Shape(graph, {'name': name + '/input_shape'}).create_node()

            idxs = list(range(axis)) if axis > 0 else list(range(axis, 0))

            axis_shape_portion = node_to_get_shape_value_of_indices(shape, idxs)
            first_dims = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]),
                                                          {'keep_dims': True})
            second_dims = Const(graph, {'value': int64_array([-1])}).create_node()

            node.in_port(0).get_source().connect(shape.in_port(0))
            axis_shape_portion.out_port(0).connect(first_dims.in_port(0))

            order_of_dims = [first_dims, second_dims] if axis > 0 else [second_dims, first_dims]

            dim = new_shape_node_from_shape_nodes(order_of_dims)

        reshape_node = Reshape(graph, {'name': node.id + '/Reshape'}).create_node()
        reshape_node.in_port(1).connect(dim.out_port(0))

        node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
        node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
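For the general branch the reshape pattern is [prod(shape[:axis]), -1] (mirrored for a negative axis). A NumPy check of the positive-axis case with made-up dimensions:

    import numpy as np

    shape = np.array([2, 3, 4, 5], dtype=np.int64)
    axis = 2
    dim = np.array([np.prod(shape[:axis]), -1], dtype=np.int64)  # ReduceProd result plus [-1]
    assert np.zeros(shape).reshape(dim).shape == (6, 20)
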
Example #5
    def replace_pattern(self, graph: Graph, match: dict):
        node = match['op']
        node.is_training = False

        shape = node.in_port(1).data.get_shape()
        assert shape is not None, 'The shape of scale input of the BatchNorm node {} is not defined'.format(node.name)

        bn_mean = Const(graph, {'name': node.name + '/mean', 'value': np.zeros(shape, dtype=np.float32),
                                'override_output_shape': True}).create_node()
        bn_std = Const(graph, {'name': node.name + '/std', 'value': np.ones(shape, dtype=np.float32),
                               'override_output_shape': True}).create_node()
        node.in_port(3).get_connection().set_source(bn_mean.out_port(0))
        node.in_port(4).get_connection().set_source(bn_std.out_port(0))

        # save the original shape
        original_shape = Shape(graph, {'name': node.in_port(0).get_source().node.soft_get('name')}).create_node()
        original_shape.in_port(0).connect(node.in_port(0).get_source())

        mvn = MVN(graph, {'name': node.name + '/mvn_', 'eps': node.soft_get('eps', 1e-6),
                          'override_output_shape': True}).create_node()
        node.in_port(0).get_connection().insert_node(mvn)

        reshape_4d = create_op_node_with_second_input(graph, Reshape, int64_array([1, -1, 0, 0]),
                                                      {'override_output_shape': True,
                                                       'name': node.soft_get('name') + '/fused_batch_and_channels'})
        mvn.in_port(0).get_connection().insert_node(reshape_4d)

        # restore original shape
        reshape_back = Reshape(graph, {'name': mvn.soft_get('name') + '/restore_shape',
                                       'override_output_shape': True}).create_node()
        reshape_back.in_port(1).connect(original_shape.out_port(0))
        mvn.out_port(0).get_connection().insert_node(reshape_back)
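The subgraph reduces training-mode BatchNorm to per-channel standardization: fuse batch and channels via Reshape([1, -1, 0, 0]), normalize the fused tensor with MVN, then restore the saved shape. A rough NumPy sketch (placing eps inside the square root is an assumption of this sketch):

    import numpy as np

    x = np.random.rand(2, 3, 4, 4).astype(np.float32)  # NCHW
    eps = 1e-6

    fused = x.reshape(1, -1, 4, 4)                     # batch and channels fused
    mean = fused.mean(axis=(2, 3), keepdims=True)
    var = fused.var(axis=(2, 3), keepdims=True)
    out = ((fused - mean) / np.sqrt(var + eps)).reshape(x.shape)
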
Example #6
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['mxreshape']

        input_index = 0
        reshape_index = 0
        shape_node = Shape(graph, dict(name=node.id +
                                       '/ShapeMXReshape')).create_node()
        shape_node.in_port(0).connect(node.in_port(0).get_source())
        output_dims_nodes = []
        for d in node.dim:
            if reshape_index < len(node.dim):
                input_index, reshape_index, output_dims_nodes = self.resolve(
                    input_index, reshape_index, node.dim, shape_node,
                    output_dims_nodes)

        concat_node = Concat(
            shape_node.graph,
            dict(name=shape_node.id + '/ConcatMXReshape_',
                 axis=0,
                 in_ports_count=len(output_dims_nodes))).create_node()

        for in_port_index, dim_node in enumerate(output_dims_nodes):
            concat_node.in_port(in_port_index).connect(dim_node.out_port(0))

        reshape_node = Reshape(graph,
                               dict(name=node.id + '/Reshape_')).create_node()
        reshape_node.in_port(1).connect(concat_node.out_port(0))
        node.in_port(0).get_connection().set_destination(
            reshape_node.in_port(0))
        node.out_port(0).get_connection().set_source(reshape_node.out_port(0))

    def replace_pattern(graph: Graph, match: dict):
        node = match['conv']
        node_name = node.soft_get('name', node.id)

        # create Reshape before convolution
        # shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
        i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
        shape = Cast(graph, {
            'name': node_name + '/to_float',
            'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)
        }).create_node()
        i_shape.in_port(0).connect(node.in_port(0).get_source())
        shape.in_port(0).connect(i_shape.out_port(0))

        N = node_to_get_shape_value_of_indices(shape, [0])
        H = node_to_get_shape_value_of_indices(shape, [1])

        div = create_op_with_const_inputs(
            graph, Div, {1: float_array([node.patch_stride])},
            {'name': node_name + '/div_stride_h'})
        div.in_port(0).connect(H.out_port(0))

        concat = create_op_with_const_inputs(
            graph, Concat, {
                2: float_array([1]),
                3: float_array([node.patch_stride])
            }, {
                'name': node_name + '/concat_all_dims',
                'in_ports_count': 4,
                'axis': 0
            })
        concat.in_port(0).connect(N.out_port(0))
        concat.in_port(1).connect(div.out_port(0))

        reshape_pattern = Cast(graph, {
            'name': node_name + '/to_int',
            'dst_type': np.int64
        }).create_node()
        concat.out_port(0).connect(reshape_pattern.in_port(0))

        reshape_in = Reshape(graph, {
            'name': node_name + '/reshape_in'
        }).create_node()
        reshape_in.in_port(1).connect(reshape_pattern.out_port(0))

        # create Reshape after Convolution
        reshape_out = create_op_node_with_second_input(
            graph, Reshape, int64_array([0, -1]),
            {'name': node_name + '/reshape_out'})

        # connect input_reshape_node
        source = node.in_port(0).get_source()
        node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
        reshape_in.in_port(0).connect(source)
        # connect output_reshape_node
        node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
        node.out_port(0).connect(reshape_out.in_port(0))
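The Concat assembles the 4-D pattern [N, H/patch_stride, 1, patch_stride] at runtime, and the trailing Reshape([0, -1]) restores the 2-D layout after the convolution. The shape arithmetic with made-up numbers:

    import numpy as np

    N, H = 8, 120                 # 2-D input [N, H]
    patch_stride = 40

    x4d = np.zeros((N, H)).reshape(N, H // patch_stride, 1, patch_stride)
    back = x4d.reshape(x4d.shape[0], -1)   # Reshape([0, -1]): keep dim 0, flatten the rest
    assert back.shape == (N, H)
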
Example #8
    def replace_pattern(graph: Graph, match: dict):
        node = match['proposal']
        assert len(node.in_ports()) == 3, "Proposal op must have exactly 3 input ports"
        im_info_shape = node.in_port(2).data.get_shape()
        assert im_info_shape is not None

        if np.array_equal(im_info_shape, [1, 6]):
            log.error(
                'The model contains Proposal layer "{}" with input of shape [1, 6]. The Inference Engine '
                'implementation of the Proposal layer uses only the first 4 values (indices 0, 1, 2 and 3). '
                'Elements with indices 4 and 5 will be ignored.'.format(
                    node.soft_get('name', node.id)),
                extra={'is_warning': True})
            begin = Const(graph, {
                'value': np.array([0, 0], dtype=np.int32)
            }).create_node()
            end = Const(graph, {
                'value': np.array([1, 3], dtype=np.int32)
            }).create_node()
            stride = Const(graph, {
                'value': np.array([1, 1], dtype=np.int32)
            }).create_node()

            cropped_im_info = StridedSlice(
                graph, {
                    'name': 'cropped_im_info',
                    'begin_mask': int64_array([1, 1]),
                    'end_mask': int64_array([1, 1]),
                    'new_axis_mask': int64_array([0]),
                    'shrink_axis_mask': int64_array([0]),
                    'ellipsis_mask': int64_array([0]),
                    'override_output_shape': True,
                }).create_node()

            node.in_port(2).get_connection().insert_node(cropped_im_info)
            begin.out_port(0).connect(cropped_im_info.in_port(1))
            end.out_port(0).connect(cropped_im_info.in_port(2))
            stride.out_port(0).connect(cropped_im_info.in_port(3))

            # update im_info_shape so that the next 'if' statement becomes true
            im_info_shape = int64_array([1, 3])

        if np.array_equal(im_info_shape, [1, 3]) or np.array_equal(
                im_info_shape, [1, 4]):
            reshape = Reshape(graph,
                              dict(name="im_info/Reshape")).create_node()
            const = Const(graph, dict(value=[im_info_shape[1]])).create_node()
            node.in_port(2).get_connection().set_destination(
                reshape.in_port(0))
            const.out_port(0).connect(reshape.in_port(1))
            reshape.out_port(0).connect(node.in_port(2))

        if node.has_port('out', 1) and not node.out_port(1).disconnected():
            # This is the case when the Proposal layer comes from an extension rather than from the opset.
            # Setting the version attribute is not recommended; this will be fixed once Proposal is updated in IE.
            graph.node[node.id]['version'] = 'extension'
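The StridedSlice with begin [0, 0], end [1, 3] and stride [1, 1] simply keeps the first three im_info values; in NumPy terms:

    import numpy as np

    im_info = np.array([[600., 800., 1., 600., 800., 1.]])  # shape [1, 6], values made up
    cropped = im_info[0:1, 0:3]                             # the StridedSlice above
    assert cropped.shape == (1, 3)
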
Example #9
    def decompose_shuffle_channel(node: Node):
        graph = node.graph
        name = node.soft_get('name', node.id)

        rename_node(node, name + '/to_be_removed')

        shape = Shape(graph, dict(name=name + '/InputShape')).create_node()
        shape.in_port(0).connect(node.in_port(0).get_source())

        # Reshape [input_batch, group, input_channels/group, -1]
        batch = node_to_get_batch_value(shape)
        group = Const(
            graph, dict(name=name + '/Rows',
                        value=int64_array([node.group]))).create_node()
        const = Const(graph, dict(name=name + '/Const',
                                  value=int64_array([-1]))).create_node()

        input_channels = node_to_get_features_dimension_value(shape)
        output_channels = create_op_node_with_second_input(
            graph,
            Div,
            np.int64(node.group), {'name': name + '/Cols'},
            input_node=input_channels)
        i_output_channels = Cast(graph, {
            'name': output_channels.name + '/Convert',
            'dst_type': np.int64
        }).create_node()
        output_channels.out_port(0).connect(i_output_channels.in_port(0))

        reshape_split_dim = new_shape_node_from_shape_nodes(
            [batch, group, i_output_channels, const])
        reshape_split_node = Reshape(
            graph, dict(name=name + '/Reshape_split_')).create_node()
        reshape_split_dim.out_port(0).connect(reshape_split_node.in_port(1))

        # Transpose(0, 2, 1, 3)
        transpose_node = create_op_node_with_second_input(
            graph,
            Transpose,
            int64_array([0, 2, 1, 3]), {'name': name + '/Transpose_'},
            input_node=reshape_split_node)

        # Reshape back to input shape
        reshape_concat = Reshape(graph, dict(name=name)).create_node()
        rename_node(reshape_concat, name)

        shape.out_port(0).connect(reshape_concat.in_port(1))
        transpose_node.out_port(0).connect(reshape_concat.in_port(0))

        # Final connections
        node.in_port(0).get_connection().set_destination(
            reshape_split_node.in_port(0))
        node.out_port(0).get_connection().set_source(
            reshape_concat.out_port(0))
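The three inserted ops implement the classic channel shuffle: reshape to [N, group, C/group, -1], swap the two channel axes, reshape back. Equivalent NumPy:

    import numpy as np

    def shuffle_channels(x, group):
        n, c = x.shape[0], x.shape[1]
        y = x.reshape(n, group, c // group, -1)   # Reshape_split_
        y = y.transpose(0, 2, 1, 3)               # Transpose(0, 2, 1, 3)
        return y.reshape(x.shape)                 # Reshape back to the input shape

    x = np.arange(2 * 6 * 4 * 4).reshape(2, 6, 4, 4)
    assert shuffle_channels(x, group=3).shape == x.shape
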
Example #10
    def replace_pattern(graph: Graph, match: dict):
        node = match['pool']

        if node.pool_step is None:
            node.stride = int64_array([1, 1, node.window[-1], node.window[-1]])

        # create Reshape before pooling
        # shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
        shape = Shape(graph, {}).create_node()
        shape.in_port(0).connect(node.in_port(0).get_source())

        split = create_op_with_const_inputs(graph, VariadicSplit, {
            1: int64_array(0),
            2: int64_array([1, -1])
        }, {'out_ports_count': 2}, shape)
        node_pool_stride = Const(graph, {
            'value': int64_array([node.pool_stride])
        }).create_node()
        pow_node = create_op_node_with_second_input(graph, Pow,
                                                    int64_array([-1]))
        pow_node.in_port(0).connect(node_pool_stride.out_port(0))

        mul = Mul(graph, {}).create_node()
        mul.in_port(0).connect(split.out_port(1))
        mul.in_port(1).connect(pow_node.out_port(0))

        const_1 = Const(graph, {'value': int64_array([1])}).create_node()

        concat = Concat(graph, {'in_ports_count': 4, 'axis': 0}).create_node()
        concat.in_port(0).connect(split.out_port(0))
        concat.in_port(3).connect(mul.out_port(0))
        concat.in_port(2).connect(const_1.out_port(0))
        concat.in_port(1).connect(node_pool_stride.out_port(0))

        reshape_in = Reshape(graph, {
            'name': '/Reshape/' + node.name
        }).create_node()
        reshape_in.in_port(1).connect(concat.out_port(0))

        # create Reshape after pooling
        reshape_out = create_op_node_with_second_input(
            graph, Reshape, int64_array([0, -1]),
            {'name': node.name + '/Reshape/'})

        # connect input_reshape_node
        source = node.in_port(0).get_source()
        node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
        reshape_in.in_port(0).connect(source)
        # connect output_reshape_node
        node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
        node.out_port(0).connect(reshape_out.in_port(0))
Example #11
    def replace_pattern(graph: Graph, match: dict):
        flatten = match['reshape']
        output_shape = np.copy(flatten.out_port(0).data.get_shape())
        output_shape[0] = 0

        reshape = Reshape(graph, dict(name=flatten.id)).create_node()
        dim = Const(graph,
                    dict(name=flatten.id + '/DimData',
                         value=output_shape)).create_node()

        flatten.in_port(0).get_connection().set_destination(reshape.in_port(0))
        dim.out_port(0).connect(reshape.in_port(1))
        flatten.out_port(0).get_connection().set_source(reshape.out_port(0))
        reshape['force_precision_in_ports'] = {1: 'int64'}
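Setting output_shape[0] = 0 relies on Reshape's special-zero semantics: a 0 in the dim input means "copy this dimension from the input", so the replacement stays valid for any batch size. Illustrated in NumPy, where the 0 has to be resolved by hand:

    import numpy as np

    x = np.zeros((7, 3, 4))                  # arbitrary batch size
    dim = np.array([0, 12], dtype=np.int64)  # 0 == keep the input's dim at this position
    resolved = [x.shape[i] if d == 0 else d for i, d in enumerate(dim)]
    assert x.reshape(resolved).shape == (7, 12)
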
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']
        input_shape = node.in_port(0).data.get_shape()
        if len(input_shape) > 2:
            new_shape = Const(graph, {
                'value': np.array([0, -1], dtype=np.int64)
            }).create_node()
            reshape = Reshape(graph, {}).create_node()
            source = node.in_port(0).get_source()
            node.in_port(0).get_connection().set_source(reshape.out_port(0))
            source.connect(reshape.in_port(0))
            new_shape.out_port(0).connect(reshape.in_port(1))
            new_shape.infer(new_shape)
            reshape.infer(reshape)
Example #13
    def replace_pattern(self, graph: Graph, match: dict):
        conv = match['conv']

        assert len(conv.out_nodes()) == 1, "Convolution operation {} should have 1 output data node".format(conv.id)
        out_data = conv.out_node()

        assert out_data.has_valid('shape'), 'Output shape is undefined for {} in back phase'.format(conv.id)
        out_shape = out_data.shape

        if out_shape.size != 3:
            return

        assert len(conv.in_nodes()) >= 1, "Convolution operation {} should have more than 1 input data node".format(
            conv.id)
        inp_data = conv.in_node()

        assert inp_data.has_valid('shape'), 'Input shape is undefined for {} in back phase'.format(conv.id)
        inp_shape = inp_data.shape
        new_inp_shape = np.insert(inp_shape, 2, 1)

        # setting to None to be overwritten by infer function
        conv.kernel_spatial_idx = None
        conv.spatial_dims = None

        # inserting fake H dimension
        conv.dilation = np.insert(conv.dilation, 2, 1)
        conv.kernel_spatial = np.append([1], conv.kernel_spatial)
        conv.pad = np.insert(conv.pad, 2, [0, 0], axis=0)
        conv.stride = np.insert(conv.stride, 2, 1)

        weights_node = conv.in_node(1)
        weights_node.value = np.reshape(weights_node.value, np.insert(weights_node.value.shape, 2, 1))
        weights_node.shape = np.array(weights_node.value.shape, dtype=np.int64)

        reshape = Reshape(graph, {'name': conv.name + '/reshape'}).create_node()
        reshape_dim = Const(graph, {'value': new_inp_shape, 'name': reshape.id + '/Dim'}).create_node()
        conv.in_port(0).get_connection().insert_node(reshape)
        reshape.in_port(1).connect(reshape_dim.out_port(0))

        reshape_back = Reshape(graph, {'name': conv.name + '/reshape_back'}).create_node()
        reshape_back_dim = Const(graph, {'value': out_shape, 'name': reshape_back.id + '/Dim'}).create_node()
        conv.out_port(0).get_connection().insert_node(reshape_back)
        reshape_back.in_port(1).connect(reshape_back_dim.out_port(0))

        # run shape inference manually for several nodes to override the shapes of nodes whose behaviour has changed
        reshape_dim.infer(reshape_dim)
        reshape.infer(reshape)
        conv.infer(conv)
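The pass lifts a 1-D convolution to 2-D by inserting a fake H dimension of size 1 into every shape-like attribute; np.insert does the heavy lifting. With made-up shapes:

    import numpy as np

    inp_shape = np.array([1, 64, 100])          # NCW
    print(np.insert(inp_shape, 2, 1))           # NC1W -> [  1  64   1 100]

    stride = np.array([1, 1, 2])
    print(np.insert(stride, 2, 1))              # [1 1 1 2]

    pad = np.array([[0, 0], [0, 0], [3, 3]])
    print(np.insert(pad, 2, [0, 0], axis=0))    # extra [0, 0] row for the fake H dim
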
Example #14
    def replace_pattern(graph: Graph, match: dict):
        node = match['proposal']
        assert len(node.in_ports()) == 3, "Proposal op must have exactly 3 input ports"
        im_info_shape = node.in_port(2).data.get_shape()
        assert im_info_shape is not None

        if np.array_equal(im_info_shape, [1, 3]) or np.array_equal(
                im_info_shape, [1, 4]):
            reshape = Reshape(graph,
                              dict(name="im_info/Reshape")).create_node()
            const = Const(graph, dict(value=[im_info_shape[1]])).create_node()
            node.in_port(2).get_connection().set_destination(
                reshape.in_port(0))
            const.out_port(0).connect(reshape.in_port(1))
            reshape.out_port(0).connect(node.in_port(2))
Example #15
    def find_and_replace_pattern(self, graph: Graph):
        for roll_node in graph.get_op_nodes(op='Roll'):
            # skip Roll nodes that already have an axes input; this pass targets Roll with empty axes only
            if not roll_node.in_port(2).disconnected():
                continue
            node_name = roll_node.soft_get('name', roll_node.id)

            # reshape to 1d tensor
            reshape_to_1d = create_op_node_with_second_input(
                graph, Reshape, int64_array([-1]),
                {'name': node_name + '/reshape'})
            roll_node.in_port(0).get_connection().insert_node(reshape_to_1d)

            # add zero const as axes input to roll
            const_zero = Const(graph, {
                'value': int64_array([0]),
                'name': node_name + '/axes'
            }).create_node()
            const_zero.out_port(0).connect(roll_node.in_port(2))

            # reshape to original shape
            shape_of = Shape(graph, {
                'name': node_name + '/shape_of'
            }).create_node()
            roll_node.in_port(0).get_connection().add_destination(
                shape_of.in_port(0))
            reshape_to_orig_shape = Reshape(graph, {}).create_node()
            rename_nodes([(roll_node, node_name + '/roll'),
                          (reshape_to_orig_shape, node_name)])
            shape_of.out_port(0).connect(reshape_to_orig_shape.in_port(1))
            roll_node.out_port(0).get_connection().insert_node(
                reshape_to_orig_shape)
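Roll with no axes rolls the flattened tensor; the pass makes that explicit by reshaping to 1-D, rolling along axis 0, and reshaping back via ShapeOf. The NumPy identity being exploited:

    import numpy as np

    x = np.arange(12).reshape(3, 4)
    shift = 5

    rolled = np.roll(x.reshape(-1), shift, axis=0).reshape(x.shape)
    assert np.array_equal(rolled, np.roll(x, shift))  # np.roll without axis also flattens
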
Example #16
    def replace_pattern(graph: Graph, match: dict):
        select = match['op']
        if select.has_valid('format') and select['format'] == 'tf':
            condition = select.in_node(0)
            input_1 = select.in_node(1)
            input_2 = select.in_node(2)

            assert np.array_equal(input_1.shape, input_2.shape)

            if len(condition.shape) == 1 and len(input_1.shape) > 1:
                new_shape = np.array([0] + [1] * (len(input_1.shape) - 1),
                                     dtype=np.int64)

                reshape_shape_const = Const(graph, {
                    'name': select.name + '/Reshape/Dim/',
                    'value': new_shape
                }).create_node()

                unsqueeze_op = Reshape(
                    graph, dict(name=select.name +
                                '/Broadcast/')).create_node(inputs=[condition])

                reshape_shape_const.out_port(0).get_connection().set_destination(unsqueeze_op.in_port(1))

                select.in_port(0).disconnect()
                select.in_port(0).get_connection().set_source(
                    unsqueeze_op.out_port(0))
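A 1-D TensorFlow-style condition of shape [N] does not broadcast against [N, ...] operands under trailing-dimension alignment, hence the Reshape to [N, 1, ..., 1]. In NumPy (the pass encodes the leading dim as Reshape's special 0; -1 plays the same role in this sketch):

    import numpy as np

    condition = np.array([True, False, True])            # shape [3]
    a, b = np.ones((3, 4, 5)), np.zeros((3, 4, 5))

    cond = condition.reshape([-1] + [1] * (a.ndim - 1))  # [3, 1, 1]
    assert np.where(cond, a, b).shape == (3, 4, 5)
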
    def replace_sub_graph(self, graph: Graph, match: dict):
        if not check_applicability(match):
            return

        reshape = match['reshape']
        div_name = match['division'].name

        input_shape = Shape(graph, dict(name=div_name + '/shape/MVN_T_')).create_node()
        shape_of_reshape = reshape.in_port(1).get_connection().get_source().node.value
        c1, c2 = shape_of_reshape[1], shape_of_reshape[2]
        c = c1 * c2

        new_reshape = create_op_node_with_second_input(graph, Reshape, int64_array([0, 0, 0, c1, c2]),
                                                       dict(name=div_name + '/first_reshape/MVN_T_'))
        permute_order = int64_array([0, 1, 2, 4, 3])
        first_permute = create_op_node_with_second_input(graph, Transpose, permute_order,
                                                         dict(name=div_name + '/first_permute/MVN_T_'), new_reshape)

        add = match['add']
        variance = match['variance']
        eps_port_num = 0 if add.in_port(0).get_connection().get_source().node.id != variance.id else 1
        eps = add.in_port(eps_port_num).get_connection().get_source().node
        mvn_node = create_op_with_const_inputs(graph, MVN, {1: int64_array([1, 2, 3])},
                                               dict(name=div_name + '/MVN/MVN_T_',
                                                    eps=eps.value, normalize_variance=1,
                                                    eps_mode='inside_sqrt'))
        first_permute.out_port(0).connect(mvn_node.in_port(0))

        second_permute = create_op_node_with_second_input(graph, Transpose, permute_order,
                                                          dict(name=div_name + '/second_permute/MVN_T_'), mvn_node)
        new_reshape2 = Reshape(graph, dict(name=div_name + '/second_reshape/MVN_T_')).create_node()
        second_permute.out_port(0).connect(new_reshape2.in_port(0))
        gamma_val = np.reshape(match['gamma_identity'].in_port(0).get_connection().get_source().node.value,
                               int64_array([1, 1, 1, c]))
        new_mul = create_op_node_with_second_input(graph, Mul, gamma_val,
                                                   dict(name=match['mul'].name + '/MVN_T_'), new_reshape2)
        beta_val = np.reshape(match['beta_identity'].in_port(0).get_connection().get_source().node.value,
                              int64_array([1, 1, 1, c]))
        new_add2 = create_op_node_with_second_input(graph, Add, beta_val,
                                                    dict(name=match['add2'].name + '/MVN_T_'), new_mul)

        transpose_connection = match['transpose'].in_port(0).get_connection()
        before_transpose = transpose_connection.get_source().node
        transpose_connection.set_destination(new_reshape.in_port(0))
        input_shape.out_port(0).connect(new_reshape2.in_port(1))
        before_transpose.out_port(0).connect(input_shape.in_port(0))
        match['transpose2'].out_port(0).get_connection().set_source(new_add2.out_port(0))
Example #18
    def convert_ifft_to_dft(self, graph: Graph, mx_fft: Node):
        mx_fft_name = mx_fft.soft_get('name', mx_fft.id)

        rank_node = Rank(graph, {'name': mx_fft_name + '/rank'}).create_node()
        sub_node = create_op_with_const_inputs(graph, Sub, {1: int64_array(1)},
                                               {'name': mx_fft_name + '/Sub'})
        rank_node.out_port(0).connect(sub_node.in_port(0))
        broadcast_node0 = create_op_with_const_inputs(
            graph, Broadcast, {0: int64_array(0)},
            {'name': mx_fft_name + '/broadcast'})
        sub_node.out_port(0).connect(broadcast_node0.in_port(1))
        concat_node = create_op_with_const_inputs(
            graph, Concat, {1: int64_array([-1, 2])}, {
                'name': mx_fft_name + '/new_shape',
                'in_ports_count': 2,
                'axis': 0
            }, broadcast_node0)

        reshape_node = Reshape(graph, {
            'name': mx_fft_name + '/reshape'
        }).create_node()
        concat_node.out_port(0).connect(reshape_node.in_port(1))

        mx_fft_connection = mx_fft.in_port(0).get_connection()
        mx_fft_connection.set_destination(reshape_node.in_port(0))
        mx_fft_connection.get_source().connect(rank_node.in_port(0))

        dft_node = create_op_with_const_inputs(
            graph, IDFT, {1: int64_array([-1])}, {
                'name': mx_fft_name + '/idft',
                'in_ports_count': 2
            }, reshape_node)

        split_node = create_op_with_const_inputs(
            graph, Split, {1: int64_array(-1)}, {
                'name': mx_fft_name + '/split',
                'num_splits': 2
            }, dft_node)
        squeeze_node = create_op_with_const_inputs(graph, Squeeze,
                                                   {1: int64_array([-1])}, {},
                                                   split_node)

        mx_fft.out_port(0).get_connection().set_source(
            squeeze_node.out_port(0))
        rename_nodes([(mx_fft, mx_fft_name + '/to_be_removed'),
                      (squeeze_node, mx_fft_name)])

    def replace_pattern(self, graph: Graph, match: dict):
        gather = match['GatherNd']
        input_shape = gather.in_node(0).shape
        indices = gather.in_node(1).value
        if indices is None:
            # We can't apply this special pass without the indices value
            return

        # 0. All needed checks that we can replace GatherNd by Gather
        gather_idx = self.indices_check(indices, input_shape)
        if gather_idx is None:
            log.warning(
                'Node {} with op=GatherNd  can\'t be normalized to op=Gather.'.
                format(gather.name))
            return

        # 1. Add Reshape and connect
        new_shape = int64_array([-1] + list(input_shape[indices.shape[-1]:]))
        reshape = Reshape(graph, {
            'name': gather.name + '/Reshape_for_GatherNd/'
        }).create_node()
        reshape_const_node = Const(graph, {
            'name': reshape.name + '/Dim',
            'value': new_shape
        }).create_node()
        gather.in_port(0).get_connection().set_destination(reshape.in_port(0))
        reshape.in_port(1).connect(reshape_const_node.out_port(0))

        # 2. Change indices from Nd to 1d:
        new_indices = np.reshape(
            np.take(indices, indices=[gather_idx], axis=-1), [-1])
        new_indices_const = Const(graph, dict(value=new_indices)).create_node()
        axis_const = Const(graph, {'value': int64_array(0)}).create_node()

        # 3. Create new Gather operation and reconnect all inputs/outputs
        new_gather = Gather(graph, {
            'name': gather.name + '/NewGather/'
        }).create_node()
        reshape.out_port(0).connect(new_gather.in_port(0))
        new_indices_const.out_port(0).connect(new_gather.in_port(1))
        axis_const.out_port(0).connect(new_gather.in_port(2))

        gather.out_port(0).get_connection().set_source(new_gather.out_port(0))

        # 4. Remove old Gather node
        graph.remove_node(gather.id)
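When only one index column actually varies (and the others select trivial dimensions), GatherNd degenerates into a plain Gather over a partially flattened input. A NumPy sketch of the equivalence for a simple case:

    import numpy as np

    inp = np.arange(15).reshape(1, 5, 3)            # input shape [1, 5, 3]
    indices = np.array([[0, 3], [0, 1], [0, 4]])    # only column 1 varies

    gather_nd = np.array([inp[tuple(idx)] for idx in indices])

    flat = inp.reshape(-1, 3)                       # [-1] + input_shape[indices.shape[-1]:]
    gathered = np.take(flat, indices[:, 1], axis=0) # Gather with axis 0
    assert np.array_equal(gather_nd, gathered)
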
Example #20
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['conv']
        input_reshape_node = Reshape(graph,
                                     {
                                         'name': '/Reshape/' + node.name,
                                         'axis': 1,
                                         'infer': Reshape.kaldi_infer
                                     }).create_node()
        output_reshape_node = Reshape(graph,
                                      {
                                          'name': node.name + '/Reshape/',
                                          'axis': 1,
                                          'infer': Reshape.kaldi_infer
                                      }).create_node()
        # connect input_reshape_node
        source = node.in_port(0).get_source()
        node.in_port(0).get_connection().set_source(input_reshape_node.out_port(0))
        input_reshape_node.in_port(0).connect(source)
        # connect output_reshape_node
        node.out_port(0).get_connection().set_source(output_reshape_node.out_port(0))
        node.out_port(0).connect(output_reshape_node.in_port(0))
Example #21
    def broadcast_with_reshape(port):
        # `input_port`, `node` and `graph` come from the enclosing scope of this helper
        input_shape = input_port.data.get_shape()
        reshape_dims = np.zeros(len(input_shape), dtype=np.int64)
        for i in range(0, node.axis):
            reshape_dims[i] = 1
        data_shape = port.data.get_shape()
        for i in range(node.axis, node.axis + len(data_shape)):
            reshape_dims[i] = data_shape[i - node.axis]
        for i in range(node.axis + len(data_shape), len(input_shape)):
            reshape_dims[i] = 1
        reshape = Reshape(graph, dict(name=port.node.name + "/Broadcast_", dim=reshape_dims)).create_node()
        port.get_connection().set_destination(reshape.in_port(0))
        reshape.out_port(0).connect(port)
Example #22
    def replace_sub_graph(self, graph: Graph, match: dict):
        mxreshape = match['op']
        if not mxreshape.reverse:
            return

        shape_node = Shape(graph, dict(name=mxreshape.id + '/Shape')).create_node()
        forward_reverse_unsqueeze_node = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
                                                                          dict(name=str(mxreshape.id) + '/ForwardUnsqueeze'))
        forward_reverse_node = Reverse(graph, dict(name=mxreshape.id + '/ForwardReverse', axis=1)).create_node()

        forward_reverse_squeeze_node = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
                                                                        dict(name=str(mxreshape.id) + '/ForwardSqueeze'))
        reshape_node = Reshape(graph, dict(name=mxreshape.id + '/Reshape')).create_node()
        shape_node.in_port(0).connect(mxreshape.in_port(0).get_source())
        mxreshape.in_port(0).get_connection().set_destination(reshape_node.in_port(0))

        forward_reverse_unsqueeze_node.in_port(0).connect(shape_node.out_port(0))
        forward_reverse_node.in_port(0).connect(forward_reverse_unsqueeze_node.out_port(0))
        forward_reverse_squeeze_node.in_port(0).connect(forward_reverse_node.out_port(0))
        reshape_node.in_port(1).connect(forward_reverse_squeeze_node.out_port(0))

        reshape_shape_node = create_op_node_with_second_input(graph, Reshape, int64_array(np.flip(mxreshape.dim, 0)),
                                                              dict(name=str(mxreshape.id) + '/ReshapeShape'))
        if np.sum(np.in1d([-2, -3, -4], mxreshape.dim), axis=0):
            reshape_shape_node = MXReshape(graph, dict(name=mxreshape.id + '/Reshape',
                                     dim=int64_array(np.flip(mxreshape.dim, 0)))).create_node()

        reshape_shape_node.in_port(0).connect(reshape_node.out_port(0))

        backward_shape_node = Shape(graph, dict(name=mxreshape.id + '/BackwardShape')).create_node()
        backward_reverse_unsqueeze_node = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
                                                                           dict(name=str(mxreshape.id) + '/BackwardUnsqueeze'))
        backward_reverse_node = Reverse(graph, dict(name=mxreshape.id + '/BackwardReverse', axis=1)).create_node()
        backward_reverse_squeeze_node = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
                                                                         dict(name=str(mxreshape.id) + '/BackwardSqueeze'))
        backward_reshape_node = Reshape(graph, dict(name=mxreshape.id + '/BackwardReshape')).create_node()

        backward_shape_node.in_port(0).connect(reshape_shape_node.out_port(0))
        backward_reverse_unsqueeze_node.in_port(0).connect(backward_shape_node.out_port(0))
        backward_reverse_node.in_port(0).connect(backward_reverse_unsqueeze_node.out_port(0))
        backward_reverse_squeeze_node.in_port(0).connect(backward_reverse_node.out_port(0))

        backward_reshape_node.in_port(0).connect(reshape_shape_node.out_port(0))
        backward_reshape_node.in_port(1).connect(backward_reverse_squeeze_node.out_port(0))

        mxreshape.out_port(0).get_connection().set_source(backward_reshape_node.out_port(0))

    def append_variances(priors_scale_node: Node, variance: list):
        graph = priors_scale_node.graph
        name = priors_scale_node.name

        sp_shape = Shape(graph, {'name': name + '/shape'}).create_node()
        priors_scale_node.out_port(0).connect(sp_shape.in_port(0))

        begin = Const(graph, {'value': np.array([-2])}).create_node()
        end = Const(graph, {'value': np.array([-1])}).create_node()
        stride = Const(graph, {'value': np.array([1])}).create_node()
        shape_part_for_tiling = StridedSlice(graph, {'name': name + '/get_-2_dim', 'begin_mask': np.array([1]),
                                                     'end_mask': np.array([1]), 'new_axis_mask': np.array([0]),
                                                     'shrink_axis_mask': np.array([0]),
                                                     'ellipsis_mask': np.array([0])}).create_node()

        sp_shape.out_port(0).connect(shape_part_for_tiling.in_port(0))
        begin.out_port(0).connect(shape_part_for_tiling.in_port(1))
        end.out_port(0).connect(shape_part_for_tiling.in_port(2))
        stride.out_port(0).connect(shape_part_for_tiling.in_port(3))

        concat_value = Const(graph, {'value': np.array([4])}).create_node()
        shape_concat = Concat(graph, {'name': name + '/shape_for_tiling', 'in_ports_count': 2,
                                      'axis': np.array(0)}).create_node()
        shape_part_for_tiling.out_port(0).connect(shape_concat.in_port(0))
        concat_value.out_port(0).connect(shape_concat.in_port(1))

        variance = Const(graph, {'name': name + '/variance', 'value': np.array(variance)}).create_node()
        tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
        variance.out_port(0).connect(tile.in_port(0))
        shape_concat.out_port(0).connect(tile.in_port(1))

        reshape_dim = Const(graph, {'value': int64_array([-1, 4])}).create_node()
        sp_reshape = Reshape(graph, {'name': name + '/reshape'}).create_node()
        sp_reshape.in_port(0).connect(priors_scale_node.out_port(0))
        sp_reshape.in_port(1).connect(reshape_dim.out_port(0))

        concat = Concat(graph,
                        {'name': name + '/priors_concat', 'axis': np.array(0), 'in_ports_count': 2}).create_node()
        sp_reshape.out_port(0).connect(concat.in_port(0))
        tile.out_port(0).connect(concat.in_port(1))

        output_dims = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
        output_node = Reshape(graph, {'name': name + '/3D_priors_wth_variances'}).create_node()
        concat.out_port(0).connect(output_node.in_port(0))
        output_dims.out_port(0).connect(output_node.in_port(1))

        return output_node
Example #24
    def replace_pattern(graph: Graph, match: dict):
        node = match['conv']
        node_name = node.soft_get('name', node.id)

        # create Reshape before convolution
        # shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
        shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
        shape.in_port(0).connect(node.in_port(0).get_source())

        split = create_op_with_const_inputs(graph, VariadicSplit, {
            1: int64_array(0),
            2: int64_array([1, -1])
        }, {
            'name': shape.name + '/split_batch',
            'out_ports_count': 2
        }, shape)

        pow_node = create_op_node_with_second_input(
            graph, Pow, int64_array([-1]),
            {'name': node_name + '/patch_stride/inverse'})
        conv_patch_stride = Const(
            graph, {
                'value': int64_array([node.patch_stride]),
                'name': node_name + '/patch_stride/'
            }).create_node()
        pow_node.in_port(0).connect(conv_patch_stride.out_port(0))

        mul = Mul(graph, {
            'name': node_name + '/mul_inverse_stride_h'
        }).create_node()
        mul.in_port(0).connect(split.out_port(1))
        mul.in_port(1).connect(pow_node.out_port(0))

        concat = create_op_with_const_inputs(
            graph, Concat, {2: int64_array([1])}, {
                'name': node_name + '/concat_all_dims',
                'in_ports_count': 4,
                'axis': 0
            })

        concat.in_port(0).connect(split.out_port(0))
        concat.in_port(1).connect(mul.out_port(0))
        concat.in_port(3).connect(conv_patch_stride.out_port(0))

        reshape_in = Reshape(graph, {
            'name': node_name + '/reshape_in'
        }).create_node()
        reshape_in.in_port(1).connect(concat.out_port(0))

        # create Reshape after Convolution
        reshape_out = create_op_node_with_second_input(
            graph, Reshape, int64_array([0, -1]),
            {'name': node_name + '/reshape_out'})

        # connect input_reshape_node
        source = node.in_port(0).get_source()
        node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
        reshape_in.in_port(0).connect(source)
        # connect output_reshape_node
        node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
        node.out_port(0).connect(reshape_out.in_port(0))
Example #25
    def replace_pattern(self, graph: Graph, match: dict):

        merge = match['merge']
        power = Pow(graph, {
            'name': merge.name + '/reciprocal_',
            'type': 'PNORM'
        }).create_node()
        const1 = Const(graph, {
            'value': -1.0,
            'name': merge.name + '/negate_const'
        }).create_node()
        merge.in_port(0).get_connection().set_destination(power.in_port(0))
        const1.out_port(0).connect(power.in_port(1))

        concat_node = Concat(
            graph, {
                'axis': 0,
                'name': merge.name + '/Concat_',
                'override_output_shape': True
            }).create_node()
        const3 = Const(graph, {
            'name': merge.name + '/const_reduce',
            'value': 0
        }).create_node()

        for ii, idx in enumerate(
                range(merge.significant, merge.to_significant + 1, 1)):
            const_node = Const(
                graph, {
                    'value': float_array(math.pow(10.0, idx)),
                    'name': merge.name + '/Const_' + ii.__str__()
                }).create_node()

            mul_node = Mul(graph, {
                'name': merge.name + '/Mul_' + ii.__str__()
            }).create_node()
            const_node.out_port(0).connect(mul_node.in_port(0))

            power.out_port(0).connect(
                mul_node.in_port(1))  # connect to the graph node
            mul_node2 = Mul(graph, {
                'name': merge.name + '/Mul_Div_' + ii.__str__()
            }).create_node()

            const_node2 = Const(
                graph, {
                    'value': float_array(math.pow(10.0, -1 * idx)),
                    'name': merge.name + '/Const_Pow_' + ii.__str__()
                }).create_node()
            cast_node = Cast(
                graph, {
                    'name': merge.name + '/Cast_' + idx.__str__(),
                    'dst_type': np.float32
                }).create_node()

            mul_node.out_port(0).connect(cast_node.in_port(0))
            const_node2.out_port(0).connect(mul_node2.in_port(1))
            cast_node.out_port(0).connect(mul_node2.in_port(0))
            concat_node.add_input_port(ii, skip_if_exist=True)
            concat_node.in_port(ii).get_connection().set_source(
                mul_node2.out_port(0))

        reducesum_node = ReduceMean(
            graph, {
                'name': merge.id + '/_pnorm_reduced_sum',
                'keep_dims': False,
                'in_ports_count': 2,
                'need_shape_inference': None,
                'infer': reduce_infer
            }).create_node()

        const3.out_port(0).connect(reducesum_node.in_port(1))
        reducesum_node.in_port(0).get_connection().set_source(
            concat_node.out_port(0))

        reshape = Reshape(graph, {
            'name': merge.name + '/Reshape_Node'
        }).create_node()
        reshape_dim = Const(graph, {
            'value': np.array([1, 5]),
            'name': merge.id + '/Reshape_Dim'
        }).create_node()
        reducesum_node.out_port(0).connect(reshape.in_port(0))
        reshape.in_port(1).connect(reshape_dim.out_port(0))
        merge.out_port(0).get_connection().set_source(reshape.out_port(0))

    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['flatten']
        name = node.soft_get('name', node.id)

        assert node.has_valid(
            'axis'), 'Flatten {} has no mandatory `axis` attribute'.format(
                name)
        assert node.has_valid(
            'end_axis'
        ), 'Flatten {} has no mandatory `end_axis` attribute'.format(name)

        axis = node.axis
        end_axis = node.end_axis

        if end_axis == -1 and axis >= 0:
            begin_dims = Const(graph, {
                'value': int64_array([0] * axis)
            }).create_node()
            middle_dim = Const(graph, {
                'value': int64_array([-1])
            }).create_node()
            end_dims = Const(graph, {'value': int64_array([])}).create_node()
        else:
            rank = Rank(graph, {'name': name + '/input_rank'}).create_node()
            node.in_port(0).get_source().connect(rank.in_port(0))

            shape = Shape(graph, {'name': name + '/input_shape'}).create_node()
            node.in_port(0).get_source().connect(shape.in_port(0))

            begin_dims = get_shape_values_by_range_idxs(shape=shape,
                                                        rank=rank,
                                                        begin=0,
                                                        end=axis)
            middle_dims = get_shape_values_by_range_idxs(shape=shape,
                                                         rank=rank,
                                                         begin=axis,
                                                         end=end_axis,
                                                         include_end=True)
            end_dims = get_shape_values_by_range_idxs(shape=shape,
                                                      rank=rank,
                                                      begin=end_axis,
                                                      end=-1,
                                                      include_begin=False,
                                                      include_end=True)

            middle_dim = create_op_node_with_second_input(
                graph, ReduceProd, int64_array([0]), {'keep_dims': True})
            middle_dims.out_port(0).connect(middle_dim.in_port(0))

        dim = new_shape_node_from_shape_nodes(
            [begin_dims, middle_dim, end_dims])

        original_name = node.soft_get('name')
        abandoned_name = original_name + '/ShouldBeDeleted'
        reshape_node = Reshape(graph, {}).create_node()
        # Keep a node with the original name to avoid confusion after renaming
        rename_nodes([(node, abandoned_name), (reshape_node, original_name)])
        reshape_node.in_port(1).connect(dim.out_port(0))

        node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
        node.in_port(0).get_connection().set_destination(
            reshape_node.in_port(0))
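The general branch builds the reshape pattern as shape[:axis] + [prod(shape[axis:end_axis+1])] + shape[end_axis+1:]. A NumPy check with made-up dimensions:

    import numpy as np

    shape = np.array([2, 3, 4, 5], dtype=np.int64)
    axis, end_axis = 1, 2

    dim = np.concatenate([shape[:axis],
                          [np.prod(shape[axis:end_axis + 1])],  # ReduceProd of the middle part
                          shape[end_axis + 1:]])
    assert np.array_equal(dim, [2, 12, 5])
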
Example #27
    def mxrepeat_decomposition(node: Node):
        graph = node.graph
        name = node.soft_get('name', node.id)

        rename_node(node, name + '/to_be_removed')

        # Unsqueeze
        input_rank = Rank(graph, {'name': name + '/Rank'}).create_node()
        node.in_port(0).get_source().connect(input_rank.in_port(0))

        axis = get_canonical_axis_index_node(input_rank, node.axis)
        unsqueeze_axis = create_op_node_with_second_input(
            graph,
            Add,
            int64_array([1]), {'name': name + '/Unsqueeze/Axis'},
            input_node=axis)

        unsqueeze = Unsqueeze(graph, {
            'name': name + '/Unsqueeze'
        }).create_node()
        unsqueeze.in_port(1).connect(unsqueeze_axis.out_port(0))

        # Tile (1, 1, ..., repeats, ..., 1)
        # we generate tile array according to the following table:

        # parts:       |      first      |  repeats |  second     |
        # i:           | 0, 1, ..., axis,| axis + 1,| ..., rank+1 |
        # tile_array:  | 1, 1, ...,  1  ,| repeats ,| ...,   1    |

        one = Const(graph, {
            'name': name + '/Broadcast/One',
            'value': int64_array([1])
        }).create_node()
        first_ones = Broadcast(graph, {
            'name': name + '/Broadcast/Ones_first_part'
        }).create_node()
        first_ones.in_port(0).connect(one.out_port(0))
        first_ones.in_port(1).connect(unsqueeze_axis.out_port(0))

        repeats = Const(graph, {
            'name': name + '/repeats',
            'value': int64_array([node.repeats])
        }).create_node()

        second_ones = Broadcast(graph, {
            'name': name + '/Broadcast/Ones_second_part'
        }).create_node()
        second_part_broadcast_shape = Sub(
            graph, {
                'name': name + '/Broadcast/Shape/second_part'
            }).create_node()
        second_part_broadcast_shape.in_port(0).connect(input_rank.out_port(0))
        second_part_broadcast_shape.in_port(1).connect(
            unsqueeze_axis.out_port(0))
        second_ones.in_port(0).connect(one.out_port(0))
        second_ones.in_port(1).connect(second_part_broadcast_shape.out_port(0))

        tile_repeats = new_shape_node_from_shape_nodes(
            [first_ones, repeats, second_ones])
        tile = Tile(graph, {'name': name + '/Tile'}).create_node()
        tile.in_port(1).connect(tile_repeats.out_port(0))

        # Reshape (input_shape[:axis], input_shape[axis] * repeats, input_shape[axis+1:])
        # we generate reshape dim array according to the following table:

        # parts:       |    first   |                rep           |  second   |
        # i:           | 0, 1, ... ,|               axis,          | ..., rank |
        # dim_array:   | inp_sh[i] ,| input_shape[axis] * repeats ,| inp_sh[i] |

        input_shape = Shape(graph, {'name': name + '/Shape'}).create_node()
        node.in_port(0).get_source().connect(input_shape.in_port(0))

        first_input_shape_part = get_shape_values_by_range_idxs(
            input_shape,
            input_rank,
            begin=0,
            end=node.axis,
            include_begin=True,
            include_end=False)

        original_axis_dim = create_op_with_const_inputs(
            graph,
            Gather, {2: int64_array(0)}, {'name': name + '/OriginalDim'},
            input_node=input_shape)
        original_axis_dim.in_port(1).connect(axis.out_port(0))

        repeated_dimention = Mul(graph, {
            'name': name + '/RepeatedDim'
        }).create_node()
        repeated_dimention.in_port(0).connect(original_axis_dim.out_port(0))
        repeated_dimention.in_port(1).connect(repeats.out_port(0))

        second_input_shape_part = get_shape_values_by_range_idxs(
            input_shape,
            input_rank,
            begin=node.axis,
            end=-1,
            include_begin=False,
            include_end=True)

        output_shape = new_shape_node_from_shape_nodes([
            first_input_shape_part, repeated_dimention, second_input_shape_part
        ])

        reshape = Reshape(graph, {'name': name}).create_node()
        rename_node(reshape, name)
        reshape.in_port(1).connect(output_shape.out_port(0))

        # Final connections
        node.in_port(0).get_connection().set_destination(unsqueeze.in_port(0))
        tile.in_port(0).connect(unsqueeze.out_port(0))
        reshape.in_port(0).connect(tile.out_port(0))
        node.out_port(0).get_connection().set_source(reshape.out_port(0))
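The Unsqueeze/Tile/Reshape chain reproduces np.repeat along an axis, following the two tables in the comments above. A NumPy sketch of the same construction:

    import numpy as np

    def repeat_via_tile(x, repeats, axis):
        y = np.expand_dims(x, axis + 1)         # Unsqueeze at axis + 1
        tile_array = [1] * y.ndim
        tile_array[axis + 1] = repeats          # (1, ..., repeats, ..., 1)
        y = np.tile(y, tile_array)
        new_shape = list(x.shape)
        new_shape[axis] *= repeats              # input_shape[axis] * repeats
        return y.reshape(new_shape)

    x = np.arange(6).reshape(2, 3)
    assert np.array_equal(repeat_via_tile(x, 2, axis=1), np.repeat(x, 2, axis=1))
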
Example #28
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        group_norm_node = match['op']
        group_norm_num_input_dims = len(
            group_norm_node.in_port(0).data.get_shape())

        # node computing initial GroupNorm input shape
        initial_shape_op_node = Shape(graph, {
            'name': group_norm_node.name + '/Shape'
        }).create_node()
        initial_shape_op_node.in_port(0).connect(
            group_norm_node.in_port(0).get_source())

        initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node)
        initial_features_dim_node = node_to_get_features_dimension_value(
            initial_shape_op_node)
        initial_spatial_dims_node = node_to_get_spatial_dimensions_value(
            initial_shape_op_node)
        group_size_node = Const(
            graph, {
                'value': int64_array([group_norm_node.num_groups]),
                'name': group_norm_node.name + '/GroupSize'
            }).create_node()

        # calculate "features // group_size" value
        reciprocal_group_size_node = Const(
            graph, {
                'value': np.array([1.0 / group_norm_node.num_groups]),
                'name': group_norm_node.name + '/ReciprocalGroupSize'
            }).create_node()

        c_div_g_node = Mul(graph, {}).create_node()
        c_div_g_node.in_port(0).connect(initial_features_dim_node.out_port(0))
        c_div_g_node.in_port(1).connect(reciprocal_group_size_node.out_port(0))

        batch_mul_group_size_node = Mul(graph, {}).create_node()
        batch_mul_group_size_node.in_port(0).connect(
            initial_batch_dim_node.out_port(0))
        batch_mul_group_size_node.in_port(1).connect(
            group_size_node.out_port(0))

        # create new node which concatenates several dims to one
        new_shape_node = new_shape_node_from_shape_nodes([
            batch_mul_group_size_node, c_div_g_node, initial_spatial_dims_node
        ])

        reshape_for_mvn_node = Reshape(graph, {}).create_node()

        group_norm_node.in_port(0).get_connection().set_destination(
            reshape_for_mvn_node.in_port(0))
        reshape_for_mvn_node.in_port(1).connect(new_shape_node.out_port(0))

        # Reshape the gamma and beta constants to the correct layout: from [C] to [1, C], [1, C, 1], [1, C, 1, 1], etc.
        gamma_beta_shape = np.ones([group_norm_num_input_dims], dtype=np.int64)
        gamma_beta_shape[1] = -1

        gamma_value = group_norm_node.in_port(1).get_source().data.get_value()
        beta_value = group_norm_node.in_port(2).get_source().data.get_value()
        assert gamma_value is not None, 'The gamma should be constant'
        assert beta_value is not None, 'The beta should be constant'
        gamma_value = np.reshape(gamma_value, gamma_beta_shape)
        group_norm_node.in_port(1).get_source().data.set_value(gamma_value)
        beta_value = np.reshape(beta_value, gamma_beta_shape)
        group_norm_node.in_port(2).get_source().data.set_value(beta_value)

        # MVN
        mvn_node = MVN(
            graph, {
                'name': group_norm_node.name + '/MVN',
                'across_channels': 1,
                'normalize_variance': 1,
                'eps': group_norm_node.eps
            }).create_node()
        mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

        # reshape to the initial shape before multiplying by gamma and adding beta
        reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
        reshape_to_initial_shape_node.in_port(0).connect(mvn_node.out_port(0))
        reshape_to_initial_shape_node.in_port(1).connect(
            initial_shape_op_node.out_port(0))

        mul_node = Mul(graph, {'name': mvn_node.name + '/Mul'}).create_node()
        mul_node.in_port(0).connect(reshape_to_initial_shape_node.out_port(0))
        group_norm_node.in_port(1).get_connection().set_destination(
            mul_node.in_port(1))

        add_node = Add(graph, {'name': mul_node.name + '/Add'}).create_node()
        add_node.in_port(0).connect(mul_node.out_port(0))
        group_norm_node.in_port(2).get_connection().set_destination(
            add_node.in_port(1))

        group_norm_node.out_port(0).get_connection().set_source(
            add_node.out_port(0))
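A minimal numeric sketch of what this replacement computes, assuming NCHW layout and illustrative sizes (plain numpy, not the Model Optimizer API):

import numpy as np

N, C, H, W, G, eps = 2, 6, 4, 4, 3, 1e-5
x = np.random.randn(N, C, H, W).astype(np.float32)
gamma = np.random.randn(C).astype(np.float32)
beta = np.random.randn(C).astype(np.float32)

# Reshape to [N*G, C//G, spatial...] and normalize everything but the first dim,
# which is what MVN with across_channels=1 sees after the Reshape above
g = x.reshape(N * G, C // G, H, W)
mean = g.mean(axis=(1, 2, 3), keepdims=True)
var = g.var(axis=(1, 2, 3), keepdims=True)
mvn = ((g - mean) / np.sqrt(var + eps)).reshape(N, C, H, W)

# gamma and beta broadcast thanks to the [1, C, 1, 1] reshape of the constants
out = mvn * gamma.reshape(1, C, 1, 1) + beta.reshape(1, C, 1, 1)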
    def replace_pattern(graph: Graph, match: dict):
        node = match['matmul']
        name = node.soft_get('name', node.id)

        A_shape = node.in_port(0).data.get_shape()
        B_shape = node.in_port(1).data.get_shape()
        out_shape = node.out_port(0).data.get_shape()

        assert A_shape is not None and B_shape is not None and out_shape is not None

        B_value = node.in_port(1).data.get_value()
        if (B_value is not None or node.in_port(1).get_source().node.has_and_set('stop_value_propagation')) \
                and B_shape[B_shape != 1].size <= 2:
            # transferring from MatMul representation: [B, I, K] * [B, K, O] = [B, I, O]
            # to FullyConnected representation: [I, K] * [O, K] = [I, O]
            B, I, K, O, aligned_A_shape, aligned_B_shape = MatMulToFullyConnected.get_matmul_BIKO(node)

            # weights normalization
            if not node.transpose_b:
                # FullyConnected weights layout is OI
                # MatMul second input layout is (B)IO
                transpose_order = list(range(B_shape.size))
                transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

                order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
                transpose = Transpose(graph, {'name': name + '/weights_transpose'}).create_node()

                weights_source = node.in_port(1).get_source()
                node.in_port(1).get_connection().set_source(transpose.out_port(0))
                transpose.in_port(0).connect(weights_source)
                transpose.in_port(1).connect(order.out_port(0))

                order.infer(order)
                transpose.infer(transpose)

            if node.in_port(1).data.get_shape().size != 2:
                const = Const(graph, {'value': int64_array([-1, K])}).create_node()
                reshape = Reshape(graph, {'name': name + '/weights_reshape'}).create_node()

                weights_source = node.in_port(1).get_source()
                node.in_port(1).get_connection().set_source(reshape.out_port(0))

                reshape.in_port(0).connect(weights_source)
                reshape.in_port(1).connect(const.out_port(0))

                const.infer(const)
                reshape.infer(reshape)

            assert np.all(np.array_equal(node.in_port(1).data.get_shape(), int64_array([O, K]))), \
                "MatMul `{}` was not converted to FullyConnected: wrong weights shape: {}, " \
                "B={}, I={}, K={}, O={}".format(name, node.in_port(1).data.get_shape(), B, I, K, O)

            node.in_port(1).bin = 'weights'
            del node['transpose_b']

            # input normalization
            if node.transpose_a:
                transpose_order = list(range(A_shape.size))
                transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

                order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
                transpose = Transpose(graph, {'name': name + '/input_transpose'}).create_node()

                input_source = node.in_port(0).get_source()
                node.in_port(0).get_connection().set_source(transpose.out_port(0))
                transpose.in_port(0).connect(input_source)
                transpose.in_port(1).connect(order.out_port(0))

                order.infer(order)
                transpose.infer(transpose)

            if A_shape.size != 2:
                const = Const(graph, {'value': int64_array([-1, K])}).create_node()
                reshape = Reshape(graph, {'name': name + '/input_reshape'}).create_node()

                input_source = node.in_port(0).get_source()
                node.in_port(0).get_connection().set_source(reshape.out_port(0))
                reshape.in_port(0).connect(input_source)
                reshape.in_port(1).connect(const.out_port(0))

                const.infer(const)
                reshape.infer(reshape)

            assert np.all(np.array_equal(node.in_port(0).data.get_shape(), int64_array([np.prod(B) * I, K]))), \
                "MatMul `{}` wasn't converted to FullyConnected: wrong input shape: {}, " \
                "B={}, I={}, K={}, O={}".format(name, node.in_port(0).data.get_shape(), B, I, K, O)

            del node['transpose_a']

            FullyConnected.update_node_stat(node, {'out-size': O})

            # output normalization
            if out_shape.size != 2:
                const = Const(graph, {'value': int64_array([*B, I, O])}).create_node()
                reshape = Reshape(graph, {'name': name + '/output_reshape'}).create_node()

                dst = node.out_port(0).get_destination()
                node.out_port(0).get_connection().set_destination(reshape.in_port(0))
                const.out_port(0).connect(reshape.in_port(1))
                reshape.out_port(0).connect(dst)

                node.infer(node)

                const.infer(const)
                reshape.infer(reshape)

        else:
            assert A_shape.size == out_shape.size
            assert B_shape.size <= out_shape.size
            if B_shape.size != out_shape.size:
                unsqueeze_dim = Const(graph, {'value': int64_array(list(range(out_shape.size - B_shape.size)))
                                              }).create_node()
                unsqueeze = Unsqueeze(graph, {}).create_node()
                B_source = node.in_port(1).get_source()
                node.in_port(1).get_connection().set_source(unsqueeze.out_port(0))
                unsqueeze.in_port(0).connect(B_source)
                unsqueeze.in_port(1).connect(unsqueeze_dim.out_port(0))

                unsqueeze_dim.infer(unsqueeze_dim)
                unsqueeze.infer(unsqueeze)

            Gemm.update_node_stat(node, {
                'transpose_a': node.has_and_set('transpose_a'),
                'transpose_b': node.has_and_set('transpose_b'),
            })
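The shape bookkeeping of the FullyConnected branch can be checked with plain numpy; the sizes below are hypothetical, and transpose_a/transpose_b are assumed false, so the weights get transposed from IO to OI:

import numpy as np

B_, I_, K_, O_ = 2, 3, 4, 5
A = np.random.randn(B_, I_, K_).astype(np.float32)
W = np.random.randn(K_, O_).astype(np.float32)   # MatMul weights in (B)IO layout

flat = A.reshape(-1, K_)        # input normalization: Reshape to [-1, K]
W_oi = W.T                      # weights normalization: Transpose IO -> OI
fc = flat @ W_oi.T              # FullyConnected output: [B*I, O]
out = fc.reshape(B_, I_, O_)    # output normalization: Reshape to [*B, I, O]

assert np.allclose(out, A @ W, atol=1e-5)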
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        group_norm_node = match['op']
        group_norm_num_input_dims = len(
            group_norm_node.in_port(0).data.get_shape())

        # node computing the initial GroupNorm input shape
        initial_shape_op_node = Shape(graph, {
            'name': group_norm_node.name + '/Shape'
        }).create_node()
        initial_shape_op_node.in_port(0).connect(
            group_norm_node.in_port(0).get_source())

        initial_shape_op_node_float = Cast(
            graph, {
                'name': initial_shape_op_node.name + '/to_float',
                'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)
            }).create_node()
        initial_shape_op_node.out_port(0).connect(
            initial_shape_op_node_float.in_port(0))

        initial_batch_dim_node = node_to_get_batch_value(
            initial_shape_op_node_float)
        initial_features_dim_node = node_to_get_features_dimension_value(
            initial_shape_op_node_float)
        initial_spatial_dims_node_int = node_to_get_spatial_dimensions_value(
            initial_shape_op_node)
        initial_spatial_dims_node = Cast(
            graph, {
                'name': initial_spatial_dims_node_int.name + '/to_float',
                'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)
            }).create_node()
        initial_spatial_dims_node_int.out_port(0).connect(
            initial_spatial_dims_node.in_port(0))

        group_size_node = Const(
            graph, {
                'value': int64_array([group_norm_node.num_groups]),
                'name': group_norm_node.name + '/GroupSize'
            }).create_node()

        # calculate "features // group_size" value
        reciprocal_group_size_node = Const(
            graph, {
                'value': np.array([1.0 / group_norm_node.num_groups]),
                'name': group_norm_node.name + '/ReciprocalGroupSize'
            }).create_node()

        c_div_g_node = Mul(graph, {}).create_node()
        c_div_g_node.in_port(0).connect(initial_features_dim_node.out_port(0))
        c_div_g_node.in_port(1).connect(reciprocal_group_size_node.out_port(0))

        batch_mul_group_size_node = Mul(graph, {}).create_node()
        batch_mul_group_size_node.in_port(0).connect(
            initial_batch_dim_node.out_port(0))
        batch_mul_group_size_node.in_port(1).connect(
            group_size_node.out_port(0))

        # create a new node which concatenates several dims into one
        new_shape_node_float = new_shape_node_from_shape_nodes([
            batch_mul_group_size_node, c_div_g_node, initial_spatial_dims_node
        ])
        new_shape_node = Cast(graph, {
            'name': new_shape_node_float.name + '/to_int64',
            'dst_type': np.int64
        }).create_node()
        new_shape_node_float.out_port(0).connect(new_shape_node.in_port(0))

        reshape_for_mvn_node = Reshape(graph, {}).create_node()

        group_norm_node.in_port(0).get_connection().set_destination(
            reshape_for_mvn_node.in_port(0))
        reshape_for_mvn_node.in_port(1).connect(new_shape_node.out_port(0))

        # Reshape the gamma and beta constants to the correct layout: from [C] to [1, C], [1, C, 1], [1, C, 1, 1], etc.
        gamma_beta_shape = np.ones([group_norm_num_input_dims], dtype=np.int64)
        gamma_beta_shape[1] = -1

        gamma_value = group_norm_node.in_port(1).get_source().data.get_value()
        beta_value = group_norm_node.in_port(2).get_source().data.get_value()
        assert gamma_value is not None, 'The gamma should be constant'
        assert beta_value is not None, 'The beta should be constant'
        gamma_value = np.reshape(gamma_value, gamma_beta_shape)
        group_norm_node.in_port(1).get_source().data.set_value(gamma_value)
        beta_value = np.reshape(beta_value, gamma_beta_shape)
        group_norm_node.in_port(2).get_source().data.set_value(beta_value)

        # MVN
        mvn_node = MVN(
            graph, {
                'name': group_norm_node.name + '/MVN',
                'normalize_variance': 1,
                'eps': group_norm_node.eps,
                'eps_mode': 'inside_sqrt'
            }).create_node()
        mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

        # MVN axes
        _, rank = get_shape_and_rank_nodes_by_port(
            mvn_node.in_port(0).get_connection().get_source(),
            return_as_a_scalar=True)
        rng = create_op_with_const_inputs(graph, Range, {
            0: int64_array(1),
            2: int64_array(1)
        }, {
            'name': group_norm_node.name + '/Range',
            'output_type': np.int64
        })
        mvn_node.in_port(1).connect(rng.out_port(0))
        rng.in_port(1).connect(rank.out_port(0))

        # reshape to the initial shape before multiplying by gamma and adding beta
        reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
        reshape_to_initial_shape_node.in_port(0).connect(mvn_node.out_port(0))
        reshape_to_initial_shape_node.in_port(1).connect(
            initial_shape_op_node.out_port(0))

        mul_node = Mul(graph, {'name': mvn_node.name + '/Mul'}).create_node()
        mul_node.in_port(0).connect(reshape_to_initial_shape_node.out_port(0))
        group_norm_node.in_port(1).get_connection().set_destination(
            mul_node.in_port(1))

        add_node = Add(graph, {'name': mul_node.name + '/Add'}).create_node()
        add_node.in_port(0).connect(mul_node.out_port(0))
        group_norm_node.in_port(2).get_connection().set_destination(
            add_node.in_port(1))

        group_norm_node.out_port(0).get_connection().set_source(
            add_node.out_port(0))
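Two details distinguish this variant from the previous one, sketched below with illustrative values: the new shape is assembled in floating point (the Mul with the reciprocal constant yields floats) and Cast back to int64 before feeding the Reshape, and the MVN axes come from Range(1, rank, 1), i.e. every dimension of the regrouped tensor except batch:

import numpy as np

# C // G is computed as C * (1/G) in float, then converted back for Reshape;
# numpy's cast truncates here, the actual Convert semantics may differ
C, G = 6, 3
c_div_g = np.float32(C) * np.float32(1.0 / G)
new_dim = np.int64(c_div_g)                    # 2

# MVN axes for the 4-D [N*G, C//G, H, W] tensor: Range(1, rank, 1)
rank = 4
axes = np.arange(1, rank, 1, dtype=np.int64)   # array([1, 2, 3])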