def replace_pattern(self, graph: Graph, match: dict):
    """Apply the global ``scale`` command-line value to the matched placeholder.

    The value is read from ``graph.graph['cmd_params'].scale``. Nothing is done
    when it is ``None`` or equal to 1; otherwise the placeholder found in
    ``match['placeholder']`` is passed to ``AddMeanScaleValues.apply_scale``
    together with a one-element scale array built by ``mo_array``.
    """
    cli_scale = graph.graph['cmd_params'].scale
    # The match dict is only consulted once we know there is a real scale to apply.
    if cli_scale is not None and not cli_scale == 1:
        input_node = match['placeholder']
        # The placeholder is expected to have at least one out node.
        assert len(input_node.out_nodes())
        AddMeanScaleValues.apply_scale(graph, input_node, {'scale': mo_array([cli_scale])})
def test_mean_values_with_data_name(self):
    """Mean values parsed from the CLI tuple syntax, with no scale values given.

    The mean/scale dictionary is built via ``parse_tuple_pairs`` and
    ``get_mean_scale_dictionary``. The expected graph is
    Parameter -> Add(mean) -> Result, i.e. no Multiply is inserted.
    """
    reference_edges = [
        *connect('parameter', '0:add_mean'),
        *connect('mean', '1:add_mean'),
        *connect('add_mean', 'result'),
    ]
    graph_ref = build_graph(nodes, reference_edges)

    parsed_means = parse_tuple_pairs('(1,2,3)')
    parsed_scales = parse_tuple_pairs('')
    argv = Namespace(mean_scale_values=get_mean_scale_dictionary(parsed_means, parsed_scales, None))

    graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
    for g in (graph, graph_ref):
        self.set_graph_attrs(g, ['parameter'])
    graph.graph['layout'] = 'NCHW'

    AddMeanScaleValues().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)
    self.check_graph_attrs(graph, graph_ref, ['parameter'])
def test_mean_values_explicit_and_scale_values_explicit_layout(self):
    """Explicit mean and scale values combined with an explicit source layout.

    The graph-level layout is 'NHWC', while ``layout_values`` supplies a
    source layout of 'nchw'. The expected graph is
    Parameter -> Add(mean) -> Multiply(scale) -> Result.
    """
    graph_ref = build_graph(nodes, [
        *connect('parameter', '0:add_mean'),
        *connect('mean', '1:add_mean'),
        *connect('add_mean', '0:mul_scale'),
        *connect('scale', '1:mul_scale'),
        *connect('mul_scale', 'result'),
    ])

    mean_arr = np.array([1., 2., 3.])
    scale_arr = np.array([1., 2., 3.])
    explicit_layout = {
        '': {
            'source_layout': 'nchw',
            'target_layout': None,
        },
    }
    argv = Namespace(mean_scale_values=[[mean_arr, scale_arr]], layout_values=explicit_layout)

    graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
    for g in (graph, graph_ref):
        self.set_graph_attrs(g, ['parameter'])
    graph.graph['layout'] = 'NHWC'

    AddMeanScaleValues().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)
    self.check_graph_attrs(graph, graph_ref, ['parameter'])
def test_insert_add_mean_scale_after_convert_different_type(self):
    """Mean/scale inserted after a Convert node use float32 constants.

    Input graph:    Parameter -> Convert -> Result
    Expected graph: Parameter -> Convert -> Add(mean) -> Multiply(scale) -> Result

    Additionally checks that the constant feeding input port 1 of the inserted
    Add has dtype ``np.float32``.

    NOTE(review): the graph setup here is identical to
    ``test_insert_add_mean_scale_after_convert``; presumably the "different
    type" comes from the 'convert' node attributes in ``nodes`` -- confirm.
    """
    graph_ref = build_graph(nodes, [
        *connect('parameter', 'convert'),
        *connect('convert', '0:add_mean'),
        *connect('mean', '1:add_mean'),
        *connect('add_mean', '0:mul_scale'),
        *connect('scale', '1:mul_scale'),
        *connect('mul_scale', 'result'),
    ])
    argv = Namespace(mean_scale_values=[[
        np.array([1., 2., 3.]),
        np.array([1., 2., 3.]),
    ]])
    graph = build_graph(
        nodes, [*connect('parameter', 'convert'), *connect('convert', 'result')],
        nodes_with_edges_only=True, cli=argv)
    graph.graph['layout'] = 'NCHW'

    AddMeanScaleValues().find_and_replace_pattern(graph)
    (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)
    self.check_graph_attrs(graph, graph_ref, [])

    add_nodes = graph.get_op_nodes(type="Add")
    # Fail with a readable message instead of an IndexError if no Add was inserted.
    self.assertGreater(len(add_nodes), 0, 'no Add node found after the transformation')
    mean_source = add_nodes[0].in_port(1).get_connection().get_source().node
    # assertEqual reports the actual dtype on failure, unlike assertTrue(a == b).
    self.assertEqual(mean_source['value'].dtype, np.float32)
def test_insert_add_mean_scale_after_convert(self):
    """Mean and scale are inserted after an existing Convert node.

    Input graph:    Parameter -> Convert -> Result
    Expected graph: Parameter -> Convert -> Add(mean) -> Multiply(scale) -> Result
    """
    expected_edges = [
        *connect('parameter', 'convert'),
        *connect('convert', '0:add_mean'),
        *connect('mean', '1:add_mean'),
        *connect('add_mean', '0:mul_scale'),
        *connect('scale', '1:mul_scale'),
        *connect('mul_scale', 'result'),
    ]
    graph_ref = build_graph(nodes, expected_edges)

    mean_arr = np.array([1., 2., 3.])
    scale_arr = np.array([1., 2., 3.])
    argv = Namespace(mean_scale_values=[[mean_arr, scale_arr]])

    input_edges = [*connect('parameter', 'convert'), *connect('convert', 'result')]
    graph = build_graph(nodes, input_edges, nodes_with_edges_only=True, cli=argv)
    graph.graph['layout'] = 'NCHW'

    AddMeanScaleValues().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)
    self.check_graph_attrs(graph, graph_ref, [])
def test_mean_values_with_colon_in_node_name_and_port(self):
    """Mean values keyed by a string that contains colons in both port and name parts.

    The parameter is given the name 'param:0', and mean/scale values are
    supplied under the key '0:param:0'. Scale is [1.], and the expected graph
    is Parameter -> Add(mean) -> Result, with no Multiply.
    """
    graph_ref = build_graph(nodes, [
        *connect('parameter', '0:add_mean'),
        *connect('mean', '1:add_mean'),
        *connect('add_mean', 'result'),
    ])

    per_input_values = {
        '0:param:0': {
            'scale': np.array([1.]),
            'mean': np.array([1., 2., 3.]),
        },
    }
    argv = Namespace(mean_scale_values=per_input_values)

    renamed_parameter = {
        'parameter': {
            'name': 'param:0',
            'id': 'param:0/placeholder_0',
            'initial_node_name': 'param:0',
        },
    }
    graph = build_graph(nodes, [*connect('parameter', 'result')], renamed_parameter,
                        nodes_with_edges_only=True, cli=argv)
    for g in (graph, graph_ref):
        self.set_graph_attrs(g, ['parameter'])
    graph.graph['layout'] = 'NCHW'

    AddMeanScaleValues().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_mean_values_optimized_and_scale_values_explicit(self):
    """An all-zero mean is expected to produce no Add; only the scale is inserted.

    Expected graph: Parameter -> Multiply(scale) -> Result.
    """
    reference_edges = [
        *connect('parameter', '0:mul_scale'),
        *connect('scale', '1:mul_scale'),
        *connect('mul_scale', 'result'),
    ]
    graph_ref = build_graph(nodes, reference_edges)

    parameter_values = {
        'scale': np.array([1., 2., 3.]),
        'mean': np.array([0., 0., 0.]),
    }
    argv = Namespace(mean_scale_values={'parameter': parameter_values})

    graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
    for g in (graph, graph_ref):
        self.set_graph_attrs(g, ['parameter'])
    graph.graph['layout'] = 'NCHW'

    AddMeanScaleValues().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)
    self.check_graph_attrs(graph, graph_ref, ['parameter'])
def test_mean_values_explicit_and_scale_values_explicit_on_cutted_graph(self):
    """Mean/scale values on a graph whose start was cut by the user.

    'parameter' receives an explicit mean and is expected to become
    Parameter -> Add(mean) -> Result. The scale is given for node 'op'.
    Because 'parameter_2' carries ``initial_node_name='op'``, the Multiply is
    expected right after 'parameter_2':
    Parameter_2 -> Multiply(scale) -> op -> Result_2.
    """
    graph_ref = build_graph(nodes, [
        *connect('parameter', '0:add_mean'),
        *connect('mean', '1:add_mean'),
        *connect('add_mean', 'result'),
        *connect('parameter_2', '0:mul_scale'),
        *connect('scale', '1:mul_scale'),
        *connect('mul_scale', 'op'),
        *connect('op', 'result_2'),
    ])

    argv = Namespace(
        mean_scale_values={'parameter': {'mean': np.array([1, 2, 3])},
                           'op': {'scale': np.array([1, 2, 3])}})
    graph = build_graph(
        nodes,
        [*connect('parameter', 'result'), *connect('parameter_2', 'op'), *connect('op', 'result_2')],
        {'parameter_2': {'initial_node_name': 'op'}},
        nodes_with_edges_only=True, cli=argv)
    self.set_graph_attrs(graph, ['parameter', 'parameter_2'])
    self.set_graph_attrs(graph_ref, ['parameter', 'parameter_2'])
    graph.graph['layout'] = 'NCHW'

    AddMeanScaleValues().find_and_replace_pattern(graph)

    # Both outputs must match: one path checks the mean, the other the scale.
    (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)
    (flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True)
    self.assertTrue(flag, resp)
    self.check_graph_attrs(graph, graph_ref, ['parameter', 'parameter_2'])