def extract(node):
        """Read LSTM projected-streams parameters from ``node.parameters`` into LSTMCell attributes.

        Stream layout consumed, in order: an optional ``<CellClip>`` value
        (``clip_value`` defaults to 50 when the first whitespace-delimited
        token is anything else), then everything up to the ``FM`` token, the
        gifo input-weight matrix, the gifo recurrent-weight matrix, the gifo
        bias vector, the input/forget/output gate weight vectors, and finally
        the projection-weight matrix.

        The matrix shapes and the clip value are stored as node attributes;
        the seven blobs are embedded as inputs 1..7 in the order listed above.
        """
        params = node.parameters
        clip_value = 50
        if collect_until_whitespace(params) == b'<CellClip>':
            clip_value = get_uint32(params.read(4))
        collect_until_token(params, b'FM')

        x_weights, x_weights_shape = read_binary_matrix(params, False)
        r_weights, r_weights_shape = read_binary_matrix(params)
        # Bias vector followed by the three per-gate vectors, read back to back.
        biases, input_gate, forget_gate, output_gate = [
            read_binary_vector(params) for _ in range(4)
        ]
        projection_weights, projection_weights_shape = read_binary_matrix(params)

        mapping_rule = dict(
            gifo_x_weights_shape=x_weights_shape,
            gifo_r_weights_shape=r_weights_shape,
            projection_weights_shape=projection_weights_shape,
            clip_value=clip_value,
        )

        # Input port numbers start at 1 and follow the order of the blobs below.
        blobs = (
            ('gifo_x_weights', x_weights),
            ('gifo_r_weights', r_weights),
            ('gifo_biases', biases),
            ('input_gate_weights', input_gate),
            ('forget_gate_weights', forget_gate),
            ('output_gate_weights', output_gate),
            ('projection_weights', projection_weights),
        )
        for port, (blob_name, blob) in enumerate(blobs, start=1):
            embed_input(mapping_rule, port, blob_name, blob)

        LSTMCell.update_node_stat(node, mapping_rule)
        return __class__.enabled
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Replace a matched LSTM cell sub-graph with a single LSTMCell node.

        :param graph: graph transformed in place.
        :param match: mapping from pattern node names to matched nodes; this
            method also adds the ``input_op``, ``input_hidden_state`` and
            ``input_cell_state`` entries to it.

        The new node always gets consumers on output ports 0 and 1; a
        ``Result`` node is created for any of those ports that the rewiring
        left unconnected.
        """
        # node that is used to identify this pattern application instance for switching between supported
        # and not supported LSTMCell sub-graphs; this value will be searched in __class__.instances_supported_by_IE.
        anchor_node = match[__class__.anchor()]
        assert anchor_node.has_valid('name'), \
            'LSTMCell anchor node {} doesn\'t have attribute name; such nodes are not supported.'.format(
                anchor_node.id)

        match['input_op'] = match['concat'].in_node(0)
        match['input_hidden_state'] = match['concat'].in_node(1)
        # The cell-state input of mul_0 is whichever of its two inputs is not sigmoid_0.
        mul_0 = match['mul_0']
        match['input_cell_state'] = mul_0.in_node(0) \
            if mul_0.in_node(0).id != match['sigmoid_0'].id else mul_0.in_node(1)

        # Copy the edge list before extending it so the structure returned by
        # self.pattern() is never mutated.
        pattern_edges = list(self.pattern()['edges'])
        pattern_edges.extend([('input_op', 'concat'),
                              ('input_cell_state', 'mul_0'),
                              ('input_hidden_state', 'concat')])
        inputs = graph.get_inputs_with_ports(
            match, pattern_edges, __class__.inputs + __class__.extra_inputs)

        lstm_op = LSTMCell(
            graph,
            dict(
                name=match['concat'].name + '/LSTMCell',
                activations=None,
            ))
        lstm_node = lstm_op.create_node(inputs)
        # Keep the generic infer function reachable and install this class's own.
        lstm_node['old_infer'] = lstm_node.infer
        lstm_node.infer = __class__.infer

        # this node consumes one of the resulting LSTMCell outputs,
        # it should be removed before reconnecting the nodes,
        # otherwise it will be reconnected to the new cell output
        graph.remove_node(match['tanh_1'].id)

        for i, output in enumerate(__class__.outputs):
            match[output].replace_node(lstm_node, i)

        # Because of LSTMCell specification, this layer MUST have 2 outputs.
        # => we need to create fake consumers for LSTMCell
        # for any of those outputs that currently have no consumer.
        for i in [0, 1]:
            if i not in lstm_node.out_nodes():
                fake_output_node = Result(
                    graph, dict(name=lstm_node.name + "/Output_{}".format(i)))
                fake_output_node.create_node(inputs=[lstm_node],
                                             edge_attrs={
                                                 'out': i,
                                                 'in': 0
                                             })

        lstm_node['tf'] = True
        # Record the ids of the matched input nodes on the new node.
        lstm_node['extra_inputs'] = {
            name: match[name].id
            for name in __class__.extra_inputs
        }
        lstm_node['inputs'] = {
            name: match[name].id
            for name in __class__.inputs
        }
Beispiel #3
0
    def test_create_node(self):
        """Creating an LSTMCell node leaves fw_tensor_debug_info on its input edges unchanged."""
        graph = build_graph(nodes, [('Op1', 'Op3', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('Op1', 'Op1')]}),
                                    ('Op2', 'Op3', {'in': 1, 'out': 0, 'fw_tensor_debug_info': [('Op2', 'Op2')]})])
        graph.stage = 'front'
        input1 = Node(graph, 'Op1')
        input2 = Node(graph, 'Op2')
        inputs = [(input1, 0), (input2, 0)]

        lstm_op = LSTMCell(graph, dict(name='LSTMCell'))
        _ = lstm_op.create_node(inputs)

        # assertEqual reports both values on failure, unlike assertTrue(a == b).
        self.assertEqual(input1.out_edge(0)['fw_tensor_debug_info'], [('Op1', 'Op1')])
        self.assertEqual(input2.out_edge(0)['fw_tensor_debug_info'], [('Op2', 'Op2')])
Beispiel #4
0
    def replace_pattern(self, graph: Graph, match: dict):
        """Replace a matched LSTMSequence node with a TensorIterator over an LSTMCell body.

        The body is: input Reshape (one time-step slice with the sequence axis
        removed) -> LSTMCell -> output Reshape (sequence axis of length 1
        re-inserted). LSTMCell outputs 0/1 are fed back to its state inputs
        through back edges. ``direction == 'reverse'`` iterates with stride -1.

        :param graph: main graph, modified in place. The ``lstm`` node is
            removed and its output data nodes become outputs of the new
            TensorIterator node.
        :param match: pattern match; ``match['lstm']`` is the sequence node.
        """
        lstm = match['lstm']

        # Build TensorIterator body first
        body = Graph(name=lstm.name + '/sub_graph')
        body.graph = graph.graph

        # 1. Input squeeze Reshape
        # Body data nodes for input ports [0, 4, 5, 1, 2], i.e. in this order:
        # X, h_init, c_init, WR, B. Only ports 1 and 2 (WR, B) carry values.
        inputs = [
            Op._create_data_node(
                body, lstm.name + '/inport/' + str(inp), {
                    'shape':
                    lstm.in_node(inp).shape.copy(),
                    'value':
                    lstm.in_node(inp).value.copy() if lstm.in_node(inp).value
                    is not None and inp in [1, 2] else None
                }) for inp in [0, 4, 5, 1, 2]
        ]

        # One iteration sees a single time step: sequence axis set to 1, batch
        # axis inferred (-1), then the sequence axis is dropped.
        inputs[0].shape[lstm.sequence_dim] = 1
        reshape_dim = inputs[0].shape.copy()
        reshape_dim[lstm.batch_dim] = -1
        reshape_dim = np.delete(reshape_dim, lstm.sequence_dim)
        input_squeeze = Reshape(
            body, dict(name=lstm.name + '/input_squeeze', internal_layer_id=0))
        squeeze_dim_data = Const(body, {
            'name': lstm.name + '/input_squeeze_dim',
            'value': reshape_dim
        }).create_node_with_data()
        inputs[0] = input_squeeze.create_node_with_data(
            [inputs[0], squeeze_dim_data],
            edge_attrs=[{
                'internal_port_id': 0
            }])

        # 2. Output unsqueeze Reshape
        # Body data nodes for cell outputs 0 and 1; if the sequence node has no
        # such output, the shape of input port 4 (h_init) is used instead.
        outputs = [
            Op._create_data_node(
                body, lstm.name + '/outport/' + str(out), {
                    'shape':
                    lstm.out_node(out).shape.copy() if out in lstm.out_nodes()
                    else lstm.in_node(4).shape.copy()
                }) for out in [0, 1]
        ]
        for out in outputs:
            add_opoutput(body, out.id, 0, False)

        unsqueezed_output_shape = outputs[0].shape.copy()
        unsqueezed_output_shape[lstm.sequence_dim] = 1
        squeezed_output_shape = np.delete(unsqueezed_output_shape,
                                          lstm.sequence_dim)
        outputs[0].shape = squeezed_output_shape
        unsqueezed_output_shape[lstm.batch_dim] = -1
        output_unsqueeze = Reshape(
            body, dict(name=lstm.name + '/output_unsqueeze',
                       internal_layer_id=2))
        unsqueeze_dim_data = Const(
            body, {
                'name': lstm.name + '/output_unsqueeze_dim',
                'value': unsqueezed_output_shape
            }).create_node_with_data()

        # 3. LSTMCell
        lstm_cell_op = LSTMCell(
            body,
            dict(hidden_size=lstm.hidden_size,
                 activations=lstm.activations,
                 activation_alpha=lstm.activation_alpha,
                 activation_beta=lstm.activation_beta,
                 clip=lstm.clip,
                 input_forget=lstm.input_forget,
                 name=lstm.name + '/LSTMCell',
                 internal_layer_id=1))
        # Edge attributes follow the input order above: h_init/c_init become
        # internal ports 1/2 (targets of the back edges); WR/B are binary blobs.
        lstm_cell_node = lstm_cell_op.create_node_with_data(
            inputs,
            data_nodes=outputs,
            edge_attrs=[{}, {
                'internal_port_id': 1
            }, {
                'internal_port_id': 2
            }, {
                'bin': 'weights'
            }, {
                'bin': 'biases'
            }])
        # LSTMCell outputs 0/1 become internal ports 4/5 (sources of the back edges).
        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 4
        lstm_cell_node[0].in_node().out_edge(1)['internal_port_id'] = 5
        lstm_cell_node[0] = output_unsqueeze.create_node_with_data(
            [lstm_cell_node[0], unsqueeze_dim_data])
        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 3
        add_opoutput(body, lstm_cell_node[0].id, 0, False)

        # 4. TensorIterator layer creating
        assert lstm.direction in ['forward', 'reverse']
        if lstm.direction == 'forward':
            stride = 1
            start = None
            end = None
        else:
            assert lstm.direction == 'reverse'
            stride = -1
            start = -1
            end = 0

        output_port_map = [{
            'external_port_id': 3,
            'internal_layer_id': 2,
            'internal_port_id': 3,
            'axis': lstm.sequence_dim,
            'stride': stride,
            'start': start,
            'end': end,
            'part_size': 1,
        }]

        # Adding h_state, c_state to outputs
        if len(lstm.out_nodes()) == 3:
            output_port_map.extend([{
                'external_port_id': 4,
                'internal_layer_id': 1,
                'internal_port_id': 4,
            }, {
                'external_port_id': 5,
                'internal_layer_id': 1,
                'internal_port_id': 5,
            }])

        ti_op = TensorIterator(
            graph, {
                'name':
                lstm.name + '/TensorIterator',
                'body':
                body,
                'in_ports_count':
                3,
                'out_ports_count':
                len(lstm.out_nodes()),
                'input_port_map': [
                    {
                        'external_port_id': 0,
                        'internal_layer_id': 0,
                        'internal_port_id': 0,
                        'axis': lstm.sequence_dim,
                        'stride': stride,
                        'start': start,
                        'end': end,
                        'part_size': 1,
                    },
                    {
                        'external_port_id': 1,
                        'internal_layer_id': 1,
                        'internal_port_id': 1,
                    },
                    {
                        'external_port_id': 2,
                        'internal_layer_id': 1,
                        'internal_port_id': 2,
                    },
                ],
                'output_port_map':
                output_port_map,
                'back_edges': [
                    {
                        'from_layer': 1,
                        'from_port': 4,
                        'to_layer': 1,
                        'to_port': 1,
                    },
                    {
                        'from_layer': 1,
                        'from_port': 5,
                        'to_layer': 1,
                        'to_port': 2,
                    },
                ]
            })

        assert sorted(lstm.out_nodes().keys()) == list(range(len(lstm.out_nodes()))), \
            "There are gaps in output ports of LSTMSequence operation. Node {}".format(lstm.id)

        outs = ti_op.create_node_with_data(
            [lstm.in_node(i) for i in [0, 4, 5]],  # X, h_init, c_init
            data_nodes=[
                lstm.out_node(i) for i in range(len(lstm.out_nodes()))
            ],
            edge_attrs=[{
                'external_port_id': 0
            }, {
                'external_port_id': 1
            }, {
                'external_port_id': 2
            }])

        # create_node_with_data returns a bare node when there is one output.
        if not isinstance(outs, list):
            outs = [outs]

        graph.remove_node(lstm.id)
        outs[0].in_edge(0)['external_port_id'] = 3
        for external_port_id, out in enumerate(outs[1:], start=4):
            out.in_edge()['external_port_id'] = external_port_id
Beispiel #5
0
    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
        """Replace a matched LSTMSequence node with a TensorIterator over an LSTMCell body.

        The body is: input Reshape (one time-step slice with the sequence axis
        removed) -> LSTMCell -> output Reshape (sequence axis of length 1
        re-inserted). LSTMCell outputs 0/1 are fed back to its state inputs
        through back edges. ``direction == 'reverse'`` iterates with stride -1.

        :param graph: main graph, modified in place; must have a 'layout'
            entry in ``graph.graph``. The ``lstm`` node is removed and its
            output data nodes become outputs of the new TensorIterator node.
        :param match: pattern match; ``match['lstm']`` is the sequence node.
        """
        lstm = match['lstm']

        # Build TensorIterator body first
        body = nx.MultiDiGraph(name=lstm.name + '/sub_graph',
                               layout=graph.graph['layout'])
        # Body data nodes for input ports [0, 3, 4, 1, 2]. Port 0 is the sequence
        # input; ports 3/4 feed the recurrent state inputs (internal ports 1/2);
        # ports 1/2 are the weights/biases and are the only ones carrying values.
        inputs = [
            Op._create_data_node(
                body, lstm.name + '/inport/' + str(inp), {
                    'shape':
                    lstm.in_node(inp).shape.copy(),
                    'value':
                    lstm.in_node(inp).value.copy() if lstm.in_node(inp).value
                    is not None and inp in [1, 2] else None
                }) for inp in [0, 3, 4, 1, 2]
        ]
        # One iteration sees a single time step: sequence axis set to 1, batch
        # axis inferred (-1), then the sequence axis is dropped.
        inputs[0].shape[lstm.sequence_dim] = 1
        reshape_dim = inputs[0].shape.copy()
        reshape_dim[lstm.batch_dim] = -1
        reshape_dim = np.delete(reshape_dim, lstm.sequence_dim)
        input_squeeze = Reshape(
            body,
            dict(name=lstm.name + '/input_squeeze',
                 internal_layer_id=0,
                 dim=reshape_dim))
        inputs[0] = input_squeeze.create_node_with_data([inputs[0]],
                                                        edge_attrs=[{
                                                            'internal_port_id':
                                                            0
                                                        }])
        lstm_cell_op = LSTMCell(
            body,
            dict(hidden_size=lstm.hidden_size,
                 name=lstm.name + '/LSTMCell',
                 internal_layer_id=1))
        # Body data nodes for cell outputs 0 and 1; if the sequence node has no
        # such output, the shape of input port 3 is used instead.
        outputs = [
            Op._create_data_node(
                body, lstm.name + '/outport/' + str(out), {
                    'shape':
                    lstm.out_node(out).shape.copy() if out in lstm.out_nodes()
                    else lstm.in_node(3).shape.copy(),
                    'is_output':
                    True
                }) for out in [0, 1]
        ]
        unsqueezed_output_shape = outputs[0].shape.copy()
        unsqueezed_output_shape[lstm.sequence_dim] = 1
        squeezed_output_shape = np.delete(unsqueezed_output_shape,
                                          lstm.sequence_dim)
        outputs[0].shape = squeezed_output_shape
        unsqueezed_output_shape[lstm.batch_dim] = -1
        output_unsqueeze = Reshape(
            body,
            dict(name=lstm.name + '/output_unsqueeze',
                 dim=unsqueezed_output_shape,
                 internal_layer_id=2))
        # TODO edge attributes should be assigned by the op itself
        # State inputs map to internal ports 1/2 (targets of the back edges);
        # the last two inputs are attached as 'weights'/'biases' blobs.
        lstm_cell_node = lstm_cell_op.create_node_with_data(
            inputs,
            data_nodes=outputs,
            edge_attrs=[{}, {
                'internal_port_id': 1
            }, {
                'internal_port_id': 2
            }, {
                'bin': 'weights'
            }, {
                'bin': 'biases'
            }])
        # LSTMCell outputs 0/1 become internal ports 4/5 (sources of the back edges).
        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 4
        lstm_cell_node[0].in_node().out_edge(1)['internal_port_id'] = 5
        lstm_cell_node[0] = output_unsqueeze.create_node_with_data(
            [lstm_cell_node[0]])
        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 3
        lstm_cell_node[0]['is_output'] = True

        assert lstm.direction in ['forward', 'reverse']
        if lstm.direction == 'forward':
            stride = 1
            start = None
            end = None
        else:
            assert lstm.direction == 'reverse'
            stride = -1
            start = -1
            end = 0

        output_port_map = [{
            'external_port_id': 3,
            'internal_layer_id': 2,
            'internal_port_id': 3,
            'axis': lstm.sequence_dim,
            'stride': stride,
            'start': start,
            'end': end,
            'part_size': 1,
        }]

        # Expose the final hidden/cell states when the sequence node has them.
        if len(lstm.out_nodes()) == 3:
            output_port_map.extend([{
                'external_port_id': 4,
                'internal_layer_id': 1,
                'internal_port_id': 4,
            }, {
                'external_port_id': 5,
                'internal_layer_id': 1,
                'internal_port_id': 5,
            }])

        ti_op = TensorIterator(
            graph, {
                'name':
                lstm.name + '/TensorIterator',
                'body':
                body,
                'input_port_map': [
                    {
                        'external_port_id': 0,
                        'internal_layer_id': 0,
                        'internal_port_id': 0,
                        'axis': lstm.sequence_dim,
                        'stride': stride,
                        'start': start,
                        'end': end,
                        'part_size': 1,
                    },
                    {
                        'external_port_id': 1,
                        'internal_layer_id': 1,
                        'internal_port_id': 1,
                    },
                    {
                        'external_port_id': 2,
                        'internal_layer_id': 1,
                        'internal_port_id': 2,
                    },
                ],
                'output_port_map':
                output_port_map,
                'back_edges': [
                    {
                        'from_layer': 1,
                        'from_port': 4,
                        'to_layer': 1,
                        'to_port': 1,
                    },
                    {
                        'from_layer': 1,
                        'from_port': 5,
                        'to_layer': 1,
                        'to_port': 2,
                    },
                ]
            })

        assert sorted(lstm.out_nodes().keys()) == list(range(len(lstm.out_nodes()))), \
            "There are gaps in output ports of LSTMSequence operation. Node {}".format(lstm.id)
        outs = ti_op.create_node_with_data(
            [lstm.in_node(i) for i in [0, 3, 4]],
            data_nodes=[
                lstm.out_node(i) for i in range(len(lstm.out_nodes()))
            ],
            edge_attrs=[{
                'external_port_id': 0
            }, {
                'external_port_id': 1
            }, {
                'external_port_id': 2
            }])

        # create_node_with_data returns a bare node when there is one output.
        if not isinstance(outs, list):
            outs = [outs]

        graph.remove_node(lstm.id)
        outs[0].in_edge(0)['external_port_id'] = 3
        for external_port_id, out in enumerate(outs[1:], start=4):
            out.in_edge()['external_port_id'] = external_port_id