def testContinueFromUnmanaged(self):
   """Verifies a CheckpointManager can pick up after unmanaged saves.

   Two checkpoints are first written directly with `Checkpoint.save` under a
   custom prefix. A manager with `max_to_keep=2` is then created over the same
   directory. The assertions expect the manager to list and delete only the
   checkpoints it saved itself, leaving both unmanaged checkpoints on disk.
   """
   ckpt_dir = self.get_temp_dir()
   unmanaged_prefix = os.path.join(ckpt_dir, "unusual_prefix")
   checkpoint = util.Checkpoint()
   first_path = checkpoint.save(unmanaged_prefix)
   second_path = checkpoint.save(unmanaged_prefix)
   del checkpoint

   # Start over with a fresh object and restore what the manager reports as
   # the latest checkpoint.
   checkpoint = util.Checkpoint()
   manager = checkpoint_management.CheckpointManager(
       checkpoint, ckpt_dir, max_to_keep=2)
   restore_status = checkpoint.restore(manager.latest_checkpoint)
   restore_status.run_restore_ops()
   self.assertEqual(2, self.evaluate(checkpoint.save_counter))

   third_path = manager.save()
   self.assertEqual([third_path], manager.checkpoints)
   fourth_path = manager.save()
   self.assertEqual([third_path, fourth_path], manager.checkpoints)
   fifth_path = manager.save()
   self.assertEqual([fourth_path, fifth_path], manager.checkpoints)

   # Unmanaged checkpoints must survive; only the manager's oldest is removed.
   ckpt_exists = checkpoint_management.checkpoint_exists
   for unmanaged_path in (first_path, second_path):
     self.assertTrue(ckpt_exists(unmanaged_path))
   self.assertFalse(ckpt_exists(third_path))
   for kept_path in (fourth_path, fifth_path):
     self.assertTrue(ckpt_exists(kept_path))
 def testContinueFromUnmanaged(self):
   """Verifies a CheckpointManager can pick up after unmanaged saves.

   Two checkpoints are first written directly with `Checkpoint.save` under a
   custom prefix. A manager with `max_to_keep=2` is then created over the same
   directory. The assertions expect the manager to list and delete only the
   checkpoints it saved itself, leaving both unmanaged checkpoints on disk.
   """
   ckpt_dir = self.get_temp_dir()
   unmanaged_prefix = os.path.join(ckpt_dir, "unusual_prefix")
   checkpoint = util.Checkpoint()
   first_path = checkpoint.save(unmanaged_prefix)
   second_path = checkpoint.save(unmanaged_prefix)
   del checkpoint

   # Start over with a fresh object and restore what the manager reports as
   # the latest checkpoint.
   checkpoint = util.Checkpoint()
   manager = checkpoint_management.CheckpointManager(
       checkpoint, ckpt_dir, max_to_keep=2)
   restore_status = checkpoint.restore(manager.latest_checkpoint)
   restore_status.run_restore_ops()
   self.assertEqual(2, self.evaluate(checkpoint.save_counter))

   third_path = manager.save()
   self.assertEqual([third_path], manager.checkpoints)
   fourth_path = manager.save()
   self.assertEqual([third_path, fourth_path], manager.checkpoints)
   fifth_path = manager.save()
   self.assertEqual([fourth_path, fifth_path], manager.checkpoints)

   # Unmanaged checkpoints must survive; only the manager's oldest is removed.
   ckpt_exists = checkpoint_management.checkpoint_exists
   for unmanaged_path in (first_path, second_path):
     self.assertTrue(ckpt_exists(unmanaged_path))
   self.assertFalse(ckpt_exists(third_path))
   for kept_path in (fourth_path, fifth_path):
     self.assertTrue(ckpt_exists(kept_path))
 def testDeletion(self):
   """Checks that the oldest checkpoint is deleted past `max_to_keep`.

   With `max_to_keep=3`, four consecutive saves are expected to leave the
   three newest checkpoints on disk and remove the first one.
   """
   ckpt = util.Checkpoint()
   manager = checkpoint_management.CheckpointManager(
       ckpt, self.get_temp_dir(), max_to_keep=3)
   saved_paths = [manager.save() for _ in range(4)]

   # Newest first, matching the order in which existence is verified.
   for retained_path in reversed(saved_paths[1:]):
     self.assertTrue(checkpoint_management.checkpoint_exists(retained_path))
   self.assertFalse(checkpoint_management.checkpoint_exists(saved_paths[0]))
 def testDeletion(self):
   """Checks that the oldest checkpoint is deleted past `max_to_keep`.

   With `max_to_keep=3`, four consecutive saves are expected to leave the
   three newest checkpoints on disk and remove the first one.
   """
   ckpt = util.Checkpoint()
   manager = checkpoint_management.CheckpointManager(
       ckpt, self.get_temp_dir(), max_to_keep=3)
   saved_paths = [manager.save() for _ in range(4)]

   # Newest first, matching the order in which existence is verified.
   for retained_path in reversed(saved_paths[1:]):
     self.assertTrue(checkpoint_management.checkpoint_exists(retained_path))
   self.assertFalse(checkpoint_management.checkpoint_exists(saved_paths[0]))
  def testRemoveCheckpoint(self):
    """Checks `remove_checkpoint` for every sharding/format combination.

    For each of unsharded/sharded and SaverDef V2/V1, a single variable is
    saved, the checkpoint is asserted to exist, then removed and asserted to
    be gone.
    """
    write_versions = (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1)
    for use_shards in (False, True):
      for write_version in write_versions:
        # Each combination runs in its own new graph and session.
        with self.session(graph=ops_lib.Graph()) as sess:
          unused_variable = variables.Variable(1.0, name="v")
          variables.global_variables_initializer().run()
          saver = saver_module.Saver(
              sharded=use_shards, write_version=write_version)

          save_path = os.path.join(
              self._base_dir, "%s-%s" % (use_shards, write_version))
          ckpt_prefix = saver.save(sess, save_path)
          self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))
          checkpoint_management.remove_checkpoint(ckpt_prefix, write_version)
          self.assertFalse(checkpoint_management.checkpoint_exists(ckpt_prefix))
  def testRemoveCheckpoint(self):
    """Checks `remove_checkpoint` for every sharding/format combination.

    For each of unsharded/sharded and SaverDef V2/V1, a single variable is
    saved, the checkpoint is asserted to exist, then removed and asserted to
    be gone.
    """
    write_versions = (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1)
    for use_shards in (False, True):
      for write_version in write_versions:
        # Each combination runs in its own new graph and session.
        # NOTE(review): `test_session` is presumably deprecated in newer
        # TensorFlow releases in favour of `session` -- confirm before
        # upgrading.
        with self.test_session(graph=ops_lib.Graph()) as sess:
          unused_variable = variables.Variable(1.0, name="v")
          variables.global_variables_initializer().run()
          saver = saver_module.Saver(
              sharded=use_shards, write_version=write_version)

          save_path = os.path.join(
              self._base_dir, "%s-%s" % (use_shards, write_version))
          ckpt_prefix = saver.save(sess, save_path)
          self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))
          checkpoint_management.remove_checkpoint(ckpt_prefix, write_version)
          self.assertFalse(checkpoint_management.checkpoint_exists(ckpt_prefix))
 def testClockReset(self, mock_time):
     """Checks checkpoint retention when the clock jumps backwards.

     The mocked clock advances one hour per save under a manager created with
     `keep_checkpoint_every_n_hours=1.`, and is then set back to 5000 before a
     second manager is created. The test expects a warning about the last
     preserved timestamp, and expects the second and third checkpoints to stay
     on disk through all later saves.
     """
     ckpt_dir = self.get_temp_dir()
     ckpt_exists = checkpoint_management.checkpoint_exists
     mock_time.time.return_value = 10000.
     checkpoint = util.Checkpoint()
     first_manager = checkpoint_management.CheckpointManager(
         checkpoint, ckpt_dir, max_to_keep=1,
         keep_checkpoint_every_n_hours=1.)
     first_path = first_manager.save()
     mock_time.time.return_value += 3600.
     second_path = first_manager.save()
     mock_time.time.return_value += 3600.
     third_path = first_manager.save()
     self.assertFalse(ckpt_exists(first_path))
     for surviving_path in (second_path, third_path):
         self.assertTrue(ckpt_exists(surviving_path))
     self.assertEqual([third_path], first_manager.checkpoints)
     state = checkpoint_management.get_checkpoint_state(ckpt_dir)
     self.assertEqual(13600., state.last_preserved_timestamp)

     # Set the clock back in time
     mock_time.time.return_value = 5000.
     del first_manager
     with test.mock.patch.object(logging, "warning") as mock_log:
         second_manager = checkpoint_management.CheckpointManager(
             checkpoint, ckpt_dir, max_to_keep=1)
         logged_call = str(mock_log.call_args)
         self.assertRegexpMatches(
             logged_call, "behind the last preserved checkpoint timestamp")
     # We should err on the side of keeping checkpoints around when we're not
     # sure whether they were preserved or not due to clock funkiness.
     self.assertTrue(ckpt_exists(second_path))
     # We know about the existing checkpoints, but they'll never be deleted and
     # so won't go in the CheckpointState proto on save.
     self.assertEqual(third_path, second_manager.latest_checkpoint)
     self.assertEqual([], second_manager.checkpoints)

     mock_time.time.return_value += 10.
     fourth_path = second_manager.save()
     for surviving_path in (second_path, third_path):
         self.assertTrue(ckpt_exists(surviving_path))
     self.assertEqual(fourth_path, second_manager.latest_checkpoint)
     self.assertEqual([fourth_path], second_manager.checkpoints)

     mock_time.time.return_value += 10.
     fifth_path = second_manager.save()
     for surviving_path in (second_path, third_path):
         self.assertTrue(ckpt_exists(surviving_path))
     self.assertEqual([fifth_path], second_manager.checkpoints)
     state = checkpoint_management.get_checkpoint_state(ckpt_dir)
     self.assertEqual(5000., state.last_preserved_timestamp)
     self.assertEqual([5020.], state.all_model_checkpoint_timestamps)
 def testClockReset(self, mock_time):
   """Checks checkpoint retention when the clock jumps backwards.

   The mocked clock advances one hour per save under a manager created with
   `keep_checkpoint_every_n_hours=1.`, and is then set back to 5000 before a
   second manager is created. The test expects a warning about the last
   preserved timestamp, and expects the second and third checkpoints to stay
   on disk through all later saves.
   """
   ckpt_dir = self.get_temp_dir()
   ckpt_exists = checkpoint_management.checkpoint_exists
   mock_time.time.return_value = 10000.
   checkpoint = util.Checkpoint()
   first_manager = checkpoint_management.CheckpointManager(
       checkpoint,
       ckpt_dir,
       max_to_keep=1,
       keep_checkpoint_every_n_hours=1.)
   first_path = first_manager.save()
   mock_time.time.return_value += 3600.
   second_path = first_manager.save()
   mock_time.time.return_value += 3600.
   third_path = first_manager.save()
   self.assertFalse(ckpt_exists(first_path))
   for surviving_path in (second_path, third_path):
     self.assertTrue(ckpt_exists(surviving_path))
   self.assertEqual([third_path], first_manager.checkpoints)
   state = checkpoint_management.get_checkpoint_state(ckpt_dir)
   self.assertEqual(13600., state.last_preserved_timestamp)

   # Set the clock back in time
   mock_time.time.return_value = 5000.
   del first_manager
   with test.mock.patch.object(logging, "warning") as mock_log:
     second_manager = checkpoint_management.CheckpointManager(
         checkpoint, ckpt_dir, max_to_keep=1)
     logged_call = str(mock_log.call_args)
     self.assertRegexpMatches(
         logged_call, "behind the last preserved checkpoint timestamp")
   # We should err on the side of keeping checkpoints around when we're not
   # sure whether they were preserved or not due to clock funkiness.
   self.assertTrue(ckpt_exists(second_path))
   # We know about the existing checkpoints, but they'll never be deleted and
   # so won't go in the CheckpointState proto on save.
   self.assertEqual(third_path, second_manager.latest_checkpoint)
   self.assertEqual([], second_manager.checkpoints)

   mock_time.time.return_value += 10.
   fourth_path = second_manager.save()
   for surviving_path in (second_path, third_path):
     self.assertTrue(ckpt_exists(surviving_path))
   self.assertEqual(fourth_path, second_manager.latest_checkpoint)
   self.assertEqual([fourth_path], second_manager.checkpoints)

   mock_time.time.return_value += 10.
   fifth_path = second_manager.save()
   for surviving_path in (second_path, third_path):
     self.assertTrue(ckpt_exists(surviving_path))
   self.assertEqual([fifth_path], second_manager.checkpoints)
   state = checkpoint_management.get_checkpoint_state(ckpt_dir)
   self.assertEqual(5000., state.last_preserved_timestamp)
   self.assertEqual([5020.], state.all_model_checkpoint_timestamps)
    def testCheckpointInterval(self):
        """Checks `save(check_interval=True)` honours `checkpoint_interval`.

        With `checkpoint_interval=2` and an explicit `step_counter`, the
        assertions expect a checkpoint at steps 0 and 2, `None` at steps 1 and
        3 and on a repeated call at step 2, and an unconditional save when
        `check_interval=False`.
        """
        tracked_var = variables.Variable(1.0)
        step_counter = variables.Variable(0)
        self.evaluate([tracked_var.initializer, step_counter.initializer])
        checkpoint = util.Checkpoint(v=tracked_var)
        manager = checkpoint_management.CheckpointManager(
            checkpoint, self.get_temp_dir(), max_to_keep=None,
            step_counter=step_counter, checkpoint_interval=2)
        ckpt_exists = checkpoint_management.checkpoint_exists

        # step_counter: 0, save an initial checkpoint.
        initial_path = manager.save(check_interval=True)
        self.assertTrue(ckpt_exists(initial_path))

        # step_counter: 1, no checkpoint saved.
        self.evaluate(step_counter.assign_add(1))
        self.assertIsNone(manager.save(check_interval=True))

        # step_counter: 2, checkpoint saved.
        self.evaluate(step_counter.assign_add(1))
        interval_path = manager.save(check_interval=True)
        self.assertTrue(ckpt_exists(interval_path))

        # no checkpoint saved when calling `save` with the same step counter.
        self.assertIsNone(manager.save(check_interval=True))

        # step_counter: 3, no checkpoint saved.
        self.evaluate(step_counter.assign_add(1))
        self.assertIsNone(manager.save(check_interval=True))

        # Always save the checkpoint.
        forced_path = manager.save(check_interval=False)
        self.assertTrue(ckpt_exists(forced_path))
  def _wait_for_glob(self, pattern, timeout_secs, for_checkpoint=True):
    """Polls until `pattern` matches something, failing the test on timeout.

    Args:
      pattern: A string; a checkpoint prefix when `for_checkpoint` is true,
        otherwise a glob pattern.
      timeout_secs: How long to keep polling, in seconds.
      for_checkpoint: If true, match with `checkpoint_exists`; otherwise
        match with `gfile.Glob`.
    """
    deadline = time.time() + timeout_secs
    while time.time() < deadline:
      if for_checkpoint:
        matched = checkpoint_management.checkpoint_exists(pattern)
      else:
        matched = len(gfile.Glob(pattern)) >= 1
      if matched:
        return
      # Poll again after a short pause.
      time.sleep(0.05)
    self.assertFalse(True, "Glob never matched any file: %s" % pattern)
Exemple #11
0
    def _wait_for_glob(self, pattern, timeout_secs, for_checkpoint=True):
        """Polls until `pattern` matches something, failing the test on timeout.

        Args:
          pattern: A string; a checkpoint prefix when `for_checkpoint` is
            true, otherwise a glob pattern.
          timeout_secs: How long to keep polling, in seconds.
          for_checkpoint: If true, match with `checkpoint_exists`; otherwise
            match with `gfile.Glob`.
        """
        deadline = time.time() + timeout_secs
        while time.time() < deadline:
            if for_checkpoint:
                matched = checkpoint_management.checkpoint_exists(pattern)
            else:
                matched = len(gfile.Glob(pattern)) >= 1
            if matched:
                return
            # Poll again after a short pause.
            time.sleep(0.05)
        self.assertFalse(True, "Glob never matched any file: %s" % pattern)
Exemple #12
0
def deploy_checkpoint(input_meta_graph_def, input_checkpoint, q_config):
    """Freezes a checkpoint and writes the quantize deploy graph to disk.

    The variables of `input_meta_graph_def` are restored from
    `input_checkpoint`, converted to constants, turned into a deploy graph
    with `CreateQuantizeDeployGraphDef`, and saved as
    `<q_config.output_dir>/deploy/deploy_model.pb`.

    Args:
      input_meta_graph_def: The `MetaGraphDef` whose `graph_def` is frozen.
        Required.
      input_checkpoint: A checkpoint prefix, or a directory from which the
        latest checkpoint is taken.
      q_config: Configuration object. `output_dir` and `output_nodes` are
        read; `output_nodes` is overwritten with the result of
        `get_quantized_nodes`.

    Raises:
      ValueError: If `input_checkpoint` does not exist or no `MetaGraphDef`
        was provided.
    """
    if not checkpoint_management.checkpoint_exists(input_checkpoint):
        raise ValueError("Input checkpoint '" + input_checkpoint +
                         "' does not exist.")

    # Validate all inputs before touching the filesystem, so that a bad call
    # does not leave an empty output directory behind.
    if not input_meta_graph_def:
        raise ValueError("You need to provide a `MetaGraphDef` for deploy.")
    quantize_eval_graph_def = input_meta_graph_def.graph_def

    if gfile.IsDirectory(input_checkpoint):
        input_checkpoint = checkpoint_management.latest_checkpoint(
            input_checkpoint)

    # `exist_ok` avoids the race between an existence check and the creation.
    deploy_dir = os.path.join(q_config.output_dir, "deploy")
    os.makedirs(q_config.output_dir, exist_ok=True)

    q_config.output_nodes = get_quantized_nodes(quantize_eval_graph_def,
                                                q_config.output_nodes)
    saver = saver_lib.import_meta_graph(input_meta_graph_def,
                                        clear_devices=True)
    with Session() as sess:
        saver.restore(sess, input_checkpoint)
        frozen_graph_def = graph_util.convert_variables_to_constants(
            sess, quantize_eval_graph_def, q_config.output_nodes)

    os.makedirs(deploy_dir, exist_ok=True)
    quantize_deploy_graph_def = CreateQuantizeDeployGraphDef(
        frozen_graph_def, q_config)
    save_pb_file(quantize_deploy_graph_def,
                 os.path.join(deploy_dir, "deploy_model.pb"))

    print("INFO: Quantize deploy graph are generated in: {}".format(
        deploy_dir))
def patched_restore(self, sess, save_path, options=None):  # type: ignore
    """
    Restores previously saved variables.

    This method runs the ops added by the constructor for restoring variables.
    It requires a session in which the graph was launched.  The variables to
    restore do not have to have been initialized, as restoring is itself a way
    to initialize variables.

    The `save_path` argument is typically a value previously returned from a
    `save()` call, or a call to `latest_checkpoint()`.

    If the name-based restore fails with a `NotFoundError`, the checkpoint is
    retried as an object-based checkpoint before giving up.

    Args:
      sess: A `Session` to use to restore the parameters. None in eager mode.
      save_path: Path where parameters were previously saved.
      options: Forwarded as the `options` argument of `sess.run` when not
        executing eagerly, and of the object-based fallback restore. It is
        not passed to the eager build path.

    Raises:
      ValueError: If save_path is None or not a valid checkpoint.
    """
    # Nothing to restore for an empty saver.
    if self._is_empty:
        return
    if save_path is None:
        raise ValueError("Can't load save_path when it is None.")

    checkpoint_prefix = compat.as_text(save_path)
    if not checkpoint_management.checkpoint_exists(checkpoint_prefix):
        raise ValueError("The passed save_path is not a valid checkpoint: " +
                         checkpoint_prefix)

    logging.info("Restoring parameters from %s", checkpoint_prefix)
    try:
        if context.executing_eagerly():
            self._build_eager(save_path, build_save=False, build_restore=True)
        else:
            # Graph mode: run the restore op, feeding the checkpoint path
            # through the saver's filename tensor.
            sess.run(
                self.saver_def.restore_op_name,
                {self.saver_def.filename_tensor_name: save_path},
                options=options,
            )
    except errors.NotFoundError as err:
        # There are three common conditions that might cause this error:
        # 0. The file is missing. We ignore here, as this is checked above.
        # 1. This is an object-based checkpoint trying name-based loading.
        # 2. The graph has been altered and a variable or other name is missing.

        # 1. The checkpoint would not be loaded successfully as is. Try to parse
        # it as an object-based checkpoint.
        try:
            names_to_keys = object_graph_key_mapping(save_path)
        except errors.NotFoundError:
            # 2. This is not an object-based checkpoint, which likely means there
            # is a graph mismatch. Re-raise the original error with
            # a helpful message (b/110263146)
            raise _wrap_restore_error_with_msg(
                err, "a Variable name or other graph key that is missing")

        # This is an object-based checkpoint. We'll print a warning and then do
        # the restore.
        logging.warning(
            "Restoring an object-based checkpoint using a name-based saver. This "
            "may be somewhat fragile, and will re-build the Saver. Instead, "
            "consider loading object-based checkpoints using "
            "tf.train.Checkpoint().")
        # The previously built fallback saver is handed back in as
        # `cached_saver`, and the result is stored for the next call.
        self._object_restore_saver = saver_from_object_based_checkpoint(
            checkpoint_path=save_path,
            var_list=self._var_list,
            builder=self._builder,
            names_to_keys=names_to_keys,
            cached_saver=self._object_restore_saver,
        )
        self._object_restore_saver.restore(sess=sess,
                                           save_path=save_path,
                                           options=options)
    except errors.InvalidArgumentError as err:
        # There is a mismatch between the graph and the checkpoint being loaded.
        # We add a more reasonable error message here to help users (b/110263146)
        raise _wrap_restore_error_with_msg(
            err, "a mismatch between the current graph and the graph")
Exemple #14
0
def _split_comma_separated(names):
    """Splits a comma separated string into a list, dropping all spaces."""
    return names.replace(" ", "").split(",")


def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_denylist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
      input_graph_def: A `GraphDef`.
      input_saver_def: A `SaverDef` (optional).
      input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
        priority.  Typically the result of `Saver.save()` or that of
        `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
        V1/V2.
      output_node_names: The name(s) of the output nodes, comma separated.
      restore_op_name: Unused.
      filename_tensor_name: Unused.
      output_graph: String where to write the frozen `GraphDef`. Nothing is
        written when it is empty.
      clear_devices: A Bool whether to remove device specifications.
      initializer_nodes: Comma separated string of initializer nodes to run
        before freezing.
      variable_names_whitelist: The set of variable names to convert
        (optional, by default, all variables are converted).
      variable_names_denylist: The set of variable names to omit converting
        to constants (optional).
      input_meta_graph_def: A `MetaGraphDef` (optional),
      input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
        and variables (optional).
      saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
        load, in string format (optional).
      checkpoint_version: Tensorflow variable file format
        (saver_pb2.SaverDef.V1 or saver_pb2.SaverDef.V2)

    Returns:
      The frozen `GraphDef`, which is also written to `output_graph` when that
      path is given.

    Raises:
      ValueError: If the checkpoint does not exist (and no SavedModel dir was
        given), if `output_node_names` is empty, or if a `Saver` cannot be
        built because the model has partition variables or no variables.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if (not input_saved_model_dir
            and not checkpoint_management.checkpoint_exists(input_checkpoint)):
        raise ValueError("Input checkpoint '" + input_checkpoint +
                         "' doesn't exist!")

    if not output_node_names:
        raise ValueError(
            "You need to supply the name of a node to --output_node_names.")

    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        if input_meta_graph_def:
            for node in input_meta_graph_def.graph_def.node:
                node.device = ""
        elif input_graph_def:
            for node in input_graph_def.node:
                node.device = ""

    if input_graph_def:
        _ = importer.import_graph_def(input_graph_def, name="")
    with session.Session() as sess:
        if input_saver_def:
            saver = saver_lib.Saver(saver_def=input_saver_def,
                                    write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            restorer = saver_lib.import_meta_graph(input_meta_graph_def,
                                                   clear_devices=True)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(_split_comma_separated(initializer_nodes))
        elif input_saved_model_dir:
            if saved_model_tags is None:
                saved_model_tags = []
            loader.load(sess, saved_model_tags, input_saved_model_dir)
        else:
            var_list = {}
            reader = py_checkpoint_reader.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()

            # List of all partition variables. Because the condition is heuristic
            # based, the list could include false positives.
            all_partition_variable_names = [
                tensor.name.split(":")[0]
                for op in sess.graph.get_operations()
                for tensor in op.values()
                if re.search(r"/part_\d+/", tensor.name)
            ]
            has_partition_var = False

            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                    if any(key in name
                           for name in all_partition_variable_names):
                        has_partition_var = True
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    continue
                var_list[key] = tensor

            try:
                saver = saver_lib.Saver(var_list=var_list,
                                        write_version=checkpoint_version)
            except TypeError:
                # `var_list` is required to be a map of variable names to Variable
                # tensors. Partition variables are Identity tensors that cannot be
                # handled by Saver.
                if has_partition_var:
                    raise ValueError(
                        "Models containing partition variables cannot be converted "
                        "from checkpoint files. Please pass in a SavedModel using "
                        "the flag --input_saved_model_dir.")
                # Models that have been frozen previously do not contain Variables.
                if _has_no_variables(sess):
                    raise ValueError(
                        "No variables were found in this model. It is likely the model "
                        "was frozen previously. You cannot freeze a graph twice."
                    )
                # Neither known cause applies: propagate the original error
                # with its traceback intact.
                raise

            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(_split_comma_separated(initializer_nodes))

        variable_names_whitelist = (
            _split_comma_separated(variable_names_whitelist)
            if variable_names_whitelist else None)
        variable_names_denylist = (
            _split_comma_separated(variable_names_denylist)
            if variable_names_denylist else None)

        # Freeze the meta graph's GraphDef when one was supplied, otherwise
        # the plain input GraphDef.
        if input_meta_graph_def:
            graph_def_to_freeze = input_meta_graph_def.graph_def
        else:
            graph_def_to_freeze = input_graph_def
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            graph_def_to_freeze,
            _split_comma_separated(output_node_names),
            variable_names_whitelist=variable_names_whitelist,
            variable_names_blacklist=variable_names_denylist)

    # Write GraphDef to file if output path has been given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

    return output_graph_def
 def testKeepAll(self):
   """Checks `max_to_keep=None` keeps everything until a limit is applied.

   Managers created with `max_to_keep=None` are expected to track every
   checkpoint, including across re-creation. A later manager with
   `max_to_keep=3` is expected to list all four existing checkpoints at
   first, and to delete the two oldest on its next save.
   """
   checkpoint = util.Checkpoint()
   temp_root = self.get_temp_dir()
   # Avoid sharing directories between eager and graph
   # TODO(allenl): stop run_in_graph_and_eager_modes reusing directories
   mode_subdir = str(context.executing_eagerly())
   directory = os.path.join(temp_root, mode_subdir)
   ckpt_exists = checkpoint_management.checkpoint_exists

   manager = checkpoint_management.CheckpointManager(
       checkpoint, directory, max_to_keep=None)
   first_path = manager.save()
   second_path = manager.save()
   third_path = manager.save()
   for saved_path in (third_path, second_path, first_path):
     self.assertTrue(ckpt_exists(saved_path))
   self.assertEqual(third_path, manager.latest_checkpoint)
   self.assertEqual([first_path, second_path, third_path],
                    manager.checkpoints)
   del manager

   # A new unlimited manager continues the existing list.
   manager = checkpoint_management.CheckpointManager(
       checkpoint, directory, max_to_keep=None)
   fourth_path = manager.save()
   all_four_paths = [first_path, second_path, third_path, fourth_path]
   self.assertEqual(all_four_paths, manager.checkpoints)
   del manager

   # A limited manager starts with the full list and trims on the next save.
   manager = checkpoint_management.CheckpointManager(
       checkpoint, directory, max_to_keep=3)
   self.assertEqual(all_four_paths, manager.checkpoints)
   for saved_path in (fourth_path, third_path, second_path, first_path):
     self.assertTrue(ckpt_exists(saved_path))
   fifth_path = manager.save()
   self.assertEqual([third_path, fourth_path, fifth_path],
                    manager.checkpoints)
   for kept_path in (fifth_path, fourth_path, third_path):
     self.assertTrue(ckpt_exists(kept_path))
   for deleted_path in (second_path, first_path):
     self.assertFalse(ckpt_exists(deleted_path))
    def testSaveRestoreState(self, mock_time):
        """Checks manager bookkeeping survives re-creating the manager.

        Three managers are created in turn over the same directory while the
        mocked clock advances. The assertions cover the tracked checkpoint
        list, the timestamps stored in the `CheckpointState`, and which
        checkpoints remain on disk once `keep_checkpoint_every_n_hours=1.5`
        is in effect.
        """
        directory = self.get_temp_dir()
        mock_time.time.return_value = 3.
        checkpoint = util.Checkpoint()
        first_manager = checkpoint_management.CheckpointManager(
            checkpoint, directory, max_to_keep=2)

        ckpt_exists = checkpoint_management.checkpoint_exists
        read_state = checkpoint_management.get_checkpoint_state

        # Checkpoint prefixes the saves below are expected to produce.
        first_name = os.path.join(directory, "ckpt-1")
        second_name = os.path.join(directory, "ckpt-2")
        third_name = os.path.join(directory, "ckpt-3")
        fourth_name = os.path.join(directory, "ckpt-4")
        fifth_name = os.path.join(directory, "ckpt-5")
        sixth_name = os.path.join(directory, "ckpt-6")

        # Mocked clock values for the first five saves.
        first_time = 10000.
        second_time = first_time + 3610.
        third_time = second_time + 3600. * 0.2
        fourth_time = third_time + 3600. * 0.5
        fifth_time = fourth_time + 3600. * 0.5

        mock_time.time.return_value = first_time
        first_manager.save()
        state = read_state(directory)
        mock_time.time.return_value = second_time
        first_manager.save()
        state = read_state(directory)
        self.assertEqual([first_time, second_time],
                         state.all_model_checkpoint_timestamps)
        self.assertEqual([first_name, second_name], first_manager.checkpoints)
        self.assertEqual(second_name, first_manager.latest_checkpoint)
        del first_manager

        second_manager = checkpoint_management.CheckpointManager(
            checkpoint, directory, max_to_keep=2,
            keep_checkpoint_every_n_hours=1.5)
        self.assertEqual([first_name, second_name], second_manager.checkpoints)
        self.assertEqual(second_name, second_manager.latest_checkpoint)

        mock_time.time.return_value = third_time
        second_manager.save()
        for present_name in (first_name, second_name):
            self.assertTrue(ckpt_exists(present_name))
        self.assertEqual([second_name, third_name], second_manager.checkpoints)
        state = read_state(directory)
        self.assertEqual(first_time, state.last_preserved_timestamp)

        mock_time.time.return_value = fourth_time
        second_manager.save()
        self.assertTrue(ckpt_exists(first_name))
        self.assertFalse(ckpt_exists(second_name))
        self.assertEqual([third_name, fourth_name], second_manager.checkpoints)

        mock_time.time.return_value = fifth_time
        second_manager.save()
        self.assertEqual([fourth_name, fifth_name], second_manager.checkpoints)
        state = read_state(directory)
        self.assertEqual(first_time, state.last_preserved_timestamp)
        del second_manager

        third_manager = checkpoint_management.CheckpointManager(
            checkpoint, directory, max_to_keep=2,
            keep_checkpoint_every_n_hours=1.5)
        self.assertEqual(fifth_name, third_manager.latest_checkpoint)
        mock_time.time.return_value += 10.
        third_manager.save()
        state = read_state(directory)
        self.assertEqual(fourth_time, state.last_preserved_timestamp)
        for present_name in (first_name, fourth_name, fifth_name, sixth_name):
            self.assertTrue(ckpt_exists(present_name))
        for removed_name in (second_name, third_name):
            self.assertFalse(ckpt_exists(removed_name))
        self.assertEqual([fifth_name, sixth_name], third_manager.checkpoints)
 def testKeepAll(self):
     """Verifies that `max_to_keep=None` never deletes checkpoints.

     Two successive managers with `max_to_keep=None` are expected to keep
     every checkpoint they or their predecessor saved. A third manager with
     `max_to_keep=3` is expected to list all existing checkpoints right
     after construction and to trim down to the newest three only once it
     saves again.
     """
     checkpoint = util.Checkpoint()
     temp_root = self.get_temp_dir()
     # Avoid sharing directories between eager and graph
     # TODO(allenl): stop run_in_graph_and_eager_modes reusing directories
     directory = os.path.join(temp_root, str(context.executing_eagerly()))
     exists = checkpoint_management.checkpoint_exists

     # First manager: three saves, nothing may be removed.
     manager = checkpoint_management.CheckpointManager(
         checkpoint, directory, max_to_keep=None)
     saved_paths = []
     for _ in range(3):
         saved_paths.append(manager.save())
     for path in reversed(saved_paths):
         self.assertTrue(exists(path))
     self.assertEqual(saved_paths[-1], manager.latest_checkpoint)
     self.assertEqual(list(saved_paths), manager.checkpoints)
     del manager

     # Second manager over the same directory: the list keeps growing.
     manager = checkpoint_management.CheckpointManager(
         checkpoint, directory, max_to_keep=None)
     saved_paths.append(manager.save())
     self.assertEqual(list(saved_paths), manager.checkpoints)
     del manager

     # Third manager has a limit, but nothing is dropped before it saves.
     manager = checkpoint_management.CheckpointManager(
         checkpoint, directory, max_to_keep=3)
     self.assertEqual(list(saved_paths), manager.checkpoints)
     for path in reversed(saved_paths):
         self.assertTrue(exists(path))

     # Saving once more leaves only the three most recent checkpoints.
     saved_paths.append(manager.save())
     newest_three = saved_paths[-3:]
     dropped = saved_paths[:-3]
     self.assertEqual(newest_three, manager.checkpoints)
     for path in reversed(newest_three):
         self.assertTrue(exists(path))
     for path in reversed(dropped):
         self.assertFalse(exists(path))
  def testSaveRestoreState(self, mock_time):
    """Checks checkpoint bookkeeping across managers under a mocked clock.

    Three `CheckpointManager` instances are created one after another over
    the same directory. The test asserts which checkpoints stay listed,
    which files remain on disk, and what `last_preserved_timestamp` the
    on-disk state reports after each save.

    Args:
      mock_time: Stand-in whose `time.return_value` sets the clock reading
        for each step. Presumably injected by a patch decorator that is not
        visible in this chunk -- TODO confirm.
    """
    directory = self.get_temp_dir()
    # Clock value in effect while the first manager is constructed.
    mock_time.time.return_value = 3.
    checkpoint = util.Checkpoint()
    first_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2)
    first_time = 10000.
    first_name = os.path.join(directory, "ckpt-1")
    mock_time.time.return_value = first_time
    first_manager.save()
    # This read is not asserted on; `state` is fetched again after the
    # second save below.
    state = checkpoint_management.get_checkpoint_state(directory)
    # Second save happens 3610 seconds (just over an hour) after the first.
    second_time = first_time + 3610.
    second_name = os.path.join(directory, "ckpt-2")
    mock_time.time.return_value = second_time
    first_manager.save()
    state = checkpoint_management.get_checkpoint_state(directory)
    # The recorded timestamps are expected to equal the mocked clock values.
    self.assertEqual([first_time, second_time],
                     state.all_model_checkpoint_timestamps)
    self.assertEqual([first_name, second_name], first_manager.checkpoints)
    self.assertEqual(second_name, first_manager.latest_checkpoint)
    del first_manager

    # A fresh manager over the same directory is expected to pick up the
    # checkpoint list left behind by the first one.
    second_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory,
        max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
    self.assertEqual([first_name, second_name], second_manager.checkpoints)
    self.assertEqual(second_name, second_manager.latest_checkpoint)
    third_name = os.path.join(directory, "ckpt-3")
    # 0.2 hours after the second save.
    third_time = second_time + 3600. * 0.2
    mock_time.time.return_value = third_time
    second_manager.save()
    # ckpt-1 drops out of the managed list but is expected to stay on disk,
    # with its save time reported as `last_preserved_timestamp`.
    self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_name))
    self.assertEqual([second_name, third_name],
                     second_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(first_time, state.last_preserved_timestamp)
    # Half an hour later: ckpt-2 is expected to be deleted when it leaves the
    # managed list, while ckpt-1 remains.
    fourth_time = third_time + 3600. * 0.5
    mock_time.time.return_value = fourth_time
    fourth_name = os.path.join(directory, "ckpt-4")
    second_manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
    self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
    self.assertEqual([third_name, fourth_name],
                     second_manager.checkpoints)
    # Another half hour later; the preserved timestamp is expected to be
    # unchanged after this save.
    fifth_time = fourth_time + 3600. * 0.5
    mock_time.time.return_value = fifth_time
    fifth_name = os.path.join(directory, "ckpt-5")
    second_manager.save()
    self.assertEqual([fourth_name, fifth_name],
                     second_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(first_time, state.last_preserved_timestamp)
    del second_manager
    # Third manager with the same settings, again rebuilt from the directory.
    third_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory,
        max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
    self.assertEqual(fifth_name, third_manager.latest_checkpoint)
    # Advance the mocked clock by ten seconds before the sixth save.
    mock_time.time.return_value += 10.
    third_manager.save()
    sixth_name = os.path.join(directory, "ckpt-6")
    state = checkpoint_management.get_checkpoint_state(directory)
    # After this save the preserved timestamp is expected to move to ckpt-4's
    # save time; ckpt-4 stays on disk although it is no longer listed, and
    # ckpt-3 is expected to be gone.
    self.assertEqual(fourth_time, state.last_preserved_timestamp)
    self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(fourth_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(fifth_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(sixth_name))
    self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
    self.assertFalse(checkpoint_management.checkpoint_exists(third_name))
    self.assertEqual([fifth_name, sixth_name],
                     third_manager.checkpoints)
Exemple #19
0
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Freezes a graph by turning its variables into constants.

    Variable values are loaded through the first available source among
    `input_saver_def`, `input_meta_graph_def` and `input_saved_model_dir`;
    when none is given, a saver is built by matching checkpoint keys to
    tensors of the session graph.

    Args:
        input_graph_def: Graph definition; imported with `name=""` when truthy.
        input_saver_def: Optional saver definition used to restore variables.
        input_checkpoint: Checkpoint path or prefix to restore from.
        output_node_names: Comma-separated output node names (required).
        restore_op_name: Unused.
        filename_tensor_name: Unused.
        output_graph: Path the serialized result is written to when truthy.
        clear_devices: Whether to blank the `device` field of every node.
        initializer_nodes: Comma-separated nodes to run after restoring.
        variable_names_whitelist: Comma-separated variable names to convert.
        variable_names_blacklist: Comma-separated variable names to skip.
        input_meta_graph_def: Optional meta graph to import and freeze.
        input_saved_model_dir: Optional SavedModel directory to load.
        saved_model_tags: Tags handed to `loader.load`; `None` becomes `[]`.
        checkpoint_version: Saver write version.

    Returns:
        The graph definition produced by
        `graph_util.convert_variables_to_constants`, or -1 after printing a
        message when the checkpoint is missing, no output node names were
        supplied, or partition variables prevent building a saver.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    def _split_names(csv_text):
        # Drops every space, then splits on commas.
        return csv_text.replace(" ", "").split(",")

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not (input_saved_model_dir
            or checkpoint_management.checkpoint_exists(input_checkpoint)):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Blanking the explicit device of each node makes the graph more
    # portable. The meta graph wins over the plain graph when both exist.
    if clear_devices:
        if input_meta_graph_def:
            nodes_to_clear = input_meta_graph_def.graph_def.node
        elif input_graph_def:
            nodes_to_clear = input_graph_def.node
        else:
            nodes_to_clear = ()
        for node in nodes_to_clear:
            node.device = ""

    if input_graph_def:
        _ = importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver_def:
            saver = saver_lib.Saver(
                saver_def=input_saver_def, write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            restorer = saver_lib.import_meta_graph(
                input_meta_graph_def, clear_devices=True)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(_split_names(initializer_nodes))
        elif input_saved_model_dir:
            tags = [] if saved_model_tags is None else saved_model_tags
            loader.load(sess, tags, input_saved_model_dir)
        else:
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            checkpoint_keys = reader.get_variable_to_shape_map()

            # Names of tensors that look like pieces of a partitioned
            # variable. The match is heuristic, so false positives are
            # possible.
            partition_names = []
            for op in sess.graph.get_operations():
                for out in op.values():
                    if re.search(r"/part_\d+/", out.name):
                        partition_names.append(out.name.split(":")[0])

            tensors_by_key = {}
            saw_partition_var = False
            for key in checkpoint_keys:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                except KeyError:
                    # This tensor doesn't exist in the graph (for example
                    # it's 'global_step' or a similar housekeeping element)
                    # so skip it.
                    continue
                if not saw_partition_var:
                    saw_partition_var = any(
                        key in name for name in partition_names)
                tensors_by_key[key] = tensor

            try:
                saver = saver_lib.Saver(
                    var_list=tensors_by_key, write_version=checkpoint_version)
            except TypeError as e:
                # `var_list` is required to be a map of variable names to
                # Variable tensors. Partition variables are Identity tensors
                # that cannot be handled by Saver.
                if not saw_partition_var:
                    raise e
                print(
                    "Models containing partition variables cannot be converted "
                    "from checkpoint files. Please pass in a SavedModel using "
                    "the flag --input_saved_model_dir.")
                return -1

            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(_split_names(initializer_nodes))

        # Empty filter strings mean "no filter" and are passed on as None.
        whitelist = (_split_names(variable_names_whitelist)
                     if variable_names_whitelist else None)
        blacklist = (_split_names(variable_names_blacklist)
                     if variable_names_blacklist else None)

        if input_meta_graph_def:
            graph_def_to_freeze = input_meta_graph_def.graph_def
        else:
            graph_def_to_freeze = input_graph_def
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            graph_def_to_freeze,
            _split_names(output_node_names),
            variable_names_whitelist=whitelist,
            variable_names_blacklist=blacklist)

    # Write GraphDef to file if output path has been given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

    return output_graph_def
Exemple #20
0
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph_def: A `GraphDef`.
    input_saver_def: A `SaverDef` (optional).
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`. Nothing is
      written when this is empty.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated string of initializer nodes to run before
                       freezing.
    variable_names_whitelist: Comma separated string of variable names to
                              convert (optional, by default, all variables are
                              converted).
    variable_names_blacklist: Comma separated string of variable names to omit
                              converting to constants (optional).
    input_meta_graph_def: A `MetaGraphDef` (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
                           and variables (optional).
    saved_model_tags: Tag(s) of the MetaGraphDef to load (optional). The value
                      is handed to `loader.load` as is, with `None` replaced
                      by an empty list. NOTE(review): earlier documentation
                      described a comma separated string, but this function
                      does not split it -- confirm the type callers pass.
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
                        or saver_pb2.SaverDef.V2)

  Returns:
    The frozen `GraphDef` on success. Instead of raising, the function prints
    a message and returns -1 when the checkpoint does not exist, when
    `output_node_names` is empty, or when partition variables prevent building
    a `Saver`; it returns 0 when the saver cannot be built and the model
    contains no variables.

  Raises:
    TypeError: Re-raised from `Saver` construction when neither of the two
      handled causes above applies.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not checkpoint_management.checkpoint_exists(input_checkpoint)):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()

      # List of all partition variables. Because the condition is heuristic
      # based, the list could include false positives.
      all_partition_variable_names = [
          tensor.name.split(":")[0]
          for op in sess.graph.get_operations()
          for tensor in op.values()
          if re.search(r"/part_\d+/", tensor.name)
      ]
      has_partition_var = False

      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        if any(key in name for name in all_partition_variable_names):
          has_partition_var = True
        var_list[key] = tensor

      try:
        saver = saver_lib.Saver(
            var_list=var_list, write_version=checkpoint_version)
      except TypeError:
        # `var_list` is required to be a map of variable names to Variable
        # tensors. Partition variables are Identity tensors that cannot be
        # handled by Saver.
        if has_partition_var:
          print("Models containing partition variables cannot be converted "
                "from checkpoint files. Please pass in a SavedModel using "
                "the flag --input_saved_model_dir.")
          return -1
        # Models that have been frozen previously do not contain Variables.
        if _has_no_variables(sess):
          print("No variables were found in this model. It is likely the model "
                "was frozen previously. You cannot freeze a graph twice.")
          return 0
        # Unknown cause: re-raise the active exception unchanged.
        raise

      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))

    # Empty filter strings are passed on as None (no filtering).
    variable_names_whitelist = (
        variable_names_whitelist.replace(" ", "").split(",")
        if variable_names_whitelist else None)
    variable_names_blacklist = (
        variable_names_blacklist.replace(" ", "").split(",")
        if variable_names_blacklist else None)

    # The meta graph, when supplied, is the graph that gets frozen.
    if input_meta_graph_def:
      graph_def_to_freeze = input_meta_graph_def.graph_def
    else:
      graph_def_to_freeze = input_graph_def
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        graph_def_to_freeze,
        output_node_names.replace(" ", "").split(","),
        variable_names_whitelist=variable_names_whitelist,
        variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def
Exemple #21
0
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Replaces the variables of a graph with constants read from a checkpoint.

  Values are restored via `input_saver_def`, else `input_meta_graph_def`, else
  `input_saved_model_dir`, else a saver assembled from the checkpoint keys
  that name a tensor in the session graph.

  Args:
    input_graph_def: Graph definition; imported with `name=""` when truthy.
    input_saver_def: Optional saver definition used to restore variables.
    input_checkpoint: Checkpoint path or prefix to restore from.
    output_node_names: Comma separated output node names (required).
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: Path the serialized result is written to when truthy.
    clear_devices: Whether to blank the `device` field of every node.
    initializer_nodes: Comma separated nodes to run after restoring.
    variable_names_whitelist: Comma separated variable names to convert.
    variable_names_blacklist: Comma separated variable names to skip.
    input_meta_graph_def: Optional meta graph to import and freeze.
    input_saved_model_dir: Optional SavedModel directory to load.
    saved_model_tags: Tags handed to `loader.load`; `None` becomes `[]`.
    checkpoint_version: Saver write version.

  Returns:
    The graph definition produced by
    `graph_util.convert_variables_to_constants`, or -1 after printing a
    message when the checkpoint is missing, no output node names were
    supplied, or partition variables prevent building a saver.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # With a SavedModel directory the checkpoint is not checked; otherwise
  # 'input_checkpoint' (possibly a prefix in the Saver V2 format) must exist.
  if not input_saved_model_dir:
    if not checkpoint_management.checkpoint_exists(input_checkpoint):
      print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
      return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Dropping explicit device specifications makes the graph more portable.
  # Only one of the two graphs is touched, the meta graph taking precedence.
  if clear_devices and input_meta_graph_def:
    for node in input_meta_graph_def.graph_def.node:
      node.device = ""
  elif clear_devices and input_graph_def:
    for node in input_graph_def.node:
      node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")

  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        init_names = initializer_nodes.replace(" ", "").split(",")
        sess.run(init_names)
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      shapes_by_key = reader.get_variable_to_shape_map()

      # Collect names that look like slices of a partitioned variable. The
      # test is a heuristic, so the result may contain false positives.
      part_pattern = re.compile(r"/part_\d+/")
      partition_names = []
      for op in sess.graph.get_operations():
        for produced in op.values():
          if part_pattern.search(produced.name):
            partition_names.append(produced.name.split(":")[0])

      restorable = {}
      has_partition_var = False
      for key in shapes_by_key:
        try:
          graph_tensor = sess.graph.get_tensor_by_name(key + ":0")
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        for name in partition_names:
          if key in name:
            has_partition_var = True
            break
        restorable[key] = graph_tensor

      try:
        saver = saver_lib.Saver(
            var_list=restorable, write_version=checkpoint_version)
      except TypeError as e:
        # `var_list` is required to be a map of variable names to Variable
        # tensors. Partition variables are Identity tensors that cannot be
        # handled by Saver.
        if has_partition_var:
          print("Models containing partition variables cannot be converted "
                "from checkpoint files. Please pass in a SavedModel using "
                "the flag --input_saved_model_dir.")
          return -1
        raise e

      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        init_names = initializer_nodes.replace(" ", "").split(",")
        sess.run(init_names)

    # Empty filter strings are forwarded as None.
    whitelist = None
    if variable_names_whitelist:
      whitelist = variable_names_whitelist.replace(" ", "").split(",")
    blacklist = None
    if variable_names_blacklist:
      blacklist = variable_names_blacklist.replace(" ", "").split(",")

    # Freeze the meta graph's definition when one was supplied.
    if input_meta_graph_def:
      graph_def_to_freeze = input_meta_graph_def.graph_def
    else:
      graph_def_to_freeze = input_graph_def
    output_names = output_node_names.replace(" ", "").split(",")
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        graph_def_to_freeze,
        output_names,
        variable_names_whitelist=whitelist,
        variable_names_blacklist=blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def