Example #1
    def replace_identityN(node: Node):
        graph = node.graph
        name = node.soft_get('name', node.id)

        assert node.has_valid(
            'data_types'), 'IdentityN {} has no `data_types` attribute'.format(
                name)
        dtypes = node.data_types

        for idx, port in node.in_ports().items():
            if not node.is_in_port_connected(
                    idx) or not node.is_out_port_connected(idx):
                # ATTENTION section in the description above
                continue
            assert idx < len(
                dtypes
            ), 'IdentityN {} has inconsistent `data_types` attribute {}'.format(
                name, dtypes)
            identity = Identity(graph, {
                'name': '{}/{}_port'.format(name, idx),
                'data_type': dtypes[idx]
            }).create_node()
            port.get_connection().set_destination(identity.in_port(0))
            node.out_port(idx).get_connection().set_source(
                identity.out_port(0))

        # ATTENTION section in the description above
        for in_port in node.in_ports().values():
            in_port.disconnect()
        for out_port in node.out_ports().values():
            out_port.disconnect()
Example #2
    def re_numerate_input_ports(loop_node: Node):
        """
        Update input port ids to be consecutive from 0 to num_input_ports - 1 and update the port_map values of the
        Loop node.

        :param loop_node: the Loop node
        :return: None
        """
        def re_number_input_port(loop_node: Node, old_port_id: int,
                                 new_port_id: int):
            loop_node.add_input_port(new_port_id, skip_if_exist=True)
            loop_node.in_port(old_port_id).get_connection().set_destination(
                loop_node.in_port(new_port_id))
            Loop.update_port_map_value(loop_node.input_port_map,
                                       'external_port_id', old_port_id,
                                       new_port_id)

        if len(loop_node.in_ports()) > 0:
            max_port_id = sorted(loop_node.in_ports().keys())[-1]
            new_port_id = 0
            for port_id in range(max_port_id + 1):
                if loop_node.is_in_port_connected(port_id):
                    if port_id != new_port_id:
                        re_number_input_port(loop_node, port_id, new_port_id)
                    new_port_id += 1

            for port_idx_to_remove in reversed(
                    range(new_port_id, max_port_id + 1)):
                if port_idx_to_remove in loop_node.in_ports().keys():
                    loop_node.delete_input_port(port_idx_to_remove)
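The compaction logic above is independent of the graph machinery; here is a minimal standalone sketch of the same idea on a plain dict (the dict and its values are made-up stand-ins for the Loop node's in-ports):

def compact_port_ids(ports: dict) -> dict:
    # Keep connected ports (non-None values), renumbering them 0..N-1 in
    # ascending order of their old ids, as re_numerate_input_ports does.
    new_ports = {}
    new_id = 0
    for old_id in sorted(ports):
        if ports[old_id] is not None:  # "connected"
            new_ports[new_id] = ports[old_id]
            new_id += 1
    return new_ports

# ports 0 and 3 connected, ports 1 and 2 dangling
assert compact_port_ids({0: 'a', 1: None, 2: None, 3: 'b'}) == {0: 'a', 1: 'b'}

The real transformation additionally rewrites the 'external_port_id' entries of the Loop's input_port_map so the body keeps referring to the correct external ports.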
Example #3
def get_rnn_input_size(node: Node):
    node_name = node.soft_get('name', node.id)
    assert node.is_in_port_connected(1), 'weights input is not connected'

    if node.format == 'onnx':
        # ONNX weights on input 1 contain only the W part; R and B are connected separately
        # weights_shape = `[num_directions, 4 * hidden_size, input_size]`
        weights_size = node.in_port(1).data.get_shape()
        assert len(
            weights_size
        ) == 3, 'incorrect weights rank for ONNX {} node {}'.format(
            node.op, node_name)
        input_size = weights_size[2]
        return input_size
    elif node.format == 'mxnet':
        multiplier = node.multiplier
        hidden_size = node.hidden_size
        num_layers = node.num_layers
        direction = 2 if node.has_num_directions else 1

        # for MXNet models we always get a flattened weights blob which contains W, R and B
        weights_size = node.in_port(1).data.get_shape()
        assert len(
            weights_size
        ) == 1, 'incorrect weights rank for MXNet {} node {}'.format(
            node.op, node_name)
        weights_size = weights_size[0]

        size = hidden_size * direction * multiplier
        other_layer_params_size = (hidden_size * direction + hidden_size +
                                   2) * size
        first_layer_params_size = weights_size - (num_layers -
                                                  1) * other_layer_params_size
        # the lines above that find first_layer_params_size were taken from MXNetSplitMultiLayers.py:79
        # input_size can be calculated from the first_layer_params_size
        # if first_layer_params_size = (input_size + hidden_size + 2) * size
        # then input_size = first_layer_params_size / size - 2 - hidden_size
        input_size = first_layer_params_size / size - 2 - hidden_size
        return input_size
    elif node.format == 'tf':
        log.error(
            'reverse infer for TensorFlow RNN operation {} is not implemented yet'
            .format(node_name),
            extra={'is_warning': True})
    else:
        raise Error('Incorrect framework name')
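The inversion in the MXNet branch can be sanity-checked numerically. A self-contained sketch (all parameter values are made-up, and mxnet_flat_weights_size is a hypothetical helper that just applies the formulas from the comments above):

def mxnet_flat_weights_size(input_size, hidden_size, multiplier, num_layers, direction):
    # forward computation of the flattened W/R/B blob length
    size = hidden_size * direction * multiplier
    first_layer = (input_size + hidden_size + 2) * size
    other_layer = (hidden_size * direction + hidden_size + 2) * size
    return first_layer + (num_layers - 1) * other_layer

# e.g. a bidirectional 2-layer LSTM (multiplier 4) with hidden_size 128:
hidden_size, multiplier, num_layers, direction = 128, 4, 2, 2
weights_size = mxnet_flat_weights_size(64, hidden_size, multiplier, num_layers, direction)

# invert it exactly the way get_rnn_input_size does
size = hidden_size * direction * multiplier
other_layer_params_size = (hidden_size * direction + hidden_size + 2) * size
first_layer_params_size = weights_size - (num_layers - 1) * other_layer_params_size
assert first_layer_params_size / size - 2 - hidden_size == 64  # input_size recovered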
Example #4
    def extend_inputs(node: Node, num_insertions: int):
        graph = node.graph
        node_name = node.soft_get('name', node.id)

        for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
            if i == 3 and not node.is_in_port_connected(3):
                continue  # no need to extend strides if they are not connected

            blank_values_arr = np.zeros(
                num_insertions) if input_name != 'strides' else np.ones(
                    num_insertions)
            blank_values_node = Const(
                graph, {
                    'name': node_name + '/extend_{}_const'.format(input_name),
                    'value': int64_array(blank_values_arr)
                }).create_node()

            if node.in_port(i).get_source().node.soft_get('type') == 'Concat':
                # a Concat already feeds this input, so reuse it
                concat = node.in_port(i).get_source().node
                # the output data node shape will change; since shapes will be
                # re-inferred, there is no need to check consistency here
                concat['override_output_shape'] = True

                last_in_port = max(concat.in_ports().keys())
                assert not concat.in_port(last_in_port).disconnected(), 'The last in_port of Concat node {} ' \
                                                                        'should be connected'. \
                    format(concat.soft_get('name', node.id))

                concat.add_input_port(last_in_port + 1)
                concat.in_port(last_in_port + 1).connect(
                    blank_values_node.out_port(0))
            else:
                # have to create concat
                concat = Concat(
                    graph, {
                        'axis': 0,
                        'name': node_name + '/concat_{}'.format(input_name),
                        'in_ports_count': 2
                    }).create_node()
                node.in_port(i).get_connection().set_destination(
                    concat.in_port(0))
                concat.in_port(1).connect(blank_values_node.out_port(0))
                concat.out_port(0).get_connection().set_destination(
                    node.in_port(i))
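On plain arrays, the effect of the transformation is simply padding: begin and end are extended with zeros and strides with ones for each inserted dimension (the NumPy values below are made-up; how the new dimensions are ultimately sliced also depends on the StridedSlice masks):

import numpy as np

begin, end, strides = np.array([1, 2]), np.array([3, 4]), np.array([1, 1])
num_insertions = 2
begin = np.concatenate([begin, np.zeros(num_insertions, dtype=np.int64)])
end = np.concatenate([end, np.zeros(num_insertions, dtype=np.int64)])
strides = np.concatenate([strides, np.ones(num_insertions, dtype=np.int64)])
print(begin, end, strides)  # [1 2 0 0] [3 4 0 0] [1 1 1 1]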
Example #5
    def reverse_infer(node: Node):
        input_shape = node.in_port(0).data.get_shape()
        if input_shape is None and node.is_in_port_connected(2) and node.in_port(2).data.get_shape() is not None:
            shape = undefined_shape_of_rank(node.in_port(2).data.get_shape()[0])
            node.in_port(0).data.set_shape(shape)
Example #6
    def quantize_to_fakequantize(graph: Graph,
                                 quantize_node: Node,
                                 set_stop_value_propagation=False):
        node_name = quantize_node.soft_get('name', quantize_node.id)
        axis = quantize_node.soft_get('axis', None)
        scale_y_shape = quantize_node.in_port(1).data.get_shape()

        if quantize_node.is_in_port_connected(2):
            zerop = quantize_node.in_port(2).get_source().node
        else:
            zerop = Const(
                graph, {
                    'value': mo_array(0, dtype=np.uint8),
                    'name': node_name + '/ZeroPoint'
                }).create_node()

        assert zerop.soft_get(
            'type'
        ) == 'Const', 'only constant for zero_point is supported for QuantizeLinear'
        zero_point_type = zerop.value.dtype
        # data type affects range of output values: [-128..127] or [0..255]
        if zero_point_type == np.int8:
            output_low_value = -128.0
            output_high_value = 127.0
        elif zero_point_type == np.uint8:
            output_low_value = 0.0
            output_high_value = 255.0
        else:
            raise Error(
                'Not expected type {} for zero point value in node {}'.format(
                    zero_point_type, zerop.soft_get('name')))

        fake_quantize = create_op_with_const_inputs(
            graph, FakeQuantize, {
                3: float_array(output_low_value),
                4: float_array(output_high_value)
            }, {
                'levels': 256,
                'name': node_name + '/FakeQuantize'
            })
        if set_stop_value_propagation:
            fake_quantize['stop_compression'] = True
            fake_quantize['stop_value_propagation'] = True
        quantize_node.in_port(0).get_connection().set_destination(
            fake_quantize.in_port(0))

        # Calculate input_low value
        mul_low = create_op_with_const_inputs(
            graph, Mul, {1: float_array(output_low_value - zerop.value)},
            {'name': node_name + '/Mul/Low'})
        quantize_node.in_port(1).get_connection().set_destination(
            mul_low.in_port(0))
        mul_low.out_port(0).connect(fake_quantize.in_port(1))

        # Calculate input_high value
        mul_high = create_op_with_const_inputs(
            graph, Mul, {1: float_array(output_high_value - zerop.value)},
            {'name': node_name + '/Mul/High'})
        mul_low.in_port(0).get_connection().add_destination(
            mul_high.in_port(0))
        mul_high.out_port(0).connect(fake_quantize.in_port(2))

        cast = Cast(graph, {
            'dst_type': zero_point_type,
            'name': node_name + '/Cast'
        }).create_node()
        fake_quantize.out_port(0).connect(cast.in_port(0))
        quantize_node.out_port(0).get_connection().set_source(cast.out_port(0))
        rename_nodes([(quantize_node, node_name + '/TBD'), (cast, node_name)])

        assert scale_y_shape is not None, "{0} contains scale (input with port 1) with shape None".\
            format(quantize_node.soft_get('name', quantize_node.id))
        if axis is not None and len(
                scale_y_shape) > 0 and scale_y_shape[0] > 1:
            input_shape = fake_quantize.in_port(0).data.get_shape()
            target_shape = np.ones(len(input_shape), np.int64)
            target_shape[axis] = input_shape[axis]
            mul_low_reshape = create_op_with_const_inputs(
                graph, Reshape, {1: int64_array(target_shape)},
                {'name': node_name + '/Reshape/Mul/Low'})
            mul_high_reshape = create_op_with_const_inputs(
                graph, Reshape, {1: int64_array(target_shape)},
                {'name': node_name + '/Reshape/Mul/High'})

            fake_quantize.in_port(1).get_connection().set_destination(
                mul_low_reshape.in_port(0))
            fake_quantize.in_port(2).get_connection().set_destination(
                mul_high_reshape.in_port(0))

            mul_low_reshape.out_port(0).connect(fake_quantize.in_port(1))
            mul_high_reshape.out_port(0).connect(fake_quantize.in_port(2))
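The rewiring encodes a simple identity: with input_low = scale * (output_low - zero_point) and input_high = scale * (output_high - zero_point), FakeQuantize with 256 levels followed by the Cast reproduces ONNX QuantizeLinear. A self-contained NumPy sketch for a uint8 zero point (test values are made-up):

import numpy as np

def quantize_linear(x, scale, zero_point):
    # ONNX QuantizeLinear reference: q = saturate(round(x / scale) + zero_point)
    return np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)

def fake_quantize_then_cast(x, scale, zero_point):
    out_low, out_high, levels = 0.0, 255.0, 256  # uint8 range, as in the code above
    in_low = scale * (out_low - zero_point)      # what Mul/Low computes
    in_high = scale * (out_high - zero_point)    # what Mul/High computes
    x = np.clip(x, in_low, in_high)
    q = np.round((x - in_low) / (in_high - in_low) * (levels - 1))
    out = q * (out_high - out_low) / (levels - 1) + out_low
    return out.astype(np.uint8)                  # the trailing Cast

x = np.array([-1.0, 0.0, 0.5, 2.0], dtype=np.float32)
scale, zp = np.float32(0.1), np.uint8(30)
assert (quantize_linear(x, scale, zp) == fake_quantize_then_cast(x, scale, zp)).all()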
Example #7
    def type_infer(node: Node):
        assert node.is_in_port_connected(1), \
            'The second input is not connected for a node {}.'.format(node.soft_get('name', node.id))
        node.out_port(0).set_data_type(node.in_port(1).get_data_type())
Example #8
def get_rnn_batch_size_and_seq_len(node: Node):
    """
    Gets batch_size and sequence_length from RNN constant inputs
    and output shapes retrieved during reverse_infer

    :param node:
    :return:
    """
    node_name = node.soft_get('name', node.id)
    out_shape = node.out_port(0).data.get_shape()
    batch_size = dynamic_dimension
    seq_len = dynamic_dimension
    in_port_with_initial_states = 3  # the port carrying initial hidden state values is framework-dependent

    if out_shape is not None:
        # note that the op is not in opset form yet but still follows the original framework's layout
        if node.batch_dim == 1:
            seq_len = out_shape[0]

            if node.format == 'mxnet':
                assert len(
                    out_shape
                ) == 3, 'incorrect out_shape rank for node {}'.format(
                    node_name)
                # for MXNet out_shape = [seq_len, batch_size, hidden_size]
                batch_size = out_shape[1]
                in_port_with_initial_states = 2
            elif node.format == 'onnx':
                assert len(
                    out_shape
                ) == 4, 'incorrect out_shape rank for node {}'.format(
                    node_name)
                # even for ONNX in extractor 'batch_dim': 1 (front/onnx/lstm_ext.py:26) despite the fact that
                # out_shape = [seq_len, num_directions, batch_size, hidden_size]
                batch_size = out_shape[2]
                in_port_with_initial_states = 5
            elif node.format == 'tf':
                log.error(
                    'reverse infer for TensorFlow RNN operation {} is not implemented yet'
                    .format(node_name),
                    extra={'is_warning': True})
            else:
                raise Error('Incorrect framework name')
        elif node.batch_dim == 0:
            # out_shape = [batch_size, num_directions, seq_len, hidden_size]
            batch_size = out_shape[0]
            seq_len = out_shape[2]
            in_port_with_initial_states = 3
        else:
            raise Error('incorrect batch_dim for node {}'.format(node_name))

    if batch_size is dynamic_dimension:
        if node.is_in_port_connected(in_port_with_initial_states):
            initial_hidden_state_size = node.in_port(
                in_port_with_initial_states).data.get_shape()
            if initial_hidden_state_size is not None:
                batch_size = initial_hidden_state_size[1]

    if seq_len is dynamic_dimension and node.format == 'onnx':
        # ONNX can store seq_len in an optional input
        if node.is_in_port_connected(4):
            seq_len_val = node.in_port(4).data.get_value()
            if seq_len_val is not None:
                seq_len = seq_len_val.item()

    return [batch_size, seq_len]
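A standalone restatement of the per-framework layout table encoded above (shapes and values are made-up):

def batch_and_seq_from_out_shape(fmt: str, batch_dim: int, out_shape):
    # mxnet, batch_dim=1: out_shape = [seq_len, batch_size, hidden_size]
    # onnx,  batch_dim=1: out_shape = [seq_len, num_directions, batch_size, hidden_size]
    # batch_dim=0:        out_shape = [batch_size, num_directions, seq_len, hidden_size]
    if batch_dim == 1:
        seq_len = out_shape[0]
        batch_size = out_shape[1] if fmt == 'mxnet' else out_shape[2]
    else:
        batch_size, seq_len = out_shape[0], out_shape[2]
    return batch_size, seq_len

assert batch_and_seq_from_out_shape('onnx', 1, [10, 2, 8, 128]) == (8, 10)
assert batch_and_seq_from_out_shape('mxnet', 1, [10, 8, 128]) == (8, 10)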