def test_summarize_graph(self):
    """Check that summarize_graph reports the fixture's single input and output."""
    # Patch open() inside mo.front.tf.loader so the loader reads the
    # in-memory pbtxt fixture; the mock handle itself is never inspected.
    with patch('mo.front.tf.loader.open', mock_open(read_data=pbtxt)):
        graph_def, _, _ = load_tf_graph_def('path', False)
        summary = summarize_graph(graph_def)

        # Exactly one output node is expected.
        outputs = summary['outputs']
        self.assertEqual(len(outputs), 1)
        self.assertEqual(outputs[0], 'Output/Identity')

        # Exactly one input, keyed by its node name.
        inputs = summary['inputs']
        self.assertEqual(len(inputs), 1)
        # assertIn shows the actual container when the key is missing,
        # unlike comparing a boolean against True.
        self.assertIn('Placeholder', inputs)

        placeholder = inputs['Placeholder']
        self.assertEqual(str(placeholder['shape']), '(1,227,227,3)')
        self.assertEqual(str(placeholder['type']), 'float32')
def get_output_node_names_list(graph_def, user_defined_output_node_names_list: list):
    """Return the output node names to use for ``graph_def``.

    A non-empty ``user_defined_output_node_names_list`` is returned as-is.
    When it is ``None`` or empty, the ``'outputs'`` entry of
    ``summarize_graph(graph_def)`` is returned instead.
    """
    user_names = user_defined_output_node_names_list
    has_user_names = user_names is not None and len(user_names) != 0
    if has_user_names:
        return user_names
    # Only inspect the graph when the caller supplied nothing usable.
    return summarize_graph(graph_def)['outputs']
def process_graph(graph, file):
    """Write one sub-graph ``.pb`` file per detected output of ``graph``.

    For every output reported by ``summarize_graph`` the sub-graph feeding
    that output is extracted, passed through ``add_placeholder`` (and
    ``replace_input_to_placeholder`` when there is something to replace),
    and written under ``./pbfiles/`` unless the sub-test is listed in
    ``./invalid_tests_list.txt``, the file already exists, ``skipTest``
    rejects the sub-graph, or the rewritten graph has at most one node.

    Args:
        graph: object exposing ``as_graph_def()``.
        file: dotted test name; its components form the output directory
            and the last component names the output files.
    """
    # Disabled tests
    invalid_tests = read_tests_from_file("./invalid_tests_list.txt")

    # Get outputs
    result = summarize_graph(graph.as_graph_def())
    outputs = result['outputs']
    print("This model has {} outputs".format(len(outputs)))

    # str.split always yields at least one component, so [-1] is safe.
    test_names = file.split('.')
    test_case_name = test_names[-1]
    # NOTE(review): every dotted component, including the last one, feeds
    # both the directory and the name prefix, so the test-case name appears
    # twice in sub_test - confirm this matches the intended naming.
    test_dir = "".join("/" + name for name in test_names)
    full_test_name_placeholder = "".join(name + '_' for name in test_names)

    # The target directory does not depend on the output being processed.
    out_dir = "./pbfiles/" + test_dir

    # Process each subgraph (test case has multiple iterations)
    # enumerate() gives each output its own position; list.index() inside
    # the loop was quadratic and returned the first match for repeated names.
    for position, output in enumerate(outputs):
        index = str(position)
        out_file = test_case_name + '-' + index + '.pb'
        sub_test = full_test_name_placeholder + test_case_name + '-' + index
        print("Testing --- ", sub_test)
        if sub_test in invalid_tests:
            print("Skipping test: {}".format(sub_test))
            continue

        out_path = out_dir + '/' + out_file
        already_written = os.path.exists(out_path)
        print("File log {} exists?: {}".format(out_path, already_written))
        if already_written:
            continue

        graph_def = tf.compat.v1.graph_util.extract_sub_graph(
            graph.as_graph_def(), [output])
        if skipTest(graph_def):
            continue

        # Add placeholder node, get new graph with placeholder node added
        # & const op name
        new_graph, node_to_replace = add_placeholder(graph_def)
        new_graph_def = new_graph.as_graph_def()
        if len(node_to_replace) != 0:
            replace_input_to_placeholder(new_graph_def, node_to_replace)

        # Write the new graph
        nodes = list(new_graph_def.node)
        if len(nodes) > 1:
            # Re-extract so only nodes still reachable from the output
            # after the rewrite are written.
            new_graph_def = tf.compat.v1.graph_util.extract_sub_graph(
                new_graph_def, [output])
            tf.io.write_graph(new_graph_def, out_dir, out_file, as_text=False)
        else:
            print(nodes)
            print("Skipping graphs with just Placeholder/Const node")