def _infer_mht_saveable_name(ms):
  """Returns the checkpoint name of the `MutableHashTable._Saveable`.

  Args:
    ms: A `MutableHashTable._Saveable`.

  Returns:
    Name of the `MutableHashTable._Saveable`.

  Raises:
    TypeError: If `ms` maps to more than one checkpoint name, which violates
      the single-saveable constraint.
  """
  name_to_ms_dict = saveable_object_util.op_list_to_dict([ms])
  if len(name_to_ms_dict) > 1:
    # Bug fix: the message previously said "name_to_var_dict", a variable
    # that does not exist in this function; report the mapping actually built.
    raise TypeError("`ms` = %s passed as arg violates the constraints.  "
                    "name_to_ms_dict = %s" % (ms, name_to_ms_dict))
  # Exactly one entry remains; its key is the inferred name.
  return next(iter(name_to_ms_dict))
    def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs):
        """Create a saver that writes global_center_variable values under the
        original trainable-variable names.

        Call this after all your variables have been created with
        ElasticAverageCustomGetter. Use this saver during training; the
        checkpoints it writes hold the global_center_variable of the trained
        parameters under the original parameter names, so they are directly
        usable for evaluation or inference.

        Args:
          var_list: List of variables to save, as per `Saver()`. If set to
            None, save all the trainable_variables that have been created
            before this call.
          name: The name of the saver.
          **kwargs: Keyword arguments forwarded to `Saver()`.

        Returns:
          A `tf.compat.v1.train.Saver` object.

        Raises:
          RuntimeError: If `self._global_map` is empty, i.e. this was called
            before the model was created with ElasticAverageCustomGetter.
        """
        if not self._global_map:
            raise RuntimeError(
                'global_center_variable is empty, please make sure '
                'this is called after model created and '
                'ElasticAverageCustomGetter is used when declaring '
                'you model')

        # Default to all trainables, then normalize to the dict form
        # ({checkpoint_name: variable}) that Saver accepts.
        if var_list is None:
            var_list = variables.trainable_variables()
        if not isinstance(var_list, dict):
            var_list = saveable_object_util.op_list_to_dict(var_list)

        # Replace each variable with its global-center counterpart (when one
        # exists) so the checkpoint holds the center values.
        swapped_var_list = {}
        for key, var in var_list.items():
            tensor = var

            if not isinstance(var, list):
                # _global_map is keyed by trainable-variable objects, so match
                # by op name against the current trainable variables first.
                for tvar in variables.trainable_variables():
                    if tvar.op.name == var.op.name:
                        tensor = self._global_map.get(tvar, var)
                        break
            else:  # Partitioned variable: swap each constituent part.
                tensor = [self._global_map.get(lvar, lvar) for lvar in var]

            swapped_var_list[key] = tensor

        return saver.Saver(swapped_var_list, name=name, **kwargs)
  def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs):
    """Build a `Saver` that stores global_center_variable values.

    Call this after every variable has been created through
    ElasticAverageCustomGetter. The returned saver is meant for use during
    training: it saves the global_center_variable of each trained parameter
    under that parameter's original name, which makes the resulting
    checkpoints directly usable for evaluation or inference.

    Args:
      var_list: Variables to save, as accepted by `Saver()`. When None, all
        trainable variables existing at call time are used.
      name: Name for the saver.
      **kwargs: Extra keyword arguments forwarded to `Saver()`.

    Returns:
      A `tf.compat.v1.train.Saver` object.

    Raises:
      RuntimeError: When no global center variables exist yet, i.e. the
        model has not been built with ElasticAverageCustomGetter.
    """
    if not self._global_map:
      raise RuntimeError('global_center_variable is empty, please make sure '
                         'this is called after model created and '
                         'ElasticAverageCustomGetter is used when declaring '
                         'you model')

    if var_list is None:
      var_list = variables.trainable_variables()
    if not isinstance(var_list, dict):
      var_list = saveable_object_util.op_list_to_dict(var_list)

    def _center_of(value):
      """Return the global-center counterpart of `value`, or `value` itself."""
      if isinstance(value, list):  # Partitioned variable: swap each shard.
        return [self._global_map.get(shard, shard) for shard in value]
      # _global_map is keyed by trainable-variable objects; locate the
      # trainable with a matching op name before looking it up.
      match = next((tv for tv in variables.trainable_variables()
                    if tv.op.name == value.op.name), None)
      return value if match is None else self._global_map.get(match, value)

    swapped = {key: _center_of(value) for key, value in var_list.items()}
    return saver.Saver(swapped, name=name, **kwargs)
Exemplo n.º 4
0
def _infer_var_name(var):
    """Returns name of the `var`.

    Args:
      var: A list containing one of the following:
        (i) A single `Variable`
        (ii) A single `ResourceVariable`
        (iii) Multiple `Variable` objects which must be slices of the same
          larger variable.
        (iv) A single `PartitionedVariable`

    Returns:
      Name of the `var`.

    Raises:
      TypeError: If `var` resolves to more than one checkpoint name, i.e.
        it does not satisfy the constraints above.
    """
    mapping = saveable_object_util.op_list_to_dict(var)
    if len(mapping) > 1:
        raise TypeError("`var` = %s passed as arg violates the constraints.  "
                        "name_to_var_dict = %s" % (var, mapping))
    # A single entry remains; its key is the inferred name.
    return list(mapping)[0]
Exemplo n.º 5
0
def _infer_var_name(var):
  """Returns name of the `var`.

  Args:
    var: A list. The list can contain either of the following:
      (i) A single `Variable`
      (ii) A single `ResourceVariable`
      (iii) Multiple `Variable` objects which must be slices of the same larger
        variable.
      (iv) A single `PartitionedVariable`

  Returns:
    Name of the `var`

  Raises:
    TypeError: If `var` maps to more than one checkpoint name, i.e. it does
      not satisfy the constraints above.
  """
  name_to_var_dict = saveable_object_util.op_list_to_dict(var)
  if len(name_to_var_dict) > 1:
    raise TypeError("`var` = %s passed as arg violates the constraints.  "
                    "name_to_var_dict = %s" % (var, name_to_var_dict))
  # Exactly one entry remains; its key is the inferred checkpoint name.
  return list(name_to_var_dict.keys())[0]
Exemplo n.º 6
0
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
    """Overrides given variable's initialization op.

    Sets variable initializer to assign op that initializes variable from
    tensor's value in the checkpoint.

    Args:
      variable: `tf.Variable` object.
      ckpt_file: string, full path of the checkpoint.
      tensor_name: Name of the tensor to load from the checkpoint.
      slice_spec: Slice specification for loading partitioned tensors.
      name: Name of the operation.
    """
    base_type = variable.dtype.base_dtype
    # Do not colocate with variable since RestoreV2 op only runs on CPU and
    # colocation will force variable (and other ops that colocate with variable)
    # to be on CPU as well. It is okay to place the variable's initializer op on
    # CPU since it will only be run once at the start.
    with ops.device(variable.device), ops.device("/cpu:0"):
        restore_op = io_ops.restore_v2(ckpt_file, [tensor_name], [slice_spec],
                                       [base_type],
                                       name=name)[0]

        # Wrap the variable in SaveableObjects so the restore below goes
        # through the same machinery a Saver would use.
        names_to_saveables = saveable_object_util.op_list_to_dict([variable])
        saveable_objects = []
        for name, op in names_to_saveables.items():
            for s in saveable_object_util.saveable_objects_for_op(op, name):
                saveable_objects.append(s)

        assert len(saveable_objects) == 1  # Should be only one variable.
    init_op = saveable_objects[0].restore([restore_op], restored_shapes=None)

    # Rewire the variable so initialization pulls the checkpoint value.
    # pylint:disable=protected-access
    variable._initializer_op = init_op
    restore_op.set_shape(variable.shape)
    variable._initial_value = restore_op
Exemplo n.º 7
0
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides given variable's initialization op.

  Sets variable initializer to assign op that initializes variable from tensor's
  value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]

    # Wrap the variable in SaveableObjects so the restore below goes through
    # the same machinery a Saver would use.
    names_to_saveables = saveable_object_util.op_list_to_dict([variable])
    saveable_objects = []
    for name, op in names_to_saveables.items():
      for s in saveable_object_util.saveable_objects_for_op(op, name):
        saveable_objects.append(s)

    assert len(saveable_objects) == 1  # Should be only one variable.
  init_op = saveable_objects[0].restore([restore_op], restored_shapes=None)

  # Rewire the variable so initialization pulls the checkpoint value.
  # pylint:disable=protected-access
  variable._initializer_op = init_op
  restore_op.set_shape(variable.shape)
  variable._initial_value = restore_op
Exemplo n.º 8
0
    def _add_attributes_to_object_graph(self, trackable_objects,
                                        object_graph_proto, node_ids,
                                        object_names, object_map,
                                        call_with_mapped_captures):
        """Create SaveableObjects and corresponding SerializedTensor protos.

        Args:
          trackable_objects: List of trackables; a trackable's list index is
            its checkpoint id and must match `node_ids`.
          object_graph_proto: Object graph proto whose `nodes` parallel
            `trackable_objects`; attributes are added to each node in place.
          node_ids: Maps each trackable to its checkpoint id.
          object_names: Maps each trackable to its checkpoint name.
          object_map: Optional mapping from a trackable to the object that
            should be saved in its place; None means each object saves itself.
          call_with_mapped_captures: Forwarded to
            `saveable_object_util.create_saveable_object` for callable
            saveable factories.

        Returns:
          A `(named_saveable_objects, feed_additions)` tuple.
          `feed_additions` is None when `self._saveables_cache` is None,
          otherwise a feed dict of volatile Python state to supply when
          saving.
        """
        named_saveable_objects = []
        if self._saveables_cache is None:
            # No SaveableObject caching. Either we're executing eagerly, or building a
            # static save which is specialized to the current Python state.
            feed_additions = None
        else:
            # If we are caching SaveableObjects, we need to build up a feed_dict with
            # functions computing volatile Python state to be saved with the
            # checkpoint.
            feed_additions = {}
        for checkpoint_id, (trackable, object_proto) in enumerate(
                zip(trackable_objects, object_graph_proto.nodes)):
            assert node_ids[trackable] == checkpoint_id
            object_name = object_names[trackable]
            if object_map is None:
                object_to_save = trackable
            else:
                object_to_save = object_map.get(trackable, trackable)
            if self._saveables_cache is not None:
                cached_attributes = self._saveables_cache.setdefault(
                    object_to_save, {})
            else:
                cached_attributes = None

            for name, saveable_factory in (
                    object_to_save._gather_saveables_for_checkpoint().items()):  # pylint: disable=protected-access
                attribute = object_proto.attributes.add()
                attribute.name = name
                attribute.checkpoint_key = "%s/%s/%s" % (
                    object_name, _OBJECT_ATTRIBUTES_NAME,
                    _escape_local_name(name))
                if cached_attributes is None:
                    saveables = None
                else:
                    saveables = cached_attributes.get(name, None)
                    if saveables is not None:
                        for saveable in saveables:
                            if attribute.checkpoint_key not in saveable.name:
                                # The checkpoint key for this SaveableObject is different. We
                                # need to re-create it.
                                saveables = None
                                del cached_attributes[name]
                                break
                if saveables is None:
                    if callable(saveable_factory):
                        maybe_saveable = saveable_object_util.create_saveable_object(
                            saveable_factory, attribute.checkpoint_key,
                            call_with_mapped_captures)
                    else:
                        maybe_saveable = saveable_factory
                    if isinstance(maybe_saveable,
                                  saveable_object_lib.SaveableObject):
                        saveables = (maybe_saveable, )
                    else:
                        # Figure out the name-based Saver's name for this variable. If it's
                        # already a SaveableObject we'd just get the checkpoint key back, so
                        # we leave full_name blank.
                        saver_dict = saveable_object_util.op_list_to_dict(
                            [maybe_saveable], convert_variable_to_tensor=False)
                        full_name, = saver_dict.keys()
                        saveables = tuple(
                            saveable_object_util.saveable_objects_for_op(
                                op=maybe_saveable,
                                name=attribute.checkpoint_key))
                        for saveable in saveables:
                            saveable.full_name = full_name
                    for saveable in saveables:
                        if attribute.checkpoint_key not in saveable.name:
                            # Bug fix: the format arguments were previously
                            # `(trackable, name, saveable.name, ...)`, which put the
                            # attribute name in the "with name" slot and vice versa.
                            raise AssertionError((
                                "The object %s produced a SaveableObject with name '%s' for "
                                "attribute '%s'. Expected a name containing '%s'."
                            ) % (trackable, saveable.name, name,
                                 attribute.checkpoint_key))
                    if cached_attributes is not None:
                        cached_attributes[name] = saveables

                # optional_restore for the attribute is the AND over its
                # saveables; None (no saveables) is treated as False below.
                optional_restore = None
                for saveable in saveables:
                    if optional_restore is None:
                        optional_restore = saveable.optional_restore
                    else:
                        optional_restore = optional_restore and saveable.optional_restore

                    if hasattr(saveable, "full_name"):
                        attribute.full_name = saveable.full_name
                    if isinstance(saveable, base.PythonStateSaveable):
                        if feed_additions is None:
                            assert self._saveables_cache is None
                            # If we're not caching saveables, then we're either executing
                            # eagerly or building a static save/restore (e.g. for a
                            # SavedModel). In either case, we should embed the current Python
                            # state in the graph rather than relying on a feed dict.
                            saveable = saveable.freeze()
                        else:
                            saveable_feed_dict = saveable.feed_dict_additions()
                            for new_feed_key in saveable_feed_dict.keys():
                                if new_feed_key in feed_additions:
                                    raise AssertionError((
                                        "The object %s tried to feed a value for the Tensor %s "
                                        "when saving, but another object is already feeding a "
                                        "value.") % (trackable, new_feed_key))
                            feed_additions.update(saveable_feed_dict)
                    named_saveable_objects.append(saveable)
                if optional_restore is None:
                    optional_restore = False
                attribute.optional_restore = optional_restore

        return named_saveable_objects, feed_additions
Exemplo n.º 9
0
    def _add_attributes_to_object_graph_for_saveable_objects(
            self, checkpoint_factory_map, object_graph_proto, node_ids,
            object_map, call_with_mapped_captures):
        """Create SaveableObjects and corresponding SerializedTensor protos.

        Args:
          checkpoint_factory_map: Maps each trackable to a list of factory
            data entries, each carrying `name`, `checkpoint_key`, and
            `factory` for one attribute to save.
          object_graph_proto: Object graph proto; attributes are appended in
            place to the node indexed by `node_ids[trackable]`.
          node_ids: Maps each trackable to its node index in
            `object_graph_proto.nodes`.
          object_map: Mapping (via `_get_mapped_trackable`) from a trackable
            to the object actually saved in its place; only consulted when
            the saveables cache is enabled.
          call_with_mapped_captures: Forwarded to
            `saveable_object_util.create_saveable_object` for callable
            factories.

        Returns:
          A `(named_saveable_objects, feed_additions)` tuple.
          `feed_additions` is None when `self._saveables_cache` is None,
          otherwise a feed dict of volatile Python state to supply when
          saving.
        """
        named_saveable_objects = []
        if self._saveables_cache is None:
            # No SaveableObject caching. Either we're executing eagerly, or building a
            # static save which is specialized to the current Python state.
            feed_additions = None
        else:
            # If we are caching SaveableObjects, we need to build up a feed_dict with
            # functions computing volatile Python state to be saved with the
            # checkpoint.
            feed_additions = {}
        for trackable, factory_data_list in checkpoint_factory_map.items():
            object_proto = object_graph_proto.nodes[node_ids[trackable]]
            if self._saveables_cache is not None:
                object_to_save = _get_mapped_trackable(trackable, object_map)
                cached_attributes = self._saveables_cache.setdefault(
                    object_to_save, {})
            else:
                cached_attributes = None

            for factory_data in factory_data_list:
                attribute = object_proto.attributes.add()
                attribute.name = name = factory_data.name
                attribute.checkpoint_key = key = factory_data.checkpoint_key
                saveable_factory = factory_data.factory

                # See if we can skip saving this checkpoint key.
                saveables = cached_attributes.get(
                    name) if cached_attributes else None
                if saveables is not None:
                    for saveable in saveables:
                        if key not in saveable.name:
                            # The checkpoint key for this SaveableObject is different. We
                            # need to re-create it.
                            saveables = None
                            del cached_attributes[name]
                            break

                if saveables is None:
                    if callable(saveable_factory):
                        maybe_saveable = saveable_object_util.create_saveable_object(
                            saveable_factory, key, call_with_mapped_captures)
                    else:
                        maybe_saveable = saveable_factory
                    if isinstance(maybe_saveable,
                                  saveable_object_lib.SaveableObject):
                        saveables = (maybe_saveable, )
                    else:
                        # Figure out the name-based Saver's name for this variable. If it's
                        # already a SaveableObject we'd just get the checkpoint key back, so
                        # we leave full_name blank.
                        saver_dict = saveable_object_util.op_list_to_dict(
                            [maybe_saveable], convert_variable_to_tensor=False)
                        full_name, = saver_dict.keys()
                        saveables = tuple(
                            saveable_object_util.saveable_objects_for_op(
                                op=maybe_saveable, name=key))
                        for saveable in saveables:
                            saveable.full_name = full_name
                    for saveable in saveables:
                        if key not in saveable.name:
                            raise AssertionError(
                                f"The object {trackable} produced a SaveableObject with name "
                                f"'{saveable.name}' for attribute '{name}'. Expected a name"
                                f" containing '{key}'.")
                    if cached_attributes is not None:
                        cached_attributes[name] = saveables

                for saveable in saveables:
                    if hasattr(saveable, "full_name"):
                        attribute.full_name = saveable.full_name
                    if isinstance(saveable, base.PythonStateSaveable):
                        if feed_additions is None:
                            assert self._saveables_cache is None
                            # If we're not caching saveables, then we're either executing
                            # eagerly or building a static save/restore (e.g. for a
                            # SavedModel). In either case, we should embed the current Python
                            # state in the graph rather than relying on a feed dict.
                            saveable = saveable.freeze()
                        else:
                            saveable_feed_dict = saveable.feed_dict_additions()
                            for new_feed_key in saveable_feed_dict.keys():
                                if new_feed_key in feed_additions:
                                    raise AssertionError(
                                        f"The object {trackable} tried to feed a value for the "
                                        f"Tensor {new_feed_key} when saving, but another object "
                                        "is already feeding a value.")
                            feed_additions.update(saveable_feed_dict)
                    named_saveable_objects.append(saveable)

        return named_saveable_objects, feed_additions
  def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs):
    """Create a saver swapping moving averages and variables.

    Intended for use during training: checkpoints written by this saver hold
    the moving averages of the trained parameters under the original
    parameter names, so a regular saver used at evaluation or inference time
    automatically restores the averaged values.

    Must be called after all variables exist and after Optimizer.minimize()
    has been called.

    Args:
      var_list: Variables to save, as accepted by `Saver()`. When None,
        every variable created before this call is saved.
      name: The name of the saver.
      **kwargs: Keyword arguments forwarded to `Saver()`.

    Returns:
      A `tf.compat.v1.train.Saver` object.

    Raises:
      RuntimeError: If apply_gradients or minimize has not been called
        beforehand.
      ValueError: If var_list contains variables without their moving
        average counterpart.
    """
    if self._swapped_variable_name_map is None:
      raise RuntimeError('Must call apply_gradients or minimize before '
                         'creating the swapping_saver')
    if var_list is None:
      var_list = variables.global_variables()
    if not isinstance(var_list, dict):
      var_list = saveable_object_util.op_list_to_dict(var_list)

    # op_list_to_dict represents a partitioned variable as the list of its
    # constituent parts; index every individual tensor by op name so the
    # swap lookup below can resolve each one.
    v_name_to_tensor = {}
    for key, value in six.iteritems(var_list):
      if isinstance(value, (list, variables.PartitionedVariable)):
        for part in value:
          v_name_to_tensor[part.op.name] = part
      else:
        v_name_to_tensor[key] = value

    # Substitute each variable with its moving-average counterpart.
    swapped_var_list = {}
    for key, value in six.iteritems(var_list):
      if isinstance(value, list):
        swapped_var_list[key] = [
            self._find_swapped_variable(v_name_to_tensor, part.op.name, part)
            for part in value
        ]
      else:
        swapped_var_list[key] = self._find_swapped_variable(
            v_name_to_tensor, key, value)

    # Build the swapping saver.
    return saver.Saver(swapped_var_list, name=name, **kwargs)
Exemplo n.º 11
0
    def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs):
        """Create a saver swapping moving averages and variables.

        You should use this saver during training.  It will save the moving
        averages of the trained parameters under the original parameter
        names.  For evaluations or inference you should use a regular saver
        and it will automatically use the moving averages for the trained
        variable.

        You must call this function after all variables have been created and
        after you have called Optimizer.minimize().

        Args:
          var_list: List of variables to save, as per `Saver()`. If set to
            None, will save all the variables that have been created before
            this call.
          name: The name of the saver.
          **kwargs: Keyword arguments forwarded to `Saver()`.

        Returns:
          A `tf.compat.v1.train.Saver` object.

        Raises:
          RuntimeError: If apply_gradients or minimize has not been called
            before.
          ValueError: If var_list is provided and contains some variables but
            not their moving average counterpart.
        """
        if self._swapped_variable_name_map is None:
            raise RuntimeError('Must call apply_gradients or minimize before '
                               'creating the swapping_saver')
        if var_list is None:
            var_list = variables.global_variables()
        if not isinstance(var_list, dict):
            var_list = saveable_object_util.op_list_to_dict(var_list)

        # Index every individual tensor by op name so the swap lookup below
        # can resolve each one.
        v_name_to_tensor = {}
        for k, tensor_or_list in six.iteritems(var_list):
            # For each partitioned variable OpListToDict returns list of constituent
            # parts instead of single tensor.
            if (isinstance(tensor_or_list, list) or isinstance(
                    tensor_or_list, variables.PartitionedVariable)):
                for tensor in tensor_or_list:
                    v_name = tensor.op.name
                    v_name_to_tensor[v_name] = tensor
            else:
                v_name_to_tensor[k] = tensor_or_list

        # Now swap variables and moving averages
        swapped_var_list = {}
        for k, tensor_or_list in six.iteritems(var_list):
            if isinstance(tensor_or_list, list):
                tensor_list_to_save = []
                for tensor in tensor_or_list:
                    v_name = tensor.op.name
                    swapped_variable = self._find_swapped_variable(
                        v_name_to_tensor, v_name, tensor)
                    tensor_list_to_save.append(swapped_variable)
                swapped_var_list[k] = tensor_list_to_save
            else:
                swapped_var_list[k] = self._find_swapped_variable(
                    v_name_to_tensor, k, tensor_or_list)

        # Build the swapping saver.
        return saver.Saver(swapped_var_list, name=name, **kwargs)