예제 #1
0
    def replace_pattern(self, graph: Graph, match: dict):
        """Apply the command-line ``scale`` value to the matched placeholder.

        Nothing is done when ``scale`` is ``None`` or equal to 1. Otherwise the
        matched ``'placeholder'`` node is asserted to have at least one output
        node, and the scale, wrapped in a one-element array, is handed to
        ``AddMeanScaleValues.apply_scale``.

        :param graph: graph whose ``graph['cmd_params'].scale`` is read
        :param match: pattern match; only its ``'placeholder'`` entry is used
        """
        scale = graph.graph['cmd_params'].scale
        if scale is None or scale == 1:
            # No scaling requested (or a no-op scale of 1): leave graph as is.
            return

        placeholder = match['placeholder']
        assert (len(placeholder.out_nodes()))

        scale_values = {'scale': np.array([scale])}
        AddMeanScaleValues.apply_scale(graph, placeholder, scale_values)
예제 #2
0
    def test_mean_values_without_data_name(self):
        """Mean values given without an input name yield a single add_mean.

        The mean/scale dictionary is built from ``'(1,2,3)'`` mean values, an
        empty scale string and ``None`` as the third argument. The expected
        graph is parameter -> add_mean -> result, with the parameter node
        named ``'None'`` in both graphs.
        """
        reference_edges = [
            *connect('parameter', '0:add_mean'),
            *connect('mean', '1:add_mean'),
            *connect('add_mean', 'result'),
        ]
        expected = build_graph(nodes, reference_edges,
                               {'parameter': {'name': 'None'}})

        cli_args = Namespace(mean_scale_values=get_mean_scale_dictionary(
            parse_tuple_pairs('(1,2,3)'), parse_tuple_pairs(''), None))

        source_edges = [*connect('parameter', 'result')]
        actual = build_graph(nodes, source_edges,
                             {'parameter': {'name': 'None'}},
                             nodes_with_edges_only=True, cli=cli_args)
        self.set_graph_attrs(actual, ['None'])
        self.set_graph_attrs(expected, ['None'])
        actual.graph['layout'] = 'NCHW'

        AddMeanScaleValues().find_and_replace_pattern(actual)

        is_equal, message = compare_graphs(actual, expected, 'result',
                                           check_op_attrs=True)
        self.assertTrue(is_equal, message)
        self.check_graph_attrs(actual, expected, ['None'])
예제 #3
0
    def test_debug_info_absence(self):
        """Transform a graph on which ``set_graph_attrs`` was never called.

        ``mean_scale_values`` is given as a list holding one pair of arrays.
        The expected graph is
        parameter -> add_mean -> mul_scale -> result, and the graph
        attributes are checked with an empty list of names.
        """
        reference_edges = [
            *connect('parameter', '0:add_mean'),
            *connect('mean', '1:add_mean'),
            *connect('add_mean', '0:mul_scale'),
            *connect('scale', '1:mul_scale'),
            *connect('mul_scale', 'result'),
        ]
        expected = build_graph(nodes, reference_edges)

        value_pair = [np.array([1., 2., 3.]), np.array([1., 2., 3.])]
        cli_args = Namespace(mean_scale_values=[value_pair])

        source_edges = [*connect('parameter', 'result')]
        actual = build_graph(nodes, source_edges,
                             nodes_with_edges_only=True, cli=cli_args)
        actual.graph['layout'] = 'NCHW'

        AddMeanScaleValues().find_and_replace_pattern(actual)

        is_equal, message = compare_graphs(actual, expected, 'result',
                                           check_op_attrs=True)
        self.assertTrue(is_equal, message)
        self.check_graph_attrs(actual, expected, [])
예제 #4
0
    def test_mean_values_with_colon_in_node_name_and_port(self):
        """Mean/scale keyed by ``'0:param:0'`` for a parameter named 'param:0'.

        The parameter node carries the name ``'param:0'`` (itself containing a
        colon), and the mean/scale key is ``'0:param:0'``. The supplied scale
        is ``[1.]`` and the expected graph contains only the add_mean branch:
        parameter -> add_mean -> result.
        """
        reference_edges = [
            *connect('parameter', '0:add_mean'),
            *connect('mean', '1:add_mean'),
            *connect('add_mean', 'result'),
        ]
        expected = build_graph(nodes, reference_edges)

        values_for_input = {
            'scale': np.array([1.]),
            'mean': np.array([1., 2., 3.]),
        }
        cli_args = Namespace(mean_scale_values={'0:param:0': values_for_input})

        parameter_attrs = {
            'name': 'param:0',
            'id': 'param:0/placeholder_0',
            'initial_node_name': 'param:0',
        }
        source_edges = [*connect('parameter', 'result')]
        actual = build_graph(nodes, source_edges,
                             {'parameter': parameter_attrs},
                             nodes_with_edges_only=True, cli=cli_args)
        self.set_graph_attrs(actual, ['parameter'])
        self.set_graph_attrs(expected, ['parameter'])
        actual.graph['layout'] = 'NCHW'

        AddMeanScaleValues().find_and_replace_pattern(actual)

        is_equal, message = compare_graphs(actual, expected, 'result',
                                           check_op_attrs=True)
        self.assertTrue(is_equal, message)
예제 #5
0
    def test_mean_values_optimized_and_scale_values_explicit(self):
        """All-zero mean with an explicit scale yields only mul_scale.

        The 'parameter' input gets mean ``[0., 0., 0.]`` and scale
        ``[1., 2., 3.]``. The expected graph has no add_mean node:
        parameter -> mul_scale -> result.
        """
        reference_edges = [
            *connect('parameter', '0:mul_scale'),
            *connect('scale', '1:mul_scale'),
            *connect('mul_scale', 'result'),
        ]
        expected = build_graph(nodes, reference_edges)

        values_for_input = {
            'scale': np.array([1., 2., 3.]),
            'mean': np.array([0., 0., 0.]),
        }
        cli_args = Namespace(mean_scale_values={'parameter': values_for_input})

        source_edges = [*connect('parameter', 'result')]
        actual = build_graph(nodes, source_edges,
                             nodes_with_edges_only=True, cli=cli_args)
        self.set_graph_attrs(actual, ['parameter'])
        self.set_graph_attrs(expected, ['parameter'])
        actual.graph['layout'] = 'NCHW'

        AddMeanScaleValues().find_and_replace_pattern(actual)

        is_equal, message = compare_graphs(actual, expected, 'result',
                                           check_op_attrs=True)
        self.assertTrue(is_equal, message)
        self.check_graph_attrs(actual, expected, ['parameter'])
예제 #6
0
    def test_add_mean_scale_values3(self):
        """A list-style mean/scale pair adds exactly one Add and one Mul op.

        Builds a graph with the single edge pl_1 -> pl_1_data, sets
        ``mean_scale_values`` to a list holding one pair of arrays, runs the
        transformation and counts nodes whose ``op`` is 'Add' or 'Mul'.
        """
        graph = build_graph(nodes_attributes, [('pl_1', 'pl_1_data')], {
            'pl_1_data': {
                'shape': np.array([1, 3, 38, 38]),
                'infer': None
            },
            'pl_1': {
                'shape': np.array([1, 3, 38, 38])
            },
        },
                            nodes_with_edges_only=True)
        graph.graph['layout'] = 'NCHW'
        argv = Namespace(mean_scale_values=[[
            np.array([1., 2., 3.]),
            np.array([1., 2., 3.])
        ]])
        graph.graph['cmd_params'] = argv
        AddMeanScaleValues().find_and_replace_pattern(graph)

        mul_op_cnt = 0
        add_op_cnt = 0
        for node_id in graph.nodes():
            node = Node(graph, node_id)
            if not node.has_valid('op'):
                continue
            if node.op == 'Mul':
                mul_op_cnt += 1
            elif node.op == 'Add':
                add_op_cnt += 1

        # assertEqual also fails when the count is zero, so the messages state
        # the expectation rather than claiming "more than one" was found.
        self.assertEqual(add_op_cnt, 1, "There should be exactly one Add op")
        self.assertEqual(mul_op_cnt, 1, "There should be exactly one Mul op")
예제 #7
0
 def test_add_mean_scale_values_without_data_name(self):
     """Mean values given without an input name grow the graph from 3 to 6.

     The graph is node_1 -> node_2 -> op_output, where node_1 is overridden
     to be a 'Parameter' op named 'data'. The mean/scale dictionary is built
     from ``'(124,117,104)'`` mean values, an empty scale string and ``None``
     as the third argument. Only the graph length is checked, before and
     after the transformation.
     """
     attr_overrides = {
         'node_2': {
             'shape': None,
             'data_type': None,
         },
         'node_1': {
             'shape': np.array([1, 3, 227, 227]),
             'op': 'Parameter',
             'name': 'data',
             'data_type': None,
         },
     }
     edges = [('node_1', 'node_2'), ('node_2', 'op_output')]
     graph = build_graph(nodes_attributes, edges, attr_overrides,
                         nodes_with_edges_only=True)
     graph.graph['layout'] = 'NCHW'

     # The third argument is None: no input name accompanies the values.
     mean_scale = get_mean_scale_dictionary(
         parse_tuple_pairs('(124,117,104)'), parse_tuple_pairs(''), None)
     graph.graph['cmd_params'] = Namespace(mean_scale_values=mean_scale)

     self.assertEqual(len(graph), 3)
     AddMeanScaleValues().find_and_replace_pattern(graph)
     self.assertEqual(len(graph), 6)
예제 #8
0
    def test_mean_values_explicit_and_scale_values_explicit_on_cutted_graph(self):
        """
        Mean is given for 'parameter' and scale for 'op', where 'parameter_2'
        carries ``initial_node_name`` 'op'. The expected graph has add_mean
        after 'parameter' and mul_scale between 'parameter_2' and 'op'.
        """
        reference_edges = [
            *connect('parameter', '0:add_mean'),
            *connect('mean', '1:add_mean'),
            *connect('add_mean', 'result'),

            *connect('parameter_2', '0:mul_scale'),
            *connect('scale', '1:mul_scale'),
            *connect('mul_scale', 'op'),
            *connect('op', 'result_2'),
        ]
        expected = build_graph(nodes, reference_edges)

        cli_args = Namespace(mean_scale_values={
            'parameter': {'mean': np.array([1, 2, 3])},
            'op': {'scale': np.array([1, 2, 3])},
        })

        source_edges = [
            *connect('parameter', 'result'),
            *connect('parameter_2', 'op'),
            *connect('op', 'result_2'),
        ]
        actual = build_graph(nodes, source_edges,
                             {'parameter_2': {'initial_node_name': 'op'}},
                             nodes_with_edges_only=True, cli=cli_args)
        self.set_graph_attrs(actual, ['parameter', 'parameter_2'])
        self.set_graph_attrs(expected, ['parameter', 'parameter_2'])
        actual.graph['layout'] = 'NCHW'
        AddMeanScaleValues().find_and_replace_pattern(actual)

        # Compare both output branches against the reference graph.
        is_equal, message = compare_graphs(actual, expected, 'result',
                                           check_op_attrs=True)
        self.assertTrue(is_equal, message)
        is_equal, message = compare_graphs(actual, expected, 'result_2',
                                           check_op_attrs=True)
        self.assertTrue(is_equal, message)
        self.check_graph_attrs(actual, expected, ['parameter', 'parameter_2'])
예제 #9
0
    def test_add_mean_scale_values_cut_graph(self):
        """
        Test the case where the user cut the start of the network and
        specified a mean/scale value for the new input node 'node_3'.

        'pl_2' carries ``initial_node_name`` 'node_3'. Mean is given for
        'pl_1' and scale for 'node_3'; the test expects exactly one Add
        (after pl_1) and exactly one Mul (after pl_2).
        """
        edges = [
            ('pl_1', 'pl_1_data'),
            ('pl_2', 'pl_2_data'),
            ('pl_2_data', 'node_3'),
            ('node_3', 'node_3_data'),
            ('pl_1_data', 'node_1'),
            ('node_3_data', 'node_1'),
        ]
        attr_overrides = {
            'pl_1_data': {'shape': np.array([1, 3, 38, 38]), 'infer': None},
            'pl_2_data': {'shape': np.array([1, 3, 38, 38]), 'infer': None},
            'pl_2': {
                'initial_node_name': 'node_3',
                'shape': np.array([1, 3, 38, 38]),
            },
            'pl_1': {'shape': np.array([1, 3, 38, 38])},
        }
        graph = build_graph(nodes_attributes, edges, attr_overrides,
                            nodes_with_edges_only=True)
        graph.graph['layout'] = 'NCHW'
        graph.graph['cmd_params'] = Namespace(mean_scale_values={
            'pl_1': {'mean': np.array([1, 2, 3])},
            'node_3': {'scale': np.array([1, 2, 3])},
        })
        AddMeanScaleValues().find_and_replace_pattern(graph)

        # Collect the op type of every node that has a valid 'op' attribute.
        op_types = []
        for node_id in graph.nodes():
            candidate = Node(graph, node_id)
            if candidate.has_valid('op'):
                op_types.append(candidate.op)

        self.assertEqual(op_types.count('Add'), 1,
                         "There should be exactly one Add op")
        self.assertEqual(op_types.count('Mul'), 1,
                         "There should be exactly one Mul op")

        # Two hops from each placeholder: placeholder -> out node -> op.
        after_pl_2 = Node(graph, 'pl_2').out_node().out_node()
        self.assertEqual(after_pl_2.op, 'Mul',
                         "The Mul op should be added after pl_2")
        after_pl_1 = Node(graph, 'pl_1').out_node().out_node()
        self.assertEqual(after_pl_1.op, 'Add',
                         "The Add op should be added after pl_1")
예제 #10
0
    def test_mean_values_explicit_and_scale_values_explicit_with_shape_of(
            self):
        """Mean and scale on a parameter that also feeds a shape_of branch.

        Expected graph: parameter -> add_mean -> mul_scale -> result, while
        the shape_of -> result_2 branch is still connected to 'parameter'
        through ``connect_data``. Both outputs are compared against the
        reference graph.
        """
        reference_edges = [
            *connect('parameter', '0:add_mean'),
            *connect('mean', '1:add_mean'),
            *connect('add_mean', '0:mul_scale'),
            *connect('scale', '1:mul_scale'),
            *connect('mul_scale', 'result'),
            *connect_data('parameter', 'shape_of'),
            *connect('shape_of', 'result_2'),
        ]
        expected = build_graph(nodes, reference_edges,
                               nodes_with_edges_only=True)

        values_for_input = {
            'mean': np.array([1, 2, 3]),
            'scale': np.array([1, 2, 3]),
        }
        cli_args = Namespace(mean_scale_values={'parameter': values_for_input})

        source_edges = [
            *connect('parameter', 'result'),
            *connect_data('parameter', 'shape_of'),
            *connect('shape_of', 'result_2'),
        ]
        actual = build_graph(nodes, source_edges,
                             nodes_with_edges_only=True, cli=cli_args)
        actual.graph['layout'] = 'NCHW'

        AddMeanScaleValues().find_and_replace_pattern(actual)

        is_equal, message = compare_graphs(actual, expected, 'result',
                                           check_op_attrs=True)
        self.assertTrue(is_equal, message)
        is_equal, message = compare_graphs(actual, expected, 'result_2',
                                           check_op_attrs=True)
        self.assertTrue(is_equal, message)