Ejemplo n.º 1
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Replace a matched LSTMCell sub-graph with a single LSTMCell node.

        The method records the three cell inputs in ``match``, creates an
        LSTMCell node fed by the inputs that ``graph.get_inputs_with_ports``
        resolves, redirects the matched output nodes to the new node and adds
        Result consumers for output ports 0 and 1 when they are left unused.

        :param graph: graph being transformed; it is modified in place.
        :param match: mapping from pattern node names to matched nodes; the
            keys 'input_op', 'input_hidden_state' and 'input_cell_state' are
            added to it.
        :raises AssertionError: if the anchor node has no valid 'name'.
        """
        # node that is used to identify this pattern application instance for switching between supported
        # and not supported LSTMCell sub-graphs; this value will be searched in __class__.instances_supported_by_IE.
        anchor_node = match[__class__.anchor()]
        # The node has no usable name at this point, so it is reported by its graph id.
        assert anchor_node.has_valid('name'), \
            'LSTMCell anchor node {} does\'t have attribute name; such nodes are not supported.'.format(
                anchor_node.id)

        match['input_op'] = match['concat'].in_node(0)
        match['input_hidden_state'] = match['concat'].in_node(1)
        # The cell state is whichever input of 'mul_0' is not the matched 'sigmoid_0' node.
        match['input_cell_state'] = match['mul_0'].in_node(0) \
            if match['mul_0'].in_node(0).id != match['sigmoid_0'].id else match['mul_0'].in_node(1)

        # Build a fresh list rather than extending the one returned by pattern(),
        # so the pattern definition is never mutated if that list is shared.
        pattern_edges = list(self.pattern()['edges']) + [
            ('input_op', 'concat'),
            ('input_cell_state', 'mul_0'),
            ('input_hidden_state', 'concat'),
        ]
        inputs = graph.get_inputs_with_ports(
            match, pattern_edges, __class__.inputs + __class__.extra_inputs)

        lstm_op = LSTMCell(
            graph,
            dict(
                name=match['concat'].name + '/LSTMCell',
                activations=None,
            ))
        lstm_node = lstm_op.create_node(inputs)
        # Keep the operation's own infer function reachable under 'old_infer'
        # before installing this class's infer in its place.
        lstm_node['old_infer'] = lstm_node.infer
        lstm_node.infer = __class__.infer

        # this node consumes one of the resulting LSTMCell outputs,
        # it should be removed before reconnecting the nodes,
        # otherwise it will be reconnected to the new cell output
        graph.remove_node(match['tanh_1'].id)

        for i, output in enumerate(__class__.outputs):
            match[output].replace_node(lstm_node, i)

        # Because of the LSTMCell specification, this layer MUST have 2 outputs.
        # => we need to create fake consumers for LSTMCell
        # for every output port that has no consumer.
        for i in (0, 1):
            if i not in lstm_node.out_nodes():
                fake_output_node = Result(
                    graph, dict(name=lstm_node.name + "/Output_{}".format(i)))
                fake_output_node.create_node(inputs=[lstm_node],
                                             edge_attrs={
                                                 'out': i,
                                                 'in': 0
                                             })

        lstm_node['tf'] = True
        # Remember the ids of the matched input nodes on the new node.
        lstm_node['extra_inputs'] = {
            name: match[name].id
            for name in __class__.extra_inputs
        }
        lstm_node['inputs'] = {
            name: match[name].id
            for name in __class__.inputs
        }
Ejemplo n.º 2
0
    def test_create_node(self):
        """Check that 'fw_tensor_debug_info' on the input edges is unchanged by LSTMCell.create_node."""
        graph = build_graph(nodes, [('Op1', 'Op3', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('Op1', 'Op1')]}),
                                    ('Op2', 'Op3', {'in': 1, 'out': 0, 'fw_tensor_debug_info': [('Op2', 'Op2')]})])
        graph.stage = 'front'
        input1 = Node(graph, 'Op1')
        input2 = Node(graph, 'Op2')
        inputs = [(input1, 0), (input2, 0)]

        lstm_op = LSTMCell(graph, dict(name='LSTMCell'))
        # Only the effect on the input edges matters; the created node itself is not inspected.
        lstm_op.create_node(inputs)

        # assertEqual reports both operands on failure, unlike assertTrue(a == b).
        self.assertEqual(input1.out_edge(0)['fw_tensor_debug_info'], [('Op1', 'Op1')])
        self.assertEqual(input2.out_edge(0)['fw_tensor_debug_info'], [('Op2', 'Op2')])