def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace the matched sub-graph with a single LSTMCell node.

    Args:
        graph: graph that holds the matched nodes; it is modified in place
            (a new LSTMCell node is created, 'tanh_1' is removed and the
            pattern outputs are re-routed to the new node).
        match: mapping from pattern node names to matched nodes. The keys
            'input_op', 'input_hidden_state' and 'input_cell_state' are
            added to it by this method.

    Raises:
        AssertionError: if the anchor node has no valid 'name' attribute.
    """
    # node that is used to identify this pattern application instance for switching between supported
    # and not supported LSTMCell sub-graphs; this value will be searched in __class__.instances_supported_by_IE.
    anchor_node = match[__class__.anchor()]
    # The node has no usable name at this point, so it is identified by its id in the message.
    assert anchor_node.has_valid('name'), \
        'LSTMCell anchor node {} doesn\'t have attribute name; such nodes are not supported.'.format(
            anchor_node.id)

    match['input_op'] = match['concat'].in_node(0)
    match['input_hidden_state'] = match['concat'].in_node(1)
    # Of the inputs 0 and 1 of 'mul_0', take the one that is not 'sigmoid_0' as the cell state.
    match['input_cell_state'] = match['mul_0'].in_node(0) \
        if match['mul_0'].in_node(0).id != match['sigmoid_0'].id else match['mul_0'].in_node(1)

    # Work on a copy so that the list returned by self.pattern() is not mutated in place.
    pattern_edges = list(self.pattern()['edges'])
    pattern_edges.extend([
        ('input_op', 'concat'),
        ('input_cell_state', 'mul_0'),
        ('input_hidden_state', 'concat'),
    ])
    inputs = graph.get_inputs_with_ports(
        match, pattern_edges, __class__.inputs + __class__.extra_inputs)

    lstm_op = LSTMCell(graph, dict(
        name=match['concat'].name + '/LSTMCell',
        activations=None,
    ))
    lstm_node = lstm_op.create_node(inputs)
    # Keep the original infer function reachable under 'old_infer' before overriding it.
    lstm_node['old_infer'] = lstm_node.infer
    lstm_node.infer = __class__.infer

    # this node consumes one of the resulting LSTMCell outputs,
    # it should be removed before reconnecting the nodes,
    # otherwise it will be reconnected to the new cell output
    graph.remove_node(match['tanh_1'].id)

    for i, output in enumerate(__class__.outputs):
        match[output].replace_node(lstm_node, i)

    # Because of LSTMCell specification, this layer MUST have 2 outputs.
    # => we need to create fake consumers for LSTMCell
    # when this node doesn't have some of the outputs.
    for i in (0, 1):
        if i not in lstm_node.out_nodes():
            fake_output_node = Result(
                graph, dict(name=lstm_node.name + "/Output_{}".format(i)))
            fake_output_node.create_node(
                inputs=[lstm_node], edge_attrs={'out': i, 'in': 0})

    lstm_node['tf'] = True
    # Remember the ids of the matched nodes that served as (extra) inputs.
    lstm_node['extra_inputs'] = {
        name: match[name].id for name in __class__.extra_inputs
    }
    lstm_node['inputs'] = {
        name: match[name].id for name in __class__.inputs
    }
def test_create_node(self):
    """Check 'fw_tensor_debug_info' on the input edges after creating an LSTMCell node.

    Builds a graph with edges Op1 -> Op3 and Op2 -> Op3 that carry
    'fw_tensor_debug_info', creates an LSTMCell node fed by Op1 and Op2, and
    verifies that out_edge(0) of each input node still holds the original value.
    """
    graph = build_graph(nodes, [
        ('Op1', 'Op3', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('Op1', 'Op1')]}),
        ('Op2', 'Op3', {'in': 1, 'out': 0, 'fw_tensor_debug_info': [('Op2', 'Op2')]}),
    ])
    graph.stage = 'front'

    input1 = Node(graph, 'Op1')
    input2 = Node(graph, 'Op2')
    inputs = [(input1, 0), (input2, 0)]

    lstm_op = LSTMCell(graph, dict(name='LSTMCell'))
    # Only the effect on the input edges is checked; the created node itself is not needed.
    lstm_op.create_node(inputs)

    # assertEqual reports both operands on failure, unlike assertTrue(a == b).
    self.assertEqual(input1.out_edge(0)['fw_tensor_debug_info'], [('Op1', 'Op1')])
    self.assertEqual(input2.out_edge(0)['fw_tensor_debug_info'], [('Op2', 'Op2')])