def test_scale_input(self):
    """Expect ScaleInput to turn ``--scale 255`` into a multiply by 1/255.

    The reference graph routes ``parameter`` into port 0 of ``mul_scale``,
    feeds the ``scale`` constant (value 1/255) into port 1, and sends the
    product to ``result``. The graph under test starts as a bare
    ``parameter -> result`` edge built with ``cli=Namespace(scale=255)``.
    After the transformation, it must compare equal to the reference.
    """
    reciprocal = 1 / 255
    ref_edges = [
        *connect('parameter', '0:mul_scale'),
        *connect('scale', '1:mul_scale'),
        *connect('mul_scale', 'result'),
    ]
    ref_attrs = {
        'scale': {'shape': [1, 1, 1, 1], 'value': np.array(reciprocal)},
        # NOTE(review): presumably the data node paired with 'scale' in
        # `nodes`; confirm against the module-level node table.
        'scale_d': {'shape': [1, 1, 1, 1], 'value': np.array(reciprocal)},
    }
    graph_ref = build_graph(nodes, ref_edges, ref_attrs)

    graph = build_graph(nodes, connect('parameter', 'result'),
                        nodes_with_edges_only=True, cli=Namespace(scale=255))

    for g in (graph, graph_ref):
        self.set_graph_attrs(g, ['parameter'])
    # Only the graph being transformed gets an explicit layout.
    graph.graph['layout'] = 'NCHW'

    transform = ScaleInput()
    transform.find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'result')
    self.assertTrue(flag, resp)
    self.check_graph_attrs(graph, graph_ref, ['parameter'])
def test_scale_input_2(self):
    """Expect ScaleInput to leave the graph untouched for ``--scale 1``.

    Both the reference graph and the graph under test are a bare
    ``parameter -> result`` edge. Only the latter carries
    ``cli=Namespace(scale=1)``, and after the transformation the two
    must still compare equal.
    """
    def build_identity(**extra):
        # A fresh edge list per graph, matching the original's separate calls.
        return build_graph(nodes, connect('parameter', 'result'),
                           nodes_with_edges_only=True, **extra)

    graph_ref = build_identity()
    graph = build_identity(cli=Namespace(scale=1))

    for g in (graph, graph_ref):
        self.set_graph_attrs(g, ['parameter'])
    # Only the graph being transformed gets an explicit layout.
    graph.graph['layout'] = 'NCHW'

    transform = ScaleInput()
    transform.find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'result')
    self.assertTrue(flag, resp)
    self.check_graph_attrs(graph, graph_ref, ['parameter'])