Example #1
    def unsqueeze_num_directions(graph: Graph, match: dict):
        """ Assuming considered LSTM/GRU/RNN node should has num_directions in output shape and add Unsqueeze
            to match it.
        """

        rnn_layer = match['rnn_layer']
        rnn_layer_name = rnn_layer.soft_get('name', rnn_layer.id)
        # num_directions is at position 1 in the output shape, and at position 0 in the hidden and cell states;
        # please refer to the docs for this transform

        direction_dim = [1, 0, 0]  # index of dimension with direction index
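        # e.g. for an ONNX LSTM the outputs are Y [seq_len, num_directions, batch, hidden_size]
        # and Y_h/Y_c [num_directions, batch, hidden_size], hence the indices [1, 0, 0]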
        for i in rnn_layer.out_nodes():
            old_data_node = rnn_layer.out_node(i)
            old_shape = old_data_node.shape.copy()
            new_shape = shape_delete(old_shape, direction_dim[i])
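            # shape_delete drops the direction axis, e.g. shape_delete([num_directions, batch, hidden], 0)
            # -> [batch, hidden] for the state outputs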

            data = Op._create_data_node(graph, name=rnn_layer_name + '/Out/{}/'.format(i), attrs={'shape': new_shape})
            graph.remove_edge(rnn_layer.id, old_data_node.id)
            graph.add_edge(rnn_layer.id, data.id, key=0, out=i)

            unsqueeze = Unsqueeze(graph, dict())

            unsqueeze_dim_data = Const(graph, {'name': rnn_layer_name + '/UnsqueezeNumDirections/{}/Dim'.format(i),
                                               'value': int64_array([direction_dim[i]])}).create_node_with_data()

            unsqueeze.create_node_with_data([data, unsqueeze_dim_data],
                                            dict(name=rnn_layer_name + '/UnsqueezeNumDirections/{}'.format(i)),
                                            data_nodes=[old_data_node])

Example #2

    def replace_pattern(self, graph: Graph, match: dict):
        lstm = match['lstm']

        # Build TensorIterator body first
        body = Graph(name=lstm.name + '/sub_graph')
        body.graph = graph.graph  # share graph-level attributes (cmd params, layout, etc.) with the parent graph

        # 1. Input squeeze Reshape
        inputs = [Op._create_data_node(body, lstm.name + '/inport/' + str(inp),
                                       {'shape': lstm.in_node(inp).shape.copy(),
                                        'value': lstm.in_node(inp).value.copy()
                                        if lstm.in_node(inp).value is not None and inp in [1, 2] else None})
                  for inp in [0, 4, 5, 1, 2]]  # X, WR, B, h_init, c_init

        inputs[0].shape[lstm.sequence_dim] = 1
        input_squeeze = Squeeze(body, dict(name=lstm.name + '/input_squeeze', internal_layer_id=0))
        squeeze_dim_data = Const(body, {'name': lstm.name + '/input_squeeze_dim',
                                        'value': [lstm.sequence_dim]}).create_node_with_data()
        inputs[0] = input_squeeze.create_node_with_data([inputs[0], squeeze_dim_data],
                                                        edge_attrs=[{'internal_port_id': 0}])
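        # e.g. with ONNX layout (sequence_dim == 0) the per-iteration slice [1, batch, input_size]
        # is squeezed to the [batch, input_size] input expected by LSTMCell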

        # 2. Output unsqueeze Reshape
        outputs = [Op._create_data_node(body, lstm.name + '/outport/' + str(out),
                                        {'shape': lstm.out_node(out).shape.copy() if out in lstm.out_nodes()
                                        else lstm.in_node(4).shape.copy()}) for out in [0, 1]]
        for out in outputs:
            add_opoutput(body, out.id, 0, False)

        outputs[0].shape = shape_delete(outputs[0].shape, lstm.sequence_dim)
        output_unsqueeze = Unsqueeze(body, dict(name=lstm.name + '/output_unsqueeze', internal_layer_id=2))
        unsqueeze_dim_data = Const(body, {'name': lstm.name + '/output_unsqueeze_dim',
                                          'value': [lstm.sequence_dim]}).create_node_with_data()
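        # mirror of the input squeeze: the cell's [batch, hidden_size] output is unsqueezed back to
        # [1, batch, hidden_size] so TensorIterator can concatenate the slices along sequence_dim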

        # 3. LSTMCell
        lstm_cell_op = LSTMCell(body, dict(hidden_size=lstm.hidden_size,
                                           activations=lstm.activations,
                                           activation_alpha=lstm.activation_alpha,
                                           activation_beta=lstm.activation_beta,
                                           clip=lstm.clip,
                                           input_forget=lstm.input_forget,
                                           name=lstm.name + '/LSTMCell',
                                           internal_layer_id=1))
        lstm_cell_node = lstm_cell_op.create_node_with_data(inputs, data_nodes=outputs,
                                                            edge_attrs=[{}, {'internal_port_id': 1},
                                                                        {'internal_port_id': 2}, {'bin': 'weights'},
                                                                        {'bin': 'biases'}])
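        # tag the cell's h/c state outputs (4, 5) and, after the unsqueeze below, the Y output (3)
        # with the internal port ids referenced by output_port_map and back_edges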
        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 4
        lstm_cell_node[0].in_node().out_edge(1)['internal_port_id'] = 5
        lstm_cell_node[0] = output_unsqueeze.create_node_with_data([lstm_cell_node[0], unsqueeze_dim_data])
        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 3
        add_opoutput(body, lstm_cell_node[0].id, 0, False)

        # 4. TensorIterator layer creating
        assert lstm.direction in ['forward', 'reverse']
        if lstm.direction == 'forward':
            stride = 1
            start = None
            end = None
        else:
            assert lstm.direction == 'reverse'
            stride = -1
            start = -1
            end = 0
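        # e.g. 'reverse' makes TensorIterator iterate from the last slice towards the first
        # (start=-1, stride=-1), while 'forward' walks the full sequence axis in order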

        output_port_map = [{
            'external_port_id': 3,
            'internal_layer_id': 2,
            'internal_port_id': 3,

            'axis': lstm.sequence_dim,
            'stride': stride,
            'start': start,
            'end': end,
            'part_size': 1,
        }]
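        # external port 3 collects the per-iteration Y slices produced by the output Unsqueeze
        # (internal_layer_id == 2) and stacks them along sequence_dim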

        # Adding h_state, c_state to outputs
        if len(lstm.out_nodes()) == 3:
            output_port_map.extend([{
                'external_port_id': 4,
                'internal_layer_id': 1,
                'internal_port_id': 4,
            }, {
                'external_port_id': 5,
                'internal_layer_id': 1,
                'internal_port_id': 5,
            }])

        ti_op = TensorIterator(graph, {
            'name': lstm.name + '/TensorIterator',
            'body': body,
            'in_ports_count': 3,
            'out_ports_count': len(lstm.out_nodes()),

            'input_port_map': [
                {
                    'external_port_id': 0,
                    'internal_layer_id': 0,
                    'internal_port_id': 0,

                    'axis': lstm.sequence_dim,
                    'stride': stride,
                    'start': start,
                    'end': end,
                    'part_size': 1,
                },
                {
                    'external_port_id': 1,
                    'internal_layer_id': 1,
                    'internal_port_id': 1,
                },
                {
                    'external_port_id': 2,
                    'internal_layer_id': 1,
                    'internal_port_id': 2,
                },
            ],

            'output_port_map': output_port_map,

            'back_edges': [
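                # the cell's h/c outputs (ports 4/5 of layer 1) are fed back
                # into its h/c inputs (ports 1/2) on the next iteration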
                {
                    'from_layer': 1,
                    'from_port': 4,
                    'to_layer': 1,
                    'to_port': 1,
                },
                {
                    'from_layer': 1,
                    'from_port': 5,
                    'to_layer': 1,
                    'to_port': 2,
                },
            ]
        })

        assert sorted(lstm.out_nodes().keys()) == list(range(len(lstm.out_nodes()))), \
            "There are gaps in output ports of LSTMSequence operation. Node {}".format(lstm.id)

        outs = ti_op.create_node_with_data([lstm.in_node(i) for i in [0, 4, 5]],  # X, h_init, c_init
                                           data_nodes=[lstm.out_node(i) for i in range(len(lstm.out_nodes()))],
                                           edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1},
                                                       {'external_port_id': 2}])

        if not isinstance(outs, list):
            outs = [outs]

        graph.remove_node(lstm.id)
        outs[0].in_edge(0)['external_port_id'] = 3
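        # renumber TI's external output edges: 3 (Y), then 4 (h_state), 5 (c_state), matching output_port_map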
        for i, out in enumerate(outs[1:]):
            external_port_id = 4 + i
            out.in_edge()['external_port_id'] = external_port_id

        ti = outs[0].in_node()
        TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
        TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
        TensorIterator.normalize_internal_ids(ti)

Example #3

    def replace_pattern(self, graph: Graph, match: dict):
        if match['rnn_layer']['op'] == 'LSTM':
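            # LSTM sequences are handled by the dedicated LSTM-to-TensorIterator transform (Example #2);
            # this pass covers GRU and RNN only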
            return

        rnn_layer = match['rnn_layer']

        # Build TensorIterator body first
        body = Graph(name=rnn_layer.name + '/sub_graph')
        body.graph = graph.graph

        # 1. Input squeeze Reshape
        inputs = [Op._create_data_node(body, rnn_layer.name + '/inport/' + str(inp),
                                       {'shape': rnn_layer.in_node(inp).shape.copy(),
                                        'value': rnn_layer.in_node(inp).value.copy()
                                        if rnn_layer.in_node(inp).value is not None and inp in [1, 2] else None})
                  for inp in [0, 4, 1, 2]]  # X, h_init, WR, B

        inputs[0].shape[rnn_layer.sequence_dim] = 1
        input_squeeze = Squeeze(body, dict(name=rnn_layer.name + '/input_squeeze', internal_layer_id=0))
        input_squeeze_dim = Const(body, dict(name=rnn_layer.name + '/input_squeeze_dim',
                                             value=rnn_layer.sequence_dim)).create_node_with_data()
        inputs[0] = input_squeeze.create_node_with_data([inputs[0], input_squeeze_dim],
                                                        edge_attrs=[{'internal_port_id': 0}])

        # 2. Output unsqueeze Reshape
        outputs = [Op._create_data_node(body, rnn_layer.name + '/outport/' + str(out),
                                        {'shape': rnn_layer.out_node(out).shape.copy()
                                         if out in rnn_layer.out_nodes() else None}) for out in [0]]
        for out in outputs:
            add_opoutput(body, out.id, 0, False)

        outputs[0].shape = shape_delete(outputs[0].shape, rnn_layer.sequence_dim)
        output_unsqueeze_dim = Const(body, dict(name=rnn_layer.name + '/output_unsqueeze_dim',
                                                value=rnn_layer.sequence_dim)).create_node_with_data()
        output_unsqueeze = Unsqueeze(body, dict(name=rnn_layer.name + '/output_unsqueeze', internal_layer_id=2))

        additional_attrs = dict(activations=rnn_layer.activations,
                                activation_alpha=rnn_layer.activation_alpha,
                                activation_beta=rnn_layer.activation_beta,
                                clip=rnn_layer.clip)
        if rnn_layer.op == 'GRU':
            additional_attrs['linear_before_reset'] = rnn_layer.linear_before_reset

        # 3. GRU/RNN Cell
        rnn_cell_op = self.get_rnn_cell(rnn_layer['op'])(
            body, dict(hidden_size=rnn_layer.hidden_size,
                       name=rnn_layer.name + '/{}Cell'.format(rnn_layer.op),
                       **additional_attrs,
                       internal_layer_id=1))

        rnn_cell_node = rnn_cell_op.create_node_with_data(inputs, data_nodes=outputs,
                                                          edge_attrs=[{}, {'internal_port_id': 1},
                                                                      {'internal_port_id': 2}, {'bin': 'weights'},
                                                                      {'bin': 'biases'}])

        # internal port for the cell's output
        rnn_cell_node.in_node().out_edge(0)['internal_port_id'] = 4  # h_state

        rnn_cell_node = output_unsqueeze.create_node_with_data([rnn_cell_node, output_unsqueeze_dim])
        rnn_cell_node.in_node().out_edge(0)['internal_port_id'] = 3
        add_opoutput(body, rnn_cell_node.id, 0, False)

        # 4. TensorIterator layer creating
        assert rnn_layer.direction in ['forward', 'reverse']
        if rnn_layer.direction == 'forward':
            stride = 1
            start = None
            end = None
        else:
            assert rnn_layer.direction == 'reverse'
            stride = -1
            start = -1
            end = 0

        # stacked h_state
        output_port_map = [{
            'external_port_id': 3,
            'internal_layer_id': 2,
            'internal_port_id': 3,
            'axis': rnn_layer.sequence_dim,
            'stride': stride,
            'start': start,
            'end': end,
            'part_size': 1,
        }]

        # Adding last h_state to outputs
        if len(rnn_layer.out_nodes()) == 2:
            output_port_map.extend([{
                'external_port_id': 4,
                'internal_layer_id': 1,
                'internal_port_id': 4,
            }])

        ti_op = TensorIterator(graph, {
            'name': rnn_layer.name + '/TensorIterator',
            'body': body,
            'in_ports_count': 4,
            'out_ports_count': len(rnn_layer.out_nodes()),

            'input_port_map': [
                {
                    'external_port_id': 0,
                    'internal_layer_id': 0,
                    'internal_port_id': 0,

                    'axis': rnn_layer.sequence_dim,
                    'stride': stride,
                    'start': start,
                    'end': end,
                    'part_size': 1,
                },
                {
                    'external_port_id': 1,
                    'internal_layer_id': 1,
                    'internal_port_id': 1,
                },
            ],

            'output_port_map': output_port_map,

            # back edge for the h state only
            'back_edges': [
                {
                    'from_layer': 1,
                    'from_port': 4,
                    'to_layer': 1,
                    'to_port': 1,
                },
            ]
        })

        assert sorted(rnn_layer.out_nodes().keys()) == list(range(len(rnn_layer.out_nodes()))), \
            "There are gaps in output ports of {}Sequence operation. Node {}".format(rnn_layer.op, rnn_layer.id)

        outs = ti_op.create_node_with_data([rnn_layer.in_node(i) for i in [0, 4]],  # X, h_init
                                           data_nodes=[rnn_layer.out_node(i)
                                                       for i in range(len(rnn_layer.out_nodes()))],
                                           edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1}])

        if not isinstance(outs, list):
            outs = [outs]

        graph.remove_node(rnn_layer.id)
        outs[0].in_edge(0)['external_port_id'] = 3
        for i, out in enumerate(outs[1:]):
            external_port_id = 4 + i
            out.in_edge()['external_port_id'] = external_port_id

        ti = outs[0].in_node()
        TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
        TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
        TensorIterator.normalize_internal_ids(ti)
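
For context, here is a minimal sketch of how such a `replace_pattern` is typically registered as a Model Optimizer pattern-based pass. The class name and the matched attributes are illustrative assumptions, not taken from the snippets above:

    # A hedged sketch, assuming the usual MO pattern-matching registration;
    # the class name and matched node attributes below are illustrative.
    from mo.graph.graph import Graph
    from mo.middle.replacement import MiddleReplacementPattern


    class RNNSequenceToTensorIterator(MiddleReplacementPattern):
        enabled = True

        def pattern(self):
            # The node key 'rnn_layer' is what replace_pattern() reads from `match`.
            return dict(
                nodes=[('rnn_layer', dict(kind='op', type='RNNSequence'))],
                edges=[],
            )

        def replace_pattern(self, graph: Graph, match: dict):
            ...  # body as in Example #3 above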