def find_and_replace_pattern(self, graph: Graph):
        """
        Fuse StridedSlice operations that read the same data node into one SplitV operation.

        For every non-constant data node, the StridedSlice consumers connected to it through
        their input port 0 are collected. The group is replaced only when all of these hold:
        every consumer has the same number of slices, the slices of the first consumer differ
        from the full range along exactly one dimension, every stride equals 1, and the sliced
        ranges neither overlap nor exceed the input shape along that dimension.
        Ranges that no StridedSlice reads are routed to new 'fake_data' nodes which are passed
        to add_opoutput.

        :param graph: graph to transform; it is modified in place
        :return: None
        """
        # Iterate over all data nodes and find all with >= 1 consumers
        data_nodes = [Node(graph, node) for node in graph.node if Node(graph, node).kind == 'data']
        for input_data in data_nodes:
            # We don't use constant data nodes
            if input_data.value is not None:
                continue

            input_shape = np.array(input_data.shape)

            # Get all StridedSlice consumers
            out_nodes = [node for node in input_data.out_nodes()
                         if node.op == 'StridedSlice' and node.in_node(0).name == input_data.name]
            if len(out_nodes) < 1:
                continue

            valid_for_replacement = True

            # All consumers must describe the same number of dimensions
            for node in out_nodes:
                if len(node.slices) != len(out_nodes[0].slices):
                    valid_for_replacement = False

            # Detect dimension for splitting: the only one where the first consumer
            # does not take the full [0, dim_size) range
            split_channel_dim = None
            for dim_id, s in enumerate(out_nodes[0].slices):
                if s.start != 0 or s.stop != input_shape[dim_id]:
                    if split_channel_dim is None:
                        split_channel_dim = dim_id
                    else:
                        valid_for_replacement = False

            # Nothing is actually sliced: there is no axis to split along. Without this guard
            # split_dims stays empty and indexing sorted_split_dims[0] below raises IndexError.
            if split_channel_dim is None:
                valid_for_replacement = False

            # split_dims contains tuples with split range and output data node
            split_dims = []
            for node in out_nodes:
                # Check that StridedSlice op has stride eq 1 and splits only feature channel
                for dim_id, s in enumerate(node.slices):
                    # We don't support StridedSlice with stride != 1
                    if s.step != 1:
                        valid_for_replacement = False
                    if dim_id == split_channel_dim:
                        split_dims.append((s.start, s.stop, node.out_node()))

            if not valid_for_replacement:
                continue

            # Check feature split intersection
            final_data_nodes_list = []
            sorted_split_dims = sorted(split_dims, key=lambda item: (item[0], item[1]))

            # check if we have similar StridedSlice operations with different outputs
            # NOTE(review): prev_sd is never advanced, so only entries equal to the first sorted
            # range are merged here; confirm whether later duplicates should be merged as well.
            prev_sd = sorted_split_dims[0]
            to_remove = []
            for i in range(1, len(sorted_split_dims)):
                if sorted_split_dims[i][0] == prev_sd[0] and sorted_split_dims[i][1] == prev_sd[1] and \
                        sorted_split_dims[i][2].name != prev_sd[2].name:
                    cur_node = sorted_split_dims[i][2]
                    # Re-attach consumers of the duplicate output to the first equivalent output
                    for out in cur_node.out_nodes():
                        attrs = deepcopy(graph.get_edge_data(cur_node.id, out.id)[0])
                        graph.remove_edge(cur_node.id, out.id)
                        graph.add_edge(prev_sd[2].id, out.id, **attrs)
                    to_remove.append(i)

            # Pop from the end so that the remaining indices stay valid
            for ind in reversed(to_remove):
                sorted_split_dims.pop(ind)

            size_splits = []
            prev_r = 0
            for l, r, out in sorted_split_dims:
                # Split dims shouldn't intersect
                if l < prev_r:
                    valid_for_replacement = False
                # Save missing tensor part
                if l > prev_r:
                    shape = np.array(input_shape)
                    size_splits.append(l - prev_r)
                    shape[split_channel_dim] = l - prev_r
                    data_node = Op._create_data_node(graph, 'fake_data', {'shape': shape})
                    add_opoutput(graph, data_node.id, 0, False)
                    final_data_nodes_list.append(data_node)

                prev_r = r
                size_splits.append(r - l)
                final_data_nodes_list.append(out)

            if prev_r > input_shape[split_channel_dim]:
                valid_for_replacement = False
            elif prev_r != input_shape[split_channel_dim]:
                # Add last part of tensor
                shape = input_shape.copy()
                shape[split_channel_dim] = input_shape[split_channel_dim] - prev_r
                size_splits.append(input_shape[split_channel_dim] - prev_r)
                data_node = Op._create_data_node(graph, 'fake_data', {'shape': shape})
                add_opoutput(graph, data_node.id, 0, False)
                final_data_nodes_list.append(data_node)

            if not valid_for_replacement:
                continue

            for node in out_nodes:
                if not np.all([x == 0 for x in node.shrink_axis_mask]):
                    # Remember the output seen before the reshape helpers run so that its entry
                    # in final_data_nodes_list can be swapped for the current output afterwards
                    out_node = node.out_node()
                    if np.any(node['shrink_axis_mask']):
                        self.add_reshape_for_shrink(graph, node)
                    if np.any(node['new_axis_mask']):
                        self.add_reshape_for_new(graph, node)

                    for idx, data_node in enumerate(final_data_nodes_list):
                        if data_node.name == out_node.name:
                            final_data_nodes_list[idx] = node.out_node()
                            break

            # Insert Split layer and remove old StridedSlice layers
            # 1. Remove connections from input_data to StridedSlice ops
            out_data_nodes = []
            name_for_future_split = out_nodes[0].name
            for node in out_nodes:
                out_data_nodes.append(node.out_node())
                graph.remove_edge(input_data.id, node.id)
                graph.remove_edge(node.id, node.out_node().id)
                graph.remove_node(node.id)
                log.debug("Removed: {}".format(node.id))

            # 2. Create Split layer and reorder outputs
            split = SplitV(graph, dict(name=name_for_future_split + "/Split", axis=split_channel_dim,
                                       size_splits=size_splits, out_ports_count=len(size_splits)))
            split.create_node_with_data(inputs=[input_data], data_nodes=final_data_nodes_list)
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Convert the matched RNN/GRU sequence node into a TensorIterator with a one-cell body.

        Nodes whose 'op' equals 'LSTM' are left untouched. For any other node a body graph
        (Squeeze -> cell -> Unsqueeze) is built, a TensorIterator wrapping that body is wired to
        inputs 0 and 4 and to every output of the original node, and the original node is
        removed from the graph.

        :param graph: graph that holds the matched node; modified in place
        :param match: match dictionary; only its 'rnn_layer' entry is used
        """
        rnn_layer = match['rnn_layer']
        if rnn_layer['op'] == 'LSTM':
            return

        prefix = rnn_layer.name

        # The TensorIterator body lives in its own graph that shares the graph-level attributes
        body = Graph(name=prefix + '/sub_graph')
        body.graph = graph.graph

        # 1. Body inputs (X, h_init, WR, B); values are copied only for ports 1 and 2
        inputs = []
        for port in [0, 4, 1, 2]:
            source = rnn_layer.in_node(port)
            shape = source.shape.copy()
            if source.value is not None and port in [1, 2]:
                value = source.value.copy()
            else:
                value = None
            inputs.append(
                Op._create_data_node(body, prefix + '/inport/' + str(port), {'shape': shape, 'value': value}))

        # One step along the sequence axis enters the body per iteration; squeeze that axis away
        inputs[0].shape[rnn_layer.sequence_dim] = 1
        input_squeeze = Squeeze(body, dict(name=prefix + '/input_squeeze', internal_layer_id=0))
        input_squeeze_dim = Const(
            body, dict(name=prefix + '/input_squeeze_dim', value=rnn_layer.sequence_dim)).create_node_with_data()
        inputs[0] = input_squeeze.create_node_with_data([inputs[0], input_squeeze_dim],
                                                        edge_attrs=[{'internal_port_id': 0}])

        # 2. Body output (port 0 only) and the pieces needed to unsqueeze it later
        outputs = []
        for port in [0]:
            if port in rnn_layer.out_nodes():
                shape = rnn_layer.out_node(port).shape.copy()
            else:
                shape = None
            outputs.append(Op._create_data_node(body, prefix + '/outport/' + str(port), {'shape': shape}))
        for data_node in outputs:
            add_opoutput(body, data_node.id, 0, False)

        outputs[0].shape = shape_delete(outputs[0].shape, rnn_layer.sequence_dim)
        output_unsqueeze_dim = Const(
            body, dict(name=prefix + '/output_unsqueeze_dim', value=rnn_layer.sequence_dim)).create_node_with_data()
        output_unsqueeze = Unsqueeze(body, dict(name=prefix + '/output_unsqueeze/', internal_layer_id=2))

        # 3. ***Cell: attributes are forwarded from the sequence node
        cell_factory = self.get_rnn_cell(rnn_layer['op'])
        cell_params = dict(hidden_size=rnn_layer.hidden_size,
                           name=prefix + '/{}Cell'.format(rnn_layer.op))
        cell_params['activations'] = rnn_layer.activations
        cell_params['activation_alpha'] = rnn_layer.activation_alpha
        cell_params['activation_beta'] = rnn_layer.activation_beta
        cell_params['clip'] = rnn_layer.clip
        if rnn_layer.op == 'GRU':
            cell_params['linear_before_reset'] = rnn_layer.linear_before_reset
        cell_params['internal_layer_id'] = 1
        rnn_cell_op = cell_factory(body, cell_params)

        cell_out = rnn_cell_op.create_node_with_data(
            inputs,
            data_nodes=outputs,
            edge_attrs=[
                {},
                {'internal_port_id': 1},
                {'internal_port_id': 2},
                {'bin': 'weights'},
                {'bin': 'biases'},
            ])

        # internal ports for outputs of cell
        cell_out.in_node().out_edge(0)['internal_port_id'] = 4  # h_state

        cell_out = output_unsqueeze.create_node_with_data([cell_out, output_unsqueeze_dim])
        cell_out.in_node().out_edge(0)['internal_port_id'] = 3
        add_opoutput(body, cell_out.id, 0, False)

        # 4. TensorIterator layer creating
        assert rnn_layer.direction in ['forward', 'reverse']
        if rnn_layer.direction == 'forward':
            stride, start, end = 1, None, None
        else:
            stride, start, end = -1, -1, 0

        # stacked h_state
        output_port_map = [{
            'external_port_id': 3,
            'internal_layer_id': 2,
            'internal_port_id': 3,
            'axis': rnn_layer.sequence_dim,
            'stride': stride,
            'start': start,
            'end': end,
            'part_size': 1,
        }]

        # Adding last h_state to outputs
        if len(rnn_layer.out_nodes()) == 2:
            output_port_map.append({
                'external_port_id': 4,
                'internal_layer_id': 1,
                'internal_port_id': 4,
            })

        input_port_map = [
            {
                'external_port_id': 0,
                'internal_layer_id': 0,
                'internal_port_id': 0,
                'axis': rnn_layer.sequence_dim,
                'stride': stride,
                'start': start,
                'end': end,
                'part_size': 1,
            },
            {
                'external_port_id': 1,
                'internal_layer_id': 1,
                'internal_port_id': 1,
            },
        ]

        # only for h state
        back_edges = [{
            'from_layer': 1,
            'from_port': 4,
            'to_layer': 1,
            'to_port': 1,
        }]

        ti_op = TensorIterator(graph, {
            'name': prefix + '/TensorIterator',
            'body': body,
            'in_ports_count': 4,
            'out_ports_count': len(rnn_layer.out_nodes()),
            'input_port_map': input_port_map,
            'output_port_map': output_port_map,
            'back_edges': back_edges,
        })

        assert sorted(rnn_layer.out_nodes().keys()) == list(range(len(rnn_layer.out_nodes()))), \
            "There are gaps in output ports of GRUSequence operation. Node {}".format(rnn_layer.id)

        ti_inputs = [rnn_layer.in_node(port) for port in [0, 4]]  # X, h_init
        ti_outputs = [rnn_layer.out_node(port) for port in range(len(rnn_layer.out_nodes()))]
        outs = ti_op.create_node_with_data(ti_inputs,
                                           data_nodes=ti_outputs,
                                           edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1}])

        # A single output may come back as a bare node; normalise to a list
        if not isinstance(outs, list):
            outs = [outs]

        graph.remove_node(rnn_layer.id)

        # External output ports are numbered from 3 upwards
        outs[0].in_edge(0)['external_port_id'] = 3
        for external_port_id, out in enumerate(outs[1:], start=4):
            out.in_edge()['external_port_id'] = external_port_id

        ti = outs[0].in_node()
        TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
        TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
        TensorIterator.normalize_internal_ids(ti)
示例#3
0
    def split_bidirectional(self,
                            bidirectional_cell: Node,
                            new_init_hiddens: list,
                            new_init_cells: list,
                            splitted_W: tuple,
                            splitted_R: tuple,
                            splitted_B: tuple):
        """
            Split one bidirectional RNNSequence node into 2 one-directional RNNSequence nodes.

            All input data nodes must already be prepared, i.e. already split per direction:
            element 0 of every list/tuple argument is used for the 'forward' node and
            element 1 for the 'reverse' node.

            Returns a list with two items, one per direction, each being the result of
            create_node_with_data for the new one-directional node.
        """
        graph = bidirectional_cell.graph

        # (output port of the bidirectional node, axis of that output that must equal 2)
        outputs_to_split = [(0, 1), (1, 0)]
        if bidirectional_cell.op == 'LSTM':
            # LSTM additionally exposes the cell state on port 2
            outputs_to_split.append((2, 0))

        all_outputs = []
        for i, direction in enumerate(['forward', 'reverse']):
            op = self.get_new_cell(bidirectional_cell, direction)

            data_nodes = []
            for port, num_dir_axis in outputs_to_split:
                source = bidirectional_cell.out_node(port)
                split_data = Op._create_data_node(
                    graph,
                    name=source.name + '/Split/' + str(i),
                    attrs={'shape': source.shape.copy()}
                )
                # Each new node handles a single direction only
                assert split_data.shape[num_dir_axis] == 2
                split_data.shape[num_dir_axis] = 1
                data_nodes.append(split_data)

            if bidirectional_cell.op == 'LSTM':
                init_cell = new_init_cells[i]
            else:
                init_cell = None

            op_inputs = [
                bidirectional_cell.in_node(0),
                splitted_W[i],
                splitted_R[i],
                splitted_B[i],
                None,
                new_init_hiddens[i],
                init_cell,
            ]
            all_outputs.append(op.create_node_with_data(inputs=op_inputs, data_nodes=data_nodes))
        return all_outputs
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Convert the matched LSTM sequence node into a TensorIterator with an LSTMCell body.

        A body graph (Squeeze -> LSTMCell -> Unsqueeze) is built, a TensorIterator wrapping that
        body is wired to inputs 0, 4 and 5 and to every output of the original node, and the
        original node is removed from the graph.

        :param graph: graph that holds the matched node; modified in place
        :param match: match dictionary; only its 'lstm' entry is used
        """
        lstm = match['lstm']
        prefix = lstm.name

        # The TensorIterator body lives in its own graph that shares the graph-level attributes
        body = Graph(name=prefix + '/sub_graph')
        body.graph = graph.graph

        # 1. Body inputs (X, h_init, c_init, WR, B); values are copied only for ports 1 and 2
        inputs = []
        for port in [0, 4, 5, 1, 2]:
            source = lstm.in_node(port)
            shape = source.shape.copy()
            if source.value is not None and port in [1, 2]:
                value = source.value.copy()
            else:
                value = None
            inputs.append(
                Op._create_data_node(body, prefix + '/inport/' + str(port), {'shape': shape, 'value': value}))

        # One step along the sequence axis enters the body per iteration; squeeze that axis away
        inputs[0].shape[lstm.sequence_dim] = 1
        input_squeeze = Squeeze(body, dict(name=prefix + '/input_squeeze', internal_layer_id=0))
        squeeze_dim_data = Const(body, {
            'name': prefix + '/input_squeeze_dim',
            'value': [lstm.sequence_dim],
        }).create_node_with_data()
        inputs[0] = input_squeeze.create_node_with_data([inputs[0], squeeze_dim_data],
                                                        edge_attrs=[{'internal_port_id': 0}])

        # 2. Body outputs for ports 0 and 1; a missing output borrows the shape of input 4
        outputs = []
        for port in [0, 1]:
            if port in lstm.out_nodes():
                shape = lstm.out_node(port).shape.copy()
            else:
                shape = lstm.in_node(4).shape.copy()
            outputs.append(Op._create_data_node(body, prefix + '/outport/' + str(port), {'shape': shape}))
        for data_node in outputs:
            add_opoutput(body, data_node.id, 0, False)

        outputs[0].shape = np.delete(outputs[0].shape, lstm.sequence_dim)
        output_unsqueeze = Unsqueeze(body, dict(name=prefix + 'output_unsqueeze', internal_layer_id=2))
        unsqueeze_dim_data = Const(body, {
            'name': prefix + '/output_unsqueeze_dim',
            'value': [lstm.sequence_dim],
        }).create_node_with_data()

        # 3. LSTMCell: attributes are forwarded from the sequence node
        lstm_cell_op = LSTMCell(body, dict(hidden_size=lstm.hidden_size,
                                           activations=lstm.activations,
                                           activation_alpha=lstm.activation_alpha,
                                           activation_beta=lstm.activation_beta,
                                           clip=lstm.clip,
                                           input_forget=lstm.input_forget,
                                           name=prefix + '/LSTMCell',
                                           internal_layer_id=1))
        lstm_cell_node = lstm_cell_op.create_node_with_data(
            inputs,
            data_nodes=outputs,
            edge_attrs=[
                {},
                {'internal_port_id': 1},
                {'internal_port_id': 2},
                {'bin': 'weights'},
                {'bin': 'biases'},
            ])

        # Internal ports of the two cell outputs, then the unsqueezed output
        cell_op_node = lstm_cell_node[0].in_node()
        cell_op_node.out_edge(0)['internal_port_id'] = 4
        lstm_cell_node[0].in_node().out_edge(1)['internal_port_id'] = 5
        lstm_cell_node[0] = output_unsqueeze.create_node_with_data([lstm_cell_node[0], unsqueeze_dim_data])
        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 3
        add_opoutput(body, lstm_cell_node[0].id, 0, False)

        # 4. TensorIterator layer creating
        assert lstm.direction in ['forward', 'reverse']
        if lstm.direction == 'forward':
            stride, start, end = 1, None, None
        else:
            stride, start, end = -1, -1, 0

        output_port_map = [{
            'external_port_id': 3,
            'internal_layer_id': 2,
            'internal_port_id': 3,
            'axis': lstm.sequence_dim,
            'stride': stride,
            'start': start,
            'end': end,
            'part_size': 1,
        }]

        # Adding h_state, c_state to outputs
        if len(lstm.out_nodes()) == 3:
            output_port_map.append({
                'external_port_id': 4,
                'internal_layer_id': 1,
                'internal_port_id': 4,
            })
            output_port_map.append({
                'external_port_id': 5,
                'internal_layer_id': 1,
                'internal_port_id': 5,
            })

        input_port_map = [
            {
                'external_port_id': 0,
                'internal_layer_id': 0,
                'internal_port_id': 0,
                'axis': lstm.sequence_dim,
                'stride': stride,
                'start': start,
                'end': end,
                'part_size': 1,
            },
            {
                'external_port_id': 1,
                'internal_layer_id': 1,
                'internal_port_id': 1,
            },
            {
                'external_port_id': 2,
                'internal_layer_id': 1,
                'internal_port_id': 2,
            },
        ]

        # Cell output ports 4 and 5 feed cell input ports 1 and 2 on the next iteration
        back_edges = [
            {
                'from_layer': 1,
                'from_port': 4,
                'to_layer': 1,
                'to_port': 1,
            },
            {
                'from_layer': 1,
                'from_port': 5,
                'to_layer': 1,
                'to_port': 2,
            },
        ]

        ti_op = TensorIterator(graph, {
            'name': prefix + '/TensorIterator',
            'body': body,
            'in_ports_count': 3,
            'out_ports_count': len(lstm.out_nodes()),
            'input_port_map': input_port_map,
            'output_port_map': output_port_map,
            'back_edges': back_edges,
        })

        assert sorted(lstm.out_nodes().keys()) == list(range(len(lstm.out_nodes()))), \
            "There are gaps in output ports of LSTMSequence operation. Node {}".format(lstm.id)

        ti_inputs = [lstm.in_node(port) for port in [0, 4, 5]]  # X, h_init, c_init
        ti_outputs = [lstm.out_node(port) for port in range(len(lstm.out_nodes()))]
        outs = ti_op.create_node_with_data(ti_inputs,
                                           data_nodes=ti_outputs,
                                           edge_attrs=[
                                               {'external_port_id': 0},
                                               {'external_port_id': 1},
                                               {'external_port_id': 2},
                                           ])

        # A single output may come back as a bare node; normalise to a list
        if not isinstance(outs, list):
            outs = [outs]

        graph.remove_node(lstm.id)

        # External output ports are numbered from 3 upwards
        outs[0].in_edge(0)['external_port_id'] = 3
        for external_port_id, out in enumerate(outs[1:], start=4):
            out.in_edge()['external_port_id'] = external_port_id

        ti = outs[0].in_node()
        TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
        TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
        TensorIterator.normalize_internal_ids(ti)
示例#5
0
    def split_multilayer_cell(self, graph: Graph, match: dict):
        """
        Split one multilayer RNNSequence-type cell into num_layers consecutive cells.

        The flat parameters blob from match['params'] is cut into per-layer weights and biases,
        optional initial hidden/cell states (inputs 2 and 3) are cut per layer, and a new cell
        is created for every layer, each one consuming output 0 of the previous one.

        :param graph: not used directly; new nodes are created in rnn_layer.graph
        :param match: match dictionary with the 'input', 'rnn_layer' and 'params' entries
        :return: list of two lists: per-layer hidden-state output data nodes and, for
                 op == 'LSTM' only, per-layer cell-state output data nodes
        """
        data_input = match['input']
        rnn_layer = match['rnn_layer']
        params = match['params'].value.copy()

        # Optional initial states; a value is only read when the port is connected
        have_hidden = 2 in rnn_layer.in_nodes()
        if have_hidden:
            hidden_state_value = rnn_layer.in_node(2).value

        have_cell = 3 in rnn_layer.in_nodes()
        if have_cell:
            cell_state_value = rnn_layer.in_node(3).value

        direction = 2 if rnn_layer.has_num_directions else 1
        num_layers = rnn_layer.num_layers
        input_size = data_input.shape[2]

        # Total number of bias elements over all layers (they sit at the tail of the blob)
        bsize = (2 * rnn_layer.hidden_size * direction * num_layers) * rnn_layer.multiplier

        size = rnn_layer.hidden_size * direction * rnn_layer.multiplier
        first_layer_params_size = (input_size + rnn_layer.hidden_size + 2) * size
        other_layer_params_size = (rnn_layer.hidden_size * direction + rnn_layer.hidden_size + 2) * size
        assert params.size == (first_layer_params_size + (num_layers - 1) * other_layer_params_size)

        input_node = data_input
        weights_offset = 0
        output_states = [[], []]

        weights_end = len(params) - bsize
        param_w = params[0:weights_end]
        param_b = params[weights_end:]
        layer_bsize = (2 * rnn_layer.hidden_size * direction) * rnn_layer.multiplier

        for layer_idx in range(num_layers):
            if layer_idx == 0:
                params_layer_size = first_layer_params_size
            else:
                params_layer_size = other_layer_params_size

            # Per-layer blob: this layer's weights followed by this layer's biases
            layer_weights_end = weights_offset + (params_layer_size - layer_bsize)
            layer_params_w = param_w[weights_offset:layer_weights_end].copy()
            bias_begin = layer_bsize * layer_idx
            layer_params_b = param_b[bias_begin:bias_begin + layer_bsize].copy()
            layer_params = np.concatenate((layer_params_w, layer_params_b), axis=0)
            weights_offset = layer_weights_end

            op = self.get_new_cell(rnn_layer, layer_idx)
            base_name = str(rnn_layer.soft_get('name', rnn_layer.id))
            params_value_node = Const(
                rnn_layer.graph,
                dict(name=base_name + '/LayerSplittedParamsLSTM/{}/'.format(layer_idx),
                     value=layer_params)).create_node_with_data()

            state_begin = layer_idx * direction
            state_end = layer_idx * direction + direction

            hidden_state_value_node = None
            if have_hidden:
                layer_hidden_state = hidden_state_value[state_begin:state_end]
                hidden_state_value_node = Const(
                    rnn_layer.graph,
                    dict(name=base_name + '/LayerSplittedHiddenState/{}/'.format(layer_idx),
                         value=layer_hidden_state)).create_node_with_data()

            cell_state_value_node = None
            if have_cell:
                layer_cell_state = cell_state_value[state_begin:state_end]
                cell_state_value_node = Const(
                    rnn_layer.graph,
                    dict(name=base_name + '/LayerSplittedCellState/{}/'.format(layer_idx),
                         value=layer_cell_state)).create_node_with_data()

            # Intermediate layers get a fresh data node; the last layer reuses the original output 0
            if layer_idx < num_layers - 1:
                output_data = Op._create_data_node(
                    rnn_layer.graph,
                    name=rnn_layer.out_node(0).name + '/LayerSplit/' + str(layer_idx),
                    attrs={'shape': rnn_layer.out_node(0).shape.copy()})
            else:
                output_data = rnn_layer.out_node(0)

            # Output nodes creating:
            state_size = np.array([data_input.shape[rnn_layer.batch_dim], rnn_layer.hidden_size],
                                  dtype=np.int64)
            if rnn_layer.has_num_directions:
                state_size = np.insert(state_size, 0, direction)

            output_hidden = Op._create_data_node(
                rnn_layer.graph,
                name=rnn_layer.out_node(1).name + '/LayerSplit/' + str(layer_idx),
                attrs={'shape': np.array(state_size)})

            current_data_nodes = [output_data, output_hidden]

            if rnn_layer.op == 'LSTM':
                output_cell = Op._create_data_node(
                    rnn_layer.graph,
                    name=rnn_layer.out_node(2).name + '/LayerSplit/' + str(layer_idx),
                    attrs={'shape': np.array(state_size)})
                current_data_nodes.append(output_cell)

            layer_inputs = [input_node, params_value_node, hidden_state_value_node, cell_state_value_node]
            data_nodes = op.create_node_with_data(inputs=layer_inputs, data_nodes=current_data_nodes)

            # Output 0 of this layer is the data input of the next one
            input_node = data_nodes[0]
            output_states[0].append(data_nodes[1])

            if rnn_layer.op == 'LSTM':
                output_states[1].append(data_nodes[2])

        return output_states
    def replace_pattern(graph, match: dict):
        # Here we will found all parts of TI: condition, inputs/outputs, back edges, body and create TensorIterator Op
        # and make all checks needed for TensorIteator work
        cond_data = match['condition'].out_node(0)
        time_data = match['condition'].out_node(1) if len(
            match['condition'].out_nodes()) > 1 else None
        name = match['condition'].name

        assert match['condition'].in_node(0).has_valid('value')

        back_edges = []
        inputs = []
        outputs = []

        for node in cond_data.out_nodes():
            if node['kind'] == 'op' and node['op'] == 'TensorIteratorBackEdge':
                back_edges.append(node.id)
            elif node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                inputs.append(node.id)
            elif node['kind'] == 'op' and node['op'] == 'TensorIteratorOutput':
                outputs.append(node.id)

        if time_data is not None:
            for node in time_data.out_nodes():
                if node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                    inputs.append(node.id)
                elif node['kind'] == 'op' and node[
                        'op'] == 'TensorIteratorOutput':
                    outputs.append(node.id)
                else:
                    # something goes wrong here
                    assert False

        condition = match['condition']
        tensor_sequence_length = condition.in_node(0)
        graph.remove_nodes_from(
            [condition.id, cond_data.id, tensor_sequence_length.id])
        if time_data is not None:
            graph.remove_nodes_from([time_data.id])

        body_nodes, extra_inputs = get_body(graph, inputs, outputs)
        body_nodes = list(set(body_nodes) - set([cond_data]))

        inputs += extra_inputs

        assert all([node in graph.nodes() for node in body_nodes])

        inputs = [Node(graph, node) for node in inputs]
        outputs = [Node(graph, node) for node in outputs]
        back_edges = [Node(graph, node) for node in back_edges]

        external_inputs = [{
            'external_data_id':
            node.in_node(1 if node.has_valid('axis') else 0),
            'internal_data_id':
            node.out_node(0),
            'axis':
            node.axis,
            'start':
            node.start,
            'end':
            node.end,
            'stride':
            node.stride,
            'part_size':
            node.part_size
        } for node in inputs]

        external_outputs = [{
            'external_data_id':
            node.out_node(0),
            'internal_data_id':
            node.in_node(1 if node.has_valid('axis') else 0),
            'axis':
            node.axis,
            'start':
            node.start,
            'end':
            node.end,
            'stride':
            node.stride,
            'part_size':
            node.part_size
        } for node in outputs]

        back_edges_data = [{
            'from_data_id': node.in_node(1),
            'to_data_id': node.out_node(0),
            'init_data_id': node.in_node(0),
        } for node in back_edges]

        body = Graph(name='body')
        body.graph = graph.graph
        body.add_nodes_from([(node, graph.node[node]) for node in body_nodes])
        body.add_edges_from([
            (u, v, k, d) for u, v, k, d in graph.edges(data=True, keys=True)
            if u in body_nodes and v in body_nodes
        ])

        graph.remove_nodes_from(body_nodes + [match['condition'].id] +
                                [inp.id for inp in inputs] +
                                [out.id for out in outputs])
        internal_id_count = 0
        real_back_edges = []
        for edge in back_edges_data:
            assert edge['from_data_id'].id in body.nodes()
            assert edge['to_data_id'].id in body.nodes()
            assert edge['init_data_id'].id in body.nodes()
            edge['from_data_id'] = Node(body, edge['from_data_id'].id)
            edge['to_data_id'] = Node(body, edge['to_data_id'].id)
            edge['init_data_id'] = Node(body, edge['init_data_id'].id)
            add_opoutput(body, edge['from_data_id'].id, 0, False)

            # Assign/reuse ids for the back-edge start; it comes from from_data_id
            assert len(edge['from_data_id'].in_nodes()) == 1
            # layer id
            if not edge['from_data_id'].in_node().has_valid(
                    'internal_layer_id'):
                edge['from_data_id'].in_node(
                )['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            edge['from_layer'] = edge['from_data_id'].in_node(
            )['internal_layer_id']

            # port id
            if 'internal_port_id' not in edge['from_data_id'].in_edge():
                edge['from_data_id'].in_edge(
                )['internal_port_id'] = internal_id_count
                internal_id_count += 1
            edge['from_port'] = edge['from_data_id'].in_edge(
            )['internal_port_id']

            # Look at all consumers for a data that ends a back-edge
            # For each such consumer, there will be a separate back-edge (and input)
            current_real_back_edges = []
            for _, consumer, key, edge_attrs in body.out_edges(
                    edge['to_data_id'].id, data=True, keys=True):

                real_edge = {}
                real_edge.update(
                    edge)  # all real back_edges have the same back-edge start

                consumer = Node(body, consumer)

                if real_edge['to_data_id'].in_node().has_valid(
                        'internal_layer_id'):
                    assert False
                    real_edge['to_data_id'].out_node()['internal_layer_id'] = \
                        real_edge['to_data_id'].in_node().internal_layer_id
                elif not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                real_edge['to_layer'] = consumer['internal_layer_id']

                assert 'internal_port_id' not in edge_attrs
                assert len(real_edge['init_data_id'].out_edges()) == 1
                assert not 'internal_port_id' in real_edge[
                    'init_data_id'].out_edge()
                edge_attrs['internal_port_id'] = internal_id_count
                internal_id_count += 1
                real_edge['to_port'] = edge_attrs['internal_port_id']
                real_edge['consumer'] = consumer
                real_edge['consumer_key'] = key

                real_edge['attrs'] = deepcopy(edge_attrs)
                current_real_back_edges.append(real_edge)

            # connect initial data node with each consumer providing actual edge attributes
            body.add_edges_from([
                (real_edge['init_data_id'].id, real_edge['consumer'].id,
                 real_edge['consumer_key'], real_edge['attrs'])
                for real_edge in current_real_back_edges
            ])

            body.remove_nodes_from(
                [edge['to_data_id'].id, edge['to_data_id'].in_node().id])
            real_back_edges += current_real_back_edges

        real_external_inputs = []

        for ext_inp in external_inputs:
            assert ext_inp['external_data_id'].id not in body.nodes()
            assert ext_inp['internal_data_id'].id in body.nodes()
            ext_inp['internal_data_id'] = Node(body,
                                               ext_inp['internal_data_id'].id)

            if ext_inp['axis'] is not None:
                # Insert squeezing resize at input port that has partitioning
                shape = ext_inp['internal_data_id'].shape.copy()
                assert not ext_inp['internal_data_id'].has_valid('value')
                new_input_data = Op._create_data_node(
                    body,
                    ext_inp['internal_data_id'].name + '/UnsqueezedInput',
                    dict(shape=np.insert(shape, ext_inp['axis'], 1)))
                dim = shape.copy()
                # try to do it dynamically reshapable along one of the axis
                # it is practically useful to reshape along batch dimension, but here we cannot detect where it is
                # so, we are guessing based onother transflormaions that it is the major dimension
                dim[0] = -1
                reshape_op = Reshape(
                    body,
                    dict(name=ext_inp['internal_data_id'].name +
                         '/InputSqueeze',
                         dim=dim))
                reshape_op.create_node_with_data(
                    [new_input_data], data_nodes=[ext_inp['internal_data_id']])
                ext_inp['internal_data_id'] = new_input_data

            ext_inp['internal_data_id']['is_input'] = True
            assert len(ext_inp['internal_data_id'].in_nodes()) == 0
            ext_inp['external_port_id'] = internal_id_count
            internal_id_count += 1
            for _, consumer, edge_attrs in body.out_edges(
                    ext_inp['internal_data_id'].id, data=True):
                real_ext_inp = {}
                real_ext_inp.update(ext_inp)
                consumer = Node(body, consumer)
                if not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                if not 'internal_port_id' in edge_attrs:
                    edge_attrs['internal_port_id'] = internal_id_count
                    internal_id_count += 1
                real_ext_inp['internal_layer_id'] = consumer[
                    'internal_layer_id']
                real_ext_inp['internal_port_id'] = edge_attrs[
                    'internal_port_id']
                real_external_inputs.append(real_ext_inp)

        for ext_out in external_outputs:
            assert ext_out['external_data_id'].id not in body.nodes()
            assert ext_out['internal_data_id'].id in body.nodes()
            ext_out['internal_data_id'] = Node(body,
                                               ext_out['internal_data_id'].id)

            if ext_out['axis'] is not None:
                # Insert unsqueezing resize at output port that has partitioning
                dim = ext_out['internal_data_id'].shape.copy()
                # trying to make it dynamically reshapable (see related comment above for the first Reshape)
                dim[0] = -1
                assert not ext_out['internal_data_id'].has_valid('value')
                reshape_op = Reshape(
                    body,
                    dict(name=ext_out['internal_data_id'].name +
                         '/OutputUnsqueeze',
                         dim=np.insert(dim, ext_out['axis'], 1)))
                ext_out['internal_data_id'] = reshape_op.create_node_with_data(
                    [ext_out['internal_data_id']])

            # TODO: add here working with simple outputs

            add_opoutput(body, ext_out['internal_data_id'].id, 0, False)
            # assert len(ext_out['internal_data_id'].out_nodes()) == 0
            assert len(ext_out['internal_data_id'].in_nodes()) == 1
            if not 'internal_layer_id' in ext_out['internal_data_id'].in_node(
            ):
                ext_out['internal_data_id'].in_node(
                )['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            if not 'internal_port_id' in ext_out['internal_data_id'].in_edge():
                ext_out['internal_data_id'].in_edge(
                )['internal_port_id'] = internal_id_count
                internal_id_count += 1
            ext_out['internal_layer_id'] = ext_out['internal_data_id'].in_node(
            )['internal_layer_id']
            ext_out['internal_port_id'] = ext_out['internal_data_id'].in_edge(
            )['internal_port_id']
            ext_out['external_port_id'] = internal_id_count
            internal_id_count += 1

        ti_op = TensorIterator(
            graph, {
                'name':
                name + '/TensorIterator',
                'body':
                body,
                'in_ports_count':
                len(external_inputs),
                'out_ports_count':
                len(external_outputs),
                'input_port_map': [{
                    field: external_input[field]
                    for field in [
                        'external_port_id', 'internal_layer_id',
                        'internal_port_id', 'axis', 'stride', 'part_size',
                        'start', 'end'
                    ]
                } for external_input in real_external_inputs],
                'output_port_map': [{
                    field: external_output[field]
                    for field in [
                        'external_port_id', 'internal_layer_id',
                        'internal_port_id', 'axis', 'stride', 'part_size',
                        'start', 'end'
                    ]
                } for external_output in external_outputs],
                'back_edges': [{
                    field: edge[field]
                    for field in
                    ['from_layer', 'from_port', 'to_layer', 'to_port']
                } for edge in real_back_edges],
            })

        ti_outs = ti_op.create_node_with_data(
            inputs=[inp['external_data_id'] for inp in external_inputs],
            edge_attrs=[{
                'external_port_id': inp['external_port_id']
            } for inp in external_inputs],
            data_nodes=[out['external_data_id'] for out in external_outputs])

        if not isinstance(ti_outs, list):
            ti_outs = [ti_outs]

        for i, out in enumerate(ti_outs):
            out.in_edge(
            )['external_port_id'] = external_outputs[i]['external_port_id']
    def find_and_replace_pattern(self, graph: nx.MultiDiGraph):
        """
        Replace a group of StridedSlice consumers of one data node with a single SplitV.

        For every non-constant data node the StridedSlice operations that read it
        through input port 0 are collected. The group is replaced only when every
        operation in it has unit strides, no shrink_axis_mask bits set and the same
        number of slices, and when the group cuts the tensor along exactly one
        dimension into non-overlapping ranges. Ranges of that dimension that no
        StridedSlice covers receive placeholder 'fake_data' output nodes so that the
        split sizes add up to the full dimension size.

        :param graph: graph that is transformed in place
        """
        # Snapshot the data nodes first: the graph is modified while they are processed
        data_nodes = [node for node in (Node(graph, node_id) for node_id in graph.node) if node.kind == 'data']
        for input_data in data_nodes:
            # We don't use constant data nodes
            if input_data.value is not None:
                continue

            input_shape = np.array(input_data.shape)

            # Only StridedSlice ops that slice this tensor (i.e. read it on port 0) are of interest
            out_nodes = [node for node in input_data.out_nodes()
                         if node.op == 'StridedSlice' and node.in_node(0).id == input_data.id]
            if not out_nodes:
                continue

            valid_for_replacement = True
            reference_slices = out_nodes[0].slices

            # Detect dimension for splitting: the only one where the first StridedSlice
            # does not take the whole range
            split_channel_dim = None
            for dim_id, s in enumerate(reference_slices):
                if s.start != 0 or s.stop != input_shape[dim_id]:
                    if split_channel_dim is None:
                        split_channel_dim = dim_id
                    else:
                        valid_for_replacement = False

            # The reference StridedSlice cuts nothing, so there is no axis to split along.
            # Going on would index the shape with None and create SplitV with axis=None.
            if split_channel_dim is None:
                continue

            # split_dims contains tuples with split range and output data node
            split_dims = []
            for node in out_nodes:
                # All ops of the group must describe the same number of dimensions, otherwise
                # an op could be removed below without contributing an output to the SplitV
                if len(node.slices) != len(reference_slices):
                    valid_for_replacement = False
                # Check that StridedSlice op has no shrink_axis_mask attribute
                if not np.all([x == False for x in node.shrink_axis_mask]):
                    valid_for_replacement = False
                # Check that StridedSlice op has stride eq 1 and splits only feature channel
                for dim_id, s in enumerate(node.slices):
                    # We don't support StridedSlice with stride != 1
                    if s.step != 1:
                        valid_for_replacement = False
                    if dim_id == split_channel_dim:
                        split_dims.append((s.start, s.stop, node.out_node()))

            if not valid_for_replacement:
                continue

            # Plan the split before touching the graph. Each entry is (size, data node), where
            # the data node is None for a range that is not produced by any StridedSlice.
            dim_size = input_shape[split_channel_dim]
            segments = []
            prev_r = 0
            # Sort by the range only: with equal (start, stop) a plain tuple sort would go on
            # to compare the Node objects, which may not define an ordering
            for l, r, out in sorted(split_dims, key=lambda item: (item[0], item[1])):
                # Split dims shouldn't intersect
                if l < prev_r:
                    valid_for_replacement = False
                    break
                # Remember missing tensor part
                if l > prev_r:
                    segments.append((l - prev_r, None))
                segments.append((r - l, out))
                prev_r = r

            if prev_r > dim_size:
                valid_for_replacement = False
            elif prev_r != dim_size:
                # Remember last part of tensor
                segments.append((dim_size - prev_r, None))

            if not valid_for_replacement:
                continue

            # The group is accepted: only now create placeholder outputs for the uncovered
            # ranges, so that a rejected group leaves no orphan data nodes in the graph
            size_splits = []
            final_data_nodes_list = []
            for size, data_node in segments:
                if data_node is None:
                    shape = input_shape.copy()
                    shape[split_channel_dim] = size
                    data_node = Op._create_data_node(graph, 'fake_data', {'shape': shape, 'is_output': True})
                size_splits.append(size)
                final_data_nodes_list.append(data_node)

            # Insert Split layer and remove old StridedSlice layers
            # 1. Remove connections from input_data to StridedSlice ops
            name_for_future_split = out_nodes[0].name
            for node in out_nodes:
                graph.remove_edge(input_data.id, node.id)
                graph.remove_edge(node.id, node.out_node().id)
                graph.remove_node(node.id)
                log.debug("Removed: {}".format(node.id))

            # 2. Create Split layer and reorder outputs
            split = SplitV(graph, dict(name=name_for_future_split + "/Split", axis=split_channel_dim,
                                       size_splits=size_splits))
            split.create_node_with_data(inputs=[input_data], data_nodes=final_data_nodes_list)
示例#8
0
def rnn_infer(node: Node, out_ports=None):
    """
    Shape inference shared by the RNN, GRU and LSTM layers.

    Input port 0 of the node is taken as the layer input data and has to be a 3D tensor.
    The node attributes read here are: batch_dim, sequence_dim, direction, blobs_wrb,
    hidden_size, has_num_directions, format, normalized, multilayers and num_layers.

    :param node: recurrent layer node whose output shapes are calculated
    :param out_ports: output ports of the extra (state) outputs; a data node is created
                      for every listed port that has no data node yet
    """
    state_ports = [] if out_ports is None else out_ports

    # 1. Attribute sanity checks (from ONNX specification)
    assert node.batch_dim <= 1
    assert node.sequence_dim <= 1
    assert node.batch_dim != node.sequence_dim
    assert node.direction in ['forward', 'reverse', 'bidirectional']

    # Explicit bin names are passed only when the node carries the blobs_wrb flag
    bin_names = (['W', 'R', 'B'],) if node.blobs_wrb else ()
    mark_input_bins(node, *bin_names)

    # 2. Main output shape
    input_shape = node.in_node(0).shape
    assert len(input_shape) == 3

    # Inputs on ports 2 and 3 whose producer has 'zero_shapes' are repeated along the
    # listed axes until they match the input data shape
    for port in (2, 3):
        if port not in node.in_nodes():
            continue
        state_data = node.in_node(port)
        if len(state_data.in_nodes()) == 0 or 'zero_shapes' not in state_data.in_node():
            continue
        for axis in state_data.in_node().zero_shapes:
            if state_data.shape[axis] != input_shape[axis]:
                state_data.value = np.repeat(state_data.value, input_shape[axis], axis=axis)
                state_data.shape[axis] = input_shape[axis]

    # Batch goes first only when batch_dim is 0, otherwise the sequence dimension leads
    if node.batch_dim == 0:
        first_dim, second_dim = node.batch_dim, node.sequence_dim
    else:
        first_dim, second_dim = node.sequence_dim, node.batch_dim
    out_shape = np.array([input_shape[first_dim], input_shape[second_dim], node.hidden_size],
                         dtype=np.int64)

    num_directions = 2 if node.direction == 'bidirectional' else 1
    if node.has_num_directions:
        if node.format == 'mxnet' and node.normalized is False:
            # In MXNet RNN layer return output with shape [seq_len, batch_size, hidden_size * num_directions]
            out_shape[-1] *= num_directions
        else:
            # ONNX-like, insert extra dimension to output shape for num_directions
            out_shape = np.insert(out_shape, 1, np.int64(num_directions))
    node.out_node(0).shape = out_shape

    # 3. Shape of the optional hidden/cell state outputs
    state_size = np.array([input_shape[node.batch_dim], node.hidden_size], dtype=np.int64)
    if node.has_num_directions:
        state_size = np.insert(state_size, 0, num_directions)

    if node.multilayers:
        # For multilayer case state sizes from every layer will be concatenated by last axis
        state_size[-1] *= node.num_layers

    for port_idx in state_ports:
        if port_idx in node.out_nodes():
            state_out = node.out_node(port_idx)
        else:
            # No data node on this port yet -> create one and register it as a graph output
            state_out = Op._create_data_node(node.graph,
                                             name=node.node + '/ExtraOutput/' + str(port_idx),
                                             attrs={'executable': True})
            if port_idx not in node.out_ports():
                node.add_output_port(port_idx)
            node.graph.add_edge(node.id, state_out.id, key=0, out=port_idx)
            add_opoutput(node.graph, state_out.id, 0, False)
        state_out.shape = state_size.copy()
    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
        """
        Move the shrink_axis_mask / new_axis_mask handling of a StridedSlice into a Reshape.

        For each of the two masks that has a bit set, the set bits are cleared on the
        matched StridedSlice, a new data node with the correspondingly adjusted shape is
        attached to its output, and a Reshape that restores the previous output shape is
        inserted between that data node and the former output data node.
        Nothing more is done once the StridedSlice is found not to have exactly
        4 inputs and 1 output.

        :param graph: graph that is modified in place
        :param match: matched sub-graph; only the 'strided_slice' node is used
        """
        ss_node = match['strided_slice']

        # add Reshape for shrink_axis_mask
        if True in ss_node['shrink_axis_mask']:
            log.info("StridedSlice op with shrink mask '{}' has been detected".format(ss_node.id))

            if not (len(ss_node.in_nodes()) == 4 and len(ss_node.out_nodes()) == 1):
                return

            squeezed_shape = ss_node.out_node().shape
            target_dim = squeezed_shape.copy()

            # Don't permute reshape if channels were squeezed.
            # This has to be evaluated before the mask bits are cleared below.
            keep_layout = bool(graph.graph['layout'] == 'NHWC' and ss_node['shrink_axis_mask'][-1] == True)

            # Shape of the StridedSlice output once shrinking is switched off:
            # every shrunk axis is kept with size 1
            unsqueezed_shape = []
            next_out_axis = 0
            for axis, is_shrunk in enumerate(ss_node['shrink_axis_mask']):
                if is_shrunk:
                    ss_node['shrink_axis_mask'][axis] = False
                    unsqueezed_shape.append(1)
                else:
                    unsqueezed_shape.append(squeezed_shape[next_out_axis])
                    next_out_axis += 1

            consumer_data = ss_node.out_node(0)

            # insert data node for StridedSlice, re-using the attributes of the old output edge
            intermediate_data = Op._create_data_node(graph, ss_node.name + "/Reshape_shrink_data",
                                                     {'shape': unsqueezed_shape})
            edge_attrs = deepcopy(graph.get_edge_data(ss_node.id, consumer_data.id)[0])
            graph.remove_edge(ss_node.id, consumer_data.id)
            graph.add_edge(ss_node.id, intermediate_data.id, **edge_attrs)

            # insert Reshape
            reshape_attrs = dict(name=ss_node.name + "/Reshape_shrink",
                                 dim=np.array(target_dim, dtype=np.int64))
            if keep_layout:
                reshape_attrs['nchw_layout'] = True
            reshape = Reshape(graph, reshape_attrs)
            reshape_data_node = reshape.create_node_with_data(
                [intermediate_data], reshape.attrs, data_nodes=[consumer_data])
            if keep_layout:
                reshape_data_node['nchw_layout'] = True

        # add Reshape for new_axis_mask
        if True in ss_node['new_axis_mask']:
            log.info("StridedSlice op with new axis mask '{}' has been detected".format(ss_node.id))

            if not (len(ss_node.in_nodes()) == 4 and len(ss_node.out_nodes()) == 1):
                return

            expanded_shape = ss_node.out_node().shape
            target_dim = expanded_shape.copy()

            # Shape of the StridedSlice output once axis insertion is switched off:
            # positions flagged in the mask are dropped
            collapsed_shape = []
            for axis, is_new in enumerate(ss_node['new_axis_mask']):
                if is_new:
                    ss_node['new_axis_mask'][axis] = False
                else:
                    collapsed_shape.append(expanded_shape[axis])

            consumer_data = ss_node.out_node(0)

            # insert data node for StridedSlice, re-using the attributes of the old output edge
            intermediate_data = Op._create_data_node(graph, ss_node.name + "/Reshape_new_data",
                                                     {'shape': collapsed_shape})
            edge_attrs = deepcopy(graph.get_edge_data(ss_node.id, consumer_data.id)[0])
            graph.remove_edge(ss_node.id, consumer_data.id)
            graph.add_edge(ss_node.id, intermediate_data.id, **edge_attrs)

            # insert Reshape
            reshape = Reshape(graph, dict(name=ss_node.name + "/Reshape_new",
                                          dim=np.array(target_dim, dtype=np.int64)))
            reshape_data_node = reshape.create_node_with_data(
                [intermediate_data], reshape.attrs, data_nodes=[consumer_data])