Example #1
0
    def test_div_test_2(self):
        """Div whose two inputs come from the same placeholder is decomposed.

        After ``Div().find_and_replace_pattern`` the graph must match a
        reference in which ``placeholder_1`` is fed to ``mul`` together with
        ``reciprocal(placeholder_1, minus_one)``, and the single Multiply node
        looked up afterwards must be named ``'my_div'``.
        """
        # Test with two same inputs from one placeholder
        graph = build_graph(nodes, [
            *connect('placeholder_1:0', '0:div'),
            *connect_data('placeholder_1:0', '1:div'),
            *connect('div', 'output'),
        ], nodes_with_edges_only=True)
        Div().find_and_replace_pattern(graph)

        graph_ref = build_graph(nodes, [
            *connect('placeholder_1:0', '0:mul'),
            *connect_data('placeholder_1:0', '0:reciprocal'),
            *connect('minus_one', '1:reciprocal'),
            *connect('reciprocal', '1:mul'),
            *connect('mul', 'output'),
        ], nodes_with_edges_only=True)

        (flag, resp) = compare_graphs(graph,
                                      graph_ref,
                                      'output',
                                      check_op_attrs=True)
        self.assertTrue(flag, resp)

        # Fail cleanly (instead of IndexError) when no Multiply node exists,
        # and let assertEqual report the actual name on mismatch.
        mul_nodes = graph.get_nodes_with_attributes(type='Multiply')
        self.assertTrue(mul_nodes, 'No node of type "Multiply" found in the transformed graph')
        self.assertEqual(graph.node[mul_nodes[0]]['name'], 'my_div')
Example #2
0
    def test_sub_test_2(self):
        """Sub whose two inputs come from the same placeholder is decomposed.

        After ``Sub().find_and_replace_pattern`` the graph must match a
        reference in which ``placeholder_1`` is fed to ``add`` together with
        ``negate(placeholder_1, minus_one)``, and the single Add node looked up
        afterwards must be named ``'my_sub'``.
        """
        # Test with two same inputs from one placeholder
        graph = build_graph(nodes, [
            *connect('placeholder_1:0', '0:sub'),
            *connect_data('placeholder_1:0', '1:sub'),
            *connect('sub', 'output'),
        ], nodes_with_edges_only=True)
        Sub().find_and_replace_pattern(graph)

        graph_ref = build_graph(nodes, [
            *connect('placeholder_1:0', '0:add'),
            *connect_data('placeholder_1:0', '0:negate'),
            *connect('minus_one', '1:negate'),
            *connect('negate', '1:add'),
            *connect('add', 'output'),
        ], nodes_with_edges_only=True)

        (flag, resp) = compare_graphs(graph,
                                      graph_ref,
                                      'output',
                                      check_op_attrs=True)
        self.assertTrue(flag, resp)

        # Fail cleanly (instead of IndexError) when no Add node exists,
        # and let assertEqual report the actual name on mismatch.
        add_nodes = graph.get_nodes_with_attributes(type='Add')
        self.assertTrue(add_nodes, 'No node of type "Add" found in the transformed graph')
        self.assertEqual(graph.node[add_nodes[0]]['name'], 'my_sub')
Example #3
0
    def test_floor_div_test_2(self):
        """FloorDiv whose two inputs come from the same placeholder is decomposed.

        After ``FloorDivDecomposition().find_and_replace_pattern`` the graph
        must match a reference where ``div(placeholder_1, placeholder_1)``
        feeds ``floor``, and the single Floor node looked up afterwards must be
        named ``'my_floor_div'``.
        """
        # Test with two same inputs from one placeholder
        graph = build_graph(nodes, [
            *connect('placeholder_1:0', '0:floor_div'),
            *connect_data('placeholder_1:0', '1:floor_div'),
            *connect('floor_div', 'output'),
        ], nodes_with_edges_only=True)
        FloorDivDecomposition().find_and_replace_pattern(graph)

        graph_ref = build_graph(nodes, [
            *connect('placeholder_1:0', '0:div'),
            *connect_data('placeholder_1:0', '1:div'),
            *connect('div', 'floor'),
            *connect('floor', 'output'),
        ], nodes_with_edges_only=True)

        (flag, resp) = compare_graphs(graph,
                                      graph_ref,
                                      'output',
                                      check_op_attrs=True)
        self.assertTrue(flag, resp)

        # Fail cleanly (instead of IndexError) when no Floor node exists,
        # and let assertEqual report the actual name on mismatch.
        floor_nodes = graph.get_nodes_with_attributes(type='Floor')
        self.assertTrue(floor_nodes, 'No node of type "Floor" found in the transformed graph')
        self.assertEqual(graph.node[floor_nodes[0]]['name'], 'my_floor_div')
Example #4
0
    def test_mean_values_explicit_and_scale_values_explicit_with_shape_of(self):
        """Explicit mean/scale values are inserted; the ShapeOf branch is kept.

        ``parameter`` feeds both ``result`` and a ``shape_of -> result_2``
        branch. With mean and scale values for ``parameter`` supplied through
        the CLI namespace, ``AddMeanScaleValues`` must produce the reference in
        which ``add_mean`` and ``mul_scale`` sit between ``parameter`` and
        ``result``, while ``shape_of`` still reads the parameter's data
        directly. Both outputs and the graph attributes registered for
        ``parameter`` are compared against the reference.
        """
        expected_edges = [
            *connect('parameter', '0:add_mean'),
            *connect('mean', '1:add_mean'),
            *connect('add_mean', '0:mul_scale'),
            *connect('scale', '1:mul_scale'),
            *connect('mul_scale', 'result'),
            *connect_data('parameter', 'shape_of'),
            *connect('shape_of', 'result_2'),
        ]
        graph_ref = build_graph(nodes, expected_edges, nodes_with_edges_only=True)

        parameter_values = {'mean': np.array([1, 2, 3]), 'scale': np.array([1, 2, 3])}
        cli_args = Namespace(mean_scale_values={'parameter': parameter_values})
        input_edges = [
            *connect('parameter', 'result'),
            *connect_data('parameter', 'shape_of'),
            *connect('shape_of', 'result_2'),
        ]
        graph = build_graph(nodes, input_edges, nodes_with_edges_only=True, cli=cli_args)

        for current in (graph, graph_ref):
            self.set_graph_attrs(current, ['parameter'])
        graph.graph['layout'] = 'NCHW'

        AddMeanScaleValues().find_and_replace_pattern(graph)
        for output_name in ('result', 'result_2'):
            flag, resp = compare_graphs(graph, graph_ref, output_name, check_op_attrs=True)
            self.assertTrue(flag, resp)
        self.check_graph_attrs(graph, graph_ref, ['parameter'])
Example #5
0
    def test_div_with_integer(self):
        """Integer division is left untouched by the Div transformation.

        Both the parameter and the divisor constant are int32, so
        ``Div().find_and_replace_pattern`` must not rewrite the graph: the
        result is compared with a copy taken before the transformation ran.
        """
        # Test where transformation should not be applied because the divisor is integer
        int_div_nodes = {
            **regular_op_with_shaped_data('parameter', [1, 227, 227, 3],
                                          {'type': 'Parameter', 'data_type': np.int32}),
            **valued_const_with_data('const', np.array([-1.], dtype=np.int32)),
            **regular_op_with_shaped_data('div', None,
                                          {'op': 'Div', 'type': 'Divide', 'name': 'my_div'}),
            **result(),
        }
        int_div_edges = [
            *connect('parameter:0', '0:div'),
            *connect_data('const:0', '1:div'),
            *connect('div', 'output'),
        ]
        graph = build_graph(int_div_nodes, int_div_edges)
        unchanged_ref = graph.copy()

        Div().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, unchanged_ref, 'output', check_op_attrs=True)
        self.assertTrue(flag, resp)
 def test_interpolate_concat_reshape_graph_comparison(self):
     """InterpolateConcat derives the Interpolate target shape from the concat peer.

     Initially ``interpolate`` takes its second input from the ``out_shape``
     node and is concatenated with ``placeholder_1``. After the transformation
     and a clean-up, the reference expects port 1 of ``interpolate`` to come
     from ``gather`` over ``shape`` (fed by ``placeholder_1``), ``indices``
     and ``axis``, while ``placeholder_1`` still feeds port 1 of ``concat``.
     """
     source_edges = [
         *connect('placeholder', '0:interpolate'),
         *connect('out_shape', '1:interpolate'),
         *connect('interpolate', '0:concat'),
         *connect('placeholder_1', '1:concat'),
         *connect('concat', 'output'),
     ]
     graph = build_graph(nodes, source_edges, nodes_with_edges_only=True)

     InterpolateConcat().find_and_replace_pattern(graph)
     # NOTE(review): cmd_params is set right before clean_up, presumably so the
     # clean-up honours keep_shape_ops — confirm against Graph.clean_up.
     graph.graph['cmd_params'] = Namespace(keep_shape_ops=True)
     graph.clean_up()

     expected_edges = [
         *connect('placeholder', '0:interpolate'),
         *connect('placeholder_1', 'shape'),
         *connect('shape', '0:gather'),
         *connect('indices', '1:gather'),
         *connect('axis', '2:gather'),
         *connect('gather', '1:interpolate'),
         *connect('interpolate', '0:concat'),
         *connect_data('placeholder_1', '1:concat'),
         *connect('concat', 'output'),
     ]
     graph_ref = build_graph(nodes, expected_edges, nodes_with_edges_only=True)

     flag, resp = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
     self.assertTrue(flag, resp)
Example #7
0
 def test_interpolate_reshape_graph_comparison(self):
     """InterpolateReshapeWA computes the Interpolate target shape in-graph.

     Initially ``interpolate`` takes its second input from ``out_shape``.
     After the transformation and a clean-up, the reference expects that
     input to be produced by ``mul``, which combines ``gather`` (over
     ``shape`` of the placeholder's data, ``indices`` and ``axis``) with
     ``multiplier``.
     """
     source_edges = [
         *connect('placeholder', '0:interpolate'),
         *connect('out_shape', '1:interpolate'),
         *connect('interpolate', 'output'),
     ]
     graph = build_graph(nodes, source_edges, nodes_with_edges_only=True)

     InterpolateReshapeWA().find_and_replace_pattern(graph)
     graph.clean_up()

     expected_edges = [
         *connect('placeholder', '0:interpolate'),
         *connect_data('placeholder', 'shape'),
         *connect('shape', '0:gather'),
         *connect('indices', '1:gather'),
         *connect('axis', '2:gather'),
         *connect('gather', '0:mul'),
         *connect('multiplier', '1:mul'),
         *connect('mul', '1:interpolate'),
         *connect('interpolate', 'output'),
     ]
     graph_ref = build_graph(nodes, expected_edges, nodes_with_edges_only=True)

     flag, resp = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
     self.assertTrue(flag, resp)
Example #8
0
    def test_not_useless_pad_non_constant_input(self):
        """A Pad whose pad amounts are computed at runtime is kept.

        Ports 1 and 2 of ``pad`` (presumably pads_begin / pads_end) are both
        fed by ``sub``, which takes ``shape_of_1`` of the placeholder and
        ``desired_output_size``, rather than by constants. ``RemoveUselessPad``
        must therefore leave the graph as-is. It is compared with a graph
        freshly built from the same node and edge description.
        """
        pad_nodes = {
            **regular_op_with_shaped_data('placeholder', [10, 20, 3], {'type': 'Parameter'}),
            **regular_op_with_shaped_data('shape_of_1', [3], {'type': 'ShapeOf'}),
            **regular_op_with_shaped_data('sub', [3], {'type': 'Subtract', 'op': 'Sub'}),
            **valued_const_with_data('desired_output_size', int64_array([10, 20, 3])),
            **regular_op_with_shaped_data('pad', [10, 20, 3], {'type': 'Pad', 'op': 'Pad'}),
            **valued_const_with_data('fill_value', np.array(1)),
            **result('result'),
        }
        pad_edges = [
            *connect('placeholder', '0:pad'),
            *connect('placeholder', 'shape_of_1'),
            *connect('shape_of_1', '0:sub'),
            *connect('desired_output_size', '1:sub'),
            *connect('sub', '1:pad'),
            *connect_data('sub', '2:pad'),
            *connect('fill_value', '3:pad'),
            *connect('pad', 'result'),
        ]

        graph = build_graph(pad_nodes, pad_edges)
        RemoveUselessPad().find_and_replace_pattern(graph)
        expected = build_graph(pad_nodes, pad_edges)

        flag, resp = compare_graphs(graph, expected, 'result')
        self.assertTrue(flag, resp)
Example #9
0
    def test_backward_bfs_multi_consumer_data_nodes(self):
        """Backward BFS only crosses multi-consumer data nodes when asked to.

        The output data node of ``mul`` feeds two results. A backward
        ``common_bfs`` from ``result`` through ``Mul`` ops looking for
        ``Parameter`` ops must find exactly one match when
        ``follow_multi_consumer_data_nodes=True`` is passed, and none when the
        flag is omitted.
        """
        # Parameter -> Mul -> Result
        # Const     -/    \-> Result2

        graph = build_graph(
            {
                **regular_op_with_shaped_data('parameter', [1], {'op': 'Parameter'}),
                **valued_const_with_data('const', int64_array([5])),
                **regular_op_with_shaped_data('mul', [1], {'op': 'Mul'}),
                **result('result'),
                **result('result2'),
            }, [
                *connect('parameter', '0:mul'),
                *connect('const', '1:mul'),
                *connect('mul:0', 'result'),
                *connect_data('mul', 'result2'),
            ])

        res = common_bfs(Node(graph, 'result'), ['Mul'], ['Parameter'],
                         is_backward=True,
                         attr_to_check='op',
                         follow_multi_consumer_data_nodes=True)
        # assertEqual reports the actual length alongside the custom message.
        self.assertEqual(
            len(res), 1,
            'The multi-consumer data node "mul_d" was not followed')

        res = common_bfs(Node(graph, 'result'), ['Mul'], ['Parameter'],
                         is_backward=True,
                         attr_to_check='op')
        self.assertEqual(
            len(res), 0, 'The multi-consumer data node "mul_d" was followed')
Example #10
0
    def test_convert_slice_to_strided_slice_three_axes(self):
        """ConvertSlice over three axes yields a StridedSlice with gathered bounds.

        The source Slice uses starts [0, 0, 0], ends [2, 150, 150], axes
        [1, 2, 3] and unit steps. In the reference, the begin and end inputs
        are built by concatenating a leading constant (port 0) with three
        Gathers over ``ss_begin_cast`` / ``ss_end_cast`` (ports 1-3). The
        reference also expects strides [1, 1, 1, 1] and begin_mask /
        end_mask [0, 1, 1, 1].
        """
        slice_attrs = {
            'starts': {'value': int64_array([0, 0, 0]), 'shape': [3]},
            'ends': {'value': int64_array([2, 150, 150]), 'shape': [3]},
            'axes': {'value': int64_array([1, 2, 3]), 'shape': [3]},
            'axes_d': {'value': int64_array([1, 2, 3]), 'shape': [3]},
            'steps': {'value': int64_array([1, 1, 1]), 'shape': [3]},
            'steps_d': {'value': int64_array([1, 1, 1]), 'shape': [3]}
        }
        graph = build_graph(nodes_attrs=nodes_attributes,
                            edges=pattern_graph,
                            update_attributes=slice_attrs,
                            nodes_with_edges_only=True)

        begin_edges = [
            *connect('ss_begin_cast:0', '0:ss_begin_gather_0'),
            *connect('ss_begin_gather_0:0', '1:ss_begin_concat'),
            *connect_data('ss_begin_cast:0', '0:ss_begin_gather_1'),
            *connect('ss_begin_gather_1:0', '2:ss_begin_concat'),
            *connect_data('ss_begin_cast:0', '0:ss_begin_gather_2'),
            *connect('ss_begin_gather_2:0', '3:ss_begin_concat'),
            *connect('ss_begin_const_0:0', '0:ss_begin_concat'),
        ]
        end_edges = [
            *connect('ss_end_cast:0', '0:ss_end_gather_0'),
            *connect('ss_end_gather_0:0', '1:ss_end_concat'),
            *connect_data('ss_end_cast:0', '0:ss_end_gather_1'),
            *connect('ss_end_gather_1:0', '2:ss_end_concat'),
            *connect_data('ss_end_cast:0', '0:ss_end_gather_2'),
            *connect('ss_end_gather_2:0', '3:ss_end_concat'),
            *connect('ss_end_const_0:0', '0:ss_end_concat'),
        ]
        expected_attrs = {
            'starts': {'value': int64_array([0, 0, 0]), 'shape': [3]},
            'ends': {'value': int64_array([2, 150, 150]), 'shape': [3]},
            'ss_strides': {'value': int64_array([1, 1, 1, 1]), 'shape': [4]},
            'ss': {'begin_mask': int64_array([0, 1, 1, 1]), 'end_mask': int64_array([0, 1, 1, 1])}
        }
        ref_graph = build_graph(nodes_attrs=nodes_attributes,
                                edges=pattern_ref_graph + begin_edges + end_edges,
                                update_attributes=expected_attrs)

        ConvertSlice().find_and_replace_pattern(graph)
        flag, resp = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
        self.assertTrue(flag, resp)
Example #11
0
    def test_broadcast_with_range_positive_test(self):
        """ExpandRangeConstant turns a broadcast arange constant into a Range.

        The source broadcasts the constant ``np.arange(0, 384).reshape((1, 384))``
        to the runtime ``shape``. The reference replaces that constant with
        ``range(start=0, range_dim, delta=1)``, where ``range_dim`` is a
        Gather over ``shape`` with inputs ``minus_one`` and ``zero``. The Range
        output is passed through an Unsqueeze with axes [0] before reaching
        Broadcast, whose port 1 still reads ``shape``.
        """
        source_nodes = {
            **regular_op_with_shaped_data('shape', [2], {'type': 'Parameter'}),
            **valued_const_with_data('value', np.arange(0, 384).reshape((1, 384))),
            **regular_op_with_empty_data('bc', {'type': 'Broadcast'}),
            **result(),
        }
        source_edges = [
            *connect('value', '0:bc'),
            *connect('shape', '1:bc'),
            *connect('bc', 'output'),
        ]
        graph = build_graph(source_nodes, source_edges, nodes_with_edges_only=True)
        ExpandRangeConstant().find_and_replace_pattern(graph)

        expected_nodes = {
            **regular_op_with_shaped_data('shape', [2], {'type': 'Parameter'}),
            # Range start
            **valued_const_with_data('start', np.array(0)),
            # Range limit: Gather over the runtime shape
            **valued_const_with_data('minus_one', np.array(-1)),
            **valued_const_with_data('zero', np.array(0)),
            **regular_op_with_empty_data('range_dim', {'type': 'Gather'}),
            # Range step
            **valued_const_with_data('delta', np.array(1)),
            **regular_op_with_empty_data('range', {'type': 'Range'}),
            # restore the leading dimension
            **valued_const_with_data('axes', np.array([0])),
            **regular_op_with_empty_data('keep_shape', {'type': 'Unsqueeze'}),
            **regular_op_with_empty_data('bc', {'type': 'Broadcast'}),
            **result(),
        }
        expected_edges = [
            *connect('start', '0:range'),
            *connect('shape', '0:range_dim'),
            *connect('minus_one', '1:range_dim'),
            *connect('zero', '2:range_dim'),
            *connect('range_dim', '1:range'),
            *connect('delta', '2:range'),
            *connect('range', '0:keep_shape'),
            *connect('axes', '1:keep_shape'),
            *connect('keep_shape', '0:bc'),
            *connect_data('shape', '1:bc'),
            *connect('bc', 'output'),
        ]
        graph_ref = build_graph(expected_nodes, expected_edges, nodes_with_edges_only=True)

        flag, resp = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(flag, resp)
Example #12
0
    def test_leaky_relu_mul_multiple_consumers(self):
        """LeakyReLU fusion when the Mul output has an additional consumer.

        Before ``LeakyReLUFusion`` runs, an extra Result named ``result_2`` is
        attached to output port 0 of ``mul``. After the fusion and a clean-up,
        the graph is compared with the reference at both ``result`` and
        ``result_2``.
        """
        # multiple consumers of Mul operation
        graph = build_graph_with_edge_attrs(nodes, edges, {})
        extra_result = Result(graph, {'name': 'result_2'}).create_node()
        mul_node = Node(graph, 'mul')
        mul_node.out_port(0).connect(extra_result.in_port(0))

        expected_nodes = {
            **regular_op_with_shaped_data('input', shape, {'type': 'Parameter', 'op': 'Parameter'}),
            **regular_op_with_shaped_data('mul', shape, {'type': 'Multiply', 'name': 'mul'}),
            **regular_op_with_shaped_data('max', shape, {'type': 'Maximum', 'name': 'final_max'}),
            **valued_const_with_data('const', float_array([0.5])),
            **regular_op_with_shaped_data('leaky_relu', shape, {'type': 'LeakyReLU',
                                                                'name': 'max_final',
                                                                'negative_slope': None}),
            **result('result'),
            **result('result_2'),
        }
        # NOTE(review): both 'max:0' and 'leaky_relu' are connected to 'result'
        # below; confirm this reference topology is intended.
        expected_edges = [
            *connect('input:0', '0:mul'),
            *connect('const', '1:mul'),
            *connect('max:0', 'result'),
            *connect('mul:0', 'result_2'),
            *connect_data('input', 'leaky_relu'),
            *connect('leaky_relu', 'result'),
        ]
        graph_ref = build_graph_with_edge_attrs(expected_nodes, expected_edges)

        LeakyReLUFusion().find_and_replace_pattern(graph)
        graph.clean_up()

        for output_name in ('result', 'result_2'):
            flag, resp = compare_graphs(graph, graph_ref, output_name)
            self.assertTrue(flag, resp)
Example #13
0
    **regular_op_with_shaped_data('mul', shape, {
        'type': 'Multiply',
        'name': 'mul'
    }),
    **regular_op_with_shaped_data('max', shape, {
        'type': 'Maximum',
        'name': 'final_max'
    }),
    **valued_const_with_data('const', float_array([0.5])),
    **result('result')
}

# Source topology for the LeakyReLU fusion tests: 'input' is multiplied by
# 'const' in 'mul', and 'max' combines the raw input (port 0) with the
# product (port 1) before feeding 'result'.
edges = [
    *connect('input:0', '0:mul'),
    *connect('const', '1:mul'),
    # reuse the input's existing data node so 'input' feeds both 'mul' and 'max'
    *connect_data('input', '0:max'),
    *connect('mul:0', '1:max'),
    *connect('max:0', 'result'),
]

ref_nodes = {
    **regular_op_with_shaped_data('input', shape, {
        'type': 'Parameter',
        'op': 'Parameter'
    }),
    **regular_op_with_shaped_data('leaky_relu', shape, {
        'type': 'LeakyReLU',
        'name': 'max_final',
        'negative_slope': None
    }),
    **result('result')