def replace_pattern(graph: Graph, match: dict): nodes = [ ('input_unsqueezed'), ('squeeze', dict(op='Reshape')), ('input_squeezed'), ('input_hidden'), ('input_cell'), ('weights'), ('biases'), ('lstm', dict(op='LSTMCell')), ('output_hidden'), ('output_cell'), ('unsqueeze', dict(op='Reshape')), ('output_unsqueezed'), ] edges = [ ('input_unsqueezed', 'squeeze'), ('squeeze', 'input_squeezed'), ('input_squeezed', 'lstm', { 'in': 0 }), ('input_hidden', 'lstm', { 'in': 1 }), ('input_cell', 'lstm', { 'in': 2 }), ('weights', 'lstm', { 'in': 3 }), ('biases', 'lstm', { 'in': 4 }), ('lstm', 'output_hidden', { 'out': 0 }), ('lstm', 'output_cell', { 'out': 1 }), ('output_hidden', 'unsqueeze'), ('unsqueeze', 'output_unsqueezed'), ] ti = match['ti'] isomorphisms = find_isomorphisms(ti.body, nodes, edges) if len(list(isomorphisms)) != 1: raise Error( 'Unsupported TensorIterator layer {} was found: either its body, ports or ' 'edges are not supported by Inference Engine. ' 'Only TensorIterator with LSTMCell in a body of strict form is supported. ' 'Please modify the original network ' 'to meet the requirements.'.format(ti.soft_get('name'))) body_match = isomorphisms[0] if body_match['input_hidden'].has_valid( 'value') or body_match['input_cell'].has_valid('value'): raise Error( 'Unsupported TensorIterator layer {} was found: initial hidden and/or cell states ' 'for LSTMCell are constants. This is not supported. ' 'Only TensorIterator with LSTMCell in a body of strict form is supported. ' 'Please modify the original network ' 'to meet the requirements.'.format(ti.soft_get('name')))
def replace_pattern(self, graph: Graph, match: dict): # This transformation works if and only if a body of TI # matches the following topology (Reshape -> LSTMCell -> Reshape) nodes = [('input_unsqueezed'), ('squeeze', dict(op='Reshape')), ('input_squeezed'), ('input_hidden'), ('input_cell'), ('weights'), ('biases'), ('lstm', dict(op='LSTMCell')), ('output_hidden'), ('output_cell'), ('unsqueeze', dict(op='Reshape')), ('output_unsqueezed'), ('const_w', dict(op='Const')), ('const_b', dict(op='Const')), ('op_output', dict(op='OpOutput')), ('op_output_1', dict(op='OpOutput')), ('op_output_2', dict(op='OpOutput'))] edges = [ ('input_unsqueezed', 'squeeze'), ('squeeze', 'input_squeezed'), ('input_squeezed', 'lstm', { 'in': 0 }), ('input_hidden', 'lstm', { 'in': 1 }), ('input_cell', 'lstm', { 'in': 2 }), ('weights', 'lstm', { 'in': 3 }), ('biases', 'lstm', { 'in': 4 }), ('const_w', 'weights'), ('const_b', 'biases'), ('lstm', 'output_hidden', { 'out': 0 }), ('lstm', 'output_cell', { 'out': 1 }), ('output_hidden', 'unsqueeze'), ('unsqueeze', 'output_unsqueezed'), ('output_unsqueezed', 'op_output'), ('output_hidden', 'op_output_1'), ('output_cell', 'op_output_2'), ] ti = match['ti'] isomorphisms = find_isomorphisms(ti.body, nodes, edges) if len(list(isomorphisms)) != 1: return isomorphism = isomorphisms[0] direct_permute = match['direct_permute'] inverse_permute = match['inverse_permute'] permute_order = [1, 0, 2] # Check both perumute orders exactly match expected one - [1, 0, 2] if not direct_permute.has_valid('order') or not np.array_equal( direct_permute.order, permute_order): return if not inverse_permute.has_valid('order') or not np.array_equal( inverse_permute.order, permute_order): return def find_ports(port_map: list, attrs: dict): """ Find all ports in a given port map with specified attributes """ result = [] for i, port in enumerate(port_map): if dict_includes(port, attrs): result.append(i) return result # Check TI has only single partitioned input/output port; all partitioned ports have defined axis data_input_port = find_ports(ti.input_port_map, {'axis': lambda attr: attr in [0, 1]}) data_output_port = find_ports(ti.output_port_map, {'axis': lambda attr: attr in [0, 1]}) assert len(data_input_port) == 1 assert len(data_output_port) == 1 data_input_port = data_input_port[0] data_output_port = data_output_port[0] # Verify that they are really connected to Permute layers (guarantied by port numbers of TI, see the pattern) assert ti.in_edge(0)['external_port_id'] == ti.input_port_map[ data_input_port]['external_port_id'] assert ti.out_edge(0)['external_port_id'] == ti.output_port_map[ data_output_port]['external_port_id'] # Verify that the TI body have required Reshapes connected to the found ports squeeze = isomorphism['squeeze'] unsqueeze = isomorphism['unsqueeze'] assert squeeze['internal_layer_id'] == ti.input_port_map[ data_input_port]['internal_layer_id'] assert squeeze.in_edge(0)['internal_port_id'] == ti.input_port_map[ data_input_port]['internal_port_id'] assert unsqueeze['internal_layer_id'] == ti.output_port_map[ data_output_port]['internal_layer_id'] assert unsqueeze.out_edge(0)['internal_port_id'] == ti.output_port_map[ data_output_port]['internal_port_id'] assert len(squeeze.in_node().shape) == 3 assert len(squeeze.out_node().shape) == 2 assert len(unsqueeze.in_node().shape) == 2 assert len(unsqueeze.out_node().shape) == 3 # Remove permutes remove_op_node_with_data_node(graph, direct_permute) remove_op_node_with_data_node(graph, inverse_permute) match['output'].shape = match['output'].shape[permute_order] # swap 0/1 axis for partitioned ports ti.input_port_map[data_input_port][ 'axis'] = 1 - ti.input_port_map[data_input_port]['axis'] ti.output_port_map[data_output_port][ 'axis'] = 1 - ti.output_port_map[data_output_port]['axis'] # smap 0-th and 1-th shape entries for reshapes inside body squeeze.in_node().shape = squeeze.in_node().shape[[1, 0, 2]] unsqueeze.out_node().shape = unsqueeze.out_node().shape[[1, 0, 2]] unsqueeze.dim = unsqueeze.dim[[1, 0, 2]]
def replace_pattern(self, graph: Graph, match: dict): # This transformation works if and only if a body of TI # matches the following topology (Squeeze -> LSTMCell -> Unsqueeze) nodes = [ ('squeeze_dim', dict(kind='op', op='Const')), ('squeeze_dim_data', dict(kind='data')), ('unsqueeze_dim', dict(kind='op', op='Const')), ('unsqueeze_dim_data', dict(kind='data')), ('input_unsqueezed', dict(kind='data')), ('squeeze', dict(kind='op', op='Squeeze')), ('input_squeezed', dict(kind='data')), ('input_hidden', dict(kind='data')), ('input_cell', dict(kind='data')), ('weights', dict(kind='data')), ('biases', dict(kind='data')), ('lstm', dict(kind='op', op='LSTMCell')), ('output_hidden', dict(kind='data')), ('output_cell', dict(kind='data')), ('unsqueeze', dict(kind='op', op='Unsqueeze')), ('output_unsqueezed', dict(kind='data')), ('const_w', dict(kind='op', op='Const')), ('const_b', dict(kind='op', op='Const')), ('op_output', dict(kind='op', op='Result')), ('op_output_1', dict(kind='op', op='Result')), ('op_output_2', dict(kind='op', op='Result')), ('input_unsqueezed_i', dict(kind='op', op='Parameter')), ('input_hidden_i', dict(kind='op', op='Parameter')), ('input_cell_i', dict(kind='op', op='Parameter')), ] edges = [ ('input_unsqueezed', 'squeeze', { 'in': 0 }), ('squeeze', 'input_squeezed'), ('squeeze_dim', 'squeeze_dim_data'), ('squeeze_dim_data', 'squeeze', { 'in': 1 }), ('input_squeezed', 'lstm', { 'in': 0 }), ('input_hidden', 'lstm', { 'in': 1 }), ('input_cell', 'lstm', { 'in': 2 }), ('weights', 'lstm', { 'in': 3 }), ('biases', 'lstm', { 'in': 4 }), ('const_w', 'weights'), ('const_b', 'biases'), ('lstm', 'output_hidden', { 'out': 0 }), ('lstm', 'output_cell', { 'out': 1 }), ('output_hidden', 'unsqueeze'), ('unsqueeze', 'output_unsqueezed'), ('unsqueeze_dim', 'unsqueeze_dim_data'), ('unsqueeze_dim_data', 'unsqueeze', { 'in': 1 }), ('output_unsqueezed', 'op_output'), ('output_hidden', 'op_output_1'), ('output_cell', 'op_output_2'), ('input_unsqueezed_i', 'input_unsqueezed'), ('input_hidden_i', 'input_hidden'), ('input_cell_i', 'input_cell'), ] ti = match['ti'] isomorphisms = find_isomorphisms(ti.body, nodes, edges) if len(list(isomorphisms)) != 1: return isomorphism = isomorphisms[0] direct_permute = match['direct_permute'] inverse_permute = match['inverse_permute'] permute_order = [1, 0, 2] # Check both perumute orders exactly match expected one - [1, 0, 2] direct_order = direct_permute.in_port(1).data.get_value() if direct_order is None or not np.array_equal(direct_order, permute_order): return inverse_order = inverse_permute.in_port(1).data.get_value() if inverse_order is None or not np.array_equal(inverse_order, permute_order): return # Check non-ShapeOf output out of direct Transpose is exactly one direct_permute_dsts = direct_permute.out_port(0).get_destinations() if len([ dst for dst in direct_permute_dsts if dst.node.soft_get('type') != 'ShapeOf' ]) != 1: return for shape_of_dst in [ dst for dst in direct_permute_dsts if dst.node.soft_get('type') == 'ShapeOf' ]: name = shape_of_dst.node.soft_get( 'name', shape_of_dst.node.id) + '/FusedToTITranspose' gather = create_op_with_const_inputs( graph, op=Gather, op_attrs={'name': name}, port_value_dict={ 1: int64_array(permute_order), 2: int64_array(0) }) shape_of_dst.node.out_port(0).get_connection().insert_node(gather) def find_ports(port_map: list, attrs: dict): """ Find all ports in a given port map with specified attributes """ result = [] for i, port in enumerate(port_map): if dict_includes(port, attrs): result.append(i) return result # Check TI has only single partitioned input/output port; all partitioned ports have defined axis data_input_port = find_ports(ti.input_port_map, {'axis': lambda attr: attr in [0, 1]}) data_output_port = find_ports(ti.output_port_map, {'axis': lambda attr: attr in [0, 1]}) assert len(data_input_port) == 1 assert len(data_output_port) == 1 data_input_port = data_input_port[0] data_output_port = data_output_port[0] # Verify that they are really connected to Transpose layers (guarantied by port numbers of TI, see the pattern) assert ti.in_edge(0)['external_port_id'] == ti.input_port_map[ data_input_port]['external_port_id'] assert ti.out_edge(0)['external_port_id'] == ti.output_port_map[ data_output_port]['external_port_id'] # Verify that the TI body have required Reshapes connected to the found ports squeeze = isomorphism['squeeze'] unsqueeze = isomorphism['unsqueeze'] assert len(squeeze.in_node().shape) == 3 assert len(squeeze.out_node().shape) == 2 assert len(unsqueeze.in_node().shape) == 2 assert len(unsqueeze.out_node().shape) == 3 # Remove permutes remove_op_node_with_data_node(graph, direct_permute) remove_op_node_with_data_node(graph, inverse_permute) match['output'].shape = match['output'].shape[permute_order] # swap 0/1 axis for partitioned ports ti.input_port_map[data_input_port][ 'axis'] = 1 - ti.input_port_map[data_input_port]['axis'] ti.output_port_map[data_output_port][ 'axis'] = 1 - ti.output_port_map[data_output_port]['axis'] isomorphism['input_unsqueezed_i'].shape = isomorphism[ 'input_unsqueezed_i'].shape[[1, 0, 2]] isomorphism['input_unsqueezed_i'].infer( isomorphism['input_unsqueezed_i']) isomorphism['squeeze_dim'].value = ti.input_port_map[data_input_port][ 'axis'] isomorphism['squeeze_dim'].infer(isomorphism['squeeze_dim']) isomorphism['squeeze']['need_shape_inference'] = True isomorphism['unsqueeze_dim'].value = ti.output_port_map[ data_output_port]['axis'] isomorphism['unsqueeze_dim'].infer(isomorphism['unsqueeze_dim']) isomorphism['unsqueeze'].infer(isomorphism['unsqueeze'])