def test5_not_constant(self):
    """EltwiseInputReshape must be a no-op when new_shape matches the data shape.

    Every ``new_shape`` edge attribute is ([1, 3]), the same as the shape of
    ``placeholder_1_data`` ([1, 3]). The reference graph is therefore built
    exactly like the input graph.

              ,--------------->consumer3              ,->consumer3
        data---(new_shape1)-->consumer1   =>   data----->consumer1
              `-(new_shape1)-->consumer2              `-->consumer2
    """
    def make_graph():
        # Build brand-new edge/attribute objects on each call so the graph
        # under test and the reference never share mutable numpy arrays.
        topology = [
            ('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([1, 3])}),
            ('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([1, 3])}),
            ('placeholder_1_data', 'consumer_3'),
            ('consumer_1', 'concat'),
            ('consumer_2', 'concat'),
            ('consumer_3', 'concat'),
        ]
        overrides = {'placeholder_1_data': {'shape': int64_array([1, 3])}}
        return build_graph(nodes_attributes, topology, overrides, nodes_with_edges_only=True)

    graph = make_graph()
    graph_ref = make_graph()

    EltwiseInputReshape().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_3(self):
    """MulFakeQuantizeFuse with a per-channel multiplier of (-1, 1, -1).

    In the expected graph, the output low/high ranges (``mi_o``/``ma_o``)
    become per-channel (1, 3, 1, 1) tensors. They are swapped on the channels
    whose multiplier is -1. Per-channel input ranges (``mi_i``/``ma_i``) of
    (10, -10, 10) / (-10, 10, -10) also appear. The input graph's own input
    ranges come from the module-level ``nodes`` table (not visible here).
    """
    source_overrides = {
        'mul_const_data': {'shape': np.array([3, 1, 1]),
                           'value': np.array([[[-1]], [[1]], [[-1]]])},
        'quantize_data': {'shape': np.array([2, 3, 4, 4])},
        'mi_o_data': {'shape': np.array([1, 1, 1, 1]),
                      'value': np.broadcast_to(np.array([0]), (1, 1, 1, 1))},
        'ma_o_data': {'shape': np.array([1, 1, 1, 1]),
                      'value': np.broadcast_to(np.array([1]), (1, 1, 1, 1))},
    }
    graph = build_graph(nodes, edges, source_overrides, nodes_with_edges_only=True)
    graph.stage = 'middle'

    expected_overrides = {
        'quantize_data': {'shape': np.array([2, 3, 4, 4])},
        'mul_const_data': {'shape': np.array([3, 1, 1]),
                           'value': np.array([[[-1]], [[1]], [[-1]]])},
        'mi_o_data': {'shape': np.array([1, 3, 1, 1]),
                      'value': np.array([[[1]], [[0]], [[1]]])},
        'ma_o_data': {'shape': np.array([1, 3, 1, 1]),
                      'value': np.array([[[0]], [[1]], [[0]]])},
        'mi_i_data': {'shape': np.array([1, 3, 1, 1]),
                      'value': np.array([[[10]], [[-10]], [[10]]])},
        'ma_i_data': {'shape': np.array([1, 3, 1, 1]),
                      'value': np.array([[[-10]], [[10]], [[-10]]])},
    }
    graph_ref = build_graph(nodes, edges_ref, expected_overrides, nodes_with_edges_only=True)

    MulFakeQuantizeFuse().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_1(self):
    """FuseTransposesSequence removes two mutually inverse Transposes.

    The orders are (0, 3, 1, 2) followed by (0, 2, 3, 1), so the pair returns
    the data to its original layout. The whole chain collapses.

          NHWC              NCHW              NHWC
        Input->DATA->Transpose->DATA->Transpose->DATA  =>  Input->DATA
    """
    to_nchw_order = np.array([0, 3, 1, 2])
    to_nhwc_order = np.array([0, 2, 3, 1])

    chain = [
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_1_data', 'permute_1'),
        ('permute_1', 'permute_1_data'),
        ('permute_1_data', 'permute_2'),
        ('permute_2', 'permute_2_data'),
        ('permute_2_data', 'op_output'),
        ('const_1', 'const_1_data'),
        # The permutation order is fed as the second (port 1) input.
        ('const_1_data', 'permute_1', {'in': 1}),
        ('const_2', 'const_2_data'),
        ('const_2_data', 'permute_2', {'in': 1}),
    ]
    chain_attrs = {
        'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
        'const_1_data': {'value': to_nchw_order},
        'permute_1_data': {'shape': np.array([1, 3, 227, 227])},
        'const_2_data': {'value': to_nhwc_order},
        'permute_2_data': {'shape': np.array([1, 227, 227, 3])},
    }
    graph = build_graph(nodes_attributes, chain, chain_attrs, nodes_with_edges_only=True)
    graph.graph['layout'] = 'NHWC'
    graph.graph['cmd_params'] = Namespace(keep_shape_ops=False)

    graph_ref = build_graph(
        nodes_attributes,
        [('placeholder_1', 'placeholder_1_data'), ('placeholder_1_data', 'op_output')],
        {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])}},
        nodes_with_edges_only=True,
    )

    FuseTransposesSequence().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'placeholder_1_data', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_2(self):
    """FusePermutesSequence drops an identity Permute that follows another one.

    The second Permute has order (0, 1, 2, 3), so only the first (0, 3, 1, 2)
    Permute is expected to survive.

        Input->DATA->Permute->DATA->Permute->DATA  =>  Input->DATA->Permute->DATA
    """
    head = [
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_1_data', 'permute_1'),
        ('permute_1', 'permute_1_data'),
    ]
    graph = build_graph(
        nodes_attributes,
        head + [('permute_1_data', 'permute_2'), ('permute_2', 'permute_2_data')],
        {
            'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
            'permute_1': {'order': np.array([0, 3, 1, 2])},
            'permute_1_data': {'shape': np.array([1, 3, 227, 227])},
            'permute_2': {'order': np.array([0, 1, 2, 3])},
            'permute_2_data': {'shape': np.array([1, 3, 227, 227]), 'is_output': True},
        },
        nodes_with_edges_only=True,
    )
    graph.graph['layout'] = 'NHWC'

    graph_ref = build_graph(
        nodes_attributes,
        list(head),
        {
            'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
            'permute_1': {'order': np.array([0, 3, 1, 2])},
            'permute_1_data': {'shape': np.array([1, 3, 227, 227])},
        },
        nodes_with_edges_only=True,
    )

    FusePermutesSequence().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'placeholder_1_data', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test6(self):
    """ReduceReplacer lowers a keep_dims Sum over axis -2 to Reshape/Pool/Power.

    Original graph:
        data(1,64,1)-->Reduce(axis=-2, keep_dims=True, reduce_type=Sum)-->data(1,1,1)

    Reference graph:
        data(1,64,1)->Reshape(1,1,64,1)->Pool(window=1,1,64,1)->data(1,1,1,1)
                    ->Reshape(1,1,1)->Power(scale=64)->data(1,1,1)

    The trailing Power(scale=64.0) multiplies the pooled value by the reduced
    axis length (64) to recover the Sum. This presumably means the pooling is
    an average; the pool method comes from ``nodes_attributes``, which is not
    visible here (TODO confirm).
    """
    graph = build_graph(nodes_attributes,
                        [('placeholder_1_data', 'reduce_1'),
                         ('reduce_1', 'reduce_1_data'),
                         ('reduce_1_data', 'concat'),
                         ],
                        {'placeholder_1_data': {'shape': np.array([1, 64, 1])},
                         'reduce_1': {'axis': np.array([-2]), 'keep_dims': True, 'reduce_type': 'Sum'},
                         'reduce_1_data': {'shape': np.array([1, 1, 1])},
                         },
                        nodes_with_edges_only=True)
    graph.graph['layout'] = 'NCHW'

    graph_ref = build_graph(nodes_attributes,
                            [('placeholder_1_data', 'reshape_1'),
                             ('reshape_1', 'reshape_1_data'),
                             ('reshape_1_data', 'pooling'),
                             ('pooling', 'pooling_data'),
                             ('pooling_data', 'reshape_2'),
                             ('reshape_2', 'reshape_2_data'),
                             ('reshape_2_data', 'power'),
                             ('power', 'power_data'),
                             ('power_data', 'concat'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 64, 1])},
                             'reshape_1': {'dim': np.array([1, 1, 64, 1])},
                             'reshape_1_data': {'shape': np.array([1, 1, 64, 1])},
                             'pooling': {'window': np.array([1, 1, 64, 1])},
                             'pooling_data': {'shape': np.array([1, 1, 1, 1])},
                             'reshape_2': {'dim': np.array([1, 1, 1])},
                             'reshape_2_data': {'shape': np.array([1, 1, 1])},
                             # Scale equals the length of the reduced axis (64).
                             'power': {'scale': 64.0},
                             'power_data': {'shape': np.array([1, 1, 1])},
                             },
                            nodes_with_edges_only=True)

    pattern = ReduceReplacer()
    pattern.find_and_replace_pattern(graph)

    (flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test5(self):
    """ReduceReplacer lowers a 6-D reduce over the last axis to Reshape/Pool/Reshape.

    Original graph:
        data(1,16,64,64,64,4)-->Reduce(axis=[5], keep_dims=False)-->data(1,16,64,64,64)

    Reference graph:
        data(1,16,64,64,64,4)->Reshape(1*16*64*64, 64, 4, 1)->Pool(1,1,4,1)
                             ->Reshape(1,16,64,64,64)

    The leading dims 1*16*64*64 = 65536 are folded into the batch of the
    intermediate 4-D tensor.
    """
    folded_batch = 1 * 16 * 64 * 64  # == 65536

    source_edges = [
        ('placeholder_1_data', 'reduce_1'),
        ('reduce_1', 'reduce_1_data'),
        ('reduce_1_data', 'concat'),
    ]
    source_attrs = {
        'placeholder_1_data': {'shape': np.array([1, 16, 64, 64, 64, 4])},
        'reduce_1': {'axis': np.array([5]), 'keep_dims': False, 'reduce_type': 'max'},
        'reduce_1_data': {'shape': np.array([1, 16, 64, 64, 64])},
    }
    graph = build_graph(nodes_attributes, source_edges, source_attrs, nodes_with_edges_only=True)
    graph.graph['layout'] = 'NCHW'

    expected_edges = [
        ('placeholder_1_data', 'reshape_1'),
        ('reshape_1', 'reshape_1_data'),
        ('reshape_1_data', 'pooling'),
        ('pooling', 'pooling_data'),
        ('pooling_data', 'reshape_2'),
        ('reshape_2', 'reshape_2_data'),
        ('reshape_2_data', 'concat'),
    ]
    expected_attrs = {
        'placeholder_1_data': {'shape': np.array([1, 16, 64, 64, 64, 4])},
        'reshape_1': {'dim': np.array([folded_batch, 64, 4, 1])},
        'reshape_1_data': {'shape': np.array([folded_batch, 64, 4, 1])},
        'pooling': {'window': np.array([1, 1, 4, 1])},
        'pooling_data': {'shape': np.array([folded_batch, 64, 1, 1])},
        'reshape_2': {'dim': np.array([1, 16, 64, 64, 64])},
        'reshape_2_data': {'shape': np.array([1, 16, 64, 64, 64])},
    }
    graph_ref = build_graph(nodes_attributes, expected_edges, expected_attrs, nodes_with_edges_only=True)

    ReduceReplacer().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test1_not_constant(self):
    """EltwiseInputNormalize reshapes lower-rank, non-constant Eltwise inputs.

    The (1, 64, 1) and (64, 1) inputs carry no ``value``. In the reference
    graph each is routed through its own Reshape to (1, 1, 64, 1), while the
    full-rank (1, 3, 64, 64) input is connected unchanged.

        data1(1,3,64,64)----.                data(1,3,64,64)-------.
        data2(1,64,1)-------->Eltwise-->...  => data(1,64,1)->Reshape->data(1,1,64,1)-->Eltwise->...
        data3(64,1)------'                   data(64,1)->Reshape->data(1,1,64,1)-'
    """
    full_shape = [1, 3, 64, 64]
    target_shape = [1, 1, 64, 1]

    graph = build_graph(
        nodes_attributes,
        [
            ('placeholder_1_data', 'eltwise_1'),
            ('placeholder_2_data', 'eltwise_1'),
            ('placeholder_3_data', 'eltwise_1'),
            ('eltwise_1', 'eltwise_1_data'),
        ],
        {
            'placeholder_1_data': {'shape': np.array(full_shape)},
            'placeholder_2_data': {'shape': np.array([1, 64, 1])},
            'placeholder_3_data': {'shape': np.array([64, 1])},
            'eltwise_1_data': {'shape': np.array(full_shape)},
        },
        nodes_with_edges_only=True,
    )

    graph_ref = build_graph(
        nodes_attributes,
        [
            ('placeholder_1_data', 'eltwise_1'),
            ('placeholder_2_data', 'reshape_1'),
            ('placeholder_3_data', 'reshape_2'),
            ('reshape_1', 'reshape_1_data'),
            ('reshape_2', 'reshape_2_data'),
            ('reshape_1_data', 'eltwise_1'),
            ('reshape_2_data', 'eltwise_1'),
            ('eltwise_1', 'eltwise_1_data'),
        ],
        {
            'placeholder_1_data': {'shape': np.array(full_shape)},
            'reshape_1': {'dim': np.array(target_shape)},
            'reshape_1_data': {'shape': np.array(target_shape)},
            'reshape_2': {'dim': np.array(target_shape)},
            'reshape_2_data': {'shape': np.array(target_shape)},
            'eltwise_1_data': {'shape': np.array(full_shape)},
        },
        nodes_with_edges_only=True,
    )

    EltwiseInputNormalize().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'eltwise_1', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_mega_hardcore(self):
    """EltwiseInputNormalize on a constant shared by several Eltwise consumers.

    ``placeholder_2_data`` is a (64, 1) constant of ones. It feeds Eltwise1,
    Eltwise3 and Eltwise4. The expected graph has three changes:
      * the shared constant itself becomes (1, 1, 64, 1);
      * Eltwise3 (whose other input is also (64, 1)) instead consumes a
        separate (64, 1) copy, ``placeholder_4_data``;
      * Eltwise3's (64, 1) output passes through a Reshape to (1, 1, 64, 1)
        before reaching Eltwise2.

    ORIGINAL GRAPH

    data1(1,3,64,64)---,->Eltwise1->data(1,3,64,64)-----,->Eltwise2->data(1,3,64,64)---,->Eltwise4->data(1,3,64,64)
                      /\                               /\                             /\
    data2(64,1)-----,-'--------------------------------'------------------------------'
                   \/                                 /
    data3(64,1)----`-->Eltwise3->data(64,1)----------'

    REFERENCE GRAPH AFTER TRANSFORMATION

    data1(1,3,64,64)---,->Eltwise1->data(1,3,64,64)-----,->Eltwise2->data(1,3,64,64)---,->Eltwise4->data(1,3,64,64)
                      /\                               /\                             /\
    data2(1,1,64,1)---'--------------------------------'-------------------------------'
                                                      /
    data4(64,1)-------,                        Reshape(1,1,64,1)
                     \/                           |
    data3(64,1)------`---->Eltwise3->data(64,1)---'
    """
    nchw = [1, 3, 64, 64]

    original_edges = [
        ('placeholder_1_data', 'eltwise_1'),
        ('placeholder_2_data', 'eltwise_1'),
        ('eltwise_1', 'eltwise_1_data'),
        ('eltwise_1_data', 'eltwise_2'),
        ('placeholder_2_data', 'eltwise_3'),
        ('placeholder_3_data', 'eltwise_3'),
        ('eltwise_3', 'eltwise_3_data'),
        ('eltwise_3_data', 'eltwise_2'),
        ('eltwise_2', 'eltwise_2_data'),
        ('eltwise_2_data', 'eltwise_4'),
        ('placeholder_2_data', 'eltwise_4'),
        ('eltwise_4', 'eltwise_4_data'),
    ]
    original_attrs = {
        'placeholder_1_data': {'shape': np.array(nchw)},
        'placeholder_2_data': {'shape': np.array([64, 1]), 'value': np.ones([64, 1])},
        'placeholder_3_data': {'shape': np.array([64, 1])},
        'eltwise_1_data': {'shape': np.array(nchw)},
        'eltwise_2_data': {'shape': np.array(nchw)},
        'eltwise_3_data': {'shape': np.array([64, 1])},
        'eltwise_4_data': {'shape': np.array(nchw)},
    }
    graph = build_graph(nodes_attributes, original_edges, original_attrs, nodes_with_edges_only=True)

    reference_edges = [
        ('placeholder_1_data', 'eltwise_1'),
        ('placeholder_2_data', 'eltwise_1'),
        ('eltwise_1', 'eltwise_1_data'),
        ('eltwise_1_data', 'eltwise_2'),
        ('placeholder_4_data', 'eltwise_3'),
        ('placeholder_3_data', 'eltwise_3'),
        ('eltwise_3', 'eltwise_3_data'),
        ('eltwise_3_data', 'reshape_1'),
        ('reshape_1', 'reshape_1_data'),
        ('reshape_1_data', 'eltwise_2'),
        ('eltwise_2', 'eltwise_2_data'),
        ('eltwise_2_data', 'eltwise_4'),
        ('placeholder_2_data', 'eltwise_4'),
        ('eltwise_4', 'eltwise_4_data'),
    ]
    reference_attrs = {
        'placeholder_1_data': {'shape': np.array(nchw)},
        'placeholder_2_data': {'shape': np.array([1, 1, 64, 1]), 'value': np.ones([1, 1, 64, 1])},
        'placeholder_3_data': {'shape': np.array([64, 1])},
        'placeholder_4_data': {'shape': np.array([64, 1]), 'value': np.ones([64, 1])},
        'reshape_1': {'dim': np.array([1, 1, 64, 1])},
        'reshape_1_data': {'shape': np.array([1, 1, 64, 1])},
        'eltwise_1_data': {'shape': np.array(nchw)},
        'eltwise_2_data': {'shape': np.array(nchw)},
        'eltwise_3_data': {'shape': np.array([64, 1])},
        'eltwise_4_data': {'shape': np.array(nchw)},
    }
    graph_ref = build_graph(nodes_attributes, reference_edges, reference_attrs, nodes_with_edges_only=True)

    EltwiseInputNormalize().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'eltwise_1', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test4(self):
    """ReduceReplacer lowers a ReduceMean over axes [1, 2, 3] to Reshape/Pool/Reshape.

    Original graph:
        data(2,3,64,64)-->Reduce(axis=[1,2,3], keep_dims=False)-->data(2)

    Reference graph:
        data(2,3,64,64)->Reshape(2,1,3*64*64,1)->Pool->data(2,1,1,1)->Reshape(2)

    The reduction axes arrive as a constant on input port 1. In the reference,
    the reshape targets are Const inputs whose leading entry is 0. This
    presumably means "copy this dim from the input" (TODO confirm against
    Reshape's special-zero semantics). ``shape_inference`` is run after the
    replacement so that data-node shapes are recomputed before comparison.
    """
    flat = 3 * 64 * 64
    input_shape = int64_array([2, 3, 64, 64])

    graph = build_graph(
        nodes_attributes,
        [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'reduce_1'),
            ('const', 'const_data'),
            ('const_data', 'reduce_1', {'in': 1}),
            ('reduce_1', 'reduce_1_data'),
            ('reduce_1_data', 'concat'),
        ],
        {
            'placeholder_1': {'shape': int64_array([2, 3, 64, 64])},
            'placeholder_1_data': {'shape': input_shape},
            'reduce_1': {'keep_dims': False, 'type': 'ReduceMean'},
            'const_data': {'value': int64_array([1, 2, 3])},
            'reduce_1_data': {'shape': int64_array([2])},
        },
        nodes_with_edges_only=True,
    )
    graph.graph['layout'] = 'NCHW'

    graph_ref = build_graph(
        nodes_attributes,
        [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'reshape_1'),
            ('reshape_1_const', 'reshape_1_const_data'),
            ('reshape_1_const_data', 'reshape_1'),
            ('reshape_1', 'reshape_1_data'),
            ('reshape_1_data', 'pooling'),
            ('pooling', 'pooling_data'),
            ('pooling_data', 'reshape_2'),
            ('reshape_2_const', 'reshape_2_const_data'),
            ('reshape_2_const_data', 'reshape_2'),
            ('reshape_2', 'reshape_2_data'),
            ('reshape_2_data', 'concat'),
        ],
        {
            'placeholder_1': {'shape': int64_array([2, 3, 64, 64])},
            'placeholder_1_data': {'shape': int64_array([2, 3, 64, 64])},
            'reshape_1_const': {'value': int64_array([0, 1, flat, 1]), 'shape': int64_array([4])},
            'reshape_1_const_data': {'value': int64_array([0, 1, flat, 1]), 'shape': int64_array([4])},
            'reshape_1_data': {'shape': int64_array([2, 1, flat, 1])},
            'pooling': {'window': int64_array([1, 1, flat, 1])},
            'pooling_data': {'shape': int64_array([2, 1, 1, 1])},
            'reshape_2_const': {'value': int64_array([0]), 'shape': int64_array([1])},
            'reshape_2_const_data': {'value': int64_array([0]), 'shape': int64_array([1])},
            'reshape_2_data': {'shape': int64_array([2])},
        },
        nodes_with_edges_only=True,
    )

    ReduceReplacer().find_and_replace_pattern(graph)
    shape_inference(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_1(self):
    """NormalizeFullyConnected wraps a 3-D FC input with a pair of Reshapes.

    The (1, 16, 512) input is flattened to (16, 512) before the FC, which has
    (512, 101) weights and out-size 101. The FC's (16, 101) output is then
    reshaped back to (1, 16, 101).
    """
    fc_out = 101

    original = build_graph(
        nodes_attributes,
        [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'fc'),
            ('fc_weights', 'fc'),
            ('fc', 'fc_data'),
        ],
        {
            'placeholder_1_data': {'shape': np.array([1, 16, 512])},
            'fc': {'out-size': fc_out},
            'fc_weights': {
                'shape': np.array([512, fc_out]),
                'value': np.ones([512, fc_out]),
                'input_channel_dim': 1,
            },
            'fc_data': {'shape': np.array([1, 16, fc_out])},
        },
        nodes_with_edges_only=True,
    )

    expected = build_graph(
        nodes_attributes,
        [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'reshape_1'),
            ('reshape_1', 'reshape_1_data'),
            ('reshape_1_data', 'fc'),
            ('fc_weights', 'fc'),
            ('fc', 'fc_data'),
            ('fc_data', 'reshape_2'),
            ('reshape_2', 'reshape_2_data'),
        ],
        {
            'placeholder_1_data': {'shape': np.array([1, 16, 512])},
            'reshape_1_data': {'shape': np.array([16, 512])},
            'reshape_2_data': {'shape': np.array([1, 16, fc_out])},
            'fc_weights': {'shape': np.array([512, fc_out]), 'value': np.ones([512, fc_out])},
            'fc': {'out-size': fc_out},
            'fc_data': {'shape': np.array([16, fc_out])},
        },
        nodes_with_edges_only=True,
    )

    NormalizeFullyConnected().find_and_replace_pattern(original)

    # Comparison is anchored at placeholder_1_data in both graphs.
    flag, resp = compare_graphs(original, expected, 'placeholder_1_data', 'placeholder_1_data',
                                check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_ss_shrink_new(self):
    """AddReshapeAfterStridedSlice replaces new/shrink axis masks with Reshapes.

    The StridedSlice sets new_axis_mask at position 1 and shrink_axis_mask at
    position 3. The reference graph expects four things:
      * both masks are cleared on the slice;
      * the slice output is (1, 227, 1, 54);
      * ``Reshape_new`` inserts the new axis -> (1, 1, 227, 1, 54);
      * ``Reshape_shrink`` removes the shrunk axis -> (1, 1, 227, 54).
    Both graphs are cleared before the assertion is evaluated.
    """
    def make_slices():
        # Five slice specs for a 4-D input (one entry lines up with the
        # inserted new axis). A fresh array is produced for each graph.
        return np.array([
            slice(0, 1, 1),
            slice(0, 1, 1),
            slice(0, 227, 1),
            slice(0, 1, 1),
            slice(0, 54, 1),
        ])

    slice_inputs = [
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_1_data', 'sslice_2'),
        ('placeholder_begin_data', 'sslice_2'),
        ('placeholder_end_data', 'sslice_2'),
        ('placeholder_stride_data', 'sslice_2'),
    ]
    tail = [
        ('sslice_2_data', 'placeholder_2'),
        ('placeholder_2', 'placeholder_2_data'),
    ]

    graph = build_graph(
        nodes_attributes_test,
        slice_inputs + [('sslice_2', 'sslice_2_data')] + tail,
        {
            'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
            'sslice_2': {
                'slices': make_slices(),
                'shrink_axis_mask': np.array([False, False, False, True, False]),
                'new_axis_mask': np.array([False, True, False, False, False]),
            },
            'sslice_2_data': {'shape': np.array([1, 1, 227, 54]), 'is_output': True},
        },
    )
    graph.graph['layout'] = 'NHWC'

    reshape_chain = [
        ('sslice_2', 'sslice_2/Reshape_new_data'),
        ('sslice_2/Reshape_new_data', 'sslice_2/Reshape_new'),
        ('sslice_2/Reshape_new', 'sslice_2/Reshape_shrink_data'),
        ('sslice_2/Reshape_shrink_data', 'sslice_2/Reshape_shrink'),
        ('sslice_2/Reshape_shrink', 'sslice_2_data'),
    ]
    graph_ref = build_graph(
        nodes_reshape,
        list(slice_inputs) + reshape_chain + list(tail),
        {
            'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
            'sslice_2': {
                'slices': make_slices(),
                'shrink_axis_mask': np.array([False, False, False, False, False]),
                'new_axis_mask': np.array([False, False, False, False, False]),
            },
            'sslice_2_data': {'shape': np.array([1, 1, 227, 54])},
            'sslice_2/Reshape_new': {'dim': np.array([1, 1, 227, 1, 54])},
            'sslice_2/Reshape_new_data': {'shape': np.array([1, 227, 1, 54])},
            'sslice_2/Reshape_shrink': {'dim': np.array([1, 1, 227, 54])},
            'sslice_2/Reshape_shrink_data': {'shape': np.array([1, 1, 227, 1, 54])},
        },
    )

    AddReshapeAfterStridedSlice().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
    graph.clear()
    graph_ref.clear()
    self.assertTrue(flag, resp)
def test_lstm_nonlinearity(self):
    """ReplaceLstmNonLinearityPattern.replace_op expands LstmNonLinearity.

    The input is split into (i_part, f_part, c_part, o_part, ct_1); output
    port 4 of the split is ct_1. The reference graph builds the three gates
    in the same way: each gate's part is added to a scaled ct_1 and passed
    through a sigmoid. The new cell state is f*ct_1 + i*tanh(c_part). It is
    concatenated with o*tanh(new cell state). ``replace_op`` is called
    directly, so the original ``lstm -> out`` edge is still expected in the
    reference.
    """
    lstm_attrs = {
        'kind': 'op',
        'op': 'LstmNonLinearity',
        'i_weights': np.array([]),
        'f_weights': np.array([]),
        'o_weights': np.array([]),
    }
    graph = build_graph(
        {
            'in': {'kind': 'op', 'op': 'Parameter'},
            'lstm': lstm_attrs,
            'out': {'kind': 'op', 'op': 'Placeholder'},
        },
        [('in', 'lstm'), ('lstm', 'out')],
        nodes_with_edges_only=True,
    )
    graph.stage = 'front'

    ct_1 = {'out': 4}
    expected_edges = [
        ('in', 'split'),
        # input gate
        ('split', 'scale_i_c', dict(ct_1)),
        ('scale_i_c', 'i_plus_c'),
        ('split', 'i_plus_c', {'out': 0}),
        ('i_plus_c', 'sigmoid_i'),
        # forget gate
        ('split', 'scale_f_c', dict(ct_1)),
        ('scale_f_c', 'f_plus_c'),
        ('split', 'f_plus_c', {'out': 1}),
        ('f_plus_c', 'sigmoid_f'),
        # candidate cell
        ('split', 'tanhcp', {'out': 2}),
        ('tanhcp', 'i_mul_tanhc'),
        ('sigmoid_i', 'i_mul_tanhc'),
        ('sigmoid_f', 'f_mul_c'),
        ('split', 'f_mul_c', dict(ct_1)),
        ('f_mul_c', 'fc_plus_itanhc'),
        ('i_mul_tanhc', 'fc_plus_itanhc'),
        # output gate
        ('split', 'scale_o_c', dict(ct_1)),
        ('scale_o_c', 'o_plus_c'),
        ('split', 'o_plus_c', {'out': 3}),
        ('o_plus_c', 'sigmoid_o'),
        ('fc_plus_itanhc', 'tanhc'),
        ('sigmoid_o', 'o_mul_tanhc'),
        ('tanhc', 'o_mul_tanhc'),
        # outputs
        ('fc_plus_itanhc', 'concat'),
        ('o_mul_tanhc', 'concat'),
        ('lstm', 'out'),
    ]
    ref_graph = build_graph(self.nodes_attributes, expected_edges, nodes_with_edges_only=True)

    ReplaceLstmNonLinearityPattern().replace_op(graph, Node(graph, 'lstm'))

    flag, resp = compare_graphs(graph, ref_graph, 'out', check_op_attrs=True)
    self.assertTrue(flag, resp)