示例#1
0
 def test_summarize_graph(self):
     """summarize_graph reports exactly one input and one output for the ``pbtxt`` fixture.

     ``open`` inside ``mo.front.tf.loader`` is patched with ``mock_open`` so that
     loading the dummy ``'path'`` reads the in-memory ``pbtxt`` text instead.
     """
     with patch('mo.front.tf.loader.open', mock_open(read_data=pbtxt)):
         # NOTE(review): second argument False presumably means "text, not binary"
         # (the fixture is pbtxt) -- confirm against load_tf_graph_def's signature.
         graph_def, _, _ = load_tf_graph_def('path', False)
         summary = summarize_graph(graph_def)

     outputs = summary['outputs']
     self.assertEqual(len(outputs), 1)
     self.assertEqual(outputs[0], 'Output/Identity')

     inputs = summary['inputs']
     self.assertEqual(len(inputs), 1)
     # assertIn reports the container on failure, unlike assertEqual(x in y, True).
     self.assertIn('Placeholder', inputs)
     placeholder = inputs['Placeholder']
     self.assertEqual(str(placeholder['shape']), '(1,227,227,3)')
     self.assertEqual(str(placeholder['type']), 'float32')
示例#2
0
def get_output_node_names_list(graph_def, user_defined_output_node_names_list: list):
    """Return the output node names to use for ``graph_def``.

    Args:
        graph_def: Graph passed to ``summarize_graph`` when no names were supplied.
        user_defined_output_node_names_list: Output names chosen by the caller;
            may be ``None`` or empty.

    Returns:
        The caller's list unchanged when it is non-empty; otherwise the
        ``'outputs'`` entry of ``summarize_graph(graph_def)``.
    """
    # None and [] are both falsy: either means "auto-detect the outputs".
    if user_defined_output_node_names_list:
        return user_defined_output_node_names_list
    return summarize_graph(graph_def)['outputs']
示例#3
0
def _split_test_name(dotted_name):
    """Split a dotted test name into ``(test_dir, test_case_name, name_prefix)``.

    For ``"a.b.c"`` this returns ``("/a/b/c", "c", "a_b_c_")``. The result has the
    directory path with a leading slash, the last component, and every component
    (including the last) followed by ``'_'``.
    """
    # str.split always yields at least one element, so [-1] is safe.
    parts = dotted_name.split('.')
    test_dir = "/" + "/".join(parts)
    name_prefix = "_".join(parts) + "_"
    return test_dir, parts[-1], name_prefix


def process_graph(graph, file):
    """Export one binary ``.pb`` file per output node of ``graph``.

    Each output reported by ``summarize_graph`` is extracted into its own
    sub-graph and passed through ``add_placeholder`` /
    ``replace_input_to_placeholder``. The result is written as
    ``<case>-<index>.pb`` into ``"./pbfiles/" + "/<a>/<b>/.../<case>"``, where
    ``file`` is the dotted name ``a.b....case`` and ``index`` is the output's
    position in the output list.

    The sub-test name is ``"a_b_..._case_" + case + "-" + index``, so the case
    name appears twice (e.g. ``a_b_c_c-0``). This is the key looked up in
    ``./invalid_tests_list.txt``.

    An output is skipped if any of these hold:
    - its sub-test is listed as invalid;
    - its ``.pb`` file already exists;
    - ``skipTest`` rejects its sub-graph;
    - the rewritten graph has at most one node.

    Args:
        graph: A TensorFlow graph exposing ``as_graph_def()``.
        file: Dotted test name, e.g. ``"suite.module.case"``.
    """
    # Disabled tests
    invalid_tests = read_tests_from_file("./invalid_tests_list.txt")

    # Build the GraphDef once; extract_sub_graph returns a new GraphDef, so the
    # same source can be reused for every output.
    full_graph_def = graph.as_graph_def()
    outputs = summarize_graph(full_graph_def)['outputs']
    print("This model has {} outputs".format(len(outputs)))

    test_dir, test_case_name, name_prefix = _split_test_name(file)
    out_dir = "./pbfiles/" + test_dir

    # Process each subgraph (test case has multiple iterations); the output's
    # position is its iteration index.
    for position, output in enumerate(outputs):
        index = str(position)
        out_file = test_case_name + '-' + index + '.pb'
        out_path = out_dir + '/' + out_file
        sub_test = name_prefix + test_case_name + '-' + index

        print("Testing --- ", sub_test)
        if sub_test in invalid_tests:
            print("Skipping test: {}".format(sub_test))
            continue

        already_exported = os.path.exists(out_path)
        print("File log {} exists?: {}".format(out_path, already_exported))
        if already_exported:
            continue

        graph_def = tf.compat.v1.graph_util.extract_sub_graph(
            full_graph_def, [output])
        if skipTest(graph_def):
            continue

        # Add a placeholder node; per the original comment this returns the new
        # graph and the const op name(s) whose consumers must be rewired.
        new_graph, node_to_replace = add_placeholder(graph_def)
        new_graph_def = new_graph.as_graph_def()
        if node_to_replace:
            replace_input_to_placeholder(new_graph_def, node_to_replace)

        # Write the new graph
        nodes = list(new_graph_def.node)
        if len(nodes) > 1:
            # Keep only the nodes that `output` depends on before writing.
            new_graph_def = tf.compat.v1.graph_util.extract_sub_graph(
                new_graph_def, [output])
            tf.io.write_graph(new_graph_def,
                              out_dir,
                              out_file,
                              as_text=False)
        else:
            print(nodes)
            print("Skipping graphs with just Placeholder/Const node")