def strip_pruning_vars(checkpoint_dir, output_node_names, output_dir, filename):
    """Remove pruning-related auxiliary variables and ops from the graph.

    Accepts training checkpoints and produces a GraphDef in which the pruning
    vars and ops have been removed, then writes it to disk in binary form.

    Args:
      checkpoint_dir: Path to the checkpoints.
      output_node_names: The name of the output nodes, comma separated.
        Whitespace is ignored, and empty entries (e.g. produced by a
        trailing comma or ",,") are skipped.
      output_dir: Directory where to write the graph.
      filename: Output GraphDef file name.

    Returns:
      None

    Raises:
      ValueError: if output_node_names is not provided or contains no
        non-empty node name.
    """
    node_names = []
    if output_node_names:
        # Remove every whitespace character from each comma-separated entry,
        # then drop entries that end up empty so inputs like " " or "a, b,"
        # never forward an empty node name to the graph utilities.
        node_names = [
            "".join(name.split()) for name in output_node_names.split(",")
        ]
        node_names = [name for name in node_names if name]
    if not node_names:
        raise ValueError(
            'Need to specify atleast 1 output node through output_node_names flag'
        )

    initial_graph_def = strip_pruning_vars_lib.graph_def_from_checkpoint(
        checkpoint_dir, node_names)
    final_graph_def = strip_pruning_vars_lib.strip_pruning_vars_fn(
        initial_graph_def, node_names)
    tf.io.write_graph(final_graph_def, output_dir, filename, as_text=False)
    tf.logging.info('\nFinal graph written to %s',
                    os.path.join(output_dir, filename))
def _get_final_outputs(self, output_tensor_names_list):
    """Strip pruning vars from the initial graph and evaluate the result.

    Stores the stripped GraphDef on ``self.final_graph_def``, imports it with
    the ``"final"`` name prefix, and fetches the requested tensors through
    ``self._get_outputs`` inside a test session opened on
    ``self.final_graph``.

    Args:
      output_tensor_names_list: Names of the tensors to fetch. After
        conversion by ``_get_node_names`` they are also passed as the
        output node names to ``strip_pruning_vars_fn``.

    Returns:
      The value produced by ``self._get_outputs`` for those tensors.
    """
    node_names = _get_node_names(output_tensor_names_list)
    self.final_graph_def = strip_pruning_vars_lib.strip_pruning_vars_fn(
        self.initial_graph_def, node_names)
    # NOTE(review): import_graph_def targets the current default graph while
    # the session below is opened on self.final_graph; presumably setUp makes
    # these the same graph -- confirm.
    tf.graph_util.import_graph_def(self.final_graph_def, name="final")
    with self.test_session(self.final_graph) as session:
        return self._get_outputs(
            session,
            self.final_graph,
            output_tensor_names_list,
            graph_prefix="final",
        )