def test5_not_constant(self):
        # consumer_1 and consumer_2 request new_shape [1, 3] on their input
        # edges, which already equals the data shape [1, 3]. The expected graph
        # is therefore built exactly like the input one: no Reshape node is
        # present in the reference.
        #
        #        ,------------------->consumer3
        #   data---(new_shape [1,3])-->consumer1    =>   same edges, no Reshape
        #        `-(new_shape [1,3])-->consumer2
        #
        def make_graph():
            # Build a fresh graph on every call so the input and the reference
            # never share edge-attribute dicts or arrays.
            edge_list = [
                ('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([1, 3])}),
                ('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([1, 3])}),
                ('placeholder_1_data', 'consumer_3'),
            ]
            edge_list += [('consumer_%d' % idx, 'concat') for idx in (1, 2, 3)]
            return build_graph(nodes_attributes,
                               edge_list,
                               {'placeholder_1_data': {'shape': int64_array([1, 3])}},
                               nodes_with_edges_only=True)

        graph = make_graph()
        graph_ref = make_graph()

        EltwiseInputReshape().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
        self.assertTrue(flag, resp)
    def test_3(self):
        # MulFakeQuantizeFuse with a mixed-sign per-channel Mul constant
        # (-1, 1, -1). In the expected graph the output range constants become
        # per-channel [1, 3, 1, 1] tensors. For the -1 channels, the original
        # scalar min (0) and max (1) are swapped; the +1 channel keeps them.
        def mul_const_data():
            return {'shape': np.array([3, 1, 1]),
                    'value': np.array([[[-1]], [[1]], [[-1]]])}

        def quantize_data():
            return {'shape': np.array([2, 3, 4, 4])}

        def scalar_4d(fill):
            # Single-element constant of shape [1, 1, 1, 1].
            return {'shape': np.array([1, 1, 1, 1]),
                    'value': np.broadcast_to(np.array([fill]), (1, 1, 1, 1))}

        def per_channel(*channel_values):
            # One value per channel, stored as a (3, 1, 1) array.
            return {'shape': np.array([1, 3, 1, 1]),
                    'value': np.array([[[v]] for v in channel_values])}

        graph = build_graph(nodes, edges, {
            'mul_const_data': mul_const_data(),
            'quantize_data': quantize_data(),
            'mi_o_data': scalar_4d(0),
            'ma_o_data': scalar_4d(1),
        }, nodes_with_edges_only=True)
        graph.stage = 'middle'

        graph_ref = build_graph(nodes, edges_ref, {
            'quantize_data': quantize_data(),
            'mul_const_data': mul_const_data(),
            'mi_o_data': per_channel(1, 0, 1),
            'ma_o_data': per_channel(0, 1, 0),
            'mi_i_data': per_channel(10, -10, 10),
            'ma_i_data': per_channel(-10, 10, -10),
        }, nodes_with_edges_only=True)

        MulFakeQuantizeFuse().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(flag, resp)
    def test_1(self):
        # Two back-to-back Transposes whose orders cancel out (NHWC -> NCHW,
        # then NCHW -> NHWC) are expected to disappear entirely:
        #
        #    NHWC           NCHW           NHWC
        #   Input->DATA->Transpose->DATA->Transpose->DATA  => Input->DATA
        #
        nhwc_shape = [1, 227, 227, 3]
        nchw_shape = [1, 3, 227, 227]

        source_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'permute_1'),
            ('permute_1', 'permute_1_data'),
            ('permute_1_data', 'permute_2'),
            ('permute_2', 'permute_2_data'),
            ('permute_2_data', 'op_output'),
            ('const_1', 'const_1_data'),
            ('const_1_data', 'permute_1', {'in': 1}),
            ('const_2', 'const_2_data'),
            ('const_2_data', 'permute_2', {'in': 1}),
        ]
        source_attrs = {
            'placeholder_1_data': {'shape': np.array(nhwc_shape)},
            'const_1_data': {'value': np.array([0, 3, 1, 2])},  # NHWC -> NCHW order
            'permute_1_data': {'shape': np.array(nchw_shape)},
            'const_2_data': {'value': np.array([0, 2, 3, 1])},  # NCHW -> NHWC order
            'permute_2_data': {'shape': np.array(nhwc_shape)},
        }
        graph = build_graph(nodes_attributes, source_edges, source_attrs,
                            nodes_with_edges_only=True)
        graph.graph['layout'] = 'NHWC'
        graph.graph['cmd_params'] = Namespace(keep_shape_ops=False)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'op_output')],
                                {'placeholder_1_data': {'shape': np.array(nhwc_shape)}},
                                nodes_with_edges_only=True)

        FuseTransposesSequence().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'placeholder_1_data', check_op_attrs=True)
        self.assertTrue(flag, resp)
    def test_2(self):
        # The second Permute uses the identity order [0, 1, 2, 3]. It is
        # expected to be dropped, while the first (NHWC -> NCHW) one stays:
        #
        #   Input->DATA->Permute->DATA->Permute->DATA  => Input->DATA->Permute->DATA
        #
        def nhwc_input():
            return {'shape': np.array([1, 227, 227, 3])}

        def first_permute():
            return {'order': np.array([0, 3, 1, 2])}

        def first_permute_output():
            return {'shape': np.array([1, 3, 227, 227])}

        head_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'permute_1'),
            ('permute_1', 'permute_1_data'),
        ]

        graph = build_graph(nodes_attributes,
                            head_edges + [('permute_1_data', 'permute_2'),
                                          ('permute_2', 'permute_2_data')],
                            {'placeholder_1_data': nhwc_input(),
                             'permute_1': first_permute(),
                             'permute_1_data': first_permute_output(),
                             'permute_2': {'order': np.array([0, 1, 2, 3])},
                             'permute_2_data': {'shape': np.array([1, 3, 227, 227]),
                                                'is_output': True},
                             },
                            nodes_with_edges_only=True)
        graph.graph['layout'] = 'NHWC'

        graph_ref = build_graph(nodes_attributes,
                                list(head_edges),
                                {'placeholder_1_data': nhwc_input(),
                                 'permute_1': first_permute(),
                                 'permute_1_data': first_permute_output(),
                                 },
                                nodes_with_edges_only=True)

        FusePermutesSequence().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'placeholder_1_data', check_op_attrs=True)
        self.assertTrue(flag, resp)
# Example #5
# 0
    def test6(self):
        #   Original graph
        #   data(1,64,1)-->Reduce(axis=-2,keep_dims=True, reduce_type=Sum)-->data(1,1,1)
        #
        #   Reference graph
        #   data(1,64,1)->Reshape(1,1,64,1)->Pool(window=1,1,64,1)->data(1,1,1,1)
        #       ->Reshape(1,1,1)->Power(scale=64)
        #
        #   The pooling window spans all 64 reduced elements and the trailing
        #   Power rescales by 64. NOTE(review): presumably the pooling averages,
        #   so the x64 rescale turns the mean back into a Sum. The pool type is
        #   not set in this test; confirm against ReduceReplacer.
        #
        graph = build_graph(nodes_attributes,
                            [('placeholder_1_data', 'reduce_1'),
                             ('reduce_1', 'reduce_1_data'),
                             ('reduce_1_data', 'concat'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 64, 1])},
                             'reduce_1': {'axis': np.array([-2]), 'keep_dims': True, 'reduce_type': 'Sum'},
                             'reduce_1_data': {'shape': np.array([1, 1, 1])},
                             }, nodes_with_edges_only=True)

        graph.graph['layout'] = 'NCHW'

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1_data', 'reshape_1'),
                                 ('reshape_1', 'reshape_1_data'),
                                 ('reshape_1_data', 'pooling'),
                                 ('pooling', 'pooling_data'),
                                 ('pooling_data', 'reshape_2'),
                                 ('reshape_2', 'reshape_2_data'),
                                 ('reshape_2_data', 'power'),
                                 ('power', 'power_data'),
                                 ('power_data', 'concat'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 64, 1])},
                                 'reshape_1': {'dim': np.array([1, 1, 64, 1])},
                                 'reshape_1_data': {'shape': np.array([1, 1, 64, 1])},
                                 'pooling': {'window': np.array([1, 1, 64, 1])},
                                 'pooling_data': {'shape': np.array([1, 1, 1, 1])},
                                 'reshape_2': {'dim': np.array([1, 1, 1])},
                                 'reshape_2_data': {'shape': np.array([1, 1, 1])},
                                 'power': {'scale': 64.0},
                                 'power_data': {'shape': np.array([1, 1, 1])},
                                 }, nodes_with_edges_only=True)

        pattern = ReduceReplacer()
        pattern.find_and_replace_pattern(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
        self.assertTrue(flag, resp)
# Example #6
# 0
    def test5(self):
        # Max-reduction over the innermost axis (size 4) of a 6D tensor. The
        # leading dims 1*16*64*64 are expected to be folded into the first
        # dimension so the reduction runs as a 4D pooling with window
        # [1, 1, 4, 1]:
        #
        #   data(1, 16, 64, 64, 64, 4)-->Reduce(axis=[5],keep_dims=False)-->data(1, 16, 64, 64, 64)
        #   =>
        #   data(1, 16, 64, 64, 64, 4)->Reshape(1*16*64*64, 64, 4, 1)->Pool(1, 1, 4, 1)->Reshape(1, 16, 64, 64, 64)
        #
        in_shape = [1, 16, 64, 64, 64, 4]
        out_shape = [1, 16, 64, 64, 64]
        folded = 1 * 16 * 64 * 64  # == 65536

        graph = build_graph(
            nodes_attributes,
            [('placeholder_1_data', 'reduce_1'),
             ('reduce_1', 'reduce_1_data'),
             ('reduce_1_data', 'concat')],
            {'placeholder_1_data': {'shape': np.array(in_shape)},
             'reduce_1': {'axis': np.array([5]), 'keep_dims': False, 'reduce_type': 'max'},
             'reduce_1_data': {'shape': np.array(out_shape)}},
            nodes_with_edges_only=True)
        graph.graph['layout'] = 'NCHW'

        # The reference is a single linear chain; consecutive names form the edges.
        chain = ['placeholder_1_data', 'reshape_1', 'reshape_1_data', 'pooling',
                 'pooling_data', 'reshape_2', 'reshape_2_data', 'concat']
        graph_ref = build_graph(
            nodes_attributes,
            list(zip(chain, chain[1:])),
            {'placeholder_1_data': {'shape': np.array(in_shape)},
             'reshape_1': {'dim': np.array([folded, 64, 4, 1])},
             'reshape_1_data': {'shape': np.array([folded, 64, 4, 1])},
             'pooling': {'window': np.array([1, 1, 4, 1])},
             'pooling_data': {'shape': np.array([folded, 64, 1, 1])},
             'reshape_2': {'dim': np.array(out_shape)},
             'reshape_2_data': {'shape': np.array(out_shape)}},
            nodes_with_edges_only=True)

        ReduceReplacer().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
        self.assertTrue(flag, resp)
    def test1_not_constant(self):
        # Eltwise inputs with lower rank than the 4D output are expected to get
        # a Reshape to (1, 1, 64, 1) inserted in front of the Eltwise:
        #
        #   data1(1,3,64,64)----.                                                   data(1,3,64,64)-------.
        #   data2(1,64,1)-------->Eltwise-->data(1,3,64,64)   =>    data(1,64,1)->Reshape->data(1,1,64,1)-->Eltwise->...
        #   data3(64,1)------'                                       data(64,1)->Reshape->data(1,1,64,1)-'
        #
        def shape_of(*dims):
            return {'shape': np.array(dims)}

        target = (1, 1, 64, 1)

        graph = build_graph(nodes_attributes,
                            [('placeholder_1_data', 'eltwise_1'),
                             ('placeholder_2_data', 'eltwise_1'),
                             ('placeholder_3_data', 'eltwise_1'),
                             ('eltwise_1', 'eltwise_1_data')],
                            {'placeholder_1_data': shape_of(1, 3, 64, 64),
                             'placeholder_2_data': shape_of(1, 64, 1),
                             'placeholder_3_data': shape_of(64, 1),
                             'eltwise_1_data': shape_of(1, 3, 64, 64)},
                            nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1_data', 'eltwise_1'),
                                 ('placeholder_2_data', 'reshape_1'),
                                 ('placeholder_3_data', 'reshape_2'),
                                 ('reshape_1', 'reshape_1_data'),
                                 ('reshape_2', 'reshape_2_data'),
                                 ('reshape_1_data', 'eltwise_1'),
                                 ('reshape_2_data', 'eltwise_1'),
                                 ('eltwise_1', 'eltwise_1_data')],
                                {'placeholder_1_data': shape_of(1, 3, 64, 64),
                                 'reshape_1': {'dim': np.array(target)},
                                 'reshape_1_data': shape_of(*target),
                                 'reshape_2': {'dim': np.array(target)},
                                 'reshape_2_data': shape_of(*target),
                                 'eltwise_1_data': shape_of(1, 3, 64, 64)},
                                nodes_with_edges_only=True)

        EltwiseInputNormalize().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'eltwise_1', check_op_attrs=True)
        self.assertTrue(flag, resp)
# Example #8
# 0
    def test_mega_hardcore(self):
        # The constant data2 (64,1) feeds eltwise_1, eltwise_3 and eltwise_4.
        # EltwiseInputNormalize is expected to:
        #   * reshape data2 itself to (1,1,64,1) for its 4D consumers,
        #   * give eltwise_3 a separate (64,1) constant of ones (data4),
        #   * insert a Reshape(1,1,64,1) after eltwise_3's (64,1) output.
        #
        #   ORIGINAL GRAPH
        #
        #   data1(1,3,64,64)---,->Eltwise1->data(1,3,64,64)-----,->Eltwise2->data(1,3,64,64)---,->Eltwise4->data(1,3,64,64)
        #                     /\                               /\                             /\
        #   data2(64,1)-----,-'--------------------------------'------------------------------'
        #                  \/                                 /
        #   data3(64,1)----`-->Eltwise3->data(64,1)----------'
        #
        #   REFERENCE GRAPH AFTER TRANSFORMATION
        #
        #   data1(1,3,64,64)---,->Eltwise1->data(1,3,64,64)-----,->Eltwise2->data(1,3,64,64)---,->Eltwise4->data(1,3,64,64)
        #                     /\                               /\                              /\
        #   data2(1,1,64,1)---'--------------------------------'-------------------------------'
        #                                                     /
        #   data4(64,1)-------,                        Reshape(1,1,64,1)
        #                    \/                           |
        #   data3(64,1)------`---->Eltwise3->data(64,1)---'
        #
        def shape_of(*dims):
            return {'shape': np.array(dims)}

        def ones_const(*dims):
            return {'shape': np.array(dims), 'value': np.ones(dims)}

        full = (1, 3, 64, 64)
        column = (64, 1)

        # Edges identical in both graphs; the middle section differs.
        head = [('placeholder_1_data', 'eltwise_1'),
                ('placeholder_2_data', 'eltwise_1'),
                ('eltwise_1', 'eltwise_1_data'),
                ('eltwise_1_data', 'eltwise_2')]
        tail = [('eltwise_2', 'eltwise_2_data'),
                ('eltwise_2_data', 'eltwise_4'),
                ('placeholder_2_data', 'eltwise_4'),
                ('eltwise_4', 'eltwise_4_data')]

        graph = build_graph(
            nodes_attributes,
            head + [('placeholder_2_data', 'eltwise_3'),
                    ('placeholder_3_data', 'eltwise_3'),
                    ('eltwise_3', 'eltwise_3_data'),
                    ('eltwise_3_data', 'eltwise_2')] + tail,
            {'placeholder_1_data': shape_of(*full),
             'placeholder_2_data': ones_const(*column),
             'placeholder_3_data': shape_of(*column),
             'eltwise_1_data': shape_of(*full),
             'eltwise_2_data': shape_of(*full),
             'eltwise_3_data': shape_of(*column),
             'eltwise_4_data': shape_of(*full)},
            nodes_with_edges_only=True)

        graph_ref = build_graph(
            nodes_attributes,
            head + [('placeholder_4_data', 'eltwise_3'),
                    ('placeholder_3_data', 'eltwise_3'),
                    ('eltwise_3', 'eltwise_3_data'),
                    ('eltwise_3_data', 'reshape_1'),
                    ('reshape_1', 'reshape_1_data'),
                    ('reshape_1_data', 'eltwise_2')] + tail,
            {'placeholder_1_data': shape_of(*full),
             'placeholder_2_data': ones_const(1, 1, 64, 1),
             'placeholder_3_data': shape_of(*column),
             'placeholder_4_data': ones_const(*column),
             'reshape_1': {'dim': np.array([1, 1, 64, 1])},
             'reshape_1_data': shape_of(1, 1, 64, 1),
             'eltwise_1_data': shape_of(*full),
             'eltwise_2_data': shape_of(*full),
             'eltwise_3_data': shape_of(*column),
             'eltwise_4_data': shape_of(*full)},
            nodes_with_edges_only=True)

        EltwiseInputNormalize().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'eltwise_1', check_op_attrs=True)
        self.assertTrue(flag, resp)
    def test4(self):
        # ReduceMean over axes [1, 2, 3] with keep_dims=False is expected to
        # become Reshape -> Pooling -> Reshape. The reduced axes are folded into
        # a single 3*64*64 dimension that the pooling window covers:
        #
        #   data(2,3,64,64)-->Reduce(axis=[1,2,3],keep_dims=False)-->data(2)
        #   =>
        #   data(2,3,64,64)->Reshape(2,1,3*64*64,1)->Pool(2,1,1,1)->Reshape(2)
        #
        input_shape = [2, 3, 64, 64]
        reduced = 3 * 64 * 64

        def vec_const(*values):
            # Attributes of a 1-D int64 constant holding `values`.
            return {'value': int64_array(list(values)),
                    'shape': int64_array([len(values)])}

        graph = build_graph(nodes_attributes, [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'reduce_1'),
            ('const', 'const_data'),
            ('const_data', 'reduce_1', {'in': 1}),
            ('reduce_1', 'reduce_1_data'),
            ('reduce_1_data', 'concat'),
        ], {
            'placeholder_1': {'shape': int64_array(input_shape)},
            'placeholder_1_data': {'shape': int64_array(input_shape)},
            'reduce_1': {'keep_dims': False, 'type': 'ReduceMean'},
            'const_data': {'value': int64_array([1, 2, 3])},  # axes to reduce
            'reduce_1_data': {'shape': int64_array([2])},
        }, nodes_with_edges_only=True)
        graph.graph['layout'] = 'NCHW'

        graph_ref = build_graph(nodes_attributes, [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'reshape_1'),
            ('reshape_1_const', 'reshape_1_const_data'),
            ('reshape_1_const_data', 'reshape_1'),
            ('reshape_1', 'reshape_1_data'),
            ('reshape_1_data', 'pooling'),
            ('pooling', 'pooling_data'),
            ('pooling_data', 'reshape_2'),
            ('reshape_2_const', 'reshape_2_const_data'),
            ('reshape_2_const_data', 'reshape_2'),
            ('reshape_2', 'reshape_2_data'),
            ('reshape_2_data', 'concat'),
        ], {
            'placeholder_1': {'shape': int64_array(input_shape)},
            'placeholder_1_data': {'shape': int64_array(input_shape)},
            'reshape_1_const': vec_const(0, 1, reduced, 1),
            'reshape_1_const_data': vec_const(0, 1, reduced, 1),
            'reshape_1_data': {'shape': int64_array([2, 1, reduced, 1])},
            'pooling': {'window': int64_array([1, 1, reduced, 1])},
            'pooling_data': {'shape': int64_array([2, 1, 1, 1])},
            'reshape_2_const': vec_const(0),
            'reshape_2_const_data': vec_const(0),
            'reshape_2_data': {'shape': int64_array([2])},
        }, nodes_with_edges_only=True)

        ReduceReplacer().find_and_replace_pattern(graph)
        shape_inference(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
        self.assertTrue(flag, resp)
# Example #10
# 0
    def test_1(self):
        # NormalizeFullyConnected on a 3D input. The [1, 16, 512] input is
        # expected to be reshaped to [16, 512] in front of the FullyConnected,
        # and its [16, 101] result reshaped back to [1, 16, 101].
        def fc_weights(**extra):
            weights = {'shape': np.array([512, 101]), 'value': np.ones([512, 101])}
            weights.update(extra)
            return weights

        def fc_attrs():
            return {'out-size': 101}

        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'fc'),
                             ('fc_weights', 'fc'),
                             ('fc', 'fc_data')],
                            {'placeholder_1_data': {'shape': np.array([1, 16, 512])},
                             'fc': fc_attrs(),
                             'fc_weights': fc_weights(input_channel_dim=1),
                             'fc_data': {'shape': np.array([1, 16, 101])}},
                            nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'reshape_1'),
                                 ('reshape_1', 'reshape_1_data'),
                                 ('reshape_1_data', 'fc'),
                                 ('fc_weights', 'fc'),
                                 ('fc', 'fc_data'),
                                 ('fc_data', 'reshape_2'),
                                 ('reshape_2', 'reshape_2_data')],
                                {'placeholder_1_data': {'shape': np.array([1, 16, 512])},
                                 'reshape_1_data': {'shape': np.array([16, 512])},
                                 'reshape_2_data': {'shape': np.array([1, 16, 101])},
                                 'fc_weights': fc_weights(),
                                 'fc': fc_attrs(),
                                 'fc_data': {'shape': np.array([16, 101])}},
                                nodes_with_edges_only=True)

        NormalizeFullyConnected().find_and_replace_pattern(graph)

        start_node = 'placeholder_1_data'
        flag, resp = compare_graphs(graph, graph_ref, start_node, start_node,
                                    check_op_attrs=True)
        self.assertTrue(flag, resp)
# Example #11
# 0
    def test_ss_shrink_new(self):
        # A StridedSlice with new_axis_mask set at position 1 and
        # shrink_axis_mask set at position 3. Both masks are expected to be
        # cleared, with the slice followed by two Reshapes: Reshape_new then
        # Reshape_shrink.
        def slices():
            return np.array([slice(0, 1, 1),
                             slice(0, 1, 1),
                             slice(0, 227, 1),
                             slice(0, 1, 1),
                             slice(0, 54, 1)])

        def mask(*set_positions):
            # Five-element boolean mask that is True only at `set_positions`.
            return np.array([pos in set_positions for pos in range(5)])

        source_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'sslice_2'),
            ('placeholder_begin_data', 'sslice_2'),
            ('placeholder_end_data', 'sslice_2'),
            ('placeholder_stride_data', 'sslice_2'),
        ]
        sink_edges = [
            ('sslice_2_data', 'placeholder_2'),
            ('placeholder_2', 'placeholder_2_data'),
        ]

        graph = build_graph(
            nodes_attributes_test,
            source_edges + [('sslice_2', 'sslice_2_data')] + sink_edges,
            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
             'sslice_2': {'slices': slices(),
                          'shrink_axis_mask': mask(3),
                          'new_axis_mask': mask(1)},
             'sslice_2_data': {'shape': np.array([1, 1, 227, 54]),
                               'is_output': True}})
        graph.graph['layout'] = 'NHWC'

        graph_ref = build_graph(
            nodes_reshape,
            source_edges + [
                ('sslice_2', 'sslice_2/Reshape_new_data'),
                ('sslice_2/Reshape_new_data', 'sslice_2/Reshape_new'),
                ('sslice_2/Reshape_new', 'sslice_2/Reshape_shrink_data'),
                ('sslice_2/Reshape_shrink_data', 'sslice_2/Reshape_shrink'),
                ('sslice_2/Reshape_shrink', 'sslice_2_data'),
            ] + sink_edges,
            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
             'sslice_2': {'slices': slices(),
                          'shrink_axis_mask': mask(),
                          'new_axis_mask': mask()},
             'sslice_2_data': {'shape': np.array([1, 1, 227, 54])},
             'sslice_2/Reshape_new': {'dim': np.array([1, 1, 227, 1, 54])},
             'sslice_2/Reshape_new_data': {'shape': np.array([1, 227, 1, 54])},
             'sslice_2/Reshape_shrink': {'dim': np.array([1, 1, 227, 54])},
             'sslice_2/Reshape_shrink_data': {'shape': np.array([1, 1, 227, 1, 54])}})

        AddReshapeAfterStridedSlice().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
        # Release both graphs before asserting, as the original test did.
        graph.clear()
        graph_ref.clear()
        self.assertTrue(flag, resp)
 def test_lstm_nonlinearity(self):
     # ReplaceLstmNonLinearityPattern.replace_op is expected to expand the
     # LstmNonLinearity node into the sub-graph described by ref_edges. The
     # input is split into (i_part, f_part, c_part, o_part, ct_1). ct_1
     # (port 4) feeds every scale_*_c node and f_mul_c.
     i_part, f_part, c_part, o_part, ct_1 = range(5)

     def from_split(dst, port):
         # Edge leaving the Split node on the given output port.
         return ('split', dst, {'out': port})

     lstm_attrs = {'kind': 'op', 'op': 'LstmNonLinearity'}
     lstm_attrs.update({name: np.array([]) for name in ('i_weights', 'f_weights', 'o_weights')})

     graph = build_graph({'in': {'kind': 'op', 'op': 'Parameter'},
                          'lstm': lstm_attrs,
                          'out': {'kind': 'op', 'op': 'Placeholder'}},
                         [('in', 'lstm'), ('lstm', 'out')],
                         nodes_with_edges_only=True)
     graph.stage = 'front'

     # Edge order is kept exactly as in the expected sub-graph.
     ref_edges = [
         ('in', 'split'),
         # i branch
         from_split('scale_i_c', ct_1),
         ('scale_i_c', 'i_plus_c'),
         from_split('i_plus_c', i_part),
         ('i_plus_c', 'sigmoid_i'),
         # f branch
         from_split('scale_f_c', ct_1),
         ('scale_f_c', 'f_plus_c'),
         from_split('f_plus_c', f_part),
         ('f_plus_c', 'sigmoid_f'),
         # c branch
         from_split('tanhcp', c_part),
         ('tanhcp', 'i_mul_tanhc'),
         ('sigmoid_i', 'i_mul_tanhc'),
         ('sigmoid_f', 'f_mul_c'),
         from_split('f_mul_c', ct_1),
         ('f_mul_c', 'fc_plus_itanhc'),
         ('i_mul_tanhc', 'fc_plus_itanhc'),
         # o branch
         from_split('scale_o_c', ct_1),
         ('scale_o_c', 'o_plus_c'),
         from_split('o_plus_c', o_part),
         ('o_plus_c', 'sigmoid_o'),
         ('fc_plus_itanhc', 'tanhc'),
         ('sigmoid_o', 'o_mul_tanhc'),
         ('tanhc', 'o_mul_tanhc'),
         # outputs
         ('fc_plus_itanhc', 'concat'),
         ('o_mul_tanhc', 'concat'),
         ('lstm', 'out'),
     ]
     ref_graph = build_graph(self.nodes_attributes, ref_edges, nodes_with_edges_only=True)

     ReplaceLstmNonLinearityPattern().replace_op(graph, Node(graph, 'lstm'))

     flag, resp = compare_graphs(graph, ref_graph, 'out', check_op_attrs=True)
     self.assertTrue(flag, resp)