def testDeletion(self):
     checkpoint = util.Checkpoint()
     manager = checkpoint_management.CheckpointManager(checkpoint,
                                                       self.get_temp_dir(),
                                                       max_to_keep=3)
     first_path = manager.save()
     second_path = manager.save()
     third_path = manager.save()
     fourth_path = manager.save()
     self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
     self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
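
# The same max_to_keep eviction behavior via the public tf.train API (a small
# sketch; the temp directory is a placeholder, not part of the test above).
import tempfile

import tensorflow as tf

manager = tf.train.CheckpointManager(
    tf.train.Checkpoint(v=tf.Variable(1.0)), tempfile.mkdtemp(), max_to_keep=3)
paths = [manager.save() for _ in range(4)]
# Only the three newest checkpoints survive; the first has been deleted.
print(manager.checkpoints)                            # Last three entries of `paths`.
print(tf.train.latest_checkpoint(manager.directory))  # Same as paths[-1].
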
    def testRemoveCheckpoint(self):
        for sharded in (False, True):
            for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
                with self.session(graph=ops_lib.Graph()) as sess:
                    unused_v = variables.Variable(1.0, name="v")
                    self.evaluate(variables.global_variables_initializer())
                    saver = saver_module.Saver(sharded=sharded,
                                               write_version=version)

                    path = os.path.join(self._base_dir,
                                        "%s-%s" % (sharded, version))
                    ckpt_prefix = saver.save(sess, path)
                    self.assertTrue(
                        checkpoint_management.checkpoint_exists(ckpt_prefix))
                    checkpoint_management.remove_checkpoint(
                        ckpt_prefix, version)
                    self.assertFalse(
                        checkpoint_management.checkpoint_exists(ckpt_prefix))
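
# A standalone sketch of the same round trip through the public (deprecated)
# tf.compat.v1 aliases of these helpers; the variable name and temp directory
# are placeholders, not taken from the test above.
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

with tf.Session() as sess:
    v = tf.Variable(1.0, name="v")
    sess.run(tf.global_variables_initializer())
    prefix = tf.train.Saver().save(
        sess, os.path.join(tempfile.mkdtemp(), "model"))

print(tf.train.checkpoint_exists(prefix))  # True: data/index files written.
tf.train.remove_checkpoint(prefix)         # Deletes the checkpoint files.
print(tf.train.checkpoint_exists(prefix))  # False after removal.
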
 @test.mock.patch.object(checkpoint_management, "time")
 def testClockReset(self, mock_time):
     directory = self.get_temp_dir()
     mock_time.time.return_value = 10000.
     checkpoint = util.Checkpoint()
     first_manager = checkpoint_management.CheckpointManager(
         checkpoint,
         directory,
         max_to_keep=1,
         keep_checkpoint_every_n_hours=1.)
     first_path = first_manager.save()
     mock_time.time.return_value += 3600.
     second_path = first_manager.save()
     mock_time.time.return_value += 3600.
     third_path = first_manager.save()
     self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
     self.assertEqual([third_path], first_manager.checkpoints)
     state = checkpoint_management.get_checkpoint_state(directory)
     self.assertEqual(13600., state.last_preserved_timestamp)
     # Set the clock back in time
     mock_time.time.return_value = 5000.
     del first_manager
     with test.mock.patch.object(logging, "warning") as mock_log:
         second_manager = checkpoint_management.CheckpointManager(
             checkpoint, directory, max_to_keep=1)
         self.assertRegex(str(mock_log.call_args),
                          "behind the last preserved checkpoint timestamp")
     # We should err on the side of keeping checkpoints around when we're not
     # sure whether they were preserved or not due to clock funkiness.
     self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
     # We know about the existing checkpoints, but they'll never be deleted and
     # so won't go in the CheckpointState proto on save.
     self.assertEqual(third_path, second_manager.latest_checkpoint)
     self.assertEqual([], second_manager.checkpoints)
     mock_time.time.return_value += 10.
     fourth_path = second_manager.save()
     self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
     self.assertEqual(fourth_path, second_manager.latest_checkpoint)
     self.assertEqual([fourth_path], second_manager.checkpoints)
     mock_time.time.return_value += 10.
     fifth_path = second_manager.save()
     self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
     self.assertEqual([fifth_path], second_manager.checkpoints)
     state = checkpoint_management.get_checkpoint_state(directory)
     self.assertEqual(5000., state.last_preserved_timestamp)
     self.assertEqual([5020.], state.all_model_checkpoint_timestamps)
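
# A short sketch of inspecting the same CheckpointState fields through the
# public API; the directory path is a placeholder.
import tensorflow as tf

state = tf.train.get_checkpoint_state("/path/to/checkpoint_dir")
if state is not None:
    print(state.model_checkpoint_path)                 # Newest checkpoint prefix.
    print(list(state.all_model_checkpoint_paths))      # All managed checkpoints.
    print(list(state.all_model_checkpoint_timestamps)) # Save times for each.
    # Timestamp of the most recently preserved checkpoint (see assertions above).
    print(state.last_preserved_timestamp)
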
    def testCheckpointInterval(self):
        v = variables.Variable(1.0)
        step_counter = variables.Variable(0)
        self.evaluate([v.initializer, step_counter.initializer])
        checkpoint = util.Checkpoint(v=v)
        manager = checkpoint_management.CheckpointManager(
            checkpoint,
            self.get_temp_dir(),
            max_to_keep=None,
            step_counter=step_counter,
            checkpoint_interval=2)

        # step_counter: 0, save an initial checkpoint.
        path = manager.save(check_interval=True)
        self.assertTrue(checkpoint_management.checkpoint_exists(path))

        # step_counter: 1, no checkpoint saved.
        self.evaluate(step_counter.assign_add(1))
        path = manager.save(check_interval=True)
        self.assertIsNone(path)

        # step_counter: 2, checkpoint saved.
        self.evaluate(step_counter.assign_add(1))
        path = manager.save(check_interval=True)
        self.assertTrue(checkpoint_management.checkpoint_exists(path))

        # No checkpoint saved when calling `save` again with the same step counter.
        path = manager.save(check_interval=True)
        self.assertIsNone(path)

        # step_counter: 3, no checkpoint saved.
        self.evaluate(step_counter.assign_add(1))
        path = manager.save(check_interval=True)
        self.assertIsNone(path)

        # Always save the checkpoint.
        path = manager.save(check_interval=False)
        self.assertTrue(checkpoint_management.checkpoint_exists(path))
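
# A rough training-loop sketch of the checkpoint_interval behavior exercised
# above, using the public tf.train API; the loop body and interval are
# placeholders.
import tempfile

import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(
    ckpt, tempfile.mkdtemp(), max_to_keep=3,
    step_counter=step, checkpoint_interval=2)

for _ in range(6):
    # ... run one training step here ...
    step.assign_add(1)
    # Writes a checkpoint only when `step` has advanced by at least
    # `checkpoint_interval` since the last save; otherwise returns None.
    path = manager.save(check_interval=True)
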
 def testContinueFromUnmanaged(self):
     directory = self.get_temp_dir()
     prefix = os.path.join(directory, "unusual_prefix")
     checkpoint = util.Checkpoint()
     first_path = checkpoint.save(prefix)
     second_path = checkpoint.save(prefix)
     del checkpoint
     checkpoint = util.Checkpoint()
     manager = checkpoint_management.CheckpointManager(checkpoint,
                                                       directory,
                                                       max_to_keep=2)
     checkpoint.restore(manager.latest_checkpoint).run_restore_ops()
     self.assertEqual(2, self.evaluate(checkpoint.save_counter))
     third_path = manager.save()
     self.assertEqual([third_path], manager.checkpoints)
     fourth_path = manager.save()
     self.assertEqual([third_path, fourth_path], manager.checkpoints)
     fifth_path = manager.save()
     self.assertEqual([fourth_path, fifth_path], manager.checkpoints)
     self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
     self.assertFalse(checkpoint_management.checkpoint_exists(third_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))
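
# The pattern the test above exercises, written against the public API in
# eager mode: a CheckpointManager picks up checkpoints written directly with
# `Checkpoint.save()` and restores the newest one (paths are placeholders).
import os
import tempfile

import tensorflow as tf

directory = tempfile.mkdtemp()
tf.train.Checkpoint(v=tf.Variable(1.0)).save(
    os.path.join(directory, "unusual_prefix"))  # Unmanaged save.

restored = tf.Variable(0.0)
manager = tf.train.CheckpointManager(
    tf.train.Checkpoint(v=restored), directory, max_to_keep=2)
# latest_checkpoint sees the unmanaged checkpoint; restore() loads it.
manager.checkpoint.restore(manager.latest_checkpoint)
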
Example #6
  def _wait_for_glob(self, pattern, timeout_secs, for_checkpoint=True):
    """Wait for a checkpoint file to appear.

    Args:
      pattern: A string; the glob pattern or checkpoint prefix to wait for.
      timeout_secs: How long to wait, in seconds.
      for_checkpoint: Whether `pattern` names a checkpoint (checked with
        `checkpoint_exists`) rather than a plain glob pattern.
    """
    end_time = time.time() + timeout_secs
    while time.time() < end_time:
      if for_checkpoint:
        if checkpoint_management.checkpoint_exists(pattern):
          return
      else:
        if len(gfile.Glob(pattern)) >= 1:
          return
      time.sleep(0.05)
    self.fail("Glob never matched any file: %s" % pattern)
Example #7
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_denylist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
      input_graph_def: A `GraphDef`.
      input_saver_def: A `SaverDef` (optional).
      input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
        priority.  Typically the result of `Saver.save()` or that of
        `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
        V1/V2.
      output_node_names: The name(s) of the output nodes, comma-separated.
      restore_op_name: Unused.
      filename_tensor_name: Unused.
      output_graph: String where to write the frozen `GraphDef`.
      clear_devices: A bool indicating whether to remove device specifications.
      initializer_nodes: Comma-separated string of initializer nodes to run
        before freezing.
      variable_names_whitelist: The set of variable names to convert (optional;
        by default, all variables are converted).
      variable_names_denylist: The set of variable names to omit converting to
        constants (optional).
      input_meta_graph_def: A `MetaGraphDef` (optional).
      input_saved_model_dir: Path to the directory with the TensorFlow
        'SavedModel' file and variables (optional).
      saved_model_tags: Group of comma-separated tag(s) of the MetaGraphDef to
        load, in string format (optional).
      checkpoint_version: TensorFlow variable file format (saver_pb2.SaverDef.V1
        or saver_pb2.SaverDef.V2).

    Returns:
      The frozen `GraphDef`. If `output_graph` is set, the graph is also
      serialized to that path.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if (not input_saved_model_dir
            and not checkpoint_management.checkpoint_exists(input_checkpoint)):
        raise ValueError("Input checkpoint '" + input_checkpoint +
                         "' doesn't exist!")

    if not output_node_names:
        raise ValueError(
            "You need to supply the name of a node to --output_node_names.")

    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        if input_meta_graph_def:
            for node in input_meta_graph_def.graph_def.node:
                node.device = ""
        elif input_graph_def:
            for node in input_graph_def.node:
                node.device = ""

    if input_graph_def:
        _ = importer.import_graph_def(input_graph_def, name="")
    with session.Session() as sess:
        if input_saver_def:
            saver = saver_lib.Saver(saver_def=input_saver_def,
                                    write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            restorer = saver_lib.import_meta_graph(input_meta_graph_def,
                                                   clear_devices=True)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))
        elif input_saved_model_dir:
            if saved_model_tags is None:
                saved_model_tags = []
            loader.load(sess, saved_model_tags, input_saved_model_dir)
        else:
            var_list = {}
            reader = py_checkpoint_reader.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()

            # List of all partition variables. Because the condition is heuristic
            # based, the list could include false positives.
            all_partition_variable_names = [
                tensor.name.split(":")[0]
                for op in sess.graph.get_operations()
                for tensor in op.values()
                if re.search(r"/part_\d+/", tensor.name)
            ]
            has_partition_var = False

            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                    if any(key in name
                           for name in all_partition_variable_names):
                        has_partition_var = True
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    continue
                var_list[key] = tensor

            try:
                saver = saver_lib.Saver(var_list=var_list,
                                        write_version=checkpoint_version)
            except TypeError as e:
                # `var_list` is required to be a map of variable names to Variable
                # tensors. Partition variables are Identity tensors that cannot be
                # handled by Saver.
                if has_partition_var:
                    raise ValueError(
                        "Models containing partition variables cannot be converted "
                        "from checkpoint files. Please pass in a SavedModel using "
                        "the flag --input_saved_model_dir.")
                # Models that have been frozen previously do not contain Variables.
                elif _has_no_variables(sess):
                    raise ValueError(
                        "No variables were found in this model. It is likely the model "
                        "was frozen previously. You cannot freeze a graph twice."
                    )
                else:
                    raise e

            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))

        variable_names_whitelist = (variable_names_whitelist.replace(
            " ", "").split(",") if variable_names_whitelist else None)
        variable_names_denylist = (variable_names_denylist.replace(
            " ", "").split(",") if variable_names_denylist else None)

        if input_meta_graph_def:
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_meta_graph_def.graph_def,
                output_node_names.replace(" ", "").split(","),
                variable_names_whitelist=variable_names_whitelist,
                variable_names_blacklist=variable_names_denylist)
        else:
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.replace(" ", "").split(","),
                variable_names_whitelist=variable_names_whitelist,
                variable_names_blacklist=variable_names_denylist)

    # Write GraphDef to file if output path has been given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString(deterministic=True))

    return output_graph_def
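
# An illustrative call of the function above (a sketch using the TF1-style
# compat APIs; the output node name "out" and the temp paths are placeholders).
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
tmp = tempfile.mkdtemp()

graph = tf.Graph()
with graph.as_default():
    v = tf.Variable(1.0, name="v")
    tf.identity(v, name="out")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ckpt_prefix = tf.train.Saver().save(sess, os.path.join(tmp, "model"))

frozen_graph_def = freeze_graph_with_def_protos(
    input_graph_def=graph.as_graph_def(),
    input_saver_def=None,
    input_checkpoint=ckpt_prefix,
    output_node_names="out",
    restore_op_name=None,       # Unused.
    filename_tensor_name=None,  # Unused.
    output_graph=os.path.join(tmp, "frozen.pb"),
    clear_devices=True,
    initializer_nodes="")
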
    @test.mock.patch.object(checkpoint_management, "time")
    def testSaveRestoreState(self, mock_time):
        directory = self.get_temp_dir()
        mock_time.time.return_value = 3.
        checkpoint = util.Checkpoint()
        first_manager = checkpoint_management.CheckpointManager(checkpoint,
                                                                directory,
                                                                max_to_keep=2)
        first_time = 10000.
        first_name = os.path.join(directory, "ckpt-1")
        mock_time.time.return_value = first_time
        first_manager.save()
        state = checkpoint_management.get_checkpoint_state(directory)
        second_time = first_time + 3610.
        second_name = os.path.join(directory, "ckpt-2")
        mock_time.time.return_value = second_time
        first_manager.save()
        state = checkpoint_management.get_checkpoint_state(directory)
        self.assertEqual([first_time, second_time],
                         state.all_model_checkpoint_timestamps)
        self.assertEqual([first_name, second_name], first_manager.checkpoints)
        self.assertEqual(second_name, first_manager.latest_checkpoint)
        del first_manager

        second_manager = checkpoint_management.CheckpointManager(
            checkpoint,
            directory,
            max_to_keep=2,
            keep_checkpoint_every_n_hours=1.5)
        self.assertEqual([first_name, second_name], second_manager.checkpoints)
        self.assertEqual(second_name, second_manager.latest_checkpoint)
        third_name = os.path.join(directory, "ckpt-3")
        third_time = second_time + 3600. * 0.2
        mock_time.time.return_value = third_time
        second_manager.save()
        self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
        self.assertTrue(checkpoint_management.checkpoint_exists(second_name))
        self.assertEqual([second_name, third_name], second_manager.checkpoints)
        state = checkpoint_management.get_checkpoint_state(directory)
        self.assertEqual(first_time, state.last_preserved_timestamp)
        fourth_time = third_time + 3600. * 0.5
        mock_time.time.return_value = fourth_time
        fourth_name = os.path.join(directory, "ckpt-4")
        second_manager.save()
        self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
        self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
        self.assertEqual([third_name, fourth_name], second_manager.checkpoints)
        fifth_time = fourth_time + 3600. * 0.5
        mock_time.time.return_value = fifth_time
        fifth_name = os.path.join(directory, "ckpt-5")
        second_manager.save()
        self.assertEqual([fourth_name, fifth_name], second_manager.checkpoints)
        state = checkpoint_management.get_checkpoint_state(directory)
        self.assertEqual(first_time, state.last_preserved_timestamp)
        del second_manager
        third_manager = checkpoint_management.CheckpointManager(
            checkpoint,
            directory,
            max_to_keep=2,
            keep_checkpoint_every_n_hours=1.5)
        self.assertEqual(fifth_name, third_manager.latest_checkpoint)
        mock_time.time.return_value += 10.
        third_manager.save()
        sixth_name = os.path.join(directory, "ckpt-6")
        state = checkpoint_management.get_checkpoint_state(directory)
        self.assertEqual(fourth_time, state.last_preserved_timestamp)
        self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
        self.assertTrue(checkpoint_management.checkpoint_exists(fourth_name))
        self.assertTrue(checkpoint_management.checkpoint_exists(fifth_name))
        self.assertTrue(checkpoint_management.checkpoint_exists(sixth_name))
        self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
        self.assertFalse(checkpoint_management.checkpoint_exists(third_name))
        self.assertEqual([fifth_name, sixth_name], third_manager.checkpoints)
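
# A minimal configuration sketch for the retention policy exercised above:
# keep the two newest checkpoints, and additionally preserve one checkpoint
# roughly every 1.5 hours of wall time (the directory is a placeholder).
import tempfile

import tensorflow as tf

manager = tf.train.CheckpointManager(
    tf.train.Checkpoint(v=tf.Variable(1.0)),
    tempfile.mkdtemp(),
    max_to_keep=2,
    keep_checkpoint_every_n_hours=1.5)
# Saved paths roll over per max_to_keep, but one checkpoint per ~1.5 hours is
# kept on disk even after it drops out of `manager.checkpoints`.
path = manager.save()
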
 def testKeepAll(self):
     checkpoint = util.Checkpoint()
     directory = os.path.join(
         self.get_temp_dir(),
         # Avoid sharing directories between eager and graph
         # TODO(allenl): stop run_in_graph_and_eager_modes reusing directories
         str(context.executing_eagerly()))
     manager = checkpoint_management.CheckpointManager(checkpoint,
                                                       directory,
                                                       max_to_keep=None)
     first_path = manager.save()
     second_path = manager.save()
     third_path = manager.save()
     self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
     self.assertEqual(third_path, manager.latest_checkpoint)
     self.assertEqual([first_path, second_path, third_path],
                      manager.checkpoints)
     del manager
     manager = checkpoint_management.CheckpointManager(checkpoint,
                                                       directory,
                                                       max_to_keep=None)
     fourth_path = manager.save()
     self.assertEqual([first_path, second_path, third_path, fourth_path],
                      manager.checkpoints)
     del manager
     manager = checkpoint_management.CheckpointManager(checkpoint,
                                                       directory,
                                                       max_to_keep=3)
     self.assertEqual([first_path, second_path, third_path, fourth_path],
                      manager.checkpoints)
     self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
     fifth_path = manager.save()
     self.assertEqual([third_path, fourth_path, fifth_path],
                      manager.checkpoints)
     self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
     self.assertFalse(checkpoint_management.checkpoint_exists(second_path))
     self.assertFalse(checkpoint_management.checkpoint_exists(first_path))