def strip_pruning_vars(checkpoint_dir, output_node_names, output_dir,
                       filename):
    """Remove pruning-related auxiliary variables and ops from the graph.

    Accepts training checkpoints and produces a GraphDef in which the pruning
    vars and ops have been removed, then writes it to disk as a binary
    GraphDef.

    Args:
      checkpoint_dir: Path to the checkpoints.
      output_node_names: The names of the output nodes, comma separated.
        Spaces are removed and empty entries (e.g. from a trailing comma)
        are ignored.
      output_dir: Directory where to write the graph.
      filename: Output GraphDef file name.

    Returns:
      None

    Raises:
      ValueError: if output_node_names is not provided or contains no
        non-empty node name.
    """
    node_names = []
    if output_node_names:
        # Drop empty entries so inputs such as "a,b," or " , " never forward
        # an empty node name to the checkpoint / stripping helpers.
        node_names = [
            name for name in output_node_names.replace(' ', '').split(',')
            if name
        ]
    if not node_names:
        raise ValueError(
            'Need to specify atleast 1 output node through output_node_names flag'
        )

    initial_graph_def = strip_pruning_vars_lib.graph_def_from_checkpoint(
        checkpoint_dir, node_names)

    final_graph_def = strip_pruning_vars_lib.strip_pruning_vars_fn(
        initial_graph_def, node_names)
    # as_text=False writes the serialized (binary) GraphDef.
    tf.io.write_graph(final_graph_def, output_dir, filename, as_text=False)
    tf.logging.info('\nFinal graph written to %s',
                    os.path.join(output_dir, filename))
# Example #2
    def _get_final_outputs(self, output_tensor_names_list):
        """Strip pruning vars from the initial graph and fetch its outputs.

        Runs ``strip_pruning_vars_lib.strip_pruning_vars_fn`` on
        ``self.initial_graph_def``, stores the result in
        ``self.final_graph_def``, imports it into ``self.final_graph`` under
        the ``"final"`` name prefix, and evaluates the requested tensors in a
        session bound to that graph.

        Args:
          output_tensor_names_list: Names of the tensors to fetch. They are
            passed through ``_get_node_names`` before stripping (presumably
            mapping tensor names to node names -- confirm against helper).

        Returns:
          Whatever ``self._get_outputs`` returns for the requested tensors of
          the stripped graph.
        """
        self.final_graph_def = strip_pruning_vars_lib.strip_pruning_vars_fn(
            self.initial_graph_def, _get_node_names(output_tensor_names_list))
        # import_graph_def loads into the *current default* graph, but the
        # lookups below read "final"-prefixed tensors from self.final_graph,
        # so make that graph the default explicitly for the import.
        with self.final_graph.as_default():
            _ = tf.graph_util.import_graph_def(self.final_graph_def,
                                               name="final")

        with self.test_session(self.final_graph) as sess2:
            final_outputs = self._get_outputs(sess2,
                                              self.final_graph,
                                              output_tensor_names_list,
                                              graph_prefix="final")
        return final_outputs