コード例 #1
0
    def test_negative_2(self):
        # The graph under test and the reference are built from identical
        # inputs ('mul' carries can_be_fused=False), so the check passes only
        # if MulFakeQuantizeFuse leaves the graph equal to the reference.
        def make_graph():
            # A fresh attribute dict on every call: the two graphs never
            # share the override object.
            return build_graph(nodes,
                               edges,
                               {'mul': {'can_be_fused': False}},
                               nodes_with_edges_only=True)

        graph = make_graph()
        graph.stage = 'middle'
        expected = make_graph()

        MulFakeQuantizeFuse().find_and_replace_pattern(graph)

        is_same, message = compare_graphs(graph, expected, 'output',
                                          check_op_attrs=True)
        self.assertTrue(is_same, message)
コード例 #2
0
    def test_negative_fq_unacceptable_levels(self, levels):
        # The reference is a copy of the graph taken before the pass runs, so
        # the check passes only if CompressQuantizeWeights leaves it unchanged.
        fq_nodes = nodes_dict(np.float32, None, levels)

        # Producers are wired to FQ input ports 0..4 in this exact order.
        fq_edges = []
        for port, producer in enumerate(('weights', 'il', 'ih', 'ol', 'oh')):
            fq_edges.extend(connect('{}:0'.format(producer), '{}:FQ'.format(port)))
        fq_edges.extend(connect('FQ:0', 'output'))

        graph = build_graph(fq_nodes, fq_edges, nodes_with_edges_only=True)
        untouched = graph.copy()
        CompressQuantizeWeights().find_and_replace_pattern(graph)

        self.assertTrue(*compare_graphs(graph, untouched, 'output',
                                        check_op_attrs=True))
コード例 #3
0
 def test_nonreplacement(self):
     # Graph and reference come from identical inputs with
     # 'input_rank_changed' set to False on 'roll'; the check passes only if
     # CorrectRollAxes leaves the graph equal to that reference.
     def make_graph():
         return build_graph(nodes_attrs=graph_node_attrs,
                            edges=graph_edges,
                            update_attributes={'roll': {'input_rank_changed': False}})

     graph = make_graph()
     graph.stage = 'front'
     CorrectRollAxes().find_and_replace_pattern(graph)

     # The reference is built only after the pass has run, as before.
     expected = make_graph()
     is_same, message = compare_graphs(graph, expected, 'output',
                                       check_op_attrs=True)
     self.assertTrue(is_same, message)
コード例 #4
0
    def test_mean_values_explicit_and_scale_values_explicit(self):
        # Reference chain: parameter -> add_mean -> mul_scale -> result, with
        # 'mean' and 'scale' on the second input ports.
        expected_edges = []
        for src, dst in (('parameter', '0:add_mean'),
                         ('mean', '1:add_mean'),
                         ('add_mean', '0:mul_scale'),
                         ('scale', '1:mul_scale'),
                         ('mul_scale', 'result')):
            expected_edges.extend(connect(src, dst))
        graph_ref = build_graph(nodes, expected_edges)

        # One [mean, scale] pair supplied through the CLI namespace.
        mean = np.array([1., 2., 3.])
        scale = np.array([1., 2., 3.])
        argv = Namespace(mean_scale_values=[[mean, scale]])

        graph = build_graph(nodes,
                            list(connect('parameter', 'result')),
                            nodes_with_edges_only=True,
                            cli=argv)
        for g in (graph, graph_ref):
            self.set_graph_attrs(g, ['parameter'])
        graph.graph['layout'] = 'NCHW'

        AddMeanScaleValues().find_and_replace_pattern(graph)

        is_same, message = compare_graphs(graph, graph_ref, 'result',
                                          check_op_attrs=True)
        self.assertTrue(is_same, message)
        self.check_graph_attrs(graph, graph_ref, ['parameter'])
コード例 #5
0
    def test_div_test_2(self):
        # Both inputs of 'div' come from the same 'placeholder_1' output.
        graph = build_graph(nodes, [
            *connect('placeholder_1:0', '0:div'),
            *connect_data('placeholder_1:0', '1:div'),
            *connect('div', 'output'),
        ], nodes_with_edges_only=True)
        Div().find_and_replace_pattern(graph)

        # Expected wiring: 'placeholder_1' feeds port 0 of both 'mul' and
        # 'reciprocal', 'minus_one' feeds port 1 of 'reciprocal', and
        # 'reciprocal' feeds port 1 of 'mul'.
        graph_ref = build_graph(nodes, [
            *connect('placeholder_1:0', '0:mul'),
            *connect_data('placeholder_1:0', '0:reciprocal'),
            *connect('minus_one', '1:reciprocal'),
            *connect('reciprocal', '1:mul'),
            *connect('mul', 'output'),
        ], nodes_with_edges_only=True)

        (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(flag, resp)

        # assertEqual reports the actual name on failure, which
        # assertTrue(a == b) would hide behind "False is not true".
        multiply_ids = graph.get_nodes_with_attributes(type='Multiply')
        self.assertEqual(graph.node[multiply_ids[0]]['name'], 'my_div')
コード例 #6
0
    def test(self):
        # Before the pass 'depth' feeds 'onehot' directly; the reference
        # routes 'depth' through 'reshape' (with 'reshape_dims') first.
        op_nodes = {}
        op_nodes.update(regular_op('input', {'type': 'Parameter'}))
        op_nodes.update(const('depth', int64_array([2])))
        op_nodes.update(regular_op('onehot',
                                   {'type': 'OneHot', 'kind': 'op', 'op': 'OneHot'}))
        op_nodes.update(regular_op('reshape',
                                   {'type': 'Reshape', 'kind': 'op', 'op': 'Reshape'}))
        op_nodes.update(const('reshape_dims', int64_array([])))
        op_nodes.update(result('result'))

        graph = build_graph(op_nodes, [
            ('input', 'onehot'),
            ('depth', 'onehot'),
            ('onehot', 'result'),
        ])
        graph.graph['layout'] = 'NCHW'
        graph.stage = 'front'

        graph_ref = build_graph(op_nodes, [
            ('input', 'onehot'),
            ('depth', 'reshape'),
            ('reshape_dims', 'reshape'),
            ('reshape', 'onehot'),
            ('onehot', 'result'),
        ])

        OneHotDepthNormalizer().find_and_replace_pattern(graph)

        is_same, message = compare_graphs(graph, graph_ref, 'result')
        self.assertTrue(is_same, message)
    def test_simple_pooling(self):
        # Input graph: input -> splice -> pool -> out_op.
        source_edges = []
        for src, dst in (('input', 'splice'),
                         ('splice', 'pool'),
                         ('pool', 'out_op')):
            source_edges.extend(connect_front(src, dst))
        graph = build_graph(self.nodes, source_edges, nodes_with_edges_only=True)
        graph.stage = 'front'
        AddReshapeTransposeAroundConvPool.find_and_replace_pattern(graph)

        # Reference edges, listed in the same order as they are wired.
        expected_pairs = (
            ('input', 'splice'),
            ('splice', '0:reshape_in'),

            # Sub-graph that produces the second input of 'reshape_in'.
            ('splice', 'shapeof'),
            ('shapeof:0', '0:gather_batch'),
            ('ind', '1:gather_batch'),
            ('axis', '2:gather_batch'),
            ('shapeof:0', '0:gather_h'),
            ('ind_h', '1:gather_h'),
            ('axis', '2:gather_h'),
            ('gather_h', '0:div'),
            ('th', '1:div'),
            ('gather_batch', '0:concat'),
            ('t', '1:concat'),
            ('h', '3:concat'),
            ('div', '2:concat'),
            ('concat', '1:reshape_in'),

            # Reshape/transpose wrapped around 'pool'.
            ('reshape_in', '0:transpose_in'),
            ('transpose_in_order', '1:transpose_in'),
            ('transpose_in', 'pool'),
            ('pool', '0:transpose_out'),
            ('transpose_out_order', '1:transpose_out'),
            ('transpose_out', '0:reshape_out'),
            ('reshape_out_shape', '1:reshape_out'),
            ('reshape_out', 'out_op'),
        )
        expected_edges = [edge
                          for src, dst in expected_pairs
                          for edge in connect_front(src, dst)]
        ref_graph = build_graph(self.ref_nodes, expected_edges)

        is_same, message = compare_graphs(graph, ref_graph, 'out_op',
                                          check_op_attrs=True)
        self.assertTrue(is_same, message)
コード例 #8
0
    def test_conv_reshape_pool(self):
        # (source, destination, skip_data) triples in wiring order; the flag
        # marks the edges that were connected with skip_data=True.
        source_table = (
            ('conv', '0:transpose_out', False),
            ('transpose_out_order', '1:transpose_out', False),
            ('transpose_out', '0:reshape_out', False),
            ('reshape_out_shape', '1:reshape_out', False),
            ('reshape_out', 'shapeof', False),

            ('shapeof', '0:gather_batch', False),
            ('ind', '1:gather_batch', False),
            ('axis', '2:gather_batch', False),
            ('shapeof', '0:gather_h', True),
            ('ind_h', '1:gather_h', False),
            ('axis', '2:gather_h', True),
            ('gather_h', '0:div', False),
            ('th', '1:div', False),
            ('gather_batch', '0:concat', False),
            ('t', '1:concat', False),
            ('h', '2:concat', False),
            ('div', '3:concat', False),
            ('concat', '1:reshape_in', False),

            ('reshape_out', '0:reshape_in', True),
            ('reshape_in', '0:transpose_in', False),
            ('transpose_in_order', '1:transpose_in', False),
            ('transpose_in', 'pool', False),
        )
        source_edges = []
        for src, dst, skip in source_table:
            if skip:
                source_edges.extend(connect(src, dst, skip_data=True))
            else:
                source_edges.extend(connect(src, dst))
        graph = build_graph(self.nodes, source_edges, nodes_with_edges_only=True)

        FuseReshapesSequenceKaldi().find_and_replace_pattern(graph)

        # Reference: 'transpose_out' feeds 'transpose_in' directly, with no
        # reshape nodes in between.
        expected_edges = []
        for src, dst in (('conv', '0:transpose_out'),
                         ('transpose_out_order', '1:transpose_out'),
                         ('transpose_out', '0:transpose_in'),
                         ('transpose_in_order', '1:transpose_in'),
                         ('transpose_in', 'pool')):
            expected_edges.extend(connect(src, dst))
        ref_graph = build_graph(self.ref_nodes, expected_edges)

        is_same, message = compare_graphs(graph, ref_graph, 'pool')
        self.assertTrue(is_same, message)
コード例 #9
0
    def test6_not_constant(self):
        #        ,--------------->consumer3                ,->consumer3
        #   data---(new_shape1)-->consumer1      =>    data----->consumer1
        #        `-(new_shape1)-->consumer2                `-->consumer2
        #
        # Both graphs use the same edges; they differ only in the attribute
        # overrides passed to build_graph.
        indices = (1, 2, 3)

        def make_edges():
            # Fresh list per graph, in the original wiring order.
            edge_list = [('placeholder_1', 'placeholder_1_data')]
            for i in indices:
                edge_list.append(('placeholder_1_data', 'eltwise_{}'.format(i)))
            for i in indices:
                edge_list.append(('eltwise_{}'.format(i), 'eltwise_{}_data'.format(i)))
            for i in indices:
                edge_list.append(('eltwise_{}_data'.format(i), 'concat'))
            return edge_list

        overrides = {'placeholder_1_data': {'shape': int64_array([1, 3])}}
        for i in indices:
            overrides['eltwise_{}_data'.format(i)] = {'shape': int64_array([1, 3])}
        graph = build_graph(nodes_attributes, make_edges(), overrides,
                            nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                make_edges(),
                                {'placeholder_1_data': {'shape': int64_array([1, 3])}},
                                nodes_with_edges_only=True)

        normalize_eltwise_inputs(graph)

        is_same, message = compare_graphs(graph, graph_ref, 'concat',
                                          check_op_attrs=True)
        self.assertTrue(is_same, message)
コード例 #10
0
    def test_attributed_slice_replacer(self, attributed_slice_attrs):
        # Nodes present before the pass.
        node_defs = {}
        node_defs.update(regular_op_with_empty_data('input', {'type': 'Parameter'}))
        node_defs.update(regular_op_with_empty_data('attributed_slice',
                                                    attributed_slice_attrs))
        node_defs.update(result())

        # Nodes that only the reference graph connects.
        node_defs.update(const('start', np.array([0, 0])))
        node_defs.update(const('end', np.array([1, -1])))
        node_defs.update(const('axis', np.array([0, 1])))
        node_defs.update(regular_op_with_empty_data('slice',
                                                    {'op': 'Slice', 'type': None}))

        graph = build_graph(
            nodes_attrs=node_defs,
            edges=[('input', 'attributed_slice'), ('attributed_slice', 'output')],
            nodes_with_edges_only=True,
        )
        graph.stage = 'front'

        AttributedSliceToSliceReplacer().find_and_replace_pattern(graph)

        # Reference: 'slice' takes 'start', 'end' and 'axis' on ports 1..3.
        expected_edges = [('input', 'slice')]
        for port, producer in enumerate(('start', 'end', 'axis'), start=1):
            expected_edges.extend(connect_front(producer, '{}:slice'.format(port)))
        expected_edges.append(('slice', 'output'))
        graph_ref = build_graph(nodes_attrs=node_defs,
                                edges=expected_edges,
                                nodes_with_edges_only=True)

        is_same, message = compare_graphs(graph, graph_ref, 'output',
                                          check_op_attrs=True)
        self.assertTrue(is_same, message)
コード例 #11
0
    def test_instance_normalization_test_1(self):
        # Source graph: 'input', 'scale' and 'B' feed 'node', which carries
        # epsilon=0.123.
        source_edges = [
            ('input', 'node'),
            ('scale', 'node'),
            ('B', 'node'),
            ('node', 'out'),
        ]
        graph = build_graph(nodes_attributes,
                            source_edges,
                            {'node': {'epsilon': 0.123}},
                            nodes_with_edges_only=True)

        # Reference graph: mvn -> mul (with 'scale') -> add (with 'B'), where
        # 'mvn_axes' is fed by 'start', 'rank' and 'step'.
        expected_edges = [
            ('input', 'mvn', {'out': 0}),
            ('input', 'rank', {'out': 0}),
            ('start', 'mvn_axes'),
            ('rank', 'mvn_axes'),
            ('step', 'mvn_axes'),
            ('mvn_axes', 'mvn'),
            ('mvn', 'mul'),
            ('scale', 'mul'),
            ('mul', 'add'),
            ('B', 'add'),
            ('add', 'out'),
        ]
        mvn_attrs = {
            'eps': 0.123,
            'eps_mode': 'inside_sqrt',
            'normalize_variance': 1,
        }
        graph_ref = build_graph(nodes_ref_attributes,
                                expected_edges,
                                {'mvn': mvn_attrs},
                                nodes_with_edges_only=True)

        graph.stage = 'front'

        InstanceNormalization().find_and_replace_pattern(graph)

        # Operation attributes are deliberately not compared here.
        is_same, message = compare_graphs(graph, graph_ref, 'out',
                                          check_op_attrs=False)
        self.assertTrue(is_same, message)
コード例 #12
0
    def test_v10_group_convolution_resolver(self):
        conv_attrs = {'type': 'Convolution', 'group': 3, 'output': 24}

        node_defs = {}
        node_defs.update(regular_op_with_shaped_data('input', [1, 3, 224, 224],
                                                     {'type': 'Parameter'}))
        node_defs.update(valued_const_with_data('weights', np.ones([3, 8, 7, 7])))
        node_defs.update(valued_const_with_data('dim', int64_array([3, 8, 1, 7, 7])))
        node_defs.update(regular_op_with_empty_data('reshape', {'type': 'Reshape'}))
        node_defs.update(regular_op_with_shaped_data('convolution', None, conv_attrs))
        node_defs.update(result())

        source_edges = []
        for src, dst in (('input', '0:convolution'),
                         ('weights', '1:convolution'),
                         ('convolution', 'output')):
            source_edges.extend(connect(src, dst))
        graph = build_graph(node_defs, source_edges, nodes_with_edges_only=True)

        V10ConvolutionWithGroupsResolver().find_and_replace_pattern(graph)

        # The reference reuses the node definitions with 'convolution' retyped
        # to GroupConvolution and its 'group' attribute removed.
        node_defs['convolution']['type'] = 'GroupConvolution'
        del node_defs['convolution']['group']

        # In the reference, 'weights' reaches the convolution through
        # 'reshape', whose second input is 'dim'.
        expected_edges = []
        for src, dst in (('input', '0:convolution'),
                         ('weights', '0:reshape'),
                         ('dim', '1:reshape'),
                         ('reshape', '1:convolution'),
                         ('convolution', 'output')):
            expected_edges.extend(connect(src, dst))
        graph_ref = build_graph(node_defs, expected_edges, nodes_with_edges_only=True)

        is_same, message = compare_graphs(graph, graph_ref,
                                          last_node='output',
                                          check_op_attrs=True)
        self.assertTrue(is_same, message)
コード例 #13
0
    def test_scaleshift2_axis1_to_mul(self):
        # Source graph: a two-input 'scaleshift_1' with axis=1.
        source_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_2', 'placeholder_2_data'),
            ('placeholder_1_data', 'scaleshift_1'),
            ('placeholder_2_data', 'scaleshift_1'),
            ('scaleshift_1', 'scaleshift_1_data'),
            ('scaleshift_1_data', 'op_output'),
        ]
        source_attrs = {
            'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
            'placeholder_2_data': {'shape': np.array([227])},
            'scaleshift_1': {'axis': 1},
            'scaleshift_1_data': {},
        }
        graph = build_graph(nodes_attributes, source_edges, source_attrs)

        # Reference graph: 'placeholder_2' goes through a Reshape fed by a
        # constant, and 'mul_1' replaces the scale-shift.
        expected_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_2', 'placeholder_2_data'),
            ('placeholder_2_data', 'placeholder_2/Reshape_'),
            ('placeholder_2/Reshape_const', 'placeholder_2/Reshape_const_data'),
            ('placeholder_2/Reshape_const_data', 'placeholder_2/Reshape_'),
            ('placeholder_2/Reshape_', 'placeholder_2/Reshape_data'),
            ('placeholder_1_data', 'mul_1'),
            ('placeholder_2/Reshape_data', 'mul_1'),
            ('mul_1', 'scaleshift_1_data'),
            ('scaleshift_1_data', 'op_output'),
        ]
        expected_attrs = {
            'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
            'placeholder_2_data': {'shape': np.array([227])},
            'placeholder_2/Reshape_const': {
                'value': np.array([1, 227, 1, 1]),
                'shape': [4],
            },
            'placeholder_2/Reshape_const_data': {
                'value': np.array([1, 227, 1, 1]),
                'shape': [4],
            },
            'placeholder_2/Reshape_data': {'shape': np.array([1, 227, 1, 1])},
            'mul_1': {'can_be_fused': True},
            'scaleshift_1_data': {},
        }
        graph_ref = build_graph(nodes_attributes, expected_edges, expected_attrs)

        graph.graph['layout'] = 'NHWC'
        convert_scale_shift_to_mul_add(graph)
        graph.clean_up()

        is_same, message = compare_graphs(graph, graph_ref, 'placeholder_1')
        self.assertTrue(is_same, message)
コード例 #14
0
    def test_set_ports_split2(self):
        op_nodes = {}
        op_nodes.update(regular_op('op1', {}))
        op_nodes.update(regular_op('split', {'op': 'Split'}))
        for consumer in ('op2', 'op3', 'op4'):
            op_nodes.update(regular_op(consumer, {}))

        # 'split' outputs are attached with out_port 0, 4 and 6, while the
        # node is then declared to have only three output ports.
        source_edges = [
            ('op1', 'split', {'fw_tensor_debug_info': {}}),
            ('split', 'op2', {'fw_tensor_debug_info': {}, 'out_port': 0}),
            ('split', 'op4', {'fw_tensor_debug_info': {}, 'out_port': 4}),
            ('split', 'op3', {'fw_tensor_debug_info': {}, 'out_port': 6}),
        ]
        graph = build_graph(op_nodes, source_edges, nodes_with_edges_only=True)

        graph.stage = 'front'
        graph.nodes()['split']['out_ports_count'] = 3

        # Reference: the same consumers on consecutive output ports 0, 1, 2.
        expected_edges = []
        for src, dst in (('op1:0', '0:split'),
                         ('split:0', '0:op2'),
                         ('split:1', '0:op4'),
                         ('split:2', '0:op3')):
            expected_edges.extend(connect_front(src, dst))
        ref_graph = build_graph(op_nodes, expected_edges, nodes_with_edges_only=True)

        SetPortsPattern().find_and_replace_pattern(graph)

        is_same, message = compare_graphs(graph, ref_graph, 'op4',
                                          check_op_attrs=True)
        self.assertTrue(is_same, message)
コード例 #15
0
    def test_CTCGreedyDecoderSingle_negative(self):
        # The reference is built from the very same edge list, so the check
        # passes only if CTCGreedyDecoderSingleReplacement leaves the graph
        # equal to its original form.
        decoder_edges = [
            ('logits', 'decoder', {'out': 0, 'in': 0}),
            ('seq_len', 'decoder', {'out': 0, 'in': 1}),
            ('decoder', 'sparse_to_dense', {'out': 0, 'in': 0}),
            ('decoder', 'cast', {'out': 1, 'in': 0}),
            # This edge specifies only the output port.
            ('cast', 'sparse_to_dense', {'out': 0}),
            ('sparse_to_dense', 'last', {'out': 0, 'in': 0}),
        ]

        graph = build_graph(self.nodes_attributes, decoder_edges,
                            nodes_with_edges_only=True)
        graph.stage = 'front'
        CTCGreedyDecoderSingleReplacement().find_and_replace_pattern(graph)

        expected = build_graph(self.nodes_attributes, decoder_edges,
                               nodes_with_edges_only=True)

        is_same, message = compare_graphs(graph, expected, 'last',
                                          check_op_attrs=True)
        self.assertTrue(is_same, message)
コード例 #16
0
    def test_ti_reverse(self):
        # (source, destination, skip_data) triples in wiring order; the flag
        # marks the edges that were connected with skip_data=True.
        source_table = (
            # Reverse in front of 'ti'.
            ('parameter:0', '0:direct_reverse', False),
            ('parameter:0', 'shapeof', True),
            ('shapeof:0', '0:gather_batch', False),
            ('gather_batch_ind', '1:gather_batch', False),
            ('gather_axis', '2:gather_batch', False),
            ('shapeof:0', '0:gather_seq', True),
            ('gather_seq_ind', '1:gather_seq', False),
            ('gather_axis', '2:gather_seq', False),
            ('gather_seq', '0:broadcast', False),
            ('gather_batch', '1:broadcast', False),
            ('broadcast', '1:direct_reverse', False),
            ('direct_reverse', '0:ti', False),
            ('init_hidden', '1:ti', False),
            # Reverse behind 'ti'.
            ('ti', 'inverse_shapeof', False),
            ('inverse_shapeof:0', '0:inverse_gather_batch', False),
            ('gather_batch_ind', '1:inverse_gather_batch', False),
            ('gather_axis', '2:inverse_gather_batch', False),
            ('inverse_shapeof:0', '0:inverse_gather_seq', True),
            ('gather_seq_ind', '1:inverse_gather_seq', False),
            ('gather_axis', '2:inverse_gather_seq', False),
            ('inverse_gather_seq', '0:inverse_broadcast', False),
            ('inverse_gather_batch', '1:inverse_broadcast', False),
            ('ti', '0:inverse_reverse', True),
            ('inverse_broadcast', '1:inverse_reverse', False),
            ('inverse_reverse', 'some_op', False),
            ('some_op', 'output', False),
        )
        source_edges = []
        for src, dst, skip in source_table:
            if skip:
                source_edges.extend(connect(src, dst, skip_data=True))
            else:
                source_edges.extend(connect(src, dst))
        graph = build_graph(nodes, source_edges, nodes_with_edges_only=True)

        ReverseTensorIteratorLSTM().find_and_replace_pattern(graph)
        graph.clean_up()

        # Reference: 'parameter' and 'init_hidden' feed 'ti' directly and
        # 'ti' feeds 'some_op'.
        expected_edges = []
        for src, dst in (('parameter', '0:ti'),
                         ('init_hidden', '1:ti'),
                         ('ti', 'some_op'),
                         ('some_op', 'output')):
            expected_edges.extend(connect(src, dst))
        ref_graph = build_graph(ref_nodes, expected_edges)

        is_same, message = compare_graphs(graph, ref_graph, 'output',
                                          check_op_attrs=True)
        self.assertTrue(is_same, message)
コード例 #17
0
    def test_image_scaler_test_3(self):
        # Source graph: 'im_scaler' with scale 2.0 and an all-zero bias.
        source_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'im_scaler'),
            ('im_scaler', 'im_scaler_data'),
            ('im_scaler_data', 'last'),
        ]
        source_attrs = {
            'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
            'im_scaler': {
                'scale': np.array(2.0),
                'bias': np.reshape(np.array([0, 0, 0]), [3, 1, 1]),
            },
        }
        graph = build_graph(nodes_attributes, source_edges, source_attrs,
                            nodes_with_edges_only=True)

        # Reference graph: a single 'mul_1' whose weight constant is 2.0.
        expected_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'mul_1'),
            ('const_mul_1_w', 'mul_1_w'),
            ('mul_1_w', 'mul_1'),
            ('mul_1', 'mul_1_data'),
            ('mul_1_data', 'last'),
        ]
        expected_attrs = {
            'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
            'const_mul_1_w': {
                'shape': np.array(2.0).shape,
                'value': np.array(2.0),
            },
        }
        graph_ref = build_graph(nodes_attributes, expected_edges, expected_attrs,
                                nodes_with_edges_only=True)

        graph.graph['layout'] = 'NCHW'
        graph.stage = 'middle'

        ImageScaler().find_and_replace_pattern(graph)

        is_same, message = compare_graphs(graph, graph_ref, 'last')
        self.assertTrue(is_same, message)
コード例 #18
0
    def test_pattern_does_not_satisfy(self, input_shape, scales):
        # Graph and reference are built from identical attribute overrides,
        # so the check passes only if UpsampleToResample leaves the graph
        # equal to the reference.
        def scales_attrs():
            return {
                'value': int64_array(scales),
                'shape': int64_array(scales).shape,
            }

        def attr_overrides():
            # Rebuilt from scratch on each call so the two graphs never share
            # arrays or dicts.
            return {
                'placeholder_data': {'shape': int64_array(input_shape)},
                'scales': scales_attrs(),
                'scales_data': scales_attrs(),
                'upsample_data': {
                    'shape': int64_array(input_shape) * int64_array(scales)
                },
            }

        graph = build_graph(graph_node_attrs, graph_edges, attr_overrides())
        graph.graph['layout'] = 'NCHW'

        ref_graph = build_graph(graph_node_attrs, graph_edges, attr_overrides())

        UpsampleToResample().find_and_replace_pattern(graph)

        is_same, message = compare_graphs(graph, ref_graph, 'output')
        self.assertTrue(is_same, message)
コード例 #19
0
    def test_all_dynamic_inputs(self):
        shape = [1, 3, 20, 20]

        # 'placeholder', 'min' and 'max' are all Parameter nodes, i.e. the
        # clamp bounds are not constants.
        node_defs = {}
        for name in ('placeholder', 'min', 'max'):
            node_defs.update(regular_op_with_shaped_data(name, [1, 3, 20, 20],
                                                         {'type': 'Parameter'}))
        node_defs.update(regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20],
                                                     {'type': None, 'op': 'Clamp'}))
        node_defs.update(regular_op_with_shaped_data('maximum', [1, 3, 20, 20],
                                                     {'type': 'Maximum', 'op': 'Maximum'}))
        node_defs.update(regular_op_with_shaped_data('minimum', [1, 3, 20, 20],
                                                     {'type': 'Minimum', 'op': 'Minimum'}))
        node_defs.update(result('result'))
        del shape  # only documents the common shape used above

        source_edges = []
        for src, dst in (('placeholder', '0:a_clamp'),
                         ('min', '1:a_clamp'),
                         ('max', '2:a_clamp'),
                         ('a_clamp', 'result')):
            source_edges.extend(connect(src, dst))
        graph = build_graph(node_defs, source_edges)
        ClampNormalizer().find_and_replace_pattern(graph)

        # Reference: maximum(placeholder, min) feeds minimum(..., max).
        expected_edges = []
        for src, dst in (('placeholder', '0:maximum'),
                         ('min', '1:maximum'),
                         ('maximum', '0:minimum'),
                         ('max', '1:minimum'),
                         ('minimum', 'result')):
            expected_edges.extend(connect(src, dst))
        ref_graph = build_graph(node_defs, expected_edges)

        is_same, message = compare_graphs(graph, ref_graph, 'result')
        self.assertTrue(is_same, message)
コード例 #20
0
    def test_positive(self, input_shape, axes, layout):
        # Source graph: the shared `nodes`/`edges` plus the shape- and
        # axes-dependent nodes for this parameter set.
        graph = build_graph_with_attrs(nodes + [
            ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
            ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
            ('square_data', dict(kind='data', shape=input_shape)),
            ('sum_axes_data', dict(kind='data', value=axes, shape=None)),
        ], edges, nodes_with_edges_only=True)
        graph.stage = 'middle'
        graph.graph['layout'] = layout

        L2NormToNorm().find_and_replace_pattern(graph)

        graph_ref = build_graph_with_attrs(nodes + [
            ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
            ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
            # NOTE(review): `.sort()` on a list or ndarray sorts in place and
            # returns None, so this presumably stores value=None (and reorders
            # `axes`) - confirm whether a sorted copy was intended.
            ('weights_node_data', dict(kind='data', value=axes.sort())),
        ], edges_after_replacement, nodes_with_edges_only=True)

        (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
        # Check the graph comparison first so that its message is reported
        # rather than an IndexError from the node lookup below.
        self.assertTrue(flag, resp)

        # assertEqual reports the actual name on failure, which
        # assertTrue(a == b) would hide behind "False is not true".
        normalize_ids = graph.get_nodes_with_attributes(type='NormalizeL2')
        self.assertEqual(graph.node[normalize_ids[0]]['name'], 'l2_norm_name')
コード例 #21
0
    def test_negative(self, input_shape, axes, layout):
        # Graph and reference are built from identical nodes and edges, so the
        # check passes only if L2NormToNorm leaves the graph equal to the
        # reference.
        def make_graph():
            extra_nodes = [
                ('input', dict(kind='op', shape=input_shape, op='Parameter',
                               data_type=np.float32)),
                ('input_data', dict(kind='data', shape=input_shape,
                                    data_type=np.float32)),
                ('square_data', dict(kind='data', shape=input_shape)),
                ('sum_axes_data', dict(kind='data', value=axes, shape=None)),
            ]
            return build_graph_with_attrs(nodes + extra_nodes, edges,
                                          nodes_with_edges_only=True)

        graph = make_graph()
        graph.stage = 'middle'
        graph.graph['layout'] = layout

        L2NormToNorm().find_and_replace_pattern(graph)

        # The reference is built only after the pass has run, as before.
        expected = make_graph()

        is_same, message = compare_graphs(graph, expected, 'result',
                                          check_op_attrs=True)
        self.assertTrue(is_same, message)
コード例 #22
0
    def test_mean_values_with_data_name(self):
        """Check AddMeanScaleValues with mean values '(1,2,3)' and no scale values.

        The input graph connects 'parameter' straight to 'result'; the
        reference graph routes 'parameter' and 'mean' through 'add_mean'.
        """
        expected_edges = [
            *connect('parameter', '0:add_mean'),
            *connect('mean', '1:add_mean'),
            *connect('add_mean', 'result'),
        ]
        graph_ref = build_graph(nodes, expected_edges)

        # Mean values only: the scale string is empty.
        mean_scale = get_mean_scale_dictionary(parse_tuple_pairs('(1,2,3)'), parse_tuple_pairs(''), None)
        cli_args = Namespace(mean_scale_values=mean_scale)

        initial_edges = [*connect('parameter', 'result')]
        graph = build_graph(nodes, initial_edges, nodes_with_edges_only=True, cli=cli_args)
        for target in (graph, graph_ref):
            self.set_graph_attrs(target, ['parameter'])
        graph.graph['layout'] = 'NCHW'

        AddMeanScaleValues().find_and_replace_pattern(graph)

        is_equal, message = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
        self.assertTrue(is_equal, message)
        self.check_graph_attrs(graph, graph_ref, ['parameter'])
コード例 #23
0
    def test_insert_old_api_map(self):
        """Check InsertReverseChannels on a graph whose 'placeholder1' carries an old API map.

        An OldAPIMapOrder entry with transpose order [0, 2, 3, 1] is stored in
        the rt_info of 'placeholder1' before the transformation runs; the
        reference graph has a 'reverse_channels' node between 'placeholder1'
        and 'mul' and reuses the same rt_info object.
        """
        input_edges = [
            *connect('placeholder1', '0:mul'),
            *connect('placeholder2', '1:mul'),
            *connect('mul', 'result'),
        ]
        graph = build_graph(get_nodes([1, 10, 10, 3]), input_edges, nodes_with_edges_only=True,
                            cli=Namespace(reverse_input_channels=True))

        node = Node(graph, 'placeholder1')
        old_api_map = OldAPIMapOrder(version=0)
        map_key = ('old_api_map_order', old_api_map.get_version())
        node.rt_info.info[map_key] = old_api_map
        node.rt_info.info[map_key].old_api_transpose_parameter([0, 2, 3, 1])

        InsertReverseChannels().find_and_replace_pattern(graph)

        reference_edges = [
            *connect('placeholder1', 'reverse_channels'),
            *connect('reverse_channels', '0:mul'),
            *connect('placeholder2', '1:mul'),
            *connect('mul', 'result'),
        ]
        graph_ref = build_graph(get_nodes([1, 10, 10, 3], 3), reference_edges)

        # Share the rt_info so the reference placeholder carries the same map.
        node2 = Node(graph_ref, 'placeholder1')
        node2.rt_info = node.rt_info

        is_equal, message = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
        self.assertTrue(is_equal, message)
コード例 #24
0
    def test_identityN_unused_ports(self):
        """Check IdentityN_to_Identity when only output port 0 of 'identityN' is consumed.

        The input graph feeds two placeholders into 'identityN' but connects
        only 'identityN:0' to an output; the reference graph contains a single
        'identity0' node between 'placeholder_0' and 'output0'.
        """
        source_edges = [
            *connect('placeholder_0', '0:identityN'),
            *connect('placeholder_1', '1:identityN'),
            *connect('identityN:0', 'output0'),
        ]
        graph = build_graph(nodes, source_edges, nodes_with_edges_only=True)

        IdentityN_to_Identity().find_and_replace_pattern(graph)

        expected_edges = [
            *connect('placeholder_0', 'identity0'),
            *connect('identity0', 'output0'),
        ]
        graph_ref = build_graph(nodes, expected_edges, nodes_with_edges_only=True)

        is_equal, message = compare_graphs(graph, graph_ref, 'output0', check_op_attrs=True)
        self.assertTrue(is_equal, message)
コード例 #25
0
 def test_3(self):
     """Check SliceLikeToStridedSlice with slice_like axes [1, 2].

     The input graph overrides the shapes of 'input_data' and
     'shape_like_input_data' and the 'axes' of 'slice_like'; the result is
     compared against a reference built from ``input_part_shape_edges``.
     """
     overrides = {
         'input_data': {'shape': int64_array([1, 224, 224, 3])},
         'shape_like_input_data': {'shape': int64_array([2, 2, 2, 2, 2])},
         'slice_like': {'axes': int64_array([1, 2])},
     }
     graph = build_graph(nodes_attributes, edges, update_attributes=overrides, nodes_with_edges_only=True)

     SliceLikeToStridedSlice().find_and_replace_pattern(graph)

     ref_graph = build_graph(nodes_attributes, input_part_shape_edges, nodes_with_edges_only=True)

     is_equal, message = compare_graphs(graph, ref_graph, 'result')
     self.assertTrue(is_equal, message)
コード例 #26
0
    def test_ScatterElementsUpdate_has_axis_and_3_inputs(self):
        """Check ScatterNormalizer when 'node' has an 'axis' attribute of 1.

        The reference graph keeps the original edges and additionally connects
        an 'axis' node, whose value is np.int64(1), to input port 3 of 'node'.
        """
        node_overrides = {'node': {'axis': 1}}
        graph = build_graph(nodes, edges, node_overrides, nodes_with_edges_only=True)

        ScatterNormalizer().find_and_replace_pattern(graph)

        expected_edges = [
            *edges,
            *connect('axis', '3:node'),
        ]
        axis_overrides = {'axis': {'value': np.int64(1)}}
        graph_ref = build_graph(nodes, expected_edges, axis_overrides, nodes_with_edges_only=True)

        is_equal, message = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(is_equal, message)
コード例 #27
0
    def test_gelu_p3(self):
        """Check that GeLUMergerErf fuses this edge arrangement into a Gelu node.

        After the transformation and clean-up the graph must match the
        reference built from ``ref_nodes``/``ref_edges``, the first Gelu node
        must use the 'erf' approximation mode, and exactly one op node named
        'final_mul' must exist and be of op 'Gelu'.
        """
        edges = [('input', 'mul'), ('div', 'erf'), ('erf', 'add'),
                 ('add', 'mul'), ('mul', 'mul0'), ('mul_param', 'mul'),
                 ('div_param', 'div'), ('add_param', 'add'),
                 ('mul0', 'result')]

        graph = build_graph(self.nodes, edges)

        graph_ref = build_graph(ref_nodes, ref_edges)
        graph.stage = 'front'

        GeLUMergerErf().find_and_replace_pattern(graph)
        graph.clean_up()

        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
        self.assertTrue(flag, resp)

        # assertEqual reports the actual value on failure, which
        # assertTrue(a == b) does not.
        gelu_nodes = graph.get_op_nodes(op='Gelu')
        self.assertEqual(gelu_nodes[0].approximation_mode, 'erf')

        # Two separate checks so a failure says whether the node count or the
        # op type was wrong.
        final_mul_nodes = graph.get_op_nodes(name='final_mul')
        self.assertEqual(len(final_mul_nodes), 1)
        self.assertEqual(final_mul_nodes[0].op, 'Gelu')
コード例 #28
0
    def test_case6_dest(self):
        """Check set_destination with the "dest" argument when Op1 is bypassed.

        The connection on input port 0 of 'Op1' is re-pointed at the
        destination of its own output port 0. The reference graph connects
        'input_data' directly to 'Op2' and expects 'input_data' to carry
        fw_tensor_debug_info [('Op1', 'Op1')].
        """
        source_edges = [('input', 'input_data'),
                        ('input_data', 'Op1'),
                        ('Op1', 'Op1_data'),
                        ('Op1_data', 'Op2')]
        graph = build_graph(nodes, source_edges)

        expected_edges = [('input', 'input_data'),
                          ('input_data', 'Op2'),
                          ('Op1', 'Op1_data')]
        graph_ref = build_graph(nodes, expected_edges)

        ref_input_data = Node(graph_ref, 'input_data')
        ref_input_data['fw_tensor_debug_info'] = [('Op1', 'Op1')]

        op1 = Node(graph, 'Op1')
        connection = op1.in_port(0).get_connection()
        new_destination = op1.out_port(0).get_destination()
        connection.set_destination(new_destination, "dest")

        is_equal, message = compare_graphs(graph, graph_ref, 'Op2', check_op_attrs=True)
        self.assertTrue(is_equal, message)
        self.check_graph_attrs_middle(graph, graph_ref)
コード例 #29
0
 def test_ifft_replacement(self, input_shape):
     """Check MXFFTToDFT on an 'fft' node marked as inverse.

     The graph is processed at the 'front' stage and compared with the
     reference built from the ``ref_converted_ifft_graph_*`` definitions.

     :param input_shape: shape assigned to 'placeholder' in both graphs
     """
     source_overrides = {
         'placeholder': {'shape': input_shape},
         'fft': {'is_inverse': True},
     }
     graph = build_graph(nodes_attrs=fft_graph_node_attrs,
                         edges=fft_graph_edges,
                         update_attributes=source_overrides)
     graph.stage = 'front'

     MXFFTToDFT().find_and_replace_pattern(graph)

     reference_overrides = {'placeholder': {'shape': input_shape}}
     ref_graph = build_graph(nodes_attrs=ref_converted_ifft_graph_node_attrs,
                             edges=ref_converted_ifft_graph_edges,
                             update_attributes=reference_overrides)

     is_equal, message = compare_graphs(graph, ref_graph, 'output')
     self.assertTrue(is_equal, message)
コード例 #30
0
    def test_pool_v2_to_attributed_pool(self):
        """Check PoolV2ToAttributedPool on a PoolingV2 node with three inputs.

        The input graph feeds 'input', 'windows' and 'strides' into 'pool_v2';
        the reference graph has a 'pool_v1' node of type 'Pooling' that takes
        only 'input'.
        """
        def shared_pool_attrs():
            # Attributes common to both pooling nodes; a new dict (and new
            # lists) is produced on every call.
            return {
                'pad': [2, 2],
                'spatial_dims': [1, 2],
                'auto_pad': 'same_upper',
                'output_spatial_shape': [2, 3],
                'pad_spatial_shape': [1, 2],
                'pool_method': 'max',
            }

        pool_v2_attrs = {'op': 'PoolingV2', **shared_pool_attrs(), 'permute_attrs': None}
        pool_v1_attrs = {'type': 'Pooling', **shared_pool_attrs()}

        nodes = {
            **shaped_const_with_data('input', int64_array([200, 200])),
            **valued_const_with_data('windows', int64_array([4, 4])),
            **valued_const_with_data('strides', int64_array([4, 4])),
            **regular_op_with_empty_data('pool_v2', pool_v2_attrs),
            **regular_op_with_empty_data('pool_v1', pool_v1_attrs),
            **result('output'),
        }

        edges = [
            *connect('input', 'pool_v2:0'),
            *connect('windows', 'pool_v2:1'),
            *connect('strides', 'pool_v2:2'),
            *connect('pool_v2', 'output'),
        ]

        graph = build_graph(nodes, edges, nodes_with_edges_only=True)
        PoolV2ToAttributedPool().find_and_replace_pattern(graph)

        expected_edges = [
            *connect('input', 'pool_v1'),
            *connect('pool_v1', 'output'),
        ]
        ref_graph = build_graph(nodes, expected_edges, nodes_with_edges_only=True)

        is_equal, message = compare_graphs(graph, ref_graph, 'output')
        self.assertTrue(is_equal, message)