def replace_sub_graph(self, graph: Graph, match: dict):
    # node that is used to identify this pattern application instance for switching between supported
    # and unsupported LSTMCell sub-graphs; this value will be searched in __class__.instances_supported_by_IE.
    anchor_node = match[__class__.anchor()]
    assert anchor_node.has_valid('name'), \
        'LSTMCell anchor node {} doesn\'t have attribute name; such nodes are not supported.'.format(anchor_node.id)

    match['input_op'] = match['concat'].in_node(0)
    match['input_hidden_state'] = match['concat'].in_node(1)
    match['input_cell_state'] = match['mul_0'].in_node(0) \
        if match['mul_0'].in_node(0).id != match['sigmoid_0'].id else match['mul_0'].in_node(1)

    pattern_edges = self.pattern()['edges']
    pattern_edges.extend([('input_op', 'concat'), ('input_cell_state', 'mul_0'), ('input_hidden_state', 'concat')])
    inputs = graph.get_inputs_with_ports(match, pattern_edges, __class__.inputs + __class__.extra_inputs)

    lstm_op = LSTMCell(graph, dict(
        name=match['concat'].name + '/LSTMCell',
        activations=None,
    ))
    lstm_node = lstm_op.create_node(inputs)
    lstm_node['old_infer'] = lstm_node.infer
    lstm_node.infer = __class__.infer

    # this node consumes one of the resulting LSTMCell outputs,
    # it should be removed before reconnecting the nodes,
    # otherwise it will be reconnected to the new cell output
    graph.remove_node(match['tanh_1'].id)

    for i, output in enumerate(__class__.outputs):
        match[output].replace_node(lstm_node, i)

    # Because of the LSTMCell specification, this layer MUST have 2 outputs,
    # so we need to create fake consumers for any LSTMCell output
    # that is missing on this node.
    for i in [0, 1]:
        if i not in lstm_node.out_nodes():
            fake_output_node = Result(graph, dict(name=lstm_node.name + "/Output_{}".format(i)))
            fake_output_node.create_node(inputs=[lstm_node], edge_attrs={'out': i, 'in': 0})

    lstm_node['tf'] = True
    lstm_node['extra_inputs'] = {name: match[name].id for name in __class__.extra_inputs}
    lstm_node['inputs'] = {name: match[name].id for name in __class__.inputs}
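# For context: a minimal sketch of the class attributes replace_sub_graph() relies on
# (anchor(), inputs, extra_inputs, outputs, infer). Every name and value below is an
# illustrative assumption, not the actual pattern definition from the repository, and
# the import path varies across Model Optimizer versions.
from mo.front.common.replacement import FrontReplacementSubgraph

class BasicLSTMCellReplacementSketch(FrontReplacementSubgraph):
    enabled = True

    # extra match keys filled in by replace_sub_graph() before get_inputs_with_ports()
    inputs = ['input_op', 'input_hidden_state', 'input_cell_state']
    extra_inputs = ['weights', 'biases']
    # pattern nodes whose consumers get rewired to the new LSTMCell outputs
    outputs = ['mul_2', 'add_1']

    @staticmethod
    def anchor():
        # name of the pattern node used to identify this application instance
        return 'concat'

    @staticmethod
    def infer(node):
        # delegate to the shape inference saved as 'old_infer' by replace_sub_graph()
        node.old_infer(node)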
def test_create_node(self):
    graph = build_graph(nodes,
                        [('Op1', 'Op3', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('Op1', 'Op1')]}),
                         ('Op2', 'Op3', {'in': 1, 'out': 0, 'fw_tensor_debug_info': [('Op2', 'Op2')]})])
    graph.stage = 'front'
    input1 = Node(graph, 'Op1')
    input2 = Node(graph, 'Op2')
    inputs = [(input1, 0), (input2, 0)]
    lstm_op = LSTMCell(graph, dict(name='LSTMCell'))
    _ = lstm_op.create_node(inputs)
    self.assertTrue(input1.out_edge(0)['fw_tensor_debug_info'] == [('Op1', 'Op1')])
    self.assertTrue(input2.out_edge(0)['fw_tensor_debug_info'] == [('Op2', 'Op2')])
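# The test above assumes a module-level `nodes` attributes dict for build_graph(); a
# hypothetical minimal version could look like this (the real dict in the test module
# may define more attributes per node):
nodes = {
    'Op1': {'kind': 'op', 'op': 'Parameter', 'name': 'Op1'},
    'Op2': {'kind': 'op', 'op': 'Parameter', 'name': 'Op2'},
    'Op3': {'kind': 'op', 'op': 'LSTMCell', 'name': 'Op3'},
}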
def extract(cls, node):
    clip_value = 50
    pb = node.parameters
    res = collect_until_whitespace(pb)
    if res == b'<CellClip>':
        clip_value = get_uint32(pb.read(4))
    collect_until_token(pb, b'FM')
    gifo_x_weights, gifo_x_weights_shape = read_binary_matrix(pb, False)
    gifo_r_weights, gifo_r_weights_shape = read_binary_matrix(pb)
    gifo_biases = read_binary_vector(pb)
    input_gate_weights = read_binary_vector(pb)
    forget_gate_weights = read_binary_vector(pb)
    output_gate_weights = read_binary_vector(pb)
    projection_weights, projection_weights_shape = read_binary_matrix(pb)

    mapping_rule = {
        'gifo_x_weights_shape': gifo_x_weights_shape,
        'gifo_r_weights_shape': gifo_r_weights_shape,
        'projection_weights_shape': projection_weights_shape,
        'clip_value': clip_value,
        'format': 'kaldi',
    }

    embed_input(mapping_rule, 1, 'gifo_x_weights', gifo_x_weights)
    embed_input(mapping_rule, 2, 'gifo_r_weights', gifo_r_weights)
    embed_input(mapping_rule, 3, 'gifo_biases', gifo_biases)
    embed_input(mapping_rule, 4, 'input_gate_weights', input_gate_weights)
    embed_input(mapping_rule, 5, 'forget_gate_weights', forget_gate_weights)
    embed_input(mapping_rule, 6, 'output_gate_weights', output_gate_weights)
    embed_input(mapping_rule, 7, 'projection_weights', projection_weights)

    LSTMCell.update_node_stat(node, mapping_rule)
    return cls.enabled
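# A standalone numpy sketch (not Model Optimizer code) of how the tensors read above fit
# together in one Kaldi projected-LSTM step. The packed gifo matrices are assumed to stack
# the gates in g (cell input), i (input), f (forget), o (output) order, and the three
# *_gate_weights vectors are assumed to be diagonal peephole weights.
import numpy as np

def lstmp_step(x, r_prev, c_prev, gifo_x, gifo_r, gifo_b, w_ic, w_fc, w_oc, w_rm):
    sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))
    pre = gifo_x @ x + gifo_r @ r_prev + gifo_b   # one matmul per input covers all four gates
    g, i, f, o = np.split(pre, 4)
    i = sigmoid(i + w_ic * c_prev)                # peephole from the previous cell state
    f = sigmoid(f + w_fc * c_prev)
    c = f * c_prev + i * np.tanh(g)
    o = sigmoid(o + w_oc * c)                     # peephole from the updated cell state
    m = o * np.tanh(c)
    return w_rm @ m, c                            # projection_weights reduce the recurrent dim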
def replace_pattern(self, graph: Graph, match: dict):
    lstm = match['lstm']

    # Build TensorIterator body first
    body = Graph(name=lstm.name + '/sub_graph')
    body.graph = graph.graph

    # 1. Input squeeze Reshape
    inputs = [Op._create_data_node(body, lstm.name + '/inport/' + str(inp),
                                   {'shape': lstm.in_node(inp).shape.copy(),
                                    'value': lstm.in_node(inp).value.copy()
                                    if lstm.in_node(inp).value is not None and inp in [1, 2] else None})
              for inp in [0, 4, 5, 1, 2]]  # ports: 0=X, 1=WR, 2=B, 4=h_init, 5=c_init

    inputs[0].shape[lstm.sequence_dim] = 1
    input_squeeze = Squeeze(body, dict(name=lstm.name + '/input_squeeze', internal_layer_id=0))
    squeeze_dim_data = Const(body, {'name': lstm.name + '/input_squeeze_dim',
                                    'value': [lstm.sequence_dim]}).create_node_with_data()
    inputs[0] = input_squeeze.create_node_with_data([inputs[0], squeeze_dim_data],
                                                    edge_attrs=[{'internal_port_id': 0}])

    # 2. Output unsqueeze Reshape
    outputs = [Op._create_data_node(body, lstm.name + '/outport/' + str(out),
                                    {'shape': lstm.out_node(out).shape.copy() if out in lstm.out_nodes()
                                     else lstm.in_node(4).shape.copy()})
               for out in [0, 1]]
    for out in outputs:
        add_opoutput(body, out.id, 0, False)

    outputs[0].shape = shape_delete(outputs[0].shape, lstm.sequence_dim)
    output_unsqueeze = Unsqueeze(body, dict(name=lstm.name + '/output_unsqueeze', internal_layer_id=2))
    unsqueeze_dim_data = Const(body, {'name': lstm.name + '/output_unsqueeze_dim',
                                      'value': [lstm.sequence_dim]}).create_node_with_data()

    # 3. LSTMCell
    lstm_cell_op = LSTMCell(body, dict(hidden_size=lstm.hidden_size,
                                       activations=lstm.activations,
                                       activation_alpha=lstm.activation_alpha,
                                       activation_beta=lstm.activation_beta,
                                       clip=lstm.clip,
                                       input_forget=lstm.input_forget,
                                       name=lstm.name + '/LSTMCell',
                                       internal_layer_id=1))
    lstm_cell_node = lstm_cell_op.create_node_with_data(inputs, data_nodes=outputs,
                                                        edge_attrs=[{}, {'internal_port_id': 1},
                                                                    {'internal_port_id': 2}, {'bin': 'weights'},
                                                                    {'bin': 'biases'}])
    lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 4
    lstm_cell_node[0].in_node().out_edge(1)['internal_port_id'] = 5
    lstm_cell_node[0] = output_unsqueeze.create_node_with_data([lstm_cell_node[0], unsqueeze_dim_data])
    lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 3
    add_opoutput(body, lstm_cell_node[0].id, 0, False)

    # 4. Create the TensorIterator layer
    assert lstm.direction in ['forward', 'reverse']
    if lstm.direction == 'forward':
        stride = 1
        start = None
        end = None
    else:
        assert lstm.direction == 'reverse'
        stride = -1
        start = -1
        end = 0

    output_port_map = [{
        'external_port_id': 3,
        'internal_layer_id': 2,
        'internal_port_id': 3,

        'axis': lstm.sequence_dim,
        'stride': stride,
        'start': start,
        'end': end,
        'part_size': 1,
    }]

    # Adding h_state, c_state to outputs
    if len(lstm.out_nodes()) == 3:
        output_port_map.extend([{
            'external_port_id': 4,
            'internal_layer_id': 1,
            'internal_port_id': 4,
        }, {
            'external_port_id': 5,
            'internal_layer_id': 1,
            'internal_port_id': 5,
        }])

    ti_op = TensorIterator(graph, {
        'name': lstm.name + '/TensorIterator',
        'body': body,
        'in_ports_count': 3,
        'out_ports_count': len(lstm.out_nodes()),

        'input_port_map': [
            {
                'external_port_id': 0,
                'internal_layer_id': 0,
                'internal_port_id': 0,

                'axis': lstm.sequence_dim,
                'stride': stride,
                'start': start,
                'end': end,
                'part_size': 1,
            },
            {
                'external_port_id': 1,
                'internal_layer_id': 1,
                'internal_port_id': 1,
            },
            {
                'external_port_id': 2,
                'internal_layer_id': 1,
                'internal_port_id': 2,
            },
        ],

        'output_port_map': output_port_map,

        'back_edges': [
            {
                'from_layer': 1,
                'from_port': 4,
                'to_layer': 1,
                'to_port': 1,
            },
            {
                'from_layer': 1,
                'from_port': 5,
                'to_layer': 1,
                'to_port': 2,
            },
        ]
    })

    assert sorted(lstm.out_nodes().keys()) == list(range(len(lstm.out_nodes()))), \
        "There are gaps in output ports of LSTMSequence operation. Node {}".format(lstm.id)

    outs = ti_op.create_node_with_data([lstm.in_node(i) for i in [0, 4, 5]],  # X, h_init, c_init
                                       data_nodes=[lstm.out_node(i) for i in range(len(lstm.out_nodes()))],
                                       edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1},
                                                   {'external_port_id': 2}])

    if not isinstance(outs, list):
        outs = list([outs])

    graph.remove_node(lstm.id)
    outs[0].in_edge(0)['external_port_id'] = 3
    for i, out in enumerate(outs[1:]):
        external_port_id = 4 + i
        out.in_edge()['external_port_id'] = external_port_id

    ti = outs[0].in_node()
    TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
    TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
    TensorIterator.normalize_internal_ids(ti)
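# A standalone numpy sketch (an assumption, not Model Optimizer code) of the iteration
# semantics encoded in the port maps above: the TensorIterator walks the sequence axis
# with the given stride, feeds one part_size=1 slice per step through Squeeze ->
# LSTMCell -> Unsqueeze, and carries h/c between steps via the back edges.
import numpy as np

def emulate_tensor_iterator(x, h, c, cell, axis=0, stride=1):
    steps = range(x.shape[axis]) if stride == 1 else reversed(range(x.shape[axis]))
    slices = []
    for t in steps:
        xt = np.take(x, t, axis=axis)            # input squeeze (internal_layer_id=0)
        h, c = cell(xt, h, c)                    # LSTMCell body (internal_layer_id=1)
        slices.append(np.expand_dims(h, axis))   # output unsqueeze (internal_layer_id=2)
    # with stride=-1 the iterator also writes output slices back to front,
    # so the concatenated result stays in original time order
    return np.concatenate(slices[::stride], axis=axis), h, c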