def test_scale_input(self):
    """Run ``ScaleInput`` with ``scale=255`` and compare with a Mul reference.

    The reference graph connects 'parameter' to '0:mul_scale', 'scale' to
    '1:mul_scale' and 'mul_scale' to 'result'; the 'scale' and 'scale_d'
    nodes are overridden with shape [1, 1, 1, 1] and value 1/255.
    """
    tracked_nodes = ('parameter',)

    # Expected edges, expanded pair by pair through connect() in order.
    expected_links = [
        edge
        for src, dst in (
            ('parameter', '0:mul_scale'),
            ('scale', '1:mul_scale'),
            ('mul_scale', 'result'),
        )
        for edge in connect(src, dst)
    ]
    # Each overridden node gets its own freshly built attribute dict.
    expected_overrides = {
        node_name: {'shape': [1, 1, 1, 1], 'value': np.array(1 / 255)}
        for node_name in ('scale', 'scale_d')
    }
    graph_ref = build_graph(nodes, expected_links, expected_overrides)

    # Graph under test: parameter wired straight to result, scale from CLI.
    cli_args = Namespace(scale=255)
    graph = build_graph(
        nodes,
        connect('parameter', 'result'),
        nodes_with_edges_only=True,
        cli=cli_args,
    )

    self.set_graph_attrs(graph, list(tracked_nodes))
    self.set_graph_attrs(graph_ref, list(tracked_nodes))
    graph.graph['layout'] = 'NCHW'

    ScaleInput().find_and_replace_pattern(graph)

    is_same, details = compare_graphs(graph, graph_ref, 'result')
    self.assertTrue(is_same, details)
    self.check_graph_attrs(graph, graph_ref, list(tracked_nodes))
def test_scale_input_1(self):
    """Run ``ScaleInput`` with ``scale=255`` on a placeholder graph.

    The result is compared, starting from 'placeholder_1_data', with a
    reference graph in which 'mul_1' (weights 'mul_1_w' of value 1/255) sits
    between 'placeholder_1' and 'placeholder_1_data'.
    """
    # Graph under test: placeholder -> data -> output.
    source_edges = [
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_1_data', 'op_output'),
    ]
    source_overrides = {
        'placeholder_1': {'shape': np.array([1, 3, 224, 224])},
    }
    graph = build_graph(
        nodes_attributes,
        source_edges,
        source_overrides,
        nodes_with_edges_only=True,
    )

    # Expected graph after the transformation.
    expected_edges = [
        ('placeholder_1', 'mul_1_data'),
        ('mul_1_data', 'mul_1'),
        ('mul_1_w', 'mul_1'),
        ('mul_1', 'placeholder_1_data'),
        ('placeholder_1_data', 'op_output'),
    ]
    expected_overrides = {
        'mul_1_w': {
            'shape': np.array([1, 1, 1]),
            'value': np.array([1 / 255]),
        },
    }
    graph_ref = build_graph(
        nodes_attributes,
        expected_edges,
        expected_overrides,
        nodes_with_edges_only=True,
    )

    graph.graph['layout'] = 'NCHW'
    graph.graph['cmd_params'] = Namespace(scale=255)

    ScaleInput().find_and_replace_pattern(graph)

    is_same, details = compare_graphs(graph, graph_ref, 'placeholder_1_data')
    self.assertTrue(is_same, details)
def test_scale_input_2(self):
    """Run ``ScaleInput`` with ``scale=1`` and expect a plain reference graph.

    Both graphs are built from the same nodes and the same
    'parameter' -> 'result' connection; only the graph under test carries
    the ``cli`` namespace.
    """
    # NOTE(review): another method named test_scale_input_2 is visible in
    # this file; if both live in the same class, the later definition
    # replaces the earlier one -- confirm.
    tracked_nodes = ('parameter',)

    graph_ref = build_graph(
        nodes,
        connect('parameter', 'result'),
        nodes_with_edges_only=True,
    )
    cli_args = Namespace(scale=1)
    graph = build_graph(
        nodes,
        connect('parameter', 'result'),
        nodes_with_edges_only=True,
        cli=cli_args,
    )

    self.set_graph_attrs(graph, list(tracked_nodes))
    self.set_graph_attrs(graph_ref, list(tracked_nodes))
    graph.graph['layout'] = 'NCHW'

    ScaleInput().find_and_replace_pattern(graph)

    is_same, details = compare_graphs(graph, graph_ref, 'result')
    self.assertTrue(is_same, details)
    self.check_graph_attrs(graph, graph_ref, list(tracked_nodes))
def test_scale_input_2(self):
    """Run ``ScaleInput`` with ``scale=1`` on a placeholder graph.

    The graph under test and the reference are built from identical nodes
    and edges; they are compared starting from 'placeholder_1_data'.
    """
    # NOTE(review): another method named test_scale_input_2 is visible in
    # this file; if both live in the same class, the later definition
    # replaces the earlier one -- confirm.
    graph = build_graph(
        nodes_attributes,
        [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'op_output'),
        ],
        nodes_with_edges_only=True,
    )
    graph_ref = build_graph(
        nodes_attributes,
        [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'op_output'),
        ],
        nodes_with_edges_only=True,
    )

    graph.graph['cmd_params'] = Namespace(scale=1)

    ScaleInput().find_and_replace_pattern(graph)

    is_same, details = compare_graphs(graph, graph_ref, 'placeholder_1_data')
    self.assertTrue(is_same, details)