Example #1
    def infer(node: Node):
        # these limitations come from the ONNX LSTM definition and the normalization rules
        assert len(node.in_nodes()) >= 3  # X, W and R
        assert len(node.in_nodes()) <= 7
        assert len(node.out_nodes()) <= 3
        assert node.batch_dim <= 1
        assert node.sequence_dim <= 1
        assert node.batch_dim != node.sequence_dim

        assert node.direction in ['forward', 'reverse', 'bidirectional']

        if node.blobs_wrb:
            mark_input_bins(node, ['W', 'R', 'B'])
        else:
            mark_input_bins(node)
        input_shape = node.in_node(0).shape
        assert len(input_shape) == 3

        for port in [2, 3]:
            if port in node.in_nodes() and len(node.in_node(port).in_nodes()) > 0 and \
               'zero_shapes' in node.in_node(port).in_node():
                for i in node.in_node(port).in_node().zero_shapes:
                    if node.in_node(port).shape[i] != input_shape[i]:
                        node.in_node(port).value = np.repeat(
                            node.in_node(port).value, input_shape[i], axis=i)
                        node.in_node(port).shape[i] = input_shape[i]

        out_shape = shape_array([
            input_shape[node.sequence_dim], input_shape[node.batch_dim],
            node.hidden_size
        ])
        assert not node.has_num_directions or node.sequence_dim == 0, \
            'If has_num_directions == True, then node.sequence_dim should be equal to 0, but it is {}'.format(
                node.sequence_dim)
        num_directions = 2 if node.direction in ['bidirectional'] else 1
        num_layers = node.num_layers
        if node.has_num_directions:
            # insert extra dimension to output shape for num_directions
            out_shape = shape_insert(out_shape, 1, np.int64(num_directions))
        node.out_node(0).shape = out_shape
        # extra outputs for hidden/cell states
        state_size = shape_array([input_shape[1], node.hidden_size])
        if node.has_num_directions:
            state_size = shape_insert(state_size, 0,
                                      num_directions * num_layers)
        for i in [1, 2]:
            if i not in node.out_nodes():
                data_node = Op._create_data_node(node.graph,
                                                 name=node.node + '/ExtraOutput/' + str(i),
                                                 attrs={'executable': True})
                node.graph.add_edge(node.id, data_node.id, key=0, out=i)
                add_opoutput(node.graph, data_node.id, 0, False)
            else:
                data_node = node.out_node(i)
            data_node.shape = state_size.copy()
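
Note: all of these snippets revolve around shape_insert from the Model Optimizer shape utilities. As a rough mental model it behaves like np.insert on a 1-D shape array; the sketch below is a simplification that ignores the masked-array handling of dynamic dimensions in the real helper.

import numpy as np

def shape_insert_sketch(shape, pos, obj):
    # Insert a value (or list of values) into a 1-D shape array at position `pos`.
    values = obj if isinstance(obj, (list, np.ndarray)) else [obj]
    return np.insert(np.asarray(shape, dtype=np.int64), pos, values)

print(shape_insert_sketch([5, 10, 128], 1, 2))    # [  5   2  10 128]
print(shape_insert_sketch([10, 128], 0, [1, 1]))  # [  1   1  10 128]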
Example #2
    def shape_alignment(node: Node):
        """
        The specification of the MatMul operation allows inputs to be aligned together before matrix multiplication.
        This method raises an error if the input shapes are not valid at any step of the alignment process.
        :return: aligned copies of both input shapes
        """
        node_name = node.soft_get('name', str(node.id))
        input_shapes = [node.in_port(i).data.get_shape() for i in range(2)]
        transpose_a = node.has_and_set('transpose_a')
        transpose_b = node.has_and_set('transpose_b')

        transformed_shapes = []
        for i, shape in enumerate(input_shapes):
            input_shape = shape.copy()
            # prerequisites check
            assert input_shape is not None, "MatMul has shape=`None` for {} input of `{}` node".format(
                i, node_name)
            assert input_shape.ndim == 1, "MatMul doesn't support scalar inputs. {} input of `{}` node has shape {}" \
                                          "".format(i, node_name, input_shape)
            assert input_shape.size >= 1, "MatMul doesn't support inputs with rank lower than 1. {} input of `{}` " \
                                          "node has shape {}".format(i, node_name, input_shape)
            rank = input_shape.size
            # shape alignment
            if rank != 1 and ((i == 0 and transpose_a) or
                              (i == 1 and transpose_b)):
                input_shape[-2], input_shape[-1] = input_shape[-1], input_shape[-2]
            if rank == 1:
                input_shape = shape_insert(input_shape, int(i == 1), 1)

            max_shape_length = max(input_shapes[0].size, input_shapes[1].size)
            input_shape = shape_insert(input_shape, 0, [1] *
                                       (max_shape_length - input_shape.size))
            transformed_shapes.append(input_shape)

        A_shape = shape_array(transformed_shapes[0])
        B_shape = shape_array(transformed_shapes[1])

        assert A_shape.size == B_shape.size, \
            "Shapes were not aligned by length for MatMul `{}`. Shapes: `{}`".format(node_name, transformed_shapes)

        # batch broadcasting
        batch_len = A_shape.size - 2
        for i in range(batch_len):
            if A_shape[i] != B_shape[i]:
                if A_shape[i] == 1:
                    A_shape[i] = B_shape[i]
                if B_shape[i] == 1:
                    B_shape[i] = A_shape[i]

        assert compatible_shapes(A_shape[:-2], B_shape[:-2]), \
            "MatMul input shapes are incorrect. BATCH_DIMs are not equal. Node: {}. Aligned shapes: {}" \
            "".format(node_name, transformed_shapes)

        return A_shape, B_shape
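
For intuition, a hand-traced run of the alignment rules above on hypothetical shapes, checked against numpy's own matmul broadcasting:

import numpy as np

# Aligning A: (3,) with B: (2, 3, 5):
# 1. A has rank 1 and i == 0, so a leading 1 is inserted: (3,) -> (1, 3)
# 2. both shapes are left-padded with 1s to equal rank:   (1, 3) -> (1, 1, 3)
# 3. batch dims are broadcast: A (1, 1, 3) vs B (2, 3, 5) -> A (2, 1, 3)
a, b = np.ones((3,)), np.ones((2, 3, 5))
print(np.matmul(a, b).shape)  # (2, 5) -- numpy applies the same alignment internally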
Example #3
def make_equal_rank(shape_1: np.ndarray, shape_2: np.ndarray):
    """
    Prepend the shorter shape with 1s. Return the updated shapes.
    :param shape_1: first shape
    :param shape_2: second shape
    :return: tuple with updated shapes
    """
    while len(shape_1) < len(shape_2):
        shape_1 = shape_insert(shape_1, 0, 1)

    while len(shape_2) < len(shape_1):
        shape_2 = shape_insert(shape_2, 0, 1)

    return shape_1, shape_2
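
A quick sanity check of the expected behaviour (this assumes shape_array and make_equal_rank from this listing are importable):

s1, s2 = make_equal_rank(shape_array([3, 4]), shape_array([2, 3, 4, 5]))
assert list(s1) == [1, 1, 3, 4]
assert list(s2) == [2, 3, 4, 5]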
Example #4
    def infer(node):
        if len(node.in_nodes()) <= 1:
            raise Error('There is no input with unsqueeze dims for the node {}'.format(node.soft_get('name')))
        unsqueeze_dims = node.in_port(1).data.get_value()
        if unsqueeze_dims is None:
            raise Error('The dimensions to unsqueeze are not defined for the node {}'.format(node.soft_get('name')))
        unsqueeze_dims = int64_array(unsqueeze_dims)

        input_value = node.in_port(0).data.get_value()
        input_shape = node.in_port(0).data.get_shape()

        # TODO remove the following line when the Inference Engine plugins support 0D tensors
        if unsqueeze_dims.ndim == 0:
            unsqueeze_dims = int64_array([unsqueeze_dims.item()])

        # make dimensions positive to correctly translate from NHWC to NCHW layout
        unsqueeze_dims = int64_array([dim + len(input_shape) + 1 if dim < 0 else dim
                                      for dim in unsqueeze_dims])
        if node.in_port(1).get_source().node.op == 'Const':
            node.in_port(1).data.set_value(unsqueeze_dims)

        output_shape = input_shape.copy()
        for dim in unsqueeze_dims:
            output_shape = shape_insert(output_shape, dim, 1)

        if input_value is not None and is_fully_defined(output_shape):
            node.out_port(0).data.set_value(input_value.reshape(output_shape))
        else:
            node.out_port(0).data.set_shape(output_shape)

        PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'axis')
Example #5
    def swap_pad_and_unsqueeze(self, pad: Node, unsqueeze: Node):
        # insert additional items to the pads in the position specified by the Unsqueeze axis
        unsqueeze_axis = unsqueeze.in_port(1).data.get_value()
        for port_id in [1, 2]:
            current_value = pad.in_port(port_id).get_connection().data.get_value()
            new_value_node = Const(pad.graph, {
                'name': pad.soft_get('name', pad.id) + '/value_{}'.format(port_id),
                'value': shape_insert(current_value, unsqueeze_axis.item(), 0),
                'override_output_shape': True,
            }).create_node()
            pad.in_port(port_id).disconnect()
            pad.in_port(port_id).connect(new_value_node.out_port(0))

        # swap Pad and Unsqueeze layers
        unsqueeze.in_port(0).disconnect()
        pad.in_port(0).get_connection().set_destination(unsqueeze.in_port(0))
        unsqueeze.out_port(0).get_connection().set_source(pad.out_port(0))
        unsqueeze.out_port(0).connect(pad.in_port(0))

        # output shapes of Pad and Unsqueeze changed so need to recalculate them
        pad['override_output_shape'] = True
        unsqueeze['override_output_shape'] = True
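
The effect of the pad-value update above, on hypothetical values: a rank-2 Pad with pads_begin = [1, 2] swapped past Unsqueeze(axis=0) needs a zero inserted at that axis so the pads stay aligned with the new rank. In plain numpy terms:

import numpy as np

print(np.insert(np.array([1, 2]), 0, 0))  # [0 1 2] -- the new dim at axis 0 gets zero padding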
Example #6
    def make_shape_4d(shape: np.array):
        """
        Create a 4D shape from a 0D, 1D, 2D or 3D one by adding new dimensions of size 1.
        :param shape: shape to extend.
        :return: 4D shape.
        """
        new_shape = shape_array(shape)
        old_shape_len = len(shape)

        # TODO think about proper way to add additional dimensions considering layout
        for x in range(4 - old_shape_len):
            # if the shape is 0D or 1D then we should add additional dimensions to batch dimension
            if len(new_shape) <= 1:
                new_shape = shape_insert(new_shape, 0, 1)
            else:
                new_shape = shape_insert(new_shape, 1, 1)
        return new_shape
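
A worked trace of the loop above (this assumes make_shape_4d from this listing is importable):

import numpy as np

print(make_shape_4d(np.array([3])))     # [1 1 1 3] -- 0D/1D shapes grow at the batch dim
print(make_shape_4d(np.array([5, 3])))  # [5 1 1 3] -- 2D+ shapes grow right after the batch dim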
Example #7
    def mark_eltwise_node(self, node, feature_channel=None):
        tensor_port, value_port = get_tensor_in_port(node), get_value_in_port(node)
        if tensor_port is None or value_port is None:
            self.set_flags_to_false(node,
                                    ['can_be_fused', 'can_be_scaleshift'])
            return

        connected_in_ports = {
            idx: port
            for idx, port in node.in_ports().items()
            if not port.disconnected()
        }
        if len(connected_in_ports) != 2:
            return

        tensor_shape = tensor_port.data.get_shape()
        out_shape = node.out_port(0).data.get_shape()
        assert tensor_shape is not None and out_shape is not None
        if not np.array_equal(tensor_shape, out_shape):
            # ScaleShift operation doesn't support broadcasting
            self.set_flags_to_false(node,
                                    ['can_be_fused', 'can_be_scaleshift'])
            return

        value_shape = value_port.data.get_shape()
        assert value_shape is not None
        assert len(value_shape) <= len(tensor_shape), \
            "No broadcasting was done for elementwise node {} due to previous checks in EltwiseChecker class. " \
            "But constant input rank is larger than tensor input rank, that is inconsistent".format(node.name)

        # if both tensors are 0D they cannot be converted to scaleshift
        if len(tensor_shape) == 0 and len(value_shape) == 0:
            self.set_flags_to_false(node, ['can_be_scaleshift'])
            return

        broadcasted_value_shape = shape_insert(
            value_shape, 0, [1] * (len(tensor_shape) - len(value_shape)))

        feature_dim = min(1, tensor_shape.size - 1) if node.graph.graph['layout'] == 'NCHW' else -1
        if feature_channel is not None:
            feature_dim = feature_channel
        ones = np.ones(len(tensor_shape))
        possible_shape = ones.copy()
        np.put(possible_shape, feature_dim, tensor_shape.item(feature_dim))

        if not np.array_equal(broadcasted_value_shape, ones) and \
                not np.array_equal(broadcasted_value_shape, possible_shape):
            # ScaleShift weights should have [1,C,1,1]-like or [1,1,1,1]-like shape
            self.set_flags_to_false(node,
                                    ['can_be_fused', 'can_be_scaleshift'])
            return

        if len(tensor_shape) not in [2, 4, 5]:
            # ScaleShift operation is supported for 2D, 4D or 5D tensor inputs
            self.set_flags_to_false(node, ['can_be_scaleshift'])
            return
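
A small standalone rerun of the weight-shape check above, for a hypothetical NCHW input of shape [1, 16, 32, 32]: the constant may only be all-ones-like ([1, 1, 1, 1]) or per-channel ([1, 16, 1, 1]) for the node to qualify as ScaleShift.

import numpy as np

tensor_shape = np.array([1, 16, 32, 32])
feature_dim = 1  # channel dim in NCHW
possible_shape = np.ones(len(tensor_shape), dtype=np.int64)
np.put(possible_shape, feature_dim, tensor_shape.item(feature_dim))
print(possible_shape)  # [ 1 16  1  1]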
Example #8
    def add_output_reshape(graph: Graph, match: dict):
        """
        Since the MXNet Y output shape is [batch_size, seq_len, hidden_size * num_directions], we need to add a reshape
        from the common format above, [batch_size, num_directions, seq_len, hidden_size], to the MXNet format.
        """
        lstm = match['rnn_layer']
        input = match['input']
        if not lstm.has_num_directions:
            return
        old_data_node = lstm.out_node(0)
        num_directions = 2 if lstm.direction in ['bidirectional'] else 1
        mxnet_shape = lstm.out_node(0).shape.copy()

        if lstm.batch_dim == 0:
            mo_shape = shape_array([
                input.shape[lstm.batch_dim], input.shape[lstm.sequence_dim],
                lstm.hidden_size
            ])
        else:
            mo_shape = shape_array([
                input.shape[lstm.sequence_dim], input.shape[lstm.batch_dim],
                lstm.hidden_size
            ])

        if lstm.has_num_directions:
            mo_shape = shape_insert(mo_shape, 1, np.int64(num_directions))

        lstm_name = lstm.soft_get('name', lstm.id)

        new_data = Op._create_data_node(graph,
                                        name=lstm_name + '/Data/Reshape_mxnet/',
                                        attrs={'shape': mo_shape})
        graph.remove_edge(lstm.id, old_data_node.id)
        graph.add_edge(lstm.id, new_data.id, key=0, out=0)

        # Add Transpose
        permute_order = Const(
            graph, {
                'name': lstm_name + '/Transpose_mxnet_order',
                'value': int64_array([0, 2, 1, 3])
            }).create_node_with_data()
        permute_data = Transpose(graph, {
            'name': lstm_name + '/Transpose_mxnet/'
        }).create_node_with_data([new_data, permute_order])

        # Add Reshape
        reshape = Reshape(graph, {'name': lstm_name + '/Reshape_mxnet/'})
        reshape_dim_data = Const(
            graph, {
                'name': lstm_name + '/Reshape_mxnet_dim',
                'value': int64_array(unmask_shape(mxnet_shape))
            }).create_node_with_data()

        reshape.create_node_with_data([permute_data, reshape_dim_data],
                                      dict(),
                                      data_nodes=[old_data_node])
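
Shape flow through the nodes created above, for a hypothetical bidirectional LSTM with batch=4, seq_len=10, hidden_size=32 and batch_dim == 0, traced in plain numpy:

import numpy as np

y = np.zeros((4, 2, 10, 32))      # MO layout: [batch, num_directions, seq_len, hidden]
y = y.transpose(0, 2, 1, 3)       # -> (4, 10, 2, 32)
print(y.reshape(4, 10, 64).shape) # -> (4, 10, 64), the MXNet layout [batch, seq, hidden * dirs]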
Example #9
    def infer(node: Node):
        indices_shape = node.in_port(0).data.get_shape()
        assert indices_shape is not None
        dim = indices_shape.size

        assert_msg = "OneHot `{0}` ({1} input port value) should be scalar: node: `{2}`, {0} value: `{3}`"
        depth = node.in_port(1).data.get_value()
        assert depth is not None and depth.ndim == 0, assert_msg.format(
            'depth', '1', node.name, depth)
        depth = depth.item(0)

        assert node.has_valid('axis')
        axis = node['axis']
        assert -1 <= axis <= dim

        # If axis == -1, insert the new depth dimension at the end of the indices shape
        axis = dim if axis == -1 else axis

        if dim == 0:
            # scalar indices case
            output_shape = [depth]
        else:  # dim >= 1
            # vector/matrix indices case
            output_shape = shape_insert(indices_shape, axis, depth)

        node.out_port(0).data.set_shape(output_shape)

        indices = node.in_port(0).data.get_value()
        depth = node.in_port(1).data.get_value()
        on_value = node.in_port(2).data.get_value()
        off_value = node.in_port(3).data.get_value()

        if indices is not None and depth is not None and on_value is not None and off_value is not None:
            onehot_value = np.full(output_shape, off_value)

            for idx in np.ndindex(tuple(indices_shape)):
                if axis == 0:
                    hot_idx = indices[idx], *idx
                elif (axis > 0) and (axis < len(output_shape) - 1):
                    hot_idx = *idx[:axis], indices[idx], *idx[axis:]
                elif axis == len(output_shape) - 1:
                    hot_idx = *idx, indices[idx]

                if -depth <= indices[idx] < depth:
                    onehot_value[hot_idx] = on_value

            node.out_port(0).data.set_value(onehot_value)

        # This operation should be inferred in original layout
        node['reinterp_shape'] = True
        node['NCHW'] = True
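
A tiny numpy rerun of the value inference above, on hypothetical inputs indices=[1, 2], depth=3, on=1.0, off=0.0, axis=-1 (so axis becomes 1 and the output shape is [2, 3]):

import numpy as np

indices, depth, on_value, off_value = np.array([1, 2]), 3, 1.0, 0.0
out = np.full((len(indices), depth), off_value)
for i, idx in enumerate(indices):
    if -depth <= idx < depth:   # out-of-range indices keep the off_value
        out[i, idx] = on_value
print(out)  # [[0. 1. 0.]
            #  [0. 0. 1.]]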
Example #10
    def infer(node: Node):
        name = node.soft_get('name', node.id)

        connected_in_ports = {
            idx: port
            for idx, port in node.in_ports().items()
            if not port.disconnected()
        }
        assert len(connected_in_ports) == 2 and 0 in connected_in_ports and 1 in connected_in_ports, \
            "Tile should have 2 connected input ports, but it doesn't for node: `{}`. Ports: {}" \
            "".format(name, connected_in_ports)

        shape = node.in_port(0).data.get_shape()
        assert shape is not None, "Undefined input shape for Tile node '{}'.".format(name)
        tile_array = node.in_port(1).data.get_value()
        assert tile_array is not None, "Undefined `repeats` (1st port input value) of Tile node '{}'".format(name)

        # align ranks of the tile_array tensor and input shape node
        if shape.size < tile_array.size:
            shape = shape_insert(shape, 0,
                                 [1] * (tile_array.size - shape.size))
        elif shape.size > tile_array.size:
            tile_array = shape_insert(tile_array, 0,
                                      [1] * (shape.size - tile_array.size))

        input_value = node.in_port(0).data.get_value()
        if input_value is not None and is_fully_defined(
                shape) and is_fully_defined(tile_array):
            node.out_port(0).data.set_value(
                np.tile(input_value.reshape(shape), tile_array))
        else:
            node.out_port(0).data.set_shape(shape * tile_array)

        PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0',
                                              'shape')
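
The rank alignment above mirrors numpy's tile: for input shape [2, 3] and repeats [2, 1, 1], the shape is left-padded to [1, 2, 3] and the output shape is the elementwise product:

import numpy as np

print(np.tile(np.ones((2, 3)), (2, 1, 1)).shape)  # (2, 2, 3) == [1, 2, 3] * [2, 1, 1]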
Example #11
def normalize_eltwise_inputs(graph: Graph):
    '''
    The function normalizes input shapes for eltwise nodes.
    In the first step the function determines which unsqueeze dims are required to normalize each input.
    In the second step the function inserts Unsqueeze nodes between non-normalized inputs and eltwise nodes.
    '''
    # Generate a map for producers of eltwise nodes with non-normalized shapes
    # and in this map every producer has another map that reflects normalized shape
    # to a list of eltwise consumers
    mapping = {}
    for eltwise_node in graph.get_op_nodes(is_eltwise=True):
        unsqueeze_dims_map = compute_unsqueeze_map_for_eltwise(eltwise_node)
        for consumer_port in eltwise_node.in_ports().values():
            producer_port = consumer_port.get_source()
            unsqueeze_dims = unsqueeze_dims_map[producer_port]
            if unsqueeze_dims is not None and len(unsqueeze_dims) > 0:
                unsqueeze_dims = tuple(unsqueeze_dims)
                if producer_port not in mapping:
                    mapping.update({producer_port: {unsqueeze_dims: [consumer_port]}})
                elif unsqueeze_dims not in mapping[producer_port]:
                    mapping[producer_port].update({unsqueeze_dims: [consumer_port]})
                else:
                    mapping[producer_port][unsqueeze_dims].append(consumer_port)

    # Walk through each producer in the map and insert Unsqueeze nodes between the producer and its eltwise nodes
    for producer_port in mapping.keys():
        producer_node = producer_port.node
        for unsqueeze_dims in mapping[producer_port].keys():
            unsqueeze_name = producer_node.soft_get('name', producer_node.id) + '/EltwiseUnsqueeze'
            unsqueeze_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(list(unsqueeze_dims))},
                                                         {'name': unsqueeze_name})
            unsqueeze_node.in_port(0).connect(producer_port)

            # Insert Unsqueeze with determined unsqueeze dimensions between the current producer and eltwise node
            for consumer_port in mapping[producer_port][unsqueeze_dims]:
                consumer_port.connect(unsqueeze_node.out_port(0))

            # The shape and value adjustments must be explicitly done within the transformation
            # since the transformation is called from Fusing transformation that excludes
            # automatic call of shape inference pass
            producer_port_value = producer_port.data.get_value()
            producer_port_shape = producer_port.data.get_shape()
            new_shape = producer_port_shape.copy()
            for unsqueeze_dim in unsqueeze_dims:
                new_shape = shape_insert(new_shape, unsqueeze_dim, 1)
            if producer_port_value is not None and is_fully_defined(new_shape):
                unsqueeze_node.out_port(0).data.set_value(np.reshape(producer_port_value, new_shape))
            else:
                unsqueeze_node.out_port(0).data.set_shape(new_shape)
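
What the inserted Unsqueeze achieves, on hypothetical producer shapes: an eltwise Add of A[N, C, H, W] with B[C, 1, 1] needs B unsqueezed at dim 0 so the ranks match and broadcasting works per channel:

import numpy as np

a, b = np.ones((1, 16, 8, 8)), np.ones((16, 1, 1))
print((a + np.expand_dims(b, 0)).shape)  # (1, 16, 8, 8); B was normalized to [1, 16, 1, 1]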
Example #12
    def infer(node: Node):
        node_name = node.soft_get('name', node.id)
        input_shape = node.in_port(0).data.get_shape()
        input_value = node.in_port(0).data.get_value()
        if input_shape is None:
            raise Error('Input shape for node "{}" is None'.format(node_name))

        assert len(node.in_nodes()) == 1, 'Wrong number of inputs to the layer {}'.format(node_name)

        if not node.has_valid('expand_axis'):
            raise Error(
                'ExpandDims axis is not defined for node {}'.format(node_name))

        expand_axes = node.expand_axis
        if expand_axes is None:
            raise Error(
                'The "expand_axis" attribute is None for node "{}"'.format(
                    node_name))

        if isinstance(expand_axes, int):
            expand_axes = int64_array([expand_axes])
        elif expand_axes.ndim == 0:
            expand_axes = expand_axes.reshape([1])

        # expand_axis is the position where the new axis is placed, so for negative values expand_dims
        # does not behave like a plain insert operation. Note: the original loop only rebound the loop
        # variable and never updated the array, so the normalization is done with a comprehension instead.
        expand_axes = int64_array([axis + len(input_shape) + 1 if axis < 0 else axis
                                   for axis in expand_axes])

        expand_axes = sorted(expand_axes)
        output_shape = input_shape.copy()
        for expand_axis in expand_axes:
            output_shape = shape_insert(output_shape, expand_axis, 1)

        if input_value is not None and is_fully_defined(output_shape):
            node.out_port(0).data.set_value(input_value.reshape(output_shape))
        else:
            node.out_port(0).data.set_shape(output_shape)
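
The negative-axis normalization above differs from a plain insert: for a rank-2 input, expand_axis=-1 maps to -1 + 2 + 1 = 2, appending the new dim exactly like np.expand_dims:

import numpy as np

print(np.expand_dims(np.ones((2, 3)), -1).shape)  # (2, 3, 1), not (2, 1, 3)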
Example #13
    def test_shape_insert_raise_exception(self):
        with self.assertRaisesRegex(Error, '.*Incorrect parameter type.*'):
            shape_insert(gen_masked_array([1, 2, 3], []), 2, {})
Example #14
    def test_shape_insert(self, shape, pos, values, result):
        self.assertTrue(strict_compare_tensors(shape_insert(shape, pos, values), result))
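
The second test is parametrized externally; representative (hypothetical, not taken from the real test data) cases consistent with the shape_insert semantics shown in this listing:

# shape=[1, 2, 3], pos=1, values=5       ->  result=[1, 5, 2, 3]
# shape=[10, 128], pos=0, values=[1, 1]  ->  result=[1, 1, 10, 128]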
Example #15
    def split_multilayer_cell(self, graph: Graph, match: dict):
        """
        Split one multilayer type=RNNSequence cell into num_layers consecutive cells.
        All parameters are split into parts for the new num_layers cells.
        """
        input = match['input']
        rnn_layer = match['rnn_layer']
        params = match['params'].value.copy()

        have_hidden = False
        if 2 in rnn_layer.in_nodes():
            hidden_state_value = rnn_layer.in_node(2).value
            have_hidden = True

        have_cell = False
        if 3 in rnn_layer.in_nodes():
            cell_state_value = rnn_layer.in_node(3).value
            have_cell = True

        direction = 2 if rnn_layer.has_num_directions else 1
        num_layers = rnn_layer.num_layers
        input_size = input.shape[2]
        bsize = (2 * rnn_layer.hidden_size * direction *
                 num_layers) * rnn_layer.multiplier

        size = rnn_layer.hidden_size * direction * rnn_layer.multiplier
        first_layer_params_size = (input_size + rnn_layer.hidden_size +
                                   2) * size
        other_layer_params_size = (rnn_layer.hidden_size * direction +
                                   rnn_layer.hidden_size + 2) * size
        assert params.size == (first_layer_params_size +
                               (num_layers - 1) * other_layer_params_size)

        input_node = input
        params_layer_size_count = 0
        output_states = [[], []]

        param_w = params[0:len(params) - bsize]
        param_b = params[len(params) - bsize:]
        layer_bsize = (2 * rnn_layer.hidden_size *
                       direction) * rnn_layer.multiplier

        for l in range(num_layers):
            params_layer_size = first_layer_params_size if l == 0 else other_layer_params_size

            layer_params_w = param_w[params_layer_size_count:
                                     params_layer_size_count + (params_layer_size - layer_bsize)].copy()
            layer_params_b = param_b[layer_bsize * l:layer_bsize * (l + 1)].copy()
            layer_params = np.concatenate((layer_params_w, layer_params_b),
                                          axis=0)
            params_layer_size_count = params_layer_size_count + params_layer_size - layer_bsize

            op = self.get_new_cell(rnn_layer, l)
            name = str(rnn_layer.soft_get('name', rnn_layer.id))
            params_value_node = Const(
                rnn_layer.graph,
                dict(name=name + '/LayerSplittedParamsLSTM/{}/'.format(l),
                     value=layer_params)).create_node_with_data()

            if have_hidden:
                layer_hidden_state = hidden_state_value[l * direction:(l + 1) * direction]
                hidden_state_value_node = Const(
                    rnn_layer.graph,
                    dict(name=name + '/LayerSplittedHiddenState/{}/'.format(l),
                         value=layer_hidden_state)).create_node_with_data()
            else:
                hidden_state_value_node = None

            if have_cell:
                layer_cell_state = cell_state_value[l * direction:(l + 1) * direction]
                cell_state_value_node = Const(
                    rnn_layer.graph,
                    dict(name=name + '/LayerSplittedCellState/{}/'.format(l),
                         value=layer_cell_state)).create_node_with_data()
            else:
                cell_state_value_node = None

            if l < num_layers - 1:
                output_data = Op._create_data_node(
                    rnn_layer.graph,
                    name=rnn_layer.out_node(0).name + '/LayerSplit/' + str(l),
                    attrs={'shape': rnn_layer.out_node(0).shape.copy()})
            else:
                output_data = rnn_layer.out_node(0)

            # Output nodes creating:
            state_size = int64_array(
                [input.shape[rnn_layer.batch_dim], rnn_layer.hidden_size])
            if rnn_layer.has_num_directions:
                state_size = shape_insert(state_size, 0, direction)

            output_hidden = Op._create_data_node(
                rnn_layer.graph,
                name=rnn_layer.out_node(1).name + '/LayerSplit/' + str(l),
                attrs={'shape': mo_array(state_size)})

            current_data_nodes = [output_data, output_hidden]

            if rnn_layer.op == 'LSTM':
                output_cell = Op._create_data_node(
                    rnn_layer.graph,
                    name=rnn_layer.out_node(2).name + '/LayerSplit/' + str(l),
                    attrs={'shape': mo_array(state_size)})
                current_data_nodes.append(output_cell)

            data_nodes = op.create_node_with_data(
                inputs=[
                    input_node, params_value_node, hidden_state_value_node,
                    cell_state_value_node
                ],
                data_nodes=current_data_nodes,
            )

            input_node = data_nodes[0]
            output_states[0].append(data_nodes[1])

            if rnn_layer.op == 'LSTM':
                output_states[1].append(data_nodes[2])

        return output_states
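
A sanity check of the parameter bookkeeping above, for a hypothetical 2-layer unidirectional LSTM (multiplier=4 gates, hidden_size=8, input_size=16):

multiplier, hidden_size, input_size, direction, num_layers = 4, 8, 16, 1, 2
size = hidden_size * direction * multiplier                 # 32
first = (input_size + hidden_size + 2) * size               # (16 + 8 + 2) * 32 = 832
other = (hidden_size * direction + hidden_size + 2) * size  # (8 + 8 + 2) * 32  = 576
print(first + (num_layers - 1) * other)                     # 1408 == expected params.size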
Example #16
    def replace_pattern(graph, match: dict):
        # Find all parts of the TI (condition, inputs/outputs, back edges, body), create the TensorIterator op
        # and make all the checks needed for TensorIterator to work
        cond_data = match['condition'].out_node(0) if not match['condition'].out_port(0).disconnected() else None
        time_data = match['condition'].out_node(1) if len(match['condition'].out_nodes()) >= 1 else None
        name = match['condition'].name

        back_edges = []
        inputs = []
        outputs = []

        if cond_data is not None:
            for node in cond_data.out_nodes():
                if node['kind'] == 'op' and node['op'] == 'TensorIteratorBackEdge':
                    back_edges.append(node.id)
                elif node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                    inputs.append(node.id)
                elif node['kind'] == 'op' and node['op'] == 'TensorIteratorOutput':
                    outputs.append(node.id)

        if time_data is not None:
            for node in time_data.out_nodes():
                if node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                    inputs.append(node.id)
                elif node['kind'] == 'op' and node['op'] == 'TensorIteratorOutput':
                    outputs.append(node.id)
                else:
                    # something went wrong here
                    assert False
        condition = match['condition']
        tensor_sequence_length = condition.in_node(0)

        nodes_to_remove = [
            n.id
            for n in (condition, cond_data, time_data, tensor_sequence_length)
            if n is not None
        ]
        graph.remove_nodes_from(nodes_to_remove)

        body_nodes, extra_inputs = get_body(graph, inputs, outputs)

        if cond_data is not None:
            body_nodes = list(set(body_nodes) - set([cond_data]))

        inputs += extra_inputs

        assert all([node in graph.nodes() for node in body_nodes])

        inputs = [Node(graph, node) for node in inputs]
        outputs = [Node(graph, node) for node in outputs]
        back_edges = [Node(graph, node) for node in back_edges]

        external_inputs = [{
            'external_data_id': node.in_node(1 if node.has_valid('axis') else 0),
            'internal_data_id': node.out_node(0),
            'axis': node.axis,
            'start': node.start,
            'end': node.end,
            'stride': node.stride,
            'part_size': node.part_size,
        } for node in inputs]

        external_outputs = [{
            'external_data_id': node.out_node(0),
            'internal_data_id': node.in_node(1 if node.has_valid('axis') else 0),
            'axis': node.axis,
            'start': node.start,
            'end': node.end,
            'stride': node.stride,
            'part_size': node.part_size,
        } for node in outputs]

        back_edges_data = [{
            'from_data_id': node.in_node(1),
            'to_data_id': node.out_node(0),
            'init_data_id': node.in_node(0),
        } for node in back_edges]

        body = Graph(name='body')
        body.graph = graph.graph
        body.add_nodes_from([(node, graph.node[node]) for node in body_nodes])
        body.add_edges_from([
            (u, v, k, d) for u, v, k, d in graph.edges(data=True, keys=True)
            if u in body_nodes and v in body_nodes
        ])

        graph.remove_nodes_from(body_nodes + [match['condition'].id] +
                                [inp.id for inp in inputs] +
                                [out.id for out in outputs])
        internal_id_count = 0
        real_back_edges = []
        for edge in back_edges_data:
            assert edge['from_data_id'].id in body.nodes()
            assert edge['to_data_id'].id in body.nodes()
            assert edge['init_data_id'].id in body.nodes()
            edge['from_data_id'] = Node(body, edge['from_data_id'].id)
            edge['to_data_id'] = Node(body, edge['to_data_id'].id)
            edge['init_data_id'] = Node(body, edge['init_data_id'].id)
            add_opoutput(body, edge['from_data_id'].id, 0, False)

            # Assign/reuse ids for the back-edge start; it comes from from_data_id
            assert len(edge['from_data_id'].in_nodes()) == 1
            # layer id
            if not edge['from_data_id'].in_node().has_valid('internal_layer_id'):
                edge['from_data_id'].in_node()['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            edge['from_layer'] = edge['from_data_id'].in_node()['internal_layer_id']

            # port id
            if 'internal_port_id' not in edge['from_data_id'].in_edge():
                edge['from_data_id'].in_edge()['internal_port_id'] = internal_id_count
                internal_id_count += 1
            edge['from_port'] = edge['from_data_id'].in_edge()['internal_port_id']

            # Look at all consumers for a data that ends a back-edge
            # For each such consumer, there will be a separate back-edge (and input)
            current_real_back_edges = []
            for _, consumer, key, edge_attrs in body.out_edges(edge['to_data_id'].id, data=True, keys=True):
                real_edge = {}
                real_edge.update(edge)  # all real back_edges have the same back-edge start

                consumer = Node(body, consumer)

                if real_edge['to_data_id'].in_node().has_valid('internal_layer_id'):
                    assert False
                    real_edge['to_data_id'].out_node()['internal_layer_id'] = \
                        real_edge['to_data_id'].in_node().internal_layer_id
                elif not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                real_edge['to_layer'] = consumer['internal_layer_id']

                assert 'internal_port_id' not in edge_attrs
                assert len(real_edge['init_data_id'].out_edges()) == 1
                assert 'internal_port_id' not in real_edge['init_data_id'].out_edge()
                edge_attrs['internal_port_id'] = internal_id_count
                internal_id_count += 1
                real_edge['to_port'] = edge_attrs['internal_port_id']
                real_edge['consumer'] = consumer
                real_edge['consumer_key'] = key

                real_edge['attrs'] = deepcopy(edge_attrs)
                current_real_back_edges.append(real_edge)

            # connect initial data node with each consumer providing actual edge attributes
            body.add_edges_from([
                (real_edge['init_data_id'].id, real_edge['consumer'].id,
                 real_edge['consumer_key'], real_edge['attrs'])
                for real_edge in current_real_back_edges
            ])

            body.remove_nodes_from(
                [edge['to_data_id'].id, edge['to_data_id'].in_node().id])
            real_back_edges += current_real_back_edges

        real_external_inputs = []

        for ext_inp in external_inputs:
            assert ext_inp['external_data_id'].id not in body.nodes()
            assert ext_inp['internal_data_id'].id in body.nodes()
            ext_inp['internal_data_id'] = Node(body,
                                               ext_inp['internal_data_id'].id)

            if ext_inp['axis'] is not None:
                # Insert squeezing resize at input port that has partitioning
                shape = ext_inp['internal_data_id'].shape.copy()
                assert not ext_inp['internal_data_id'].has_valid('value')
                new_input_data = Op._create_data_node(
                    body,
                    ext_inp['internal_data_id'].name + '/UnsqueezedInput',
                    dict(shape=shape_insert(shape, ext_inp['axis'], 1)))

                reshape_op = Squeeze(body, dict(name=ext_inp['internal_data_id'].name + '/InputSqueeze'))
                reshape_dim_data = Const(body, {
                    'name': ext_inp['internal_data_id'].name + '/ReshapeDim',
                    'value': ext_inp['axis']
                }).create_node_with_data()
                reshape_op.create_node_with_data([new_input_data, reshape_dim_data],
                                                 data_nodes=[ext_inp['internal_data_id']])
                ext_inp['internal_data_id'] = new_input_data

            ext_inp['internal_data_id']['is_input'] = True
            assert len(ext_inp['internal_data_id'].in_nodes()) == 0
            ext_inp['external_port_id'] = internal_id_count
            internal_id_count += 1
            for _, consumer, edge_attrs in body.out_edges(ext_inp['internal_data_id'].id, data=True):
                real_ext_inp = {}
                real_ext_inp.update(ext_inp)
                consumer = Node(body, consumer)
                if not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                if 'internal_port_id' not in edge_attrs:
                    edge_attrs['internal_port_id'] = internal_id_count
                    internal_id_count += 1
                real_ext_inp['internal_layer_id'] = consumer['internal_layer_id']
                real_ext_inp['internal_port_id'] = edge_attrs['internal_port_id']
                real_external_inputs.append(real_ext_inp)

        for ext_out in external_outputs:
            assert ext_out['external_data_id'].id not in body.nodes()
            assert ext_out['internal_data_id'].id in body.nodes()
            ext_out['internal_data_id'] = Node(body,
                                               ext_out['internal_data_id'].id)

            if ext_out['axis'] is not None:
                # Insert unsqueezing resize at output port that has partitioning
                reshape_op = Unsqueeze(body, dict(name=ext_out['internal_data_id'].name + '/OutputUnsqueeze'))
                reshape_dim_data = Const(body, {
                    'name': ext_out['internal_data_id'].name + '/ReshapeDim',
                    'value': ext_out['axis']
                }).create_node_with_data()
                ext_out['internal_data_id'] = reshape_op.create_node_with_data(
                    [ext_out['internal_data_id'], reshape_dim_data])

            # TODO: add here working with simple outputs

            if not any([
                    out_node.soft_get('op', None) == 'Result'
                    for out_node in ext_out['internal_data_id'].out_nodes()
            ]):
                add_opoutput(body, ext_out['internal_data_id'].id, 0, False)

            # assert len(ext_out['internal_data_id'].out_nodes()) == 0
            assert len(ext_out['internal_data_id'].in_nodes()) == 1
            if 'internal_layer_id' not in ext_out['internal_data_id'].in_node():
                ext_out['internal_data_id'].in_node()['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            if 'internal_port_id' not in ext_out['internal_data_id'].in_edge():
                ext_out['internal_data_id'].in_edge()['internal_port_id'] = internal_id_count
                internal_id_count += 1
            ext_out['internal_layer_id'] = ext_out['internal_data_id'].in_node()['internal_layer_id']
            ext_out['internal_port_id'] = ext_out['internal_data_id'].in_edge()['internal_port_id']
            ext_out['external_port_id'] = internal_id_count
            internal_id_count += 1

        # create TensorIterator layer with pre-computed components
        ti_op = TensorIterator(graph, {
            'name': name + '/TensorIterator',
            'body': body,
            'in_ports_count': len(external_inputs),
            'out_ports_count': len(external_outputs),
            'input_port_map': [{
                field: external_input[field]
                for field in ['external_port_id', 'internal_layer_id', 'internal_port_id',
                              'axis', 'stride', 'part_size', 'start', 'end']
            } for external_input in real_external_inputs],
            'output_port_map': [{
                field: external_output[field]
                for field in ['external_port_id', 'internal_layer_id', 'internal_port_id',
                              'axis', 'stride', 'part_size', 'start', 'end']
            } for external_output in external_outputs],
            'back_edges': [{
                field: edge[field]
                for field in ['from_layer', 'from_port', 'to_layer', 'to_port']
            } for edge in real_back_edges],
        })

        ti_outs = ti_op.create_node_with_data(
            inputs=[inp['external_data_id'] for inp in external_inputs],
            edge_attrs=[{
                'external_port_id': inp['external_port_id']
            } for inp in external_inputs],
            data_nodes=[out['external_data_id'] for out in external_outputs])

        if not isinstance(ti_outs, list):
            ti_outs = [ti_outs]

        for i, out in enumerate(ti_outs):
            out.in_edge()['external_port_id'] = external_outputs[i]['external_port_id']

        ti = ti_outs[0].in_node()
        TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
        TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
        TensorIterator.normalize_internal_ids(ti)
Example #17
def rnn_infer(node: Node, out_ports=None):
    """
    General infer function for RNN, GRU and LSTM layers.
    Assumes that the 0-port input of the node is the input data for the recurrent layer and that the node has attrs:
    hidden_size,
    """
    if out_ports is None:
        out_ports = []

    # 1. Necessary checks (from ONNX specification)
    assert node.batch_dim <= 1
    assert node.sequence_dim <= 1
    assert node.batch_dim != node.sequence_dim
    assert node.direction in ['forward', 'reverse', 'bidirectional']

    if node.blobs_wrb:
        mark_input_bins(node, ['W', 'R', 'B'])
    else:
        mark_input_bins(node)

    # 2. Output shape calculations
    input_shape = node.in_node(0).shape
    assert len(input_shape) == 3

    # Reshape input nodes
    for port in [2, 3]:
        if port in node.in_nodes() and len(node.in_node(port).in_nodes()) > 0 and \
                'zero_shapes' in node.in_node(port).in_node():
            for i in node.in_node(port).in_node().zero_shapes:
                if node.in_node(port).shape[i] != input_shape[i]:
                    node.in_node(port).value = np.repeat(
                        node.in_node(port).value, input_shape[i], axis=i)
                    node.in_node(port).shape[i] = input_shape[i]

    out_shape = [
        input_shape[node.sequence_dim], input_shape[node.batch_dim],
        node.hidden_size
    ]

    if node.batch_dim == 0:
        out_shape = [
            input_shape[node.batch_dim], input_shape[node.sequence_dim],
            node.hidden_size
        ]

    num_directions = 2 if node.direction in ['bidirectional'] else 1
    if node.has_num_directions:
        if node.format == 'mxnet' and node.normalized is False:
            # The MXNet RNN layer returns output with shape [seq_len, batch_size, hidden_size * num_directions]
            out_shape[-1] *= num_directions
        else:
            # ONNX-like, insert extra dimension to output shape for num_directions
            out_shape = shape_insert(out_shape, 1, np.int64(num_directions))

    # Output 0 is required; create it if it doesn't exist
    if 0 not in node.out_nodes():
        data_node = Op._create_data_node(node.graph,
                                         name=node.node + '/ExtraOutput/{}'.format(0),
                                         attrs={'executable': True})
        if 0 not in node.out_ports():
            node.add_output_port(0)
        node.graph.add_edge(node.id, data_node.id, key=0, out=0)
        add_opoutput(node.graph, data_node.id, 0, False)
    node.out_port(0).data.set_shape(out_shape)

    # 3. Extra outputs for hidden/cell states shape calculations (optional)
    state_size = [input_shape[node.batch_dim], node.hidden_size]
    if node.has_num_directions:
        state_size = shape_insert(state_size, 0, num_directions)

    if node.multilayers:
        # In the multilayer case, the state sizes from every layer are concatenated along the last axis
        num_layers = node.num_layers
        state_size[-1] *= num_layers

    for i in out_ports:
        # If the node has no consumers for the hidden/cell states, create them
        if i not in node.out_nodes():
            data_node = Op._create_data_node(node.graph,
                                             name=node.node + '/ExtraOutput/' + str(i),
                                             attrs={'executable': True})
            if i not in node.out_ports():
                node.add_output_port(i)
            node.graph.add_edge(node.id, data_node.id, key=0, out=i)
            add_opoutput(node.graph, data_node.id, 0, False)
        else:
            data_node = node.out_node(i)
        data_node.shape = shape_array(state_size)
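
Shapes produced by the rules above for a hypothetical bidirectional, single-layer ONNX LSTM with seq_len=10, batch=4, hidden_size=32 and batch_dim == 1 (so the default out_shape ordering applies), sketched in plain numpy:

import numpy as np

out_shape = np.insert(np.array([10, 4, 32]), 1, 2)  # Y:        [10  2  4 32]
state_size = np.insert(np.array([4, 32]), 0, 2)     # Y_h, Y_c: [ 2  4 32]
print(out_shape, state_size)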