def unsqueeze_num_directions(graph: Graph, match: dict):
    """Re-insert the num_directions dimension into the outputs of an LSTM/GRU/RNN node.

    Assumes the matched RNN node is expected to have a num_directions dimension in its
    output shapes. For every output port, the node is reconnected to a new data node whose
    shape has that dimension removed, and an Unsqueeze op is inserted from the new data
    node back into the original output data node, restoring the original shape for
    the existing consumers.

    :param graph: graph containing the matched node; modified in place.
    :param match: pattern match dict; ``match['rnn_layer']`` is the RNN node to process.
    """
    rnn_layer = match['rnn_layer']
    # Fall back to the node id when the node has no 'name' attribute; the same prefix is
    # used for every node created below so they are named consistently.
    rnn_layer_name = rnn_layer.soft_get('name', rnn_layer.id)

    # num_directions is at position 1 in the output shape and at position 0 in the hidden
    # and cell states; please refer to docs in this transform.
    direction_dim = [1, 0, 0]  # per output port: index of the num_directions dimension

    for i in rnn_layer.out_nodes():
        old_data_node = rnn_layer.out_node(i)
        old_shape = old_data_node.shape.copy()
        new_shape = shape_delete(old_shape, direction_dim[i])

        # Detach the original output data node and connect this output port to a fresh
        # data node carrying the shape without the num_directions dimension.
        data = Op._create_data_node(graph, name=rnn_layer_name + '/Out/{}/'.format(i),
                                    attrs={'shape': new_shape})
        graph.remove_edge(rnn_layer.id, old_data_node.id)
        graph.add_edge(rnn_layer.id, data.id, key=0, out=i)

        unsqueeze = Unsqueeze(graph, dict())
        unsqueeze_dim_data = Const(graph, {
            'name': rnn_layer_name + '/UnsqueezeNumDirections/{}/Dim'.format(i),
            'value': int64_array([direction_dim[i]]),
        }).create_node_with_data()

        # The Unsqueeze writes into the original data node, whose outgoing edges were
        # never removed, so downstream consumers stay connected.
        unsqueeze.create_node_with_data([data, unsqueeze_dim_data],
                                        dict(name=rnn_layer_name + '/UnsqueezeNumDirections/{}'.format(i)),
                                        data_nodes=[old_data_node])
def replace_pattern(self, graph: Graph, match: dict):
    """Replace a matched LSTMSequence node with an equivalent TensorIterator.

    The TensorIterator body squeezes one slice of X along ``lstm.sequence_dim``, runs an
    LSTMCell on it and unsqueezes the resulting hidden state back. The hidden and cell
    states are fed back into the cell through back edges. The original node is removed
    and its input/output data nodes are connected to the TensorIterator instead.

    External TensorIterator ports:
      inputs:  0 - X, 1 - h_init, 2 - c_init
      outputs: 3 - stacked hidden states, 4 - last h_state, 5 - last c_state
               (4 and 5 only when the LSTMSequence has the corresponding outputs)

    :param graph: graph to transform in place.
    :param match: pattern match dict; ``match['lstm']`` is the LSTMSequence node.
    """
    lstm = match['lstm']

    # Build TensorIterator body first
    body = Graph(name=lstm.name + '/sub_graph')
    body.graph = graph.graph

    # 1. Input squeeze Reshape
    # Body inputs are taken from LSTMSequence input ports 0, 4, 5, 1, 2, i.e. in LSTMCell
    # order X, h_init, c_init, WR, B. Only ports 1 and 2 (weights and biases) keep their
    # constant values inside the body.
    inputs = [Op._create_data_node(body, lstm.name + '/inport/' + str(inp),
                                   {'shape': lstm.in_node(inp).shape.copy(),
                                    'value': lstm.in_node(inp).value.copy()
                                    if lstm.in_node(inp).value is not None and inp in [1, 2] else None})
              for inp in [0, 4, 5, 1, 2]]  # X, h_init, c_init, WR, B

    # Each iteration sees a single slice of X along the sequence dimension.
    inputs[0].shape[lstm.sequence_dim] = 1
    input_squeeze = Squeeze(body, dict(name=lstm.name + '/input_squeeze', internal_layer_id=0))
    squeeze_dim_data = Const(body, {'name': lstm.name + '/input_squeeze_dim',
                                    'value': [lstm.sequence_dim]}).create_node_with_data()
    inputs[0] = input_squeeze.create_node_with_data([inputs[0], squeeze_dim_data],
                                                    edge_attrs=[{'internal_port_id': 0}])

    # 2. Output unsqueeze Reshape
    # Cell outputs: h_state (initially shaped like LSTMSequence output 0; the sequence
    # dimension is removed below) and c_state (shaped like output 1 when it exists,
    # otherwise like the h_init input).
    outputs = [Op._create_data_node(body, lstm.name + '/outport/' + str(out),
                                    {'shape': lstm.out_node(out).shape.copy() if out in lstm.out_nodes()
                                     else lstm.in_node(4).shape.copy()})
               for out in [0, 1]]
    for out in outputs:
        add_opoutput(body, out.id, 0, False)

    outputs[0].shape = shape_delete(outputs[0].shape, lstm.sequence_dim)
    output_unsqueeze = Unsqueeze(body, dict(name=lstm.name + '/output_unsqueeze', internal_layer_id=2))
    unsqueeze_dim_data = Const(body, {'name': lstm.name + '/output_unsqueeze_dim',
                                      'value': [lstm.sequence_dim]}).create_node_with_data()

    # 3. LSTMCell
    lstm_cell_op = LSTMCell(body, dict(hidden_size=lstm.hidden_size,
                                       activations=lstm.activations,
                                       activation_alpha=lstm.activation_alpha,
                                       activation_beta=lstm.activation_beta,
                                       clip=lstm.clip,
                                       input_forget=lstm.input_forget,
                                       name=lstm.name + '/LSTMCell',
                                       internal_layer_id=1))
    # One edge attribute dict per cell input: X, h_init, c_init, WR, B.
    lstm_cell_node = lstm_cell_op.create_node_with_data(inputs, data_nodes=outputs,
                                                        edge_attrs=[{}, {'internal_port_id': 1},
                                                                    {'internal_port_id': 2}, {'bin': 'weights'},
                                                                    {'bin': 'biases'}])
    # Internal ids of the cell outputs: 4 - h_state, 5 - c_state
    lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 4
    lstm_cell_node[0].in_node().out_edge(1)['internal_port_id'] = 5
    lstm_cell_node[0] = output_unsqueeze.create_node_with_data([lstm_cell_node[0], unsqueeze_dim_data])
    lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 3
    add_opoutput(body, lstm_cell_node[0].id, 0, False)

    # 4. TensorIterator layer creating
    assert lstm.direction in ['forward', 'reverse']
    if lstm.direction == 'forward':
        stride, start, end = 1, None, None
    else:  # 'reverse'
        stride, start, end = -1, -1, 0

    num_outputs = len(lstm.out_nodes())

    # stacked h_state
    output_port_map = [{
        'external_port_id': 3,
        'internal_layer_id': 2,
        'internal_port_id': 3,

        'axis': lstm.sequence_dim,
        'stride': stride,
        'start': start,
        'end': end,
        'part_size': 1,
    }]

    # Adding last h_state / c_state to outputs. LSTMSequence output 1 becomes external
    # port 4 and output 2 becomes external port 5 (see the loop at the end), so each
    # present output needs its own map entry.
    if num_outputs >= 2:
        output_port_map.append({
            'external_port_id': 4,
            'internal_layer_id': 1,
            'internal_port_id': 4,
        })
    if num_outputs == 3:
        output_port_map.append({
            'external_port_id': 5,
            'internal_layer_id': 1,
            'internal_port_id': 5,
        })

    ti_op = TensorIterator(graph, {
        'name': lstm.name + '/TensorIterator',
        'body': body,
        'in_ports_count': 3,
        'out_ports_count': num_outputs,

        'input_port_map': [
            {
                'external_port_id': 0,
                'internal_layer_id': 0,
                'internal_port_id': 0,

                'axis': lstm.sequence_dim,
                'stride': stride,
                'start': start,
                'end': end,
                'part_size': 1,
            },
            {
                'external_port_id': 1,
                'internal_layer_id': 1,
                'internal_port_id': 1,
            },
            {
                'external_port_id': 2,
                'internal_layer_id': 1,
                'internal_port_id': 2,
            },
        ],

        'output_port_map': output_port_map,

        # h_state (port 4) and c_state (port 5) of one iteration feed the cell's
        # h_init (port 1) and c_init (port 2) inputs on the next one.
        'back_edges': [
            {
                'from_layer': 1,
                'from_port': 4,
                'to_layer': 1,
                'to_port': 1,
            },
            {
                'from_layer': 1,
                'from_port': 5,
                'to_layer': 1,
                'to_port': 2,
            },
        ]
    })

    assert sorted(lstm.out_nodes().keys()) == list(range(num_outputs)), \
        "There are gaps in output ports of LSTMSequence operation. Node {}".format(lstm.id)

    outs = ti_op.create_node_with_data([lstm.in_node(i) for i in [0, 4, 5]],  # X, h_init, c_init
                                       data_nodes=[lstm.out_node(i) for i in range(num_outputs)],
                                       edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1},
                                                   {'external_port_id': 2}])

    if not isinstance(outs, list):
        outs = [outs]

    graph.remove_node(lstm.id)
    outs[0].in_edge(0)['external_port_id'] = 3
    for external_port_id, out in enumerate(outs[1:], start=4):
        out.in_edge()['external_port_id'] = external_port_id

    ti = outs[0].in_node()
    TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
    TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
    TensorIterator.normalize_internal_ids(ti)
def replace_pattern(self, graph: Graph, match: dict):
    """Replace a matched GRU/RNN sequence node with an equivalent TensorIterator.

    Returns immediately when the matched node is an LSTM. Otherwise the TensorIterator
    body squeezes one slice of X along ``rnn_layer.sequence_dim``, runs the matching
    cell op on it and unsqueezes the resulting hidden state back. The hidden state is fed
    back into the cell through a back edge. The original node is removed and its
    input/output data nodes are connected to the TensorIterator instead.

    External TensorIterator ports:
      inputs:  0 - X, 1 - h_init
      outputs: 3 - stacked hidden states, 4 - last h_state (only when the node has 2 outputs)

    :param graph: graph to transform in place.
    :param match: pattern match dict; ``match['rnn_layer']`` is the sequence node.
    """
    if match['rnn_layer']['op'] == 'LSTM':
        return

    rnn_layer = match['rnn_layer']

    # Build TensorIterator body first
    body = Graph(name=rnn_layer.name + '/sub_graph')
    body.graph = graph.graph

    # 1. Input squeeze Reshape
    # Body inputs are taken from input ports 0, 4, 1, 2, i.e. in cell order
    # X, h_init, WR, B. Only ports 1 and 2 (weights and biases) keep their constant values.
    inputs = [
        Op._create_data_node(
            body, rnn_layer.name + '/inport/' + str(inp),
            {'shape': rnn_layer.in_node(inp).shape.copy(),
             'value': rnn_layer.in_node(inp).value.copy()
             if rnn_layer.in_node(inp).value is not None and inp in [1, 2] else None})
        for inp in [0, 4, 1, 2]
    ]  # X, h_init, WR, B

    # Each iteration sees a single slice of X along the sequence dimension.
    inputs[0].shape[rnn_layer.sequence_dim] = 1
    input_squeeze = Squeeze(body, dict(name=rnn_layer.name + '/input_squeeze', internal_layer_id=0))
    input_squeeze_dim = Const(body, dict(name=rnn_layer.name + '/input_squeeze_dim',
                                         value=rnn_layer.sequence_dim)).create_node_with_data()
    inputs[0] = input_squeeze.create_node_with_data([inputs[0], input_squeeze_dim],
                                                    edge_attrs=[{'internal_port_id': 0}])

    # 2. Output unsqueeze Reshape
    # Single cell output: h_state, initially shaped like output 0 of the sequence node;
    # the sequence dimension is removed below.
    outputs = [
        Op._create_data_node(
            body, rnn_layer.name + '/outport/' + str(out),
            {'shape': rnn_layer.out_node(out).shape.copy() if out in rnn_layer.out_nodes() else None})
        for out in [0]
    ]
    for out in outputs:
        add_opoutput(body, out.id, 0, False)

    outputs[0].shape = shape_delete(outputs[0].shape, rnn_layer.sequence_dim)
    output_unsqueeze_dim = Const(body, dict(name=rnn_layer.name + '/output_unsqueeze_dim',
                                            value=rnn_layer.sequence_dim)).create_node_with_data()
    output_unsqueeze = Unsqueeze(body, dict(name=rnn_layer.name + '/output_unsqueeze/', internal_layer_id=2))

    additional_attrs = dict(activations=rnn_layer.activations,
                            activation_alpha=rnn_layer.activation_alpha,
                            activation_beta=rnn_layer.activation_beta,
                            clip=rnn_layer.clip)
    if rnn_layer.op == 'GRU':
        additional_attrs['linear_before_reset'] = rnn_layer.linear_before_reset

    # 3. ***Cell — the cell op class is obtained from self.get_rnn_cell for this op type
    rnn_cell_op = self.get_rnn_cell(rnn_layer['op'])(
        body, dict(hidden_size=rnn_layer.hidden_size,
                   name=rnn_layer.name + '/{}Cell'.format(rnn_layer.op),
                   **additional_attrs,
                   internal_layer_id=1))

    # One edge attribute dict per cell input: X, h_init, WR, B. h_init is tagged with
    # internal port 1 so the input_port_map and the back edge can find it.
    rnn_cell_node = rnn_cell_op.create_node_with_data(inputs, data_nodes=outputs,
                                                      edge_attrs=[{}, {'internal_port_id': 1},
                                                                  {'bin': 'weights'}, {'bin': 'biases'}])

    # internal ports for outputs of cell
    rnn_cell_node.in_node().out_edge(0)['internal_port_id'] = 4  # h_state

    rnn_cell_node = output_unsqueeze.create_node_with_data([rnn_cell_node, output_unsqueeze_dim])
    rnn_cell_node.in_node().out_edge(0)['internal_port_id'] = 3
    add_opoutput(body, rnn_cell_node.id, 0, False)

    # 4. TensorIterator layer creating
    assert rnn_layer.direction in ['forward', 'reverse']
    if rnn_layer.direction == 'forward':
        stride, start, end = 1, None, None
    else:  # 'reverse'
        stride, start, end = -1, -1, 0

    num_outputs = len(rnn_layer.out_nodes())

    # stacked h_state
    output_port_map = [{
        'external_port_id': 3,
        'internal_layer_id': 2,
        'internal_port_id': 3,

        'axis': rnn_layer.sequence_dim,
        'stride': stride,
        'start': start,
        'end': end,
        'part_size': 1,
    }]

    # Adding last h_state to outputs
    if num_outputs == 2:
        output_port_map.append({
            'external_port_id': 4,
            'internal_layer_id': 1,
            'internal_port_id': 4,
        })

    ti_op = TensorIterator(graph, {
        'name': rnn_layer.name + '/TensorIterator',
        'body': body,
        'in_ports_count': 4,
        'out_ports_count': num_outputs,

        'input_port_map': [
            {
                'external_port_id': 0,
                'internal_layer_id': 0,
                'internal_port_id': 0,

                'axis': rnn_layer.sequence_dim,
                'stride': stride,
                'start': start,
                'end': end,
                'part_size': 1,
            },
            {
                'external_port_id': 1,
                'internal_layer_id': 1,
                'internal_port_id': 1,
            },
        ],

        'output_port_map': output_port_map,
        # only for h state: cell output port 4 feeds cell input port 1 on the next iteration
        'back_edges': [
            {
                'from_layer': 1,
                'from_port': 4,
                'to_layer': 1,
                'to_port': 1,
            },
        ]
    })

    assert sorted(rnn_layer.out_nodes().keys()) == list(range(num_outputs)), \
        "There are gaps in output ports of {}Sequence operation. Node {}".format(rnn_layer.op, rnn_layer.id)

    outs = ti_op.create_node_with_data([rnn_layer.in_node(i) for i in [0, 4]],  # X, h_init
                                       data_nodes=[rnn_layer.out_node(i) for i in range(num_outputs)],
                                       edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1}])

    if not isinstance(outs, list):
        outs = [outs]

    graph.remove_node(rnn_layer.id)
    outs[0].in_edge(0)['external_port_id'] = 3
    for external_port_id, out in enumerate(outs[1:], start=4):
        out.in_edge()['external_port_id'] = external_port_id

    ti = outs[0].in_node()
    TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
    TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
    TensorIterator.normalize_internal_ids(ti)