Example #1
 def test_initialize_with_data_structures(self, enable_async_ckpt):
   if enable_async_ckpt and not context.executing_eagerly():
     self.skipTest(
         "Skipping this test as async checkpoint does not support graph mode.")
   checkpoint = trackable_utils.Checkpoint(
       a=[variables_lib.Variable(0.), variables_lib.Variable(1.)],
       b={"a": variables_lib.Variable(2.), "b": variables_lib.Variable(3.)})
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
   ckpt_options = checkpoint_options.CheckpointOptions(
       experimental_enable_async_checkpoint=enable_async_ckpt)
   save_path = checkpoint.save(file_prefix=checkpoint_prefix,
                               options=ckpt_options)
   load_checkpoint = trackable_utils.Checkpoint(
       a=[variables_lib.Variable(4.), variables_lib.Variable(5.)],
       b={"a": variables_lib.Variable(6.), "b": variables_lib.Variable(7.)})
   # When async checkpoint is enabled, we need to first make sure that the
   # checkpoint saving is fully complete before the checkpoint file can be
   # loaded by another checkpoint instance. Calling checkpoint.restore() is a
   # trick to make sure its async thread is joined.
   if enable_async_ckpt:
     checkpoint.restore(save_path)
   load_checkpoint.restore(save_path)
   self.assertAllClose(self.evaluate(load_checkpoint.a), [0, 1])
   self.assertAllClose(self.evaluate(load_checkpoint.b), {"a": 2, "b": 3})
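For context, here is a minimal standalone sketch of the same pattern through the public `tf.train` API, assuming a TensorFlow release that exposes the `experimental_enable_async_checkpoint` flag; the paths are illustrative.

import tensorflow as tf

# Hypothetical standalone sketch (not the test above): checkpointing variables
# held in a list and a dict, with async checkpointing turned on.
ckpt = tf.train.Checkpoint(
    a=[tf.Variable(0.), tf.Variable(1.)],
    b={"a": tf.Variable(2.), "b": tf.Variable(3.)})
options = tf.train.CheckpointOptions(
    experimental_enable_async_checkpoint=True)  # assumes a TF version with async support
path = ckpt.save("/tmp/structured_ckpt", options=options)

restored = tf.train.Checkpoint(
    a=[tf.Variable(4.), tf.Variable(5.)],
    b={"a": tf.Variable(6.), "b": tf.Variable(7.)})
restored.restore(path)  # the list/dict entries take the saved values 0., 1., 2., 3.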
Example #2
    def testCheckpointSaveRestoreIoDevice(self, distribution):
        def state():
            with distribution.scope():
                v = variables_lib.Variable(random_ops.random_normal([]))
                return v

        ckpt_options = checkpoint_options.CheckpointOptions(
            experimental_io_device="/job:localhost")

        def checkpoint():
            v = state()
            # Save random weights into checkpoint.
            checkpoint = trackable_utils.Checkpoint(v=v)
            prefix = os.path.join(self.get_temp_dir(), "ckpt")
            with self.test_session():
                save_path = checkpoint.save(prefix, options=ckpt_options)
            return save_path

        save_path = checkpoint()

        v = state()
        checkpoint = trackable_utils.Checkpoint(v=v)
        # Restore from the checkpoint inside a distribution.scope().
        # Check that restore works without error.
        with self.test_session():
            with distribution.scope():
                checkpoint.restore(save_path, options=ckpt_options)
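Outside the test harness, the same option can be used directly with `tf.train.Checkpoint`; a small sketch, assuming a setup where file I/O should be routed through the localhost job (the path is illustrative):

import tensorflow as tf

v = tf.Variable(tf.random.normal([]))
ckpt = tf.train.Checkpoint(v=v)
# Place the SaveV2/RestoreV2 file I/O ops on the local host.
io_options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
path = ckpt.save("/tmp/io_device_ckpt", options=io_options)
ckpt.restore(path, options=io_options)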
Example #3
  def save(self, file_prefix, options=None):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.
    Returns:
      An `Operation`, or None when executing eagerly.
    """
    options = options or checkpoint_options.CheckpointOptions()
    tensor_names = []
    tensors = []
    tensor_slices = []
    for saveable in self._saveable_objects:
      for spec in saveable.specs:
        tensor = spec.tensor
        # A tensor value of `None` indicates that this SaveableObject gets
        # recorded in the object graph, but that no value is saved in the
        # checkpoint.
        if tensor is not None:
          tensor_names.append(spec.name)
          tensors.append(tensor)
          tensor_slices.append(spec.slice_spec)
    save_device = options.experimental_io_device or "cpu:0"
    with ops.device(save_device):
      return io_ops.save_v2(file_prefix, tensor_names, tensor_slices, tensors)
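The `io_ops.save_v2` call corresponds to the public `tf.raw_ops.SaveV2` op. A rough sketch of what a single-device save reduces to; the checkpoint key and path are illustrative, not the exact names `Checkpoint` would generate:

import tensorflow as tf

v = tf.Variable([1., 2., 3.])
with tf.device("cpu:0"):  # mirrors the default save_device above
  tf.raw_ops.SaveV2(
      prefix="/tmp/raw_v2_ckpt",
      tensor_names=["v/.ATTRIBUTES/VARIABLE_VALUE"],  # illustrative checkpoint key
      shape_and_slices=[""],  # empty slice spec: save the full tensor
      tensors=[v.read_value()])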
Example #4
  def restore(self, file_prefix, options=None):
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object.

    Returns:
      A dictionary mapping from SaveableObject names to restore operations.
    """
    options = options or checkpoint_options.CheckpointOptions()
    restore_specs = []
    tensor_structure = []
    for saveable in self._saveable_objects:
      saveable_tensor_structure = []
      tensor_structure.append(saveable_tensor_structure)
      for spec in saveable.specs:
        saveable_tensor_structure.append(spec.name)
        restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
    tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
    restore_device = options.experimental_io_device or "cpu:0"
    with ops.device(restore_device):
      restored_tensors = io_ops.restore_v2(
          file_prefix, tensor_names, tensor_slices, tensor_dtypes)
    structured_restored_tensors = nest.pack_sequence_as(
        tensor_structure, restored_tensors)
    restore_ops = {}
    for saveable, restored_tensors in zip(self._saveable_objects,
                                          structured_restored_tensors):
      restore_ops[saveable.name] = saveable.restore(
          restored_tensors, restored_shapes=None)
    return restore_ops
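Likewise, `io_ops.restore_v2` maps to `tf.raw_ops.RestoreV2` and `nest.pack_sequence_as` to `tf.nest.pack_sequence_as`. A rough sketch that reads back the tensor written in the previous sketch (same illustrative key and path):

import tensorflow as tf

with tf.device("cpu:0"):
  restored = tf.raw_ops.RestoreV2(
      prefix="/tmp/raw_v2_ckpt",
      tensor_names=["v/.ATTRIBUTES/VARIABLE_VALUE"],
      shape_and_slices=[""],
      dtypes=[tf.float32])
# RestoreV2 returns a flat list of tensors; regroup it to mirror the
# per-saveable structure, as the method above does.
structured = tf.nest.pack_sequence_as(
    [["v/.ATTRIBUTES/VARIABLE_VALUE"]], list(restored))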
Example #5
def load_internal(export_dir, tags=None, options=None, loader_cls=Loader,
                  filters=None):
  """Loader implementation."""
  options = options or load_options.LoadOptions()
  if tags is not None and not isinstance(tags, set):
    # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
    # sequences for nest.flatten, so we put those through as-is.
    tags = nest.flatten(tags)
  saved_model_proto, debug_info = (
      loader_impl.parse_saved_model_with_debug_info(export_dir))

  if (len(saved_model_proto.meta_graphs) == 1 and
      saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
    meta_graph_def = saved_model_proto.meta_graphs[0]
    # The tensor_content field contains raw bytes in little-endian format,
    # which causes problems when loaded on big-endian systems that
    # require a byteswap.
    if sys.byteorder == "big":
      saved_model_utils.swap_function_tensor_content(meta_graph_def, "little",
                                                     "big")
    if (tags is not None
        and set(tags) != set(meta_graph_def.meta_info_def.tags)):
      raise ValueError(
          ("The SavedModel at {} has one MetaGraph with tags {}, but got an "
           "incompatible argument tags={} to tf.saved_model.load. You may omit "
           "it, pass 'None', or pass matching tags.")
          .format(export_dir, meta_graph_def.meta_info_def.tags, tags))
    object_graph_proto = meta_graph_def.object_graph_def

    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    with ops.init_scope():
      try:
        loader = loader_cls(object_graph_proto, saved_model_proto, export_dir,
                            ckpt_options, filters)
      except errors.NotFoundError as err:
        raise FileNotFoundError(
            str(err) + "\n If trying to load on a different device from the "
            "computational device, consider setting the "
            "`experimental_io_device` option on tf.saved_model.LoadOptions "
            "to the io_device such as '/job:localhost'."
        )
      root = loader.get(0)
      if isinstance(loader, Loader):
        root.graph_debug_info = loader.adjust_debug_info_func_names(debug_info)
    root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
  else:
    if filters:
      raise ValueError("SavedModels saved from Tensorflow V1 or Estimator (any "
                       "version) cannot be loaded with node filters.")
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
      root.graph_debug_info = debug_info

  if filters:
    return {node_id: loader.get(node_id) for node_id in filters}
  else:
    return {"root": root}
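On the caller's side, the `experimental_io_device` that feeds the `CheckpointOptions` above is supplied through `tf.saved_model.LoadOptions`. A minimal sketch, assuming a SavedModel has already been exported to the illustrative path:

import tensorflow as tf

load_opts = tf.saved_model.LoadOptions(experimental_io_device="/job:localhost")
loaded = tf.saved_model.load("/tmp/exported_model", options=load_opts)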
Example #6
 def setUp(self):
   super(SaverTest, self).setUp()
   cpus = config.list_physical_devices("CPU")
   # Set 3 virtual CPUs
   config.set_logical_device_configuration(cpus[0], [
       context.LogicalDeviceConfiguration(),
       context.LogicalDeviceConfiguration(),
       context.LogicalDeviceConfiguration()
   ])
   self.local_options = checkpoint_options.CheckpointOptions(
       experimental_io_device=LOCALHOST)
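The virtual-CPU setup can be reproduced with the public `tf.config` API; a sketch of the equivalent configuration (the `LOCALHOST` constant above is assumed to be a device string such as "/job:localhost"):

import tensorflow as tf

cpus = tf.config.list_physical_devices("CPU")
# Split the single physical CPU into three logical devices; this must run
# before the devices are initialized (i.e. at program start).
tf.config.set_logical_device_configuration(
    cpus[0],
    [tf.config.LogicalDeviceConfiguration()] * 3)
print(tf.config.list_logical_devices("CPU"))  # expect three CPU devices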
Example #7
    def restore(self, file_prefix, options=None):
        """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object.

    Returns:
      A dictionary mapping from SaveableObject names to restore operations.
    """
        options = options or checkpoint_options.CheckpointOptions()

        def restore_fn():
            restore_ops = {}
            # Sort by device name to avoid propagating non-deterministic dictionary
            # ordering in some Python versions.
            for device, saver in sorted(self._single_device_savers.items()):
                with ops.device(device):
                    restore_ops.update(saver.restore(file_prefix, options))

            return restore_ops

        # Since this will cause a function re-trace on each restore, limit this to
        # cases where it is needed: eager execution and when there are multiple
        # tasks/single-device savers. Note that the retrace is needed to ensure we
        # pick up the latest values of options like experimental_io_device.
        if context.executing_eagerly() and len(self._single_device_savers) > 1:
            first_device, _ = list(self._single_device_savers.items())[0]

            @def_function.function(jit_compile=False)
            def tf_function_restore():
                restore_ops = restore_fn()
                restore_tensors = {}
                # tf.functions must return tensors, thus we use control dependencies so
                # that we can return a tensor which depends on the given op.
                with ops.device(saveable_object_util.set_cpu0(first_device)):
                    for name, op in restore_ops.items():
                        with ops.control_dependencies([op]):
                            restore_tensors[name] = array_ops.identity(
                                file_prefix)
                return restore_tensors

            restore_ops = tf_function_restore()
        else:
            restore_ops = restore_fn()

        for callback in self._after_restore_callbacks:
            callback()

        return restore_ops
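The `tf.function` wrapper above must return tensors rather than bare ops, hence the control-dependency-plus-identity pattern. A toy sketch of that pattern in isolation (the variable and tag are illustrative):

import tensorflow as tf

v = tf.Variable(0.)

@tf.function(jit_compile=False)
def run_and_report(tag):
  op = v.assign_add(1.)  # stands in for a per-device restore op
  with tf.control_dependencies([op]):
    # The returned tensor is only produced after `op` has executed.
    return tf.identity(tag)

print(run_and_report(tf.constant("ckpt_prefix")))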
Example #8
 def testAssertConsumedNoCheckpoint(self, enable_async_ckpt):
   if enable_async_ckpt and not context.executing_eagerly():
     self.skipTest(
         "Skipping this test as async checkpoint does not support graph mode.")
   prefix = os.path.join(self.get_temp_dir(), "ckpt")
   v = variable_scope.get_variable(name="v", initializer=0.)
   self.evaluate(v.initializer)
   ckpt = trackable_utils.Checkpoint(v=v)
   self.evaluate(trackable_utils.gather_initializers(ckpt))
   ckpt_options = checkpoint_options.CheckpointOptions(
       experimental_enable_async_checkpoint=enable_async_ckpt)
   save_path = ckpt.save(file_prefix=prefix, options=ckpt_options)
   status = ckpt.restore(save_path=save_path)
   del ckpt
   status.assert_consumed()
Example #9
def load_internal(export_dir, tags=None, options=None, loader_cls=Loader):
  """Loader implementation."""
  options = options or load_options.LoadOptions()
  if tags is not None and not isinstance(tags, set):
    # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
    # sequences for nest.flatten, so we put those through as-is.
    tags = nest.flatten(tags)
  saved_model_proto, debug_info = (
      loader_impl.parse_saved_model_with_debug_info(export_dir))

  if (len(saved_model_proto.meta_graphs) == 1 and
      saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
    meta_graph_def = saved_model_proto.meta_graphs[0]
    if (tags is not None
        and set(tags) != set(meta_graph_def.meta_info_def.tags)):
      raise ValueError(
          ("The SavedModel at {} has one MetaGraph with tags {}, but got an "
           "incompatible argument tags={} to tf.saved_model.load. You may omit "
           "it, pass 'None', or pass matching tags.")
          .format(export_dir, meta_graph_def.meta_info_def.tags, tags))
    object_graph_proto = meta_graph_def.object_graph_def

    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    with ops.init_scope():
      try:
        loader = loader_cls(object_graph_proto, saved_model_proto, export_dir,
                            ckpt_options)
      except errors.NotFoundError as err:
        raise FileNotFoundError(
            str(err) + "\n If trying to load on a different device from the "
            "computational device, consider setting the "
            "`experimental_io_device` option on tf.saved_model.LoadOptions "
            "to the io_device such as '/job:localhost'."
        )
      root = loader.get(0)
      if isinstance(loader, Loader):
        root.graph_debug_info = loader.adjust_debug_info_func_names(debug_info)
    root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
  else:
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
      root.graph_debug_info = debug_info
  return root
Example #10
    def restore(self, file_prefix, options=None):
        """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object.

    Returns:
      When not run eagerly or when saving on a single device, returns a
      dictionary mapping from SaveableObject names to restore operations;
      otherwise, returns an empty dict.
    """
        options = options or checkpoint_options.CheckpointOptions()

        def restore_fn():
            restore_ops = {}
            # Sort by device name to avoid propagating non-deterministic dictionary
            # ordering in some Python versions.
            for device, saver in sorted(self._single_device_savers.items()):
                with ops.device(device):
                    restore_ops.update(saver.restore(file_prefix, options))
            for _, (_, restore_fn) in self._registered_savers.items():
                restore_fn(file_prefix)
            return restore_ops

        # Since this will cause a function re-trace on each restore, limit this to
        # cases where it is needed: eager execution and when there are multiple
        # tasks/single-device savers. Note that the retrace is needed to ensure we
        # pick up the latest values of options like experimental_io_device.
        if context.executing_eagerly() and len(self._single_device_savers) > 1:

            @def_function.function(jit_compile=False)
            def tf_function_restore():
                restore_fn()
                return {}

            restore_ops = tf_function_restore()
        else:
            restore_ops = restore_fn()

        for callback in self._after_restore_callbacks:
            callback()

        return restore_ops
Example #11
 def testCustomNumbering(self, enable_async_ckpt):
   if enable_async_ckpt and not context.executing_eagerly():
     self.skipTest(
         "Skipping this test as async checkpoint does not support graph mode.")
   directory = self.get_temp_dir()
   prefix = os.path.join(directory, "ckpt")
   step = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
   checkpoint = trackable_utils.Checkpoint(step=step)
   ckpt_options = checkpoint_options.CheckpointOptions(
       experimental_enable_async_checkpoint=enable_async_ckpt)
   self.evaluate(step.initializer)
   for i in range(5):
     path = checkpoint.write("%s-%d" % (prefix, self.evaluate(step)),
                             options=ckpt_options)
     expected_suffix = "-%d" % (2 * i,)
     if not path.endswith(expected_suffix):
       self.fail("%s should have suffix %s" % (path, expected_suffix))
     self.evaluate(step.assign_add(2))
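For reference, a minimal sketch of the distinction this test relies on: `Checkpoint.write` uses the caller's prefix verbatim, while `Checkpoint.save` appends and increments its own `save_counter` (paths are illustrative):

import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)

# write(): the caller controls any numbering embedded in the prefix.
p1 = ckpt.write("/tmp/custom_ckpt-%d" % int(step))
# save(): numbering is managed through ckpt.save_counter.
p2 = ckpt.save("/tmp/auto_ckpt")
print(p1, p2)  # e.g. /tmp/custom_ckpt-0 and /tmp/auto_ckpt-1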
Example #12
    def testPassingCheckpointOptions(self):
        localhost = "/job:localhost/device:CPU:0"
        options = checkpoint_options.CheckpointOptions(
            experimental_io_device=localhost)
        prefix = os.path.join(self.get_temp_dir(), "ckpt")
        v = variable_scope.get_variable(name="v", initializer=0.)
        self.evaluate(v.initializer)
        ckpt = trackable_utils.Checkpoint(v=v)
        self.evaluate(trackable_utils.gather_initializers(ckpt))
        save_path = ckpt.save(file_prefix=prefix, options=options)
        status = ckpt.restore(save_path=save_path, options=options)
        del ckpt
        status.assert_consumed()

        # In graph mode, verify that the save and restore ops were set to run on
        # localhost.
        if not context.executing_eagerly():
            for op in ops.get_default_graph().get_operations():
                if op.type in ("SaveV2", "RestoreV2"):
                    self.assertEqual(localhost, op.device)
Example #13
  def save(self, file_prefix, options=None):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.
    Returns:
      An `Operation`, or None when executing eagerly.
    """
    options = options or checkpoint_options.CheckpointOptions()
    tensor_names = []
    tensors = []
    tensor_slices = []
    for saveable in self._saveable_objects:
      for spec in saveable.specs:
        tensor_names.append(spec.name)
        tensors.append(spec.tensor)
        tensor_slices.append(spec.slice_spec)
    save_device = options.experimental_io_device or "cpu:0"
    with ops.device(save_device):
      return io_ops.save_v2(file_prefix, tensor_names, tensor_slices, tensors)
Example #14
    def restore(self, file_prefix, options=None):
        """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object.

    Returns:
      A dictionary mapping from SaveableObject names to restore operations.
    """
        options = options or checkpoint_options.CheckpointOptions()
        restore_ops = {}
        # Sort by device name to avoid propagating non-deterministic dictionary
        # ordering in some Python versions.
        for device, saver in sorted(self._single_device_savers.items()):
            with ops.device(device):
                restore_ops.update(saver.restore(file_prefix, options))

        for callback in self._after_restore_callbacks:
            callback()

        return restore_ops
Example #15
 def testMoreComplexSaveableReturned(self, enable_async_ckpt):
   if enable_async_ckpt and not context.executing_eagerly():
     self.skipTest(
         "Skipping this test as async checkpoint does not support graph mode.")
   v = _OwnsMirroredVariables()
   checkpoint = trackable_utils.Checkpoint(v=v)
   test_dir = self.get_temp_dir()
   prefix = os.path.join(test_dir, "ckpt")
   self.evaluate(v.non_dep_variable.assign(42.))
   ckpt_options = checkpoint_options.CheckpointOptions(
       experimental_enable_async_checkpoint=enable_async_ckpt)
   save_path = checkpoint.save(file_prefix=prefix, options=ckpt_options)
   self.evaluate(v.non_dep_variable.assign(43.))
   self.evaluate(v.mirrored.assign(44.))
   checkpoint.restore(save_path).assert_consumed().initialize_or_restore()
   self.assertEqual(42., self.evaluate(v.non_dep_variable))
   self.assertEqual(42., self.evaluate(v.mirrored))
   self.evaluate(v.non_dep_variable.assign(44.))
   save_path = checkpoint.save(file_prefix=prefix, options=ckpt_options)
   self.evaluate(v.non_dep_variable.assign(45.))
   checkpoint.restore(save_path).assert_consumed().initialize_or_restore()
   self.assertEqual(44., self.evaluate(v.non_dep_variable))
   self.assertEqual(44., self.evaluate(v.mirrored))
Example #16
  def save(self, file_prefix, options=None):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.
    Returns:
      An `Operation`, or None when executing eagerly.
    """
    options = options or checkpoint_options.CheckpointOptions()
    for callback in self._before_save_callbacks:
      callback()

    # IMPLEMENTATION DETAILS: most clients should skip.
    #
    # Suffix for any well-formed "checkpoint_prefix", when sharded.
    # Transformations:
    # * Users pass in "save_path" in save() and restore().  Say "myckpt".
    # * checkpoint_prefix gets fed <save_path><sharded_suffix>.
    #
    # Example:
    #   During runtime, a temporary directory is first created, which contains
    #   files
    #
    #     <train dir>/myckpt_temp/
    #        part-?????-of-?????{.index, .data-00000-of-00001}
    #
    #   Before .save() finishes, they will be (hopefully, atomically) renamed to
    #
    #     <train dir>/
    #        myckpt{.index, .data-?????-of-?????}
    #
    #   Filesystems with eventual consistency (such as S3) don't need a
    #   temporary location. Using a temporary directory in those cases might
    #   cause situations where files are not available during copy.
    #
    # Users only need to interact with the user-specified prefix, which is
    # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
    # prefix directly, instead of any physical pathname.  (On failure and
    # subsequent restore, an outdated and orphaned temporary directory can be
    # safely removed.)
    with ops.device("CPU"):
      sharded_suffix = array_ops.where(
          string_ops.regex_full_match(file_prefix, "^s3://.*"),
          constant_op.constant(".part"),
          constant_op.constant("_temp_%s/part" % uuid.uuid4().hex))
      tmp_checkpoint_prefix = string_ops.string_join(
          [file_prefix, sharded_suffix])

    def save_fn():
      num_shards = len(self._single_device_savers)
      sharded_saves = []
      sharded_prefixes = []
      num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
      last_device = None
      for shard, (device, saver) in enumerate(
          sorted(self._single_device_savers.items())):
        last_device = device
        with ops.device(saveable_object_util.set_cpu0(device)):
          shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
                                          num_shards_tensor)
        sharded_prefixes.append(shard_prefix)
        with ops.device(device):
          # _SingleDeviceSaver will use the CPU device when necessary, but
          # initial read operations should be placed on the SaveableObject's
          # device.
          sharded_saves.append(saver.save(shard_prefix, options))

      with ops.control_dependencies(sharded_saves):
        # Merge on the io_device if specified, otherwise co-locates the merge op
        # with the last device used.
        merge_device = (
            options.experimental_io_device or
            saveable_object_util.set_cpu0(last_device))
        with ops.device(merge_device):
          # V2 format write path consists of a metadata merge step.  Once
          # merged, attempts to delete the temporary directory,
          # "<user-fed prefix>_temp".
          return gen_io_ops.merge_v2_checkpoints(
              sharded_prefixes, file_prefix, delete_old_dirs=True)

    # Since this will cause a function re-trace on each save, limit this to the
    # cases where it is needed: eager execution and when there are multiple
    # tasks/single-device savers. Note that the retrace is needed to ensure we
    # pick up the latest values of options like experimental_io_device.
    if context.executing_eagerly() and len(self._single_device_savers) > 1:
      # Explicitly place the identity op on the first device.
      @def_function.function(experimental_compile=False)
      def tf_function_save():
        save_fn()
      tf_function_save()
    else:
      return save_fn()
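The temporary-prefix dance described in the long comment above is invisible to callers: after a successful save only the merged, user-visible files remain. A small sketch that lists them, assuming an eager save to an illustrative prefix:

import tensorflow as tf

ckpt = tf.train.Checkpoint(v=tf.Variable([1., 2.]))
prefix = ckpt.save("/tmp/merge_demo_ckpt")
# After MergeV2Checkpoints runs, only files such as
#   <prefix>.index and <prefix>.data-00000-of-00001
# are left; the *_temp working directory has been deleted.
print(tf.io.gfile.glob(prefix + "*"))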
Example #17
    def __init__(self,
                 filepath,
                 monitor='val_loss',
                 verbose=0,
                 save_best_only=False,
                 save_weights_only=False,
                 mode='auto',
                 save_freq='epoch',
                 options=None,
                 **kwargs):
        super(ModelCheckpoint, self).__init__()
        self.filepaths = []
        self._supports_tf_logs = True
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = tf.python.keras.utils.io_utils.path_to_string(filepath)
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.save_freq = save_freq
        self.epochs_since_last_save = 0
        self._batches_seen_since_last_saving = 0
        self._last_batch_seen = 0

        if save_weights_only:
            if options is None or isinstance(
                    options, checkpoint_options_lib.CheckpointOptions):
                self._options = options or checkpoint_options_lib.CheckpointOptions(
                )
            else:
                raise TypeError(
                    'If save_weights_only is True, then `options` must be '
                    'either None or a tf.train.CheckpointOptions')
        else:
            if options is None or isinstance(options,
                                             save_options_lib.SaveOptions):
                self._options = options or save_options_lib.SaveOptions()
            else:
                raise TypeError(
                    'If save_weights_only is False, then `options` must be '
                    'either None or a tf.saved_model.SaveOptions')

        # Deprecated field `load_weights_on_restart` is for loading the checkpoint
        # file from `filepath` at the start of `model.fit()`
        # TODO(rchao): Remove the arg during next breaking release.
        if 'load_weights_on_restart' in kwargs:
            self.load_weights_on_restart = kwargs['load_weights_on_restart']
            logging.warning(
                '`load_weights_on_restart` argument is deprecated. '
                'Please use `model.load_weights()` for loading weights '
                'before the start of `model.fit()`.')
        else:
            self.load_weights_on_restart = False

        # Deprecated field `period` is for the number of epochs between which
        # the model is saved.
        if 'period' in kwargs:
            self.period = kwargs['period']
            logging.warning(
                '`period` argument is deprecated. Please use `save_freq` '
                'to specify the frequency in number of batches seen.')
        else:
            self.period = 1

        if mode not in ['auto', 'min', 'max']:
            logging.warning(
                'ModelCheckpoint mode %s is unknown, '
                'fallback to auto mode.', mode)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.Inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.Inf
        else:
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.best = -np.Inf
            else:
                self.monitor_op = np.less
                self.best = np.Inf

        if self.save_freq != 'epoch' and not isinstance(self.save_freq, int):
            raise ValueError('Unrecognized save_freq: {}'.format(
                self.save_freq))

        # Only the chief worker writes model checkpoints, but all workers
        # restore checkpoint at on_train_begin().
        self._chief_worker_only = False
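A minimal usage sketch of the two `options` branches above, assuming the standard `tf.keras.callbacks.ModelCheckpoint` entry point and illustrative file paths:

import tensorflow as tf

# Weights-only checkpoints take a tf.train.CheckpointOptions ...
weights_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath="/tmp/weights_ckpt",
    save_weights_only=True,
    options=tf.train.CheckpointOptions(experimental_io_device="/job:localhost"))

# ... while full-model checkpoints take a tf.saved_model.SaveOptions.
model_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath="/tmp/model_ckpt",
    save_weights_only=False,
    options=tf.saved_model.SaveOptions(experimental_io_device="/job:localhost"))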
Example #18
    def restore(self, save_path, options=None):
        """Restore a training checkpoint with host mesh placement."""
        options = options or checkpoint_options.CheckpointOptions()
        if save_path is None:
            return util.InitializationOnlyStatus(self._graph_view, ops.uid())
        reader = py_checkpoint_reader.NewCheckpointReader(save_path)
        graph_building = not context.executing_eagerly()
        if graph_building:
            dtype_map = None
        else:
            dtype_map = reader.get_variable_to_dtype_map()
        try:
            object_graph_string = reader.get_tensor(
                base.OBJECT_GRAPH_PROTO_KEY)
        except errors_impl.NotFoundError:
            # The object graph proto does not exist in this checkpoint. Try the
            # name-based compatibility mode.
            restore_coordinator = util._NameBasedRestoreCoordinator(  # pylint: disable=protected-access
                save_path=save_path,
                dtype_map=dtype_map)
            if not graph_building:
                for existing_trackable in self._graph_view.list_objects():
                    # pylint: disable=protected-access
                    existing_trackable._maybe_initialize_trackable()
                    existing_trackable._name_based_restores.add(
                        restore_coordinator)
                    existing_trackable._name_based_attribute_restore(
                        restore_coordinator)
                    # pylint: enable=protected-access
            return util.NameBasedSaverStatus(restore_coordinator,
                                             graph_view=self._graph_view)

        if graph_building:
            if self._file_prefix_placeholder is None:
                # DTensor change: provide a hint for mesh broadcasting to put the input
                # onto the host mesh.
                self._file_prefix_placeholder = api.pack(
                    [constant_op.constant("model")] *
                    self._mesh.num_local_devices(),
                    layout.Layout.replicated(self._mesh.host_mesh(), rank=0))
            file_prefix_tensor = self._file_prefix_placeholder
            file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
        else:
            # DTensor change: provide a hint for mesh broadcasting to put the input
            # onto the host mesh.
            file_prefix_tensor = api.pack([constant_op.constant(save_path)] *
                                          self._mesh.num_local_devices(),
                                          layout.Layout.replicated(
                                              self._mesh.host_mesh(), rank=0))
            file_prefix_feed_dict = None
        object_graph_proto = (
            trackable_object_graph_pb2.TrackableObjectGraph())
        object_graph_proto.ParseFromString(object_graph_string)
        # DTensor Change: Hook the proper DSaver in restore.
        checkpoint = _DCheckpointRestoreCoordinator(
            mesh=self._mesh,
            object_graph_proto=object_graph_proto,
            save_path=save_path,
            save_path_tensor=file_prefix_tensor,
            reader=reader,
            restore_op_cache=self._restore_op_cache,
            graph_view=self._graph_view,
            options=options,
            saveables_cache=self._saveables_cache)
        base.CheckpointPosition(checkpoint=checkpoint,
                                proto_id=0).restore(self._graph_view.root)

        # Attached dependencies are not attached to the root, so should be restored
        # separately.
        if self._graph_view.attached_dependencies:
            for ref in self._graph_view.attached_dependencies:
                if ref.name == "root":
                    # Root dependency is automatically added to attached dependencies --
                    # this can be ignored since it maps back to the root object.
                    continue
                proto_id = None
                # Find proto ID of attached dependency (if it is in the proto).
                for proto_ref in object_graph_proto.nodes[0].children:
                    if proto_ref.local_name == ref.name:
                        proto_id = proto_ref.node_id
                        break

                if proto_id in checkpoint.object_by_proto_id:
                    # Object has already been restored. This can happen when there's an
                    # indirect connection from the attached object to the root.
                    continue

                base.CheckpointPosition(checkpoint=checkpoint,
                                        proto_id=proto_id).restore(ref.ref)

        load_status = util.CheckpointLoadStatus(
            checkpoint,
            graph_view=self._graph_view,
            feed_dict=file_prefix_feed_dict)
        return load_status
Example #19
File: load.py  Project: npfp/tensorflow
def load_partial(export_dir, filters, tags=None, options=None):
  """Partially load a SavedModel (saved from V2).

  Similar to `tf.saved_model.load`, but with an additional argument that
  lets you specify which nodes to load.
  `tf.saved_model.load_partial(export_dir, ["root"])` and
  `tf.saved_model.load(export_dir)` are equivalent.

  Note: This only works for SavedModels saved with TensorFlow V2 from
  `tf.saved_model.save` or Keras. This will not load SavedModels saved from
  the Estimator API.

  In Tensorflow V2, SavedModel stores the **object graph** of the saved object.
  The graph contains nodes (`tf.Module`, `tf.Variable`, `tf.function`, Keras
  layers, etc.) and edges that are the name of the attributes connecting the
  objects.

  *Example 1*

  >>> model = tf.Module()
  >>> model.child_layer = tf.Module()
  >>> model.child_layer.v = tf.Variable(5.)
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   ['root.child_layer', 'root.child_layer.v'])
  >>> loaded['root.child_layer'].v.numpy()
  5.0
  >>> loaded['root.child_layer'].v is loaded['root.child_layer.v']
  True

  *Example 2*
  >>> model = tf.Module()
  >>> model.child_layer = tf.Module()
  >>> model.child_layer.v = tf.Variable(5.)
  >>>
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> # Create a variable
  >>> new_variable = tf.Variable(0.)
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   {'root.child_layer': None, 'root.child_layer.v': new_variable})
  >>> loaded['root.child_layer'].v.numpy()
  5.0
  >>> new_variable.numpy()
  5.0

  **Loading under different distribution strategies**
  You can load different parts of the model under different distribution
  strategies. Note that this is very experimental so use with care.

  >>> model = tf.Module()
  >>> model.layer_1 = tf.Module()
  >>> model.layer_1.v = tf.Variable(5.)
  >>> model.layer_2 = tf.Module()
  >>> model.layer_2.v = tf.Variable(7.)
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> # Load with no strategy
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   ['root.layer_1'])
  >>> loaded['root.layer_1'].v
  <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>
  >>> strategy = tf.distribute.MirroredStrategy()
  >>> with strategy.scope():
  ...   loaded2 = tf.__internal__.saved_model.load_partial(
  ...     '/tmp/model',
  ...     ['root.layer_2'])
  >>> loaded2['root.layer_2'].v
  MirroredVariable:{
      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=7.0>
  }

  Args:
    export_dir: The SavedModel directory to load from.
    filters: A list or dictionary where each element or key is a string
      path to nodes that should be loaded. Node paths consist of all the child
      attribute names to reach that node in the form: `root.{attribute_name}`.
      The loader will load all of the specified nodes and their recursive
      descendants. When this option is defined, the loader will return a
      dictionary mapping the node paths to the loaded objects.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.
    options: `tf.saved_model.LoadOptions` object that specifies options for
      loading.

  Returns:
    A dictionary mapping node paths from the filter to loaded objects.
  """
  options = options or load_options.LoadOptions()
  if tags is not None and not isinstance(tags, set):
    # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
    # sequences for nest.flatten, so we put those through as-is.
    tags = nest.flatten(tags)
  saved_model_proto, debug_info = (
      loader_impl.parse_saved_model_with_debug_info(export_dir))

  if (len(saved_model_proto.meta_graphs) == 1 and
      saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
    metrics.IncrementReadApi(_LOAD_V2_LABEL)
    meta_graph_def = saved_model_proto.meta_graphs[0]
    # The tensor_content field contains raw bytes in little-endian format,
    # which causes problems when loaded on big-endian systems that
    # require a byteswap.
    if sys.byteorder == "big":
      saved_model_utils.swap_function_tensor_content(meta_graph_def, "little",
                                                     "big")
    if (tags is not None
        and set(tags) != set(meta_graph_def.meta_info_def.tags)):
      raise ValueError(
          f"Got an incompatible argument to `tags`: {tags}. The SavedModel at "
          f"{export_dir} has one MetaGraph with tags "
          f"{meta_graph_def.meta_info_def.tags}. You may omit the argument, "
          "pass 'None', or pass matching tags.")
    object_graph_proto = meta_graph_def.object_graph_def

    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    with ops.init_scope():
      try:
        loader = Loader(object_graph_proto, saved_model_proto, export_dir,
                        ckpt_options, options, filters)
      except errors.NotFoundError as err:
        raise FileNotFoundError(
            str(err) + "\n You may be trying to load on a different device "
            "from the computational device. Consider setting the "
            "`experimental_io_device` option in `tf.saved_model.LoadOptions` "
            "to the io_device such as '/job:localhost'.")
      root = loader.get(0)
      root.graph_debug_info = loader.adjust_debug_info_func_names(debug_info)
    root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
    metrics.IncrementRead(write_version="2")
  else:
    if filters:
      raise ValueError("SavedModels saved from Tensorflow 1.x or Estimator (any"
                       " version) cannot be loaded with node filters.")
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
      root.graph_debug_info = debug_info

  if filters:
    return {node_id: loader.get(node_id) for node_id in filters}
  else:
    return {"root": root}
Example #20
def save(obj, export_dir, signatures=None, options=None):
    # pylint: disable=line-too-long
    """Exports the Trackable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).

  Example usage:

  ```python
  class Adder(tf.Module):

    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    def add(self, x):
      return x + x + 1.

  to_export = Adder()
  tf.saved_model.save(to_export, '/tmp/adder')
  ```

  The resulting SavedModel is then servable with an input named "x", its value
  having any shape and dtype float32.

  The optional `signatures` argument controls which methods in `obj` will be
  available to programs which consume `SavedModel`s, for example, serving
  APIs. Python functions may be decorated with
  `@tf.function(input_signature=...)` and passed as signatures directly, or
  lazily with a call to `get_concrete_function` on the method decorated with
  `@tf.function`.

  If the `signatures` argument is omitted, `obj` will be searched for
  `@tf.function`-decorated methods. If exactly one `@tf.function` is found, that
  method will be used as the default signature for the SavedModel. This behavior
  is expected to change in the future, when a corresponding
  `tf.saved_model.load` symbol is added. At that point signatures will be
  completely optional, and any `@tf.function` attached to `obj` or its
  dependencies will be exported for use with `load`.

  When invoking a signature in an exported SavedModel, `Tensor` arguments are
  identified by name. These names will come from the Python function's argument
  names by default. They may be overridden by specifying a `name=...` argument
  in the corresponding `tf.TensorSpec` object. Explicit naming is required if
  multiple `Tensor`s are passed through a single argument to the Python
  function.

  The outputs of functions used as `signatures` must either be flat lists, in
  which case outputs will be numbered, or a dictionary mapping string keys to
  `Tensor`, in which case the keys will be used to name outputs.

  Signatures are available in objects returned by `tf.saved_model.load` as a
  `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
  on an object with a custom `.signatures` attribute will raise an exception.

  Since `tf.keras.Model` objects are also Trackable, this function can be
  used to export Keras models. For example, exporting with a signature
  specified:

  ```python
  class Model(tf.keras.Model):

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
    def serve(self, serialized):
      ...

  m = Model()
  tf.saved_model.save(m, '/tmp/saved_model/')
  ```

  Exporting from a function without a fixed signature:

  ```python
  class Model(tf.keras.Model):

    @tf.function
    def call(self, x):
      ...

  m = Model()
  tf.saved_model.save(
      m, '/tmp/saved_model/',
      signatures=m.call.get_concrete_function(
          tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="inp")))
  ```

  `tf.keras.Model` instances constructed from inputs and outputs already have a
  signature and so do not require a `@tf.function` decorator or a `signatures`
  argument. If neither are specified, the model's forward pass is exported.

  ```python
  x = input_layer.Input((4,), name="x")
  y = core.Dense(5, name="out")(x)
  model = training.Model(x, y)
  tf.saved_model.save(model, '/tmp/saved_model/')
  # The exported SavedModel takes "x" with shape [None, 4] and returns "out"
  # with shape [None, 5]
  ```

  Variables must be tracked by assigning them to an attribute of a tracked
  object or to an attribute of `obj` directly. TensorFlow objects (e.g. layers
  from `tf.keras.layers`, optimizers from `tf.train`) track their variables
  automatically. This is the same tracking scheme that `tf.train.Checkpoint`
  uses, and an exported `Checkpoint` object may be restored as a training
  checkpoint by pointing `tf.train.Checkpoint.restore` to the SavedModel's
  "variables/" subdirectory. Currently, variables are the only stateful objects
  supported by `tf.saved_model.save`, but others (e.g. tables) will be supported
  in the future.

  `tf.function` does not hard-code device annotations from outside the function
  body, instead using the calling context's device. This means, for example,
  that exporting a model that runs on a GPU and serving it on a CPU will
  generally work, with some exceptions. `tf.device` annotations inside the body
  of the function will be hard-coded in the exported model; this type of
  annotation is discouraged. Device-specific operations, e.g. with "cuDNN" in
  the name or with device-specific layouts, may cause issues. Currently a
  `DistributionStrategy` is another exception: active distribution strategies
  will cause device placements to be hard-coded in a function. Exporting a
  single-device computation and importing under a `DistributionStrategy` is
  not currently supported, but may be in the future.

  SavedModels exported with `tf.saved_model.save` [strip default-valued
  attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes)
  automatically, which removes one source of incompatibilities when the consumer
  of a SavedModel is running an older TensorFlow version than the
  producer. There are however other sources of incompatibilities which are not
  handled automatically, such as when the exported model contains operations
  which the consumer does not have definitions for.

  A single tf.function can generate many ConcreteFunctions. If a downstream tool
  wants to refer to all concrete functions generated by a single tf.function you
  can use the `function_aliases` argument to store a map from the alias name to
  all concrete function names.
  E.g.
  ```python
  class MyModel:

    @tf.function
    def func():
      ...

    @tf.function
    def serve():
      ...
      func()

  model = MyModel()
  signatures = {
      'serving_default': model.serve.get_concrete_function(),
  }
  options = tf.saved_model.SaveOptions(function_aliases={
      'my_func': func,
  })
  tf.saved_model.save(model, export_dir, signatures, options)
  ```

  Args:
    obj: A trackable object to export.
    export_dir: A directory in which to write the SavedModel.
    signatures: Optional, either a `tf.function` with an input signature
      specified or the result of `f.get_concrete_function` on a
      `@tf.function`-decorated function `f`, in which case `f` will be used to
      generate a signature for the SavedModel under the default serving
      signature key. `signatures` may also be a dictionary, in which case it
      maps from signature keys to either `tf.function` instances with input
      signatures or concrete functions. The keys of such a dictionary may be
      arbitrary strings, but will typically be from the
      `tf.saved_model.signature_constants` module.
    options: Optional, `tf.saved_model.SaveOptions` object that specifies
      options for saving.

  Raises:
    ValueError: If `obj` is not trackable.

  @compatibility(eager)
  Not well supported when graph building. From TensorFlow 1.x,
  `tf.compat.v1.enable_eager_execution()` should run first. Calling
  tf.saved_model.save in a loop when graph building from TensorFlow 1.x will
  add new save operations to the default graph each iteration.

  May not be called from within a function body.
  @end_compatibility
  """
    options = options or save_options.SaveOptions()
    # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
    # compatible (no sessions) and share it with this export API rather than
    # making a SavedModel proto and writing it directly.
    saved_model = saved_model_pb2.SavedModel()
    meta_graph_def = saved_model.meta_graphs.add()

    _, exported_graph, object_saver, asset_info = _build_meta_graph(
        obj, export_dir, signatures, options, meta_graph_def)
    saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION

    # Write the checkpoint, copy assets into the assets directory, and write out
    # the SavedModel proto itself.
    utils_impl.get_or_create_variables_dir(export_dir)
    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    object_saver.save(utils_impl.get_variables_path(export_dir),
                      options=ckpt_options)
    builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
                                                export_dir)
    # Note that this needs to be the last file operation when saving the
    # SavedModel. Users rely on checking saved_model_dir/saved_model.pb as an
    # indication that the SavedModel is completely written.
    if context.executing_eagerly():
        try:
            context.async_wait()  # Ensure save operations have completed.
        except errors.NotFoundError as err:
            raise FileNotFoundError(
                str(err) +
                "\n If trying to save on a different device from the "
                "computational device, consider setting the "
                "`experimental_io_device` option on tf.saved_model.SaveOptions "
                "to the io_device such as '/job:localhost'.")

    path = os.path.join(compat.as_str(export_dir),
                        compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
    file_io.atomic_write_string_to_file(
        path, saved_model.SerializeToString(deterministic=True))

    # Clean reference cycles so repeated export()s don't make work for the garbage
    # collector. Before this point, we need to keep references to captured
    # constants in the saved graph.
    ops.dismantle_graph(exported_graph)
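Finally, the `experimental_io_device` that `save()` copies into `CheckpointOptions` above is set through `tf.saved_model.SaveOptions`; a minimal sketch with an illustrative module and export path:

import tensorflow as tf

module = tf.Module()
module.v = tf.Variable(1.)

save_opts = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")
tf.saved_model.save(module, "/tmp/io_device_saved_model", options=save_opts)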