示例#1
0
    def test_scale_input(self):
        """Check that ``scale=255`` rewires ``parameter`` -> ``result`` through ``mul_scale``.

        The expected graph feeds ``parameter`` into port 0 of ``mul_scale`` and a
        ``scale`` input holding 1/255 into port 1; ``mul_scale`` then feeds ``result``.
        """
        expected_edges = [
            *connect('parameter', '0:mul_scale'),
            *connect('scale', '1:mul_scale'),
            *connect('mul_scale', 'result'),
        ]
        # Both the 'scale' node and 'scale_d' get their own shape list / array
        # holding the reciprocal of the CLI scale value.
        expected_attrs = {
            node_name: {'shape': [1, 1, 1, 1], 'value': np.array(1 / 255)}
            for node_name in ('scale', 'scale_d')
        }
        graph_ref = build_graph(nodes, expected_edges, expected_attrs)

        graph = build_graph(
            nodes,
            connect('parameter', 'result'),
            nodes_with_edges_only=True,
            cli=Namespace(scale=255),
        )
        for g in (graph, graph_ref):
            self.set_graph_attrs(g, ['parameter'])
        graph.graph['layout'] = 'NCHW'

        ScaleInput().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'result')
        self.assertTrue(flag, resp)
        self.check_graph_attrs(graph, graph_ref, ['parameter'])
示例#2
0
    def test_scale_input_1(self):
        """Check that ``scale=255`` inserts ``mul_1`` after ``placeholder_1``.

        In the expected graph ``mul_1`` sits between ``placeholder_1`` and
        ``placeholder_1_data`` and takes weights ``mul_1_w`` holding 1/255.
        """
        graph = build_graph(
            nodes_attributes,
            [('placeholder_1', 'placeholder_1_data'), ('placeholder_1_data', 'op_output')],
            {'placeholder_1': {'shape': np.array([1, 3, 224, 224])}},
            nodes_with_edges_only=True,
        )

        expected_edges = [
            ('placeholder_1', 'mul_1_data'),
            ('mul_1_data', 'mul_1'),
            ('mul_1_w', 'mul_1'),
            ('mul_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'op_output'),
        ]
        expected_attrs = {
            'mul_1_w': {'shape': np.array([1, 1, 1]), 'value': np.array([1 / 255])},
        }
        graph_ref = build_graph(nodes_attributes, expected_edges, expected_attrs,
                                nodes_with_edges_only=True)

        # Same two graph-level settings as direct assignment, in the same order.
        graph.graph.update(layout='NCHW', cmd_params=Namespace(scale=255))

        ScaleInput().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'placeholder_1_data')
        self.assertTrue(flag, resp)
示例#3
0
    def test_scale_input_2(self):
        """Check that ``scale=1`` leaves the ``parameter`` -> ``result`` graph unchanged."""
        def make_passthrough(**extra):
            # Fresh edge list per call; only the graph under test gets a CLI namespace.
            return build_graph(nodes, connect('parameter', 'result'),
                               nodes_with_edges_only=True, **extra)

        graph_ref = make_passthrough()
        graph = make_passthrough(cli=Namespace(scale=1))
        for g in (graph, graph_ref):
            self.set_graph_attrs(g, ['parameter'])
        graph.graph['layout'] = 'NCHW'

        ScaleInput().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'result')
        self.assertTrue(flag, resp)
        self.check_graph_attrs(graph, graph_ref, ['parameter'])
示例#4
0
    def test_scale_input_2(self):
        """Check that ``scale=1`` leaves the ``placeholder_1`` -> ``op_output`` graph unchanged."""
        def make_passthrough():
            # Input and reference graphs share the same topology; build each independently.
            return build_graph(nodes_attributes,
                               [('placeholder_1', 'placeholder_1_data'),
                                ('placeholder_1_data', 'op_output')],
                               nodes_with_edges_only=True)

        graph = make_passthrough()
        graph_ref = make_passthrough()
        graph.graph['cmd_params'] = Namespace(scale=1)

        ScaleInput().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'placeholder_1_data')
        self.assertTrue(flag, resp)