def test_scale_input(self):
        """Check the result of ScaleInput when the CLI scale is 255.

        The graph under test has the single connection
        ``parameter -> result`` and is built with ``cli=Namespace(scale=255)``.
        After ``ScaleInput().find_and_replace_pattern`` runs, it must compare
        equal (via ``compare_graphs`` at ``'result'``) to a reference graph
        with the connections ``parameter -> 0:mul_scale``,
        ``scale -> 1:mul_scale`` and ``mul_scale -> result``, where the
        ``scale`` and ``scale_d`` nodes are overridden to shape
        ``[1, 1, 1, 1]`` and value ``1 / 255``.
        """
        watched_nodes = ['parameter']

        # Reference graph: edges are produced pair by pair, in order.
        expected_links = (
            ('parameter', '0:mul_scale'),
            ('scale', '1:mul_scale'),
            ('mul_scale', 'result'),
        )
        expected_edges = [
            edge
            for source, target in expected_links
            for edge in connect(source, target)
        ]
        # Each node name gets its own attribute dict and its own array.
        scale_overrides = {
            node_name: {'shape': [1, 1, 1, 1], 'value': np.array(1 / 255)}
            for node_name in ('scale', 'scale_d')
        }
        expected = build_graph(nodes, expected_edges, scale_overrides)

        # Graph under test: the transformation reads its settings from cli.
        actual = build_graph(
            nodes,
            connect('parameter', 'result'),
            nodes_with_edges_only=True,
            cli=Namespace(scale=255),
        )

        for candidate in (actual, expected):
            self.set_graph_attrs(candidate, watched_nodes)
        actual.graph['layout'] = 'NCHW'

        ScaleInput().find_and_replace_pattern(actual)

        is_equal, message = compare_graphs(actual, expected, 'result')
        self.assertTrue(is_equal, message)
        self.check_graph_attrs(actual, expected, watched_nodes)
    # Example #2
    # 0
    def test_scale_input_2(self):
        """Check the result of ScaleInput when the CLI scale is 1.

        Both the reference graph and the graph under test have the single
        connection ``parameter -> result``; only the latter is built with
        ``cli=Namespace(scale=1)``.  After
        ``ScaleInput().find_and_replace_pattern`` runs, the two graphs must
        still compare equal (via ``compare_graphs`` at ``'result'``).
        """
        watched_nodes = ['parameter']

        # Edges are generated separately for each graph so nothing is shared.
        expected = build_graph(
            nodes,
            connect('parameter', 'result'),
            nodes_with_edges_only=True,
        )
        actual = build_graph(
            nodes,
            connect('parameter', 'result'),
            nodes_with_edges_only=True,
            cli=Namespace(scale=1),
        )

        for candidate in (actual, expected):
            self.set_graph_attrs(candidate, watched_nodes)
        actual.graph['layout'] = 'NCHW'

        ScaleInput().find_and_replace_pattern(actual)

        is_equal, message = compare_graphs(actual, expected, 'result')
        self.assertTrue(is_equal, message)
        self.check_graph_attrs(actual, expected, watched_nodes)