Exemplo n.º 1
0
    def objects_ids_and_slot_variables_and_paths(self):
        """Traverse the object graph and list all accessible objects.

        Looks for `Trackable` objects which are dependencies of
        `root_trackable`. Slot variables are included only if the variable
        they are slotting for and the optimizer are both dependencies of
        `root_trackable` (i.e. if they would be saved with a checkpoint).

        Returns:
          A tuple of (trackable objects, paths from root for each object,
          object -> node id, slot variables, object_names).
        """
        # Breadth-first traversal yields a deterministic ordering plus each
        # object's path from the root.
        trackables, paths = self._breadth_first_traversal()
        # Checkpoint-name string for every reachable object, derived from its
        # path from the root.
        names = object_identity.ObjectIdentityDictionary()
        for trackable, path in paths.items():
            names[trackable] = trackable_utils.object_path_to_string(path)
        # Node ids follow traversal order.
        ids = object_identity.ObjectIdentityDictionary()
        for index, trackable in enumerate(trackables):
            ids[trackable] = index
        # NOTE: this helper appends discovered slot variables to `trackables`
        # and extends `ids`/`names` in place.
        slots = _serialize_slot_variables(
            trackable_objects=trackables,
            node_ids=ids,
            object_names=names)
        return trackables, paths, ids, slots, names
Exemplo n.º 2
0
 def _serialize_gathered_objects(self,
                                 trackable_objects,
                                 node_paths,
                                 object_map=None,
                                 call_with_mapped_captures=None):
     """Create SaveableObjects and protos for gathered objects.

     Args:
       trackable_objects: List of reachable trackables, in traversal order.
       node_paths: Dictionary mapping each trackable to its path from the
         root.
       object_map: Optional dictionary mapping trackables to copies that
         should be saved in their place.
       call_with_mapped_captures: Optional callable forwarded to attribute
         serialization.

     Returns:
       A tuple of (named saveable objects, object graph proto, feed
       additions, registered savers).
     """
     # Checkpoint-name string for each object, derived from its path.
     names = object_identity.ObjectIdentityDictionary()
     for trackable, path in node_paths.items():
         names[trackable] = trackable_utils.object_path_to_string(path)
     # Node ids follow list order.
     ids = object_identity.ObjectIdentityDictionary()
     for index, trackable in enumerate(trackable_objects):
         ids[trackable] = index
     slot_variables = _serialize_slot_variables(
         trackable_objects=trackable_objects,
         node_ids=ids,
         object_names=names)
     graph_proto = self._fill_object_graph_proto(
         trackable_objects=trackable_objects,
         node_ids=ids,
         slot_variables=slot_variables)
     saveables, feeds, registered_savers = (
         self._add_attributes_to_object_graph(
             trackable_objects=trackable_objects,
             object_graph_proto=graph_proto,
             node_ids=ids,
             object_names=names,
             object_map=object_map,
             call_with_mapped_captures=call_with_mapped_captures))
     # Record in the proto which trackables have checkpoint values or
     # descendants with checkpoint values.
     self._add_checkpoint_values_check(trackable_objects, graph_proto)
     return saveables, graph_proto, feeds, registered_savers
Exemplo n.º 3
0
 def _serialize_gathered_objects(self,
                                 trackable_objects,
                                 path_to_root,
                                 object_map=None,
                                 call_with_mapped_captures=None):
     """Create SaveableObjects and protos for gathered objects."""
     # Derive a checkpoint-name prefix for every object from its path to the
     # root.
     names = object_identity.ObjectIdentityDictionary()
     for trackable, path in path_to_root.items():
         names[trackable] = _object_prefix_from_path(path)
     # Node ids follow list order.
     ids = object_identity.ObjectIdentityDictionary()
     for index, trackable in enumerate(trackable_objects):
         ids[trackable] = index
     slot_variables = _serialize_slot_variables(
         trackable_objects=trackable_objects,
         node_ids=ids,
         object_names=names)
     graph_proto = self._fill_object_graph_proto(
         trackable_objects=trackable_objects,
         node_ids=ids,
         slot_variables=slot_variables)
     saveables, feeds = (
         self._add_attributes_to_object_graph(
             trackable_objects=trackable_objects,
             object_graph_proto=graph_proto,
             node_ids=ids,
             object_names=names,
             object_map=object_map,
             call_with_mapped_captures=call_with_mapped_captures))
     return saveables, graph_proto, feeds
Exemplo n.º 4
0
 def __init__(self, root):
     """Wraps `root` in an augmented graph view.

     A SaveableObject cache is created only when building a graph (neither
     executing eagerly nor inside a tf.function); otherwise it stays None.
     """
     saveables_cache = None
     if not (context.executing_eagerly() or ops.inside_function()):
         saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary()
     super(_AugmentedGraphView, self).__init__(root, saveables_cache)
     # Object -> (name -> dep)
     self._extra_dependencies = object_identity.ObjectIdentityDictionary()
     self._functions = object_identity.ObjectIdentityDictionary()
     # Cache shared between objects in the same object graph; handed to each
     # trackable's `_list_extra_dependencies_for_serialization` and
     # `_list_functions_for_serialization`.
     self._serialization_cache = object_identity.ObjectIdentityDictionary()
    def __init__(self,
                 decay,
                 num_updates=None,
                 zero_debias=False,
                 name="ExponentialMovingAverage"):
        """Creates a new ExponentialMovingAverage object.

        Shadow variables are only created, and the ops that maintain the
        moving averages only added, once `apply()` is called.

        Passing `num_updates` makes the decay rate dynamic: the actual decay
        used is `min(decay, (1 + num_updates) / (10 + num_updates))`.
        Typically `num_updates` is a variable holding the training-step
        count, so the decay is lower early in training and the moving
        averages move faster at first.

        Args:
          decay: Float.  The decay to use.
          num_updates: Optional count of number of updates applied to
            variables.
          zero_debias: If `True`, zero debias moving-averages that are
            initialized with tensors.
          name: String. Optional prefix name to use for the name of ops
            added in `apply()`.
        """
        self._name = name
        self._decay = decay
        self._num_updates = num_updates
        self._zero_debias = zero_debias
        # Identity-keyed map from each variable to its shadow
        # (moving-average) variable; filled in by `apply()`.
        self._averages = object_identity.ObjectIdentityDictionary()
Exemplo n.º 6
0
    def initialize_variables():
      lifted_map = object_identity.ObjectIdentityDictionary()
      # Stack all the is-initialized predicates into one tensor and read it
      # back as numpy, so the remote case pays for a single RPC between
      # client and worker instead of one per variable.
      with ops.init_scope():
        init_flags = []
        for var, _ in initializers:
          init_flags.append(
              resource_variable_ops.var_is_initialized_op(var.handle))
        init_flags = array_ops.stack(init_flags).numpy()

      # Collect initializer ops for the variables that still need them.
      pending_inits = []
      for (var, init_op), already_init in zip(initializers, init_flags):
        with ops.init_scope():
          if already_init:
            continue
        pending_inits.append(init_op)

      if pending_inits:
        lifted_map = lift_to_graph.lift_to_graph(
            pending_inits, ops.get_default_graph(), op_map=lifted_map)
      # Assign each still-uninitialized variable from its lifted initializer.
      for (var, init_op), already_init in zip(initializers, init_flags):
        with ops.init_scope():
          if already_init:
            continue
        var.assign(lifted_map[init_op], read_value=False)
Exemplo n.º 7
0
  def __init__(self, checkpoint_view):
    self.checkpoint_view = checkpoint_view
    trackables, ids, slots = (
        self.checkpoint_view.objects_ids_and_slot_variables())
    self.nodes = trackables
    self.node_ids = ids
    self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
    self.slot_variables = slots
    self.concrete_functions = []

    # Append every `Function` attached to an existing node as a node itself.
    original_nodes = list(self.nodes)
    seen_names = set()
    for node in original_nodes:
      for function in checkpoint_view.list_functions(node).values():
        if function not in self.node_ids:
          self.node_ids[function] = len(self.nodes)
          self.nodes.append(function)
        if isinstance(function, def_function.Function):
          # Listing the concrete functions is done for its side effects:
          #  - it populates the cache for functions that have an
          #    input_signature and have not been called;
          #  - it forces creation-time effects of concrete functions, e.g.
          #    variables created on first run.
          concrete = (
              function._list_all_concrete_functions_for_serialization())  # pylint: disable=protected-access
        else:
          concrete = [function]
        # De-duplicate concrete functions by name.
        for concrete_function in concrete:
          if concrete_function.name not in seen_names:
            seen_names.add(concrete_function.name)
            self.concrete_functions.append(concrete_function)
Exemplo n.º 8
0
 def initialize_variables():
   lifted = object_identity.ObjectIdentityDictionary()
   for var, init_op in initializer_map.items():
     with ops.init_scope():
       if resource_variable_ops.var_is_initialized_op(var.handle):
         # Skip variables that are already initialized at trace time.
         continue
     # Thread the accumulated op map through each successive lift.
     lifted = lift_to_graph.lift_to_graph(
         [init_op], ops.get_default_graph(), op_map=lifted)
     var.assign(lifted[init_op])
Exemplo n.º 9
0
def _serialize_slot_variables(trackable_objects, node_ids, object_names):
    """Gather and name slot variables.

    Mutates its arguments in place: every slot variable discovered is
    appended to `trackable_objects` and given entries in `node_ids` and
    `object_names`.

    Args:
      trackable_objects: List of reachable trackables; extended with any
        slot variables found.
      node_ids: `ObjectIdentityDictionary` mapping trackable -> node id;
        slot variables are added to it.
      object_names: `ObjectIdentityDictionary` mapping trackable ->
        checkpoint name string; slot variables are added to it.

    Returns:
      An `ObjectIdentityDictionary` mapping each optimizer-like trackable to
      a list of `SlotVariableReference` protos for its slot variables.

    Raises:
      NotImplementedError: if a slot variable has its own checkpoint
        dependencies, or if a slot variable is also a regular dependency of
        some trackable.
    """
    # Snapshot the list before slot variables start being appended to it.
    non_slot_objects = list(trackable_objects)
    slot_variables = object_identity.ObjectIdentityDictionary()
    for trackable in non_slot_objects:
        # Treat an object as an optimizer if it is a V1 optimizer or exposes
        # the slot-variable restore hook.
        if (isinstance(trackable, optimizer_v1.Optimizer)
                # TODO(b/110718070): Fix Keras imports.
                # Note: dir() is used rather than hasattr() here to avoid triggering
                # custom __getattr__ code, see b/152031870 for context.
                or "_create_or_restore_slot_variable" in dir(trackable)):
            naming_scheme = _slot_variable_naming_for_optimizer(
                optimizer_path=object_names[trackable])
            slot_names = trackable.get_slot_names()
            for slot_name in slot_names:
                # Probe every non-slot object; most have no slot for this
                # name, signalled by an exception or a None return.
                for original_variable_node_id, original_variable in enumerate(
                        non_slot_objects):
                    try:
                        slot_variable = trackable.get_slot(
                            original_variable, slot_name)
                    except (AttributeError, KeyError):
                        slot_variable = None
                    if slot_variable is None:
                        continue
                    slot_variable._maybe_initialize_trackable()  # pylint: disable=protected-access
                    if slot_variable._checkpoint_dependencies:  # pylint: disable=protected-access
                        # TODO(allenl): Gather dependencies of slot variables.
                        raise NotImplementedError(
                            "Currently only variables with no dependencies can be saved as "
                            "slot variables. File a feature request if this limitation "
                            "bothers you.")
                    if slot_variable in node_ids:
                        raise NotImplementedError(
                            "A slot variable was re-used as a dependency of a Trackable "
                            f"object: {slot_variable}. This is not currently allowed. "
                            "File a feature request if this limitation bothers you."
                        )
                    checkpoint_name = naming_scheme(
                        variable_path=object_names[original_variable],
                        slot_name=slot_name)
                    object_names[slot_variable] = checkpoint_name
                    # Slot variables get ids after all previously known
                    # objects, in discovery order.
                    slot_variable_node_id = len(trackable_objects)
                    node_ids[slot_variable] = slot_variable_node_id
                    trackable_objects.append(slot_variable)
                    slot_variable_proto = (
                        trackable_object_graph_pb2.TrackableObjectGraph.
                        TrackableObject.SlotVariableReference(
                            slot_name=slot_name,
                            original_variable_node_id=original_variable_node_id,
                            slot_variable_node_id=slot_variable_node_id))
                    slot_variables.setdefault(trackable,
                                              []).append(slot_variable_proto)
    return slot_variables
Exemplo n.º 10
0
def get_checkpoint_factories_and_keys(object_names, object_map=None):
    """Gets a map of saveable factories and corresponding checkpoint keys.

  Args:
    object_names: a dictionary that maps `Trackable` objects to auto-generated
      string names.
    object_map: a dictionary mapping `Trackable` to copied `Trackable` objects.
      The copied objects are generated from `Trackable._map_resources()` which
      copies the object into another graph. Generally only resource objects
      (e.g. Variables, Tables) will be in this map.

  Returns:
    A tuple of (
      Dictionary mapping trackable -> list of _CheckpointFactoryData,
      Dictionary mapping registered saver name -> {object name -> trackable})
  """
    factory_map = object_identity.ObjectIdentityDictionary()
    registered = collections.defaultdict(dict)
    for trackable, object_name in object_names.items():
        # The mapped object only supplies the saving machinery; keys and all
        # other data come from the original `trackable`.
        save_target = util.get_mapped_trackable(trackable, object_map)

        saver_name = registration.get_registered_saver_name(save_target)
        if saver_name:
            # Store the original trackable (not `save_target`) because the
            # original is what gets written into the object proto.
            registered[saver_name][object_name] = trackable
            continue

        factory_map[trackable] = []
        saveables = saveable_object_util.saveable_objects_from_trackable(
            save_target)  # pylint: disable=protected-access
        for name, saveable_factory in saveables.items():
            # For compatibility during the SaveableObject deprecation, the
            # legacy saveable name (when present) wins over `name` as the
            # checkpoint-key suffix.
            key_suffix = saveable_compat.get_saveable_name(
                save_target) or name
            checkpoint_key = trackable_utils.checkpoint_key(
                object_name, key_suffix)

            if not saveable_compat.force_checkpoint_conversion_enabled():
                # With checkpoint conversion disabled, also use the legacy
                # saveable name (when there is one) as the attribute name.
                name = key_suffix

            factory_map[trackable].append(
                _CheckpointFactoryData(factory=saveable_factory,
                                       name=name,
                                       checkpoint_key=checkpoint_key))
    return factory_map, registered
Exemplo n.º 11
0
    def __init__(self, checkpoint_view, wrapped_functions=None):
        """Initializes a SaveableView.

        Args:
          checkpoint_view: A GraphView object.
          wrapped_functions: Dictionary that maps concrete functions to
            functions that do not capture cached variable values.
        """
        self.checkpoint_view = checkpoint_view
        trackables, ids, slots = (
            self.checkpoint_view.objects_ids_and_slot_variables())
        self.nodes = trackables
        self.node_ids = ids
        self.captured_tensor_node_ids = (
            object_identity.ObjectIdentityDictionary())
        self.slot_variables = slots
        self.concrete_functions = []

        # Maps functions -> wrapped functions that capture variables
        self.wrapped_functions = wrapped_functions or {}
        # Maps names of concrete functions to names of their wrapped
        # counterparts; when writing SavedFunction protos the wrapped names
        # replace the originals.
        self.function_name_map = {}
        for original, wrapped in self.wrapped_functions.items():
            self.function_name_map[compat.as_text(original.name)] = (
                compat.as_text(wrapped.name))

        # Append every `Function` attached to an existing node as a node
        # itself.
        original_nodes = list(self.nodes)
        seen_names = set()
        for node in original_nodes:
            for function in checkpoint_view.list_functions(node).values():
                if function not in self.node_ids:
                    self.node_ids[function] = len(self.nodes)
                    self.nodes.append(function)
                if isinstance(function, def_function.Function):
                    # Listing the concrete functions is done for its side
                    # effects:
                    #  - it populates the cache for functions that have an
                    #    input_signature and have not been called;
                    #  - it forces creation-time effects of concrete
                    #    functions, e.g. variables created on first run.
                    concrete = (
                        function.
                        _list_all_concrete_functions_for_serialization())  # pylint: disable=protected-access
                else:
                    concrete = [function]
                # De-duplicate concrete functions by name.
                for concrete_function in concrete:
                    if concrete_function.name not in seen_names:
                        seen_names.add(concrete_function.name)
                        self.concrete_functions.append(concrete_function)
Exemplo n.º 12
0
 def _descendants_with_paths(self):
     """Returns all nodes reachable from self.root, with their paths.

     The traversal is breadth first, so the recorded path for each node is a
     shortest path from the root.
     """
     ordered = []
     paths = object_identity.ObjectIdentityDictionary()
     paths[self.root] = ()
     queue = collections.deque([self.root])
     while queue:
         node = queue.popleft()
         ordered.append(node)
         for name, child in self.children(node).items():
             if child in paths:
                 # Already reached via an equal-or-shorter path.
                 continue
             paths[child] = paths[node] + (
                 base.TrackableReference(name, child),)
             queue.append(child)
     return ordered, paths
Exemplo n.º 13
0
 def _breadth_first_traversal(self):
     """Find shortest paths to all dependencies of self.root."""
     ordered = []
     paths = object_identity.ObjectIdentityDictionary()
     paths[self.root] = ()
     queue = collections.deque([self.root])
     while queue:
         node = queue.popleft()
         ordered.append(node)
         # `list_children` yields (name, child) pairs.
         for name, child in self.list_children(node):
             if child in paths:
                 # Already reached via an equal-or-shorter path.
                 continue
             paths[child] = paths[node] + (
                 base.TrackableReference(name, child),)
             queue.append(child)
     return ordered, paths
Exemplo n.º 14
0
    def _add_checkpoint_values_check(self, trackable_objects,
                                     object_graph_proto):
        """Determines which objects have checkpoint values and saves to the proto.

        Args:
          trackable_objects: A list of all trackable objects.
          object_graph_proto: A `TrackableObjectGraph` proto.
        """
        # child -> set of trackables that reference it (the "parents"). Once
        # a trackable is known to have checkpoint values, all of its parents
        # can be marked as having checkpoint values too.
        parent_sets = object_identity.ObjectIdentityDictionary()
        checkpointed = object_identity.ObjectIdentitySet()

        # Pass 1: record parent links, and seed the checkpointed set with
        # every node whose proto carries attributes, slot variables, or a
        # registered saver.
        for trackable, proto in zip(trackable_objects,
                                    object_graph_proto.nodes):
            if (proto.attributes or proto.slot_variables
                    or proto.HasField("registered_saver")):
                checkpointed.add(trackable)
            for child_proto in proto.children:
                child = trackable_objects[child_proto.node_id]
                if child not in parent_sets:
                    parent_sets[child] = object_identity.ObjectIdentitySet()
                parent_sets[child].add(trackable)

        # Pass 2: propagate "has checkpoint values" upward through the
        # parent links.
        pending = object_identity.ObjectIdentitySet()
        pending.update(checkpointed)
        while pending:
            trackable = pending.pop()
            if trackable not in parent_sets:
                # Some trackables may not have parents (e.g. slot variables).
                continue
            direct_parents = parent_sets.pop(trackable)
            checkpointed.update(direct_parents)
            for parent in direct_parents:
                if parent in parent_sets:
                    pending.add(parent)

        # Write the computed flag into every node's proto.
        for node_id, trackable in enumerate(trackable_objects):
            object_graph_proto.nodes[
                node_id].has_checkpoint_values.value = bool(
                    trackable in checkpointed)
Exemplo n.º 15
0
    def get_initialization_function(self, *args, **kwargs):
        """Returns a `ConcreteFunction` which initializes this function's variables.

        Requires that this function hasn't been accessed yet through either
        calling it or calling get_concrete_function. Fails if we cannot build
        an initializer function which does not depend on the concrete values
        of the inputs to this function.

        Note that running this function will overwrite any values currently
        assigned to variables, for example restores from a checkpoint.

        Args:
          *args: arguments to the underlying python callable.
          **kwargs: keyword arguments to the python callable.

        Returns:
          A `ConcreteFunction` object which initializes the variables of this
          function.

        Raises:
          RuntimeError: if called after the variables have been initialized.
        """
        with self._lock:
            if self._stateful_fn is not None:
                raise RuntimeError(
                    "get_initialization_function cannot be called after the function "
                    "has been used")
            # Trace the function and collect its variable initializers so
            # they can be extracted and run eagerly; fail only if that is
            # impossible.
            initializers = object_identity.ObjectIdentityDictionary()
            self._initialize(args, kwargs, add_initializers_to=initializers)

        # Note: using defun here avoids an infinite recursion.
        @function_lib.defun
        def _run_initializers():
            for var, init_op in initializers.items():
                lifted = lift_to_graph.lift_to_graph(
                    [init_op], ops.get_default_graph())
                var.assign(lifted[init_op])

        return _run_initializers.get_concrete_function()
Exemplo n.º 16
0
def get_checkpoint_factories_and_keys(object_names):
    """Gets a map of saveable factories and corresponding checkpoint keys.

  Args:
    object_names: a dictionary that maps `Trackable` objects to auto-generated
      string names.
  Returns:
    A dictionary mapping Trackables -> a list of _CheckpointFactoryData.
  """
    factory_map = object_identity.ObjectIdentityDictionary()
    for trackable, object_name in object_names.items():
        entries = []
        saveables = trackable._gather_saveables_for_checkpoint()  # pylint: disable=protected-access
        for name, saveable_factory in saveables.items():
            # Keys have the form
            # "<object path>/<attributes marker>/<escaped attr name>".
            checkpoint_key = "%s/%s/%s" % (
                object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name))
            entries.append(
                _CheckpointFactoryData(factory=saveable_factory,
                                       name=name,
                                       checkpoint_key=checkpoint_key))
        factory_map[trackable] = entries
    return factory_map
Exemplo n.º 17
0
  def test_serialize_gathered_objects_with_map(self):
    # Object graph: one plain variable and one trackable that saves through
    # a registered saver.
    root = autotrackable.AutoTrackable()
    root.v = variables.Variable(1.0)
    root.registered = TrackableWithRegisteredSaver()

    # Map both originals to stand-in copies.
    registered_copy = TrackableWithRegisteredSaver()
    v_copy = variables.Variable(1.0)
    mapping = object_identity.ObjectIdentityDictionary()
    mapping[root.registered] = registered_copy
    mapping[root.v] = v_copy

    named_saveables, _, _, registered_savers = (
        save_util_v1.serialize_gathered_objects(
            graph_view.ObjectGraphView(root), mapping))

    # The saveable must wrap the mapped copy, not the original variable.
    self.assertLen(named_saveables, 1)
    self.assertIsNot(named_saveables[0].op, root.v)
    self.assertIs(named_saveables[0].op, v_copy)

    # Likewise, the registered saver should be handed the mapped copy.
    returned = registered_savers["Custom.RegisteredSaver"]["registered"]
    self.assertIsNot(root.registered, returned)
    self.assertIs(registered_copy, returned)
Exemplo n.º 18
0
 def _breadth_first_traversal(self):
     """Find shortest paths to all dependencies of self.root."""
     ordered = []
     root_paths = object_identity.ObjectIdentityDictionary()
     root_paths[self.root] = ()
     queue = collections.deque([self.root])
     while queue:
         node = queue.popleft()
         # Object-based saving cannot handle NotTrackable instances; fail
         # loudly instead of silently skipping them.
         if isinstance(node, tracking.NotTrackable):
             raise NotImplementedError((
                 "The object %s does not support object-based saving. File a "
                 "feature request if this limitation bothers you. In the meantime, "
                 "you can remove the dependency on this object and save everything "
                 "else.") % (node,))
         ordered.append(node)
         for name, child in self.list_dependencies(node):
             if child in root_paths:
                 # Already reached via an equal-or-shorter path.
                 continue
             root_paths[child] = root_paths[node] + (
                 base.TrackableReference(name, child),)
             queue.append(child)
     return ordered, root_paths
Exemplo n.º 19
0
def get_checkpoint_factories_and_keys(object_names, object_map=None):
    """Builds saveable-factory and registered-saver maps for checkpointing.

    Args:
      object_names: a dictionary that maps `Trackable` objects to
        auto-generated string names.
      object_map: a dictionary mapping `Trackable` to copied `Trackable`
        objects. The copied objects are generated from
        `Trackable._map_resources()` which copies the object into another
        graph. Generally only resource objects (e.g. Variables, Tables) will
        be in this map.

    Returns:
      A tuple of (
        Dictionary mapping trackable -> list of _CheckpointFactoryData,
        Dictionary mapping registered saver name -> {object name -> trackable})
    """
    factory_map = object_identity.ObjectIdentityDictionary()
    registered_savers = collections.defaultdict(dict)
    for trackable, object_name in object_names.items():
        # The (possibly copied) object supplies the saving functionality;
        # checkpoint keys and other metadata use the original `trackable`.
        object_to_save = _get_mapped_trackable(trackable, object_map)
        saver_name = registration.get_registered_saver_name(object_to_save)
        if saver_name:
            # A registered saver handles this object; no factories needed.
            registered_savers[saver_name][object_name] = trackable
            continue
        factories = factory_map[trackable] = []
        saveables = saveable_object_util.saveable_objects_from_trackable(
            object_to_save)  # pylint: disable=protected-access
        for name, saveable_factory in saveables.items():
            factories.append(
                _CheckpointFactoryData(
                    factory=saveable_factory,
                    name=name,
                    checkpoint_key=trackable_utils.checkpoint_key(
                        object_name, name)))
    return factory_map, registered_savers
Exemplo n.º 20
0
    def _add_attributes_to_object_graph(self, trackable_objects,
                                        object_graph_proto, node_ids,
                                        object_names, object_map,
                                        call_with_mapped_captures):
        """Create SaveableObjects and corresponding SerializedTensor protos.

        Args:
          trackable_objects: list of trackables, ordered so each object's
            index equals its node id in `object_graph_proto` (asserted below).
          object_graph_proto: the object-graph proto whose per-node
            `attributes` are filled in here.
          node_ids: dict mapping trackable -> node id; used only for the
            ordering assertion.
          object_names: dict mapping trackable -> checkpoint name string.
          object_map: optional dict mapping trackables to copied trackables
            whose SaveableObjects should be created instead (e.g. resources
            copied into an export graph); `None` to save the originals.
          call_with_mapped_captures: forwarded to
            `saveable_object_util.create_saveable_object` for callable
            factories.

        Returns:
          A tuple of (named_saveable_objects, feed_additions).
          `feed_additions` is `None` when no saveables cache is in use,
          otherwise a feed dict of volatile Python state to save.
        """
        named_saveable_objects = []
        if self._saveables_cache is None:
            # No SaveableObject caching. Either we're executing eagerly, or building a
            # static save which is specialized to the current Python state.
            feed_additions = None
        else:
            # If we are caching SaveableObjects, we need to build up a feed_dict with
            # functions computing volatile Python state to be saved with the
            # checkpoint.
            feed_additions = {}
        if object_map is None:
            mapped_object_names = object_names
        else:
            # Re-key names by the mapped (copied) objects so the factory
            # lookup below matches `object_to_save`.
            mapped_object_names = object_identity.ObjectIdentityDictionary()
            for trackable, name in object_names.items():
                mapped_object_names[object_map.get(trackable,
                                                   trackable)] = name
        # NOTE(review): other versions of get_checkpoint_factories_and_keys
        # return a (factory_map, registered_savers) tuple rather than a bare
        # map -- confirm this call site matches the version in use.
        checkpoint_factory_map = get_checkpoint_factories_and_keys(
            mapped_object_names)
        for checkpoint_id, (trackable, object_proto) in enumerate(
                zip(trackable_objects, object_graph_proto.nodes)):
            assert node_ids[trackable] == checkpoint_id
            if object_map is None:
                object_to_save = trackable
            else:
                object_to_save = object_map.get(trackable, trackable)
            if self._saveables_cache is not None:
                cached_attributes = self._saveables_cache.setdefault(
                    object_to_save, {})
            else:
                cached_attributes = None

            for factory_data in checkpoint_factory_map[object_to_save]:
                attribute = object_proto.attributes.add()
                attribute.name = name = factory_data.name
                attribute.checkpoint_key = key = factory_data.checkpoint_key
                saveable_factory = factory_data.factory

                # See if we can skip saving this checkpoint key.
                saveables = cached_attributes.get(
                    name) if cached_attributes else None
                if saveables is not None:
                    for saveable in saveables:
                        if key not in saveable.name:
                            # The checkpoint key for this SaveableObject is different. We
                            # need to re-create it.
                            saveables = None
                            del cached_attributes[name]
                            break

                if saveables is None:
                    if callable(saveable_factory):
                        maybe_saveable = saveable_object_util.create_saveable_object(
                            saveable_factory, key, call_with_mapped_captures)
                    else:
                        maybe_saveable = saveable_factory
                    if isinstance(maybe_saveable,
                                  saveable_object_lib.SaveableObject):
                        saveables = (maybe_saveable, )
                    else:
                        # Figure out the name-based Saver's name for this variable. If it's
                        # already a SaveableObject we'd just get the checkpoint key back, so
                        # we leave full_name blank.
                        saver_dict = saveable_object_util.op_list_to_dict(
                            [maybe_saveable], convert_variable_to_tensor=False)
                        full_name, = saver_dict.keys()
                        saveables = tuple(
                            saveable_object_util.saveable_objects_for_op(
                                op=maybe_saveable, name=key))
                        for saveable in saveables:
                            saveable.full_name = full_name
                    # Every SaveableObject must embed the checkpoint key in
                    # its name; otherwise restore lookups would miss it.
                    for saveable in saveables:
                        if key not in saveable.name:
                            raise AssertionError(
                                f"The object {trackable} produced a SaveableObject with name "
                                f"'{saveable.name}' for attribute '{name}'. Expected a name"
                                f" containing '{key}'.")
                    if cached_attributes is not None:
                        cached_attributes[name] = saveables

                # optional_restore is the AND over all saveables for this
                # attribute: the attribute may be absent at restore time only
                # if every saveable allows it.
                optional_restore = None
                for saveable in saveables:
                    if optional_restore is None:
                        optional_restore = saveable.optional_restore
                    else:
                        optional_restore = optional_restore and saveable.optional_restore

                    if hasattr(saveable, "full_name"):
                        attribute.full_name = saveable.full_name
                    if isinstance(saveable, base.PythonStateSaveable):
                        if feed_additions is None:
                            assert self._saveables_cache is None
                            # If we're not caching saveables, then we're either executing
                            # eagerly or building a static save/restore (e.g. for a
                            # SavedModel). In either case, we should embed the current Python
                            # state in the graph rather than relying on a feed dict.
                            saveable = saveable.freeze()
                        else:
                            saveable_feed_dict = saveable.feed_dict_additions()
                            for new_feed_key in saveable_feed_dict.keys():
                                if new_feed_key in feed_additions:
                                    raise AssertionError(
                                        f"The object {trackable} tried to feed a value for the "
                                        f"Tensor {new_feed_key} when saving, but another object "
                                        "is already feeding a value.")
                            feed_additions.update(saveable_feed_dict)
                    named_saveable_objects.append(saveable)
                if optional_restore is None:
                    # No saveables for this attribute: treat it as required.
                    optional_restore = False
                attribute.optional_restore = optional_restore

        return named_saveable_objects, feed_additions
Exemplo n.º 21
0
  def map_resources(self):
    """Makes new resource handle ops corresponding to existing resource tensors.

    Creates resource handle ops in the current default graph, whereas
    `accessible_objects` will be from an eager context. Resource mapping adds
    resource handle ops to the main GraphDef of a SavedModel, which allows the
    C++ loader API to interact with variables.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from object in `accessible_objects` to
          replacement objects created to hold the new resource tensors.
        resource_map: A dictionary mapping from resource tensors extracted from
          `accessible_objects` to newly created resource tensors.
        asset_info: An _AssetInfo tuple describing external assets referenced
          from accessible_objects.
    """
    # Only makes sense when adding to the export Graph
    assert not context.executing_eagerly()
    # TODO(allenl): Handle MirroredVariables and other types of variables which
    # may need special casing.
    object_map = object_identity.ObjectIdentityDictionary()
    resource_map = {}
    asset_info = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={})

    for node_id, obj in enumerate(self.nodes):
      if isinstance(obj, tracking.CapturableResource):
        # Create a fresh resource in the export graph and remember which
        # node captured the original handle.
        # pylint: disable=protected-access
        with ops.device(obj._resource_device):
          new_resource = obj._create_resource()
        # pylint: enable=protected-access
        resource_map[obj.resource_handle] = new_resource
        self.captured_tensor_node_ids[obj.resource_handle] = node_id
      elif (ds_values.is_distributed_variable(obj) or
            resource_variable_ops.is_resource_variable(obj)):
        # Copy the variable (the primary component, for distributed
        # variables) into the export graph without initializing it.
        obj_to_copy = obj.primary if ds_values.is_distributed_variable(
            obj) else obj
        new_variable = resource_variable_ops.copy_to_graph_uninitialized(
            obj_to_copy)
        if ds_values.is_distributed_variable(obj):
          self.captured_tensor_node_ids[obj] = node_id
        object_map[obj] = new_variable
        resource_map[obj.handle] = new_variable.handle
        self.captured_tensor_node_ids[obj.handle] = node_id
      elif isinstance(obj, tracking.Asset):
        # Assets are recorded in asset_info rather than copied as resources.
        _process_asset(obj, asset_info, resource_map)
        self.captured_tensor_node_ids[obj.asset_path] = node_id

    for concrete_function in self.concrete_functions:
      if not concrete_function.graph.saveable:
        raise ValueError(
            ("Unable to save function {name} for the following reason(s):\n" +
             "\n".join(concrete_function.graph.saving_errors))
            .format(name=concrete_function.name))
      for capture in concrete_function.captured_inputs:
        # Any captured tensor that is not already mapped must be a simple
        # constant; embed it as a new constant node in the export graph.
        if (tensor_util.is_tensor(capture)
            and capture.dtype not in _UNCOPIABLE_DTYPES
            and capture not in self.captured_tensor_node_ids):
          capture_constant_value = tensor_util.constant_value(capture)
          if capture_constant_value is None:
            raise ValueError(
                ("Attempted to save a function {} which references a symbolic "
                 "Tensor {} that is not a simple constant. This is not "
                 "supported.").format(concrete_function.name, capture))
          copied_tensor = constant_op.constant(capture_constant_value)
          node_id = len(self.nodes)
          node = _CapturedConstant(
              eager_tensor=capture, graph_tensor=copied_tensor)
          self.nodes.append(node)
          self.node_ids[capture] = node_id
          self.node_ids[node] = node_id
          self.captured_tensor_node_ids[capture] = node_id
          resource_map[capture] = copied_tensor

    return object_map, resource_map, asset_info
Exemplo n.º 22
0
def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
  """Clone a functional `Model` instance.

  Model cloning is similar to calling a model on new inputs,
  except that it creates new layers (and thus new weights) instead
  of sharing the weights of the existing layers.

  Input layers are always cloned.

  Arguments:
      model: Instance of `Model`.
      input_tensors: optional list of input tensors
          to build the model upon. If not provided,
          placeholders will be created.
      layer_fn: callable to be applied on non-input layers in the model. By
          default it clones the layer. Another example is to preserve the layer
          to share the weights. This is required when we create a per-replica
          copy of the model with distribution strategy; we want the weights to
          be shared but still feed inputs separately so we create new input
          layers.

  Returns:
      An instance of `Model` reproducing the behavior
      of the original model, on top of new inputs tensors,
      using newly instantiated weights.

  Raises:
      ValueError: in case of invalid `model` argument value or `layer_fn`
      argument value.
  """
  # Only graph-network (functional) models can be cloned this way; reject
  # Sequential and subclassed models up front.
  if not isinstance(model, Model):
    raise ValueError('Expected `model` argument '
                     'to be a `Model` instance, got ', model)
  if isinstance(model, Sequential):
    raise ValueError('Expected `model` argument '
                     'to be a functional `Model` instance, '
                     'got a `Sequential` instance instead:', model)
  if not model._is_graph_network:
    raise ValueError('Expected `model` argument '
                     'to be a functional `Model` instance, '
                     'but got a subclass model instead.')

  layer_map = {}  # Cache for created layers.
  tensor_map = object_identity.ObjectIdentityDictionary(
  )  # Map {reference_tensor: corresponding_tensor}
  if input_tensors is None:
    # Create placeholders to build the model on top of.
    input_tensors = []
    for layer in model._input_layers:
      input_tensor = Input(
          batch_shape=layer._batch_input_shape,
          dtype=layer.dtype,
          sparse=layer.sparse,
          name=layer.name)
      input_tensors.append(input_tensor)
      # Cache newly created input layer.
      newly_created_input_layer = input_tensor._keras_history.layer
      layer_map[layer] = newly_created_input_layer
  else:
    # Make sure that all input tensors come from a Keras layer.
    # If tensor comes from an input layer: cache the input layer.
    input_tensors = nest.flatten(input_tensors)
    input_tensors_ = []
    for i, input_tensor in enumerate(input_tensors):
      if not K.is_keras_tensor(input_tensor):
        # Wrap raw tensors in an Input layer so they fit the functional API.
        original_input_layer = model._input_layers[i]
        name = original_input_layer.name
        input_tensor = Input(tensor=input_tensor,
                             name='input_wrapper_for_' + name)

        input_tensors_.append(input_tensor)
        # Cache newly created input layer.
        newly_created_input_layer = input_tensor._keras_history.layer
        layer_map[original_input_layer] = newly_created_input_layer
      else:
        input_tensors_.append(input_tensor)
    input_tensors = input_tensors_

  # Seed the tensor map with the correspondence between old and new inputs.
  for x, y in zip(model.inputs, input_tensors):
    tensor_map[x] = y

  if not callable(layer_fn):
    raise ValueError('Expected `layer_fn` argument to be a callable.')

  # Has the side effect of filling out `layer_map` and `tensor_map`.
  new_nodes = _make_new_nodes(model._nodes_by_depth, layer_fn, layer_map,
                              tensor_map)
  # Check that we did compute the model outputs,
  # then instantiate a new model from inputs and outputs.
  output_tensors = []
  for x in model.outputs:
    assert x in tensor_map, 'Could not compute output ' + str(x)
    output_tensors.append(tensor_map[x])

  # Restore the original input/output nesting structure before rebuilding.
  input_tensors = nest.pack_sequence_as(model._nested_inputs, input_tensors)
  output_tensors = nest.pack_sequence_as(model._nested_outputs, output_tensors)
  metrics_names = model.metrics_names
  model = Model(input_tensors, output_tensors, name=model.name)
  # Layers not directly tied to outputs of the Model, such as loss layers
  # created in `add_loss` and `add_metric`.
  ancillary_layers = [
      layer for layer in layer_map.values() if layer not in model.layers
  ]
  if ancillary_layers:
    _insert_ancillary_layers(model, ancillary_layers, metrics_names, new_nodes)
  return model
Exemplo n.º 23
0
    def map_resources(self):
        """Makes new resource handle ops corresponding to existing resource tensors.

    Creates resource handle ops in the current default graph, whereas
    `accessible_objects` will be from an eager context. Resource mapping adds
    resource handle ops to the main GraphDef of a SavedModel, which allows the
    C++ loader API to interact with variables.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from object in `accessible_objects` to
          replacement objects created to hold the new resource tensors.
        resource_map: A dictionary mapping from resource tensors extracted from
          `accessible_objects` to newly created resource tensors.
        asset_info: An _AssetInfo tuple describing external assets referenced
          from accessible_objects.
    """
        # Only makes sense when adding to the export Graph
        assert not context.executing_eagerly()
        # TODO(allenl): Handle MirroredVariables and other types of variables which
        # may need special casing.
        object_map = object_identity.ObjectIdentityDictionary()
        resource_map = {}
        asset_info = _AssetInfo(asset_defs=[],
                                asset_initializers_by_resource={},
                                asset_filename_map={},
                                asset_index={})

        for node_id, obj in enumerate(self.nodes):
            if isinstance(obj, tracking.CapturableResource):
                # Shallow-copy the resource holder so the copy can own a new
                # handle created in the export graph.
                new_obj = object_map[obj] = copy.copy(obj)
                # pylint: disable=protected-access
                with ops.device(obj._resource_device):
                    new_resource = new_obj._create_resource()
                new_obj._resource_handle = new_resource
                # pylint: enable=protected-access
                resource_map[obj.resource_handle] = new_resource
                self.captured_tensor_node_ids[obj.resource_handle] = node_id
            elif (ds_values.is_distributed_variable(obj)
                  or resource_variable_ops.is_resource_variable(obj)):
                # Copy the variable (the primary component, for distributed
                # variables) into the export graph without initializing it.
                obj_to_copy = obj.primary if ds_values.is_distributed_variable(
                    obj) else obj
                new_variable = resource_variable_ops.copy_to_graph_uninitialized(
                    obj_to_copy)
                if ds_values.is_distributed_variable(obj):
                    self.captured_tensor_node_ids[obj] = node_id
                    # Map every per-replica component to the single copied
                    # variable so all handles resolve to the same export node.
                    for v in obj.values:
                        object_map[v] = new_variable
                        resource_map[v.handle] = new_variable.handle
                        self.captured_tensor_node_ids[v.handle] = node_id
                object_map[obj] = new_variable
                resource_map[obj.handle] = new_variable.handle
                self.captured_tensor_node_ids[obj.handle] = node_id
            elif isinstance(obj, tracking.Asset):
                # Assets are recorded in asset_info rather than copied as
                # resources.
                _process_asset(obj, asset_info, resource_map)
                self.captured_tensor_node_ids[obj.asset_path] = node_id

        # Note: some concrete functions can have been realized when tracing other
        # functions, and might closure-capture tensors from their parent functions.
        # This is normal, but it means those concrete functions can't be serialized
        # as their own independent endpoints, so we filter them out here.
        bad_functions = []
        for concrete_function in self.concrete_functions:
            if not concrete_function.graph.saveable:
                raise ValueError((
                    "Unable to save function {name} for the following reason(s):\n"
                    + "\n".join(concrete_function.graph.saving_errors)).format(
                        name=concrete_function.name))
            for capture in concrete_function.captured_inputs:
                if (tensor_util.is_tensor(capture)
                        and capture.dtype not in _UNCOPIABLE_DTYPES
                        and capture not in self.captured_tensor_node_ids):
                    if hasattr(capture, "_cached_variable"):
                        # Captures of cached variable reads are handled by
                        # wrapping the function once and recording the
                        # original-name -> wrapped-name mapping.
                        if concrete_function not in self.wrapped_functions:
                            wrapped = self.wrapped_functions[
                                concrete_function] = (
                                    function_serialization.
                                    wrap_cached_variables(concrete_function))
                            self.function_name_map[compat.as_text(
                                concrete_function.name)] = (compat.as_text(
                                    wrapped.name))
                        continue
                    capture_constant_value = tensor_util.constant_value(
                        capture)
                    if capture_constant_value is None:
                        # Non-constant symbolic capture: the function cannot
                        # be exported as an independent endpoint.
                        bad_functions.append(concrete_function)
                        continue
                    copied_tensor = constant_op.constant(
                        capture_constant_value)
                    node_id = len(self.nodes)
                    node = _CapturedConstant(eager_tensor=capture,
                                             graph_tensor=copied_tensor)
                    self.nodes.append(node)
                    self.node_ids[capture] = node_id
                    self.node_ids[node] = node_id
                    self.captured_tensor_node_ids[capture] = node_id
                    resource_map[capture] = copied_tensor

        # Swap in wrapped functions and drop the unserializable ones.
        self.concrete_functions = [
            self.wrapped_functions.get(x, x) for x in self.concrete_functions
            if x not in bad_functions
        ]
        return object_map, resource_map, asset_info
Exemplo n.º 24
0
  def get_concrete_function(self, *args, **kwargs):
    """Returns a `ConcreteFunction` specialized to inputs and execution context.

    If this `Function` was created with an `input_signature`, `args` and
    `kwargs` may be omitted. With an input signature there is only one
    concrete function associated with this `Function`.

    If there is no fixed `input_signature` associated with this
    `Function`, positional and keyword arguments to `get_concrete_function`
    follow the same rules as input signature specification, with `tf.TensorSpec`
    objects describing `tf.Tensor`s which will be passed to the concrete
    function.

    Each `tf.Tensor` argument to the concrete function must have a unique name,
    either because it is the only one associated with a named argument of the
    Python function or because an explicit `name=` was passed to its
    `tf.TensorSpec` object. These names become the argument names for the
    concrete function.

    Arguments to the concrete function may always be specified as keyword
    arguments, naming the Tensor input. Positional arguments may be used instead
    when each preceding argument to the Python function is a Tensor.

    ```python
    @tf.function
    def f(x):
      return x

    f_concrete = f.get_concrete_function(tf.TensorSpec([], tf.float64))
    f_concrete(tf.constant(1.))
    f_concrete(x=tf.constant(1.))
    ```

    Nested structures containing Tensors may be specified when retrieving
    concrete functions. Structures with multiple Tensors are expanded into
    multiple arguments of the concrete function. Since multiple concrete
    function arguments are associated with one argument to the original
    function, these Tensors must be named explicitly. Tensors in nested
    structures may not be passed using positional arguments when calling the
    concrete function.

    ```python
    f_concrete2 = f.get_concrete_function(
        (tf.TensorSpec(None, tf.float64, name="first"),
         tf.TensorSpec([], tf.float32, name="second")))
    # Keyword arguments are required when identifying Tensors in nested
    # structures.
    f_concrete2(first=tf.constant([1.]), second=tf.constant(0.))
    ```

    Functions with fixed input signatures have only one concrete function
    associated with them, which can be retrieved without specifying any
    arguments. As before Tensors must have unique names, either inferred from
    the argument names in the original Python function or specified
    explicitly.

    ```python
    @tf.function(input_signature=(tf.TensorSpec(None, tf.float32)))
    def f_sig(y):
      return y

    f_sig_concrete = f.get_concrete_function()
    f_sig_concrete(tf.constant(1.))
    f_sig_concrete(y=tf.constant(1.))
    ```

    Args:
      *args: inputs to specialize on.
      **kwargs: inputs to specialize on.

    Returns:
      A TensorFlow function which takes exactly one `tf.Tensor` per argument.

    Raises:
      ValueError: if this object has not yet been called on concrete values.
    """
    # First-call initialization (tracing + variable creation) runs under the
    # lock so concurrent first calls don't race.
    with self._lock:
      if self._stateful_fn is None:
        initializer_map = object_identity.ObjectIdentityDictionary()
        self._initialize(args, kwargs, add_initializers_to=initializer_map)
        self._initialize_uninitialized_variables(initializer_map)

    # NOTE(review): the docstring promises a ValueError when this object has
    # never been called, but no final else-branch is visible here -- if both
    # branches below are skipped this returns None. Confirm against the full
    # source (this snippet may be truncated).
    if self._created_variables:
      # In this case we have created variables on the first call, so we run the
      # defunned version which is guaranteed to never create variables.
      return self._stateless_fn.get_concrete_function(*args, **kwargs)
    elif self._stateful_fn is not None:
      # In this case we have not created variables on the first call. So we can
      # run the first trace but we should fail if variables are created.
      concrete = self._stateful_fn.get_concrete_function(*args, **kwargs)
      if self._created_variables:
        raise ValueError("Creating variables on a non-first call to a function"
                         " decorated with tf.function.")
      return concrete
Exemplo n.º 25
0
def lift_to_graph(tensors,
                  graph,
                  sources=None,
                  disallowed_placeholders=None,
                  add_sources=False,
                  handle_captures=False,
                  base_graph=None,
                  op_map=None):
  """Copies the tensors and all of their inputs recursively to the outer graph.

  Args:
    tensors: The Tensors to lift.
    graph: The graph to lift to.
    sources: Optional sequence of nodes to start from. If omitted the whole
      subgraph which feeds into `tensors` is lifted.
    disallowed_placeholders: An optional set of ops which may not appear in the
      lifted graph. Defaults to all placeholders.
    add_sources: A boolean indicating whether placeholders which are not in
      sources should be allowed.
    handle_captures: A boolean indicating whether to re-capture the source
      tensors in the new graph or simply create a vanilla placeholder.
    base_graph: The graph from which to lift ops. This will be inferred if not
      specified.
    op_map: A map containing all the existing nodes that have been lifted to
      the destination graph, so they won't be lifted and copied again.

  Returns:
    A mapping from ops in the current default graph to ops in `graph`.

  Raises:
    UnliftableError: If a placeholder blocks lifting.
  """
  # Resource variables are not copied op-by-op; they are mapped to themselves
  # in `op_map` below. Only the remaining tensors drive the graph traversal.
  variable_init_tensors = []
  init_tensors = []
  for tensor in tensors:
    if isinstance(tensor, resource_variable_ops.ResourceVariable):
      variable_init_tensors.append(tensor)
    else:
      init_tensors.append(tensor)
  # NOTE(review): this indexing assumes at least one non-variable tensor was
  # passed in; an all-variable `tensors` would raise IndexError here unless
  # `base_graph` is given — TODO confirm callers guarantee this.
  base_graph = base_graph or init_tensors[0].graph
  op_map = op_map or object_identity.ObjectIdentityDictionary()

  # Check that the initializer does not depend on any placeholders.
  sources = object_identity.ObjectIdentitySet(sources or [])
  visited_ops = set(x.op for x in sources)
  # For each op, the set of consuming ops that lie inside the extracted
  # subgraph (filled in by `map_subgraph` below).
  op_outputs = collections.defaultdict(set)

  # First we extract the subgraph between init_tensors and sources.
  for init_tensor in init_tensors:
    sources.update(op_selector.map_subgraph(
        init_tensor=init_tensor,
        sources=sources,
        disallowed_placeholders=disallowed_placeholders,
        visited_ops=visited_ops,
        op_outputs=op_outputs,
        add_sources=add_sources))

  # Try to topologically sort the nodes we've extracted. Now we know how many of
  # their outputs are part of this subgraph.
  ops_to_copy = []
  marked_ops = set([])
  # Seed the traversal with the subgraph's terminal ops: those whose outputs
  # are not consumed by any other op inside the subgraph.
  ops_to_visit = [_as_operation(t) for t in init_tensors
                  if not op_outputs[_as_operation(t)]]
  unvisited_ops = set(ops_to_visit)
  while unvisited_ops:
    while ops_to_visit:
      op = ops_to_visit.pop()
      if op in marked_ops:
        continue
      marked_ops.add(op)
      ops_to_copy.append(op)
      for inp in op_selector.graph_inputs(op):
        # Don't lift the TPUReplicateMetadata nodes out of the function, because
        # it has no registered kernels.
        if inp.type == "TPUReplicateMetadata":
          continue
        unvisited_ops.add(inp)
        # Only descend into an input once all of its in-subgraph consumers
        # have been marked, so `ops_to_copy` stays reverse-topological.
        if (all(x in marked_ops for x in op_outputs[inp]) and
            inp not in sources):
          ops_to_visit.append(inp)
    unvisited_ops.difference_update(marked_ops)
    if unvisited_ops:
      # `unvisited_ops` should only have elements if the graph has a loop. In
      # this case we want to keep copying and there's no topological ordering;
      # we'll do ugly post-hoc mutations instead.
      ops_to_visit.append(next(iter(unvisited_ops)))

  # When the topological sort fails due to loops, it can result in exceptions
  # later when copying a node which inputs haven't been copied yet. We can
  # improve that pseudo-topological order slightly by putting the ops without
  # inputs, such as constants, at the start of the topological order (i.e at
  # the end of ops_to_copy).
  ops_to_copy.sort(key=(lambda op: len(op_selector.graph_inputs(op)) == 0))

  # When lifting from one FuncGraph to another, we will need to capture the
  # relevant tensors as well.
  captures = []
  inverse_captures = object_identity.ObjectIdentityDictionary()
  internal_captures = []
  if (isinstance(base_graph, func_graph.FuncGraph) and
      isinstance(graph, func_graph.FuncGraph)):
    captures = base_graph.captures
    for external_capture, internal_capture in captures:
      inverse_captures[internal_capture] = external_capture
    internal_captures = base_graph.internal_captures

  # ops_to_copy now holds a reverse topologically sorted list of ops which
  # ends in the initializer. We copy those to the outermost graph and
  # build the initialization op there.
  with graph.as_default():
    # Variables need no copying: map them to themselves so downstream lookups
    # in `op_map` succeed.
    for i in variable_init_tensors:
      op_map[i] = i
    source_ops = set()
    # Add the sources in the same order as the original graph.
    for s in internal_captures:
      if s in sources:
        sources.remove(s)
        source_ops.add(s.op)
        _copy_source(
            s=s,
            graph=graph,
            op_map=op_map,
            handle_captures=handle_captures,
            inverse_captures=inverse_captures,
            base_graph=base_graph)
    # Any remaining sources (not matched to an internal capture) are copied
    # in arbitrary set order.
    for s in sources:
      source_ops.add(s.op)
      _copy_source(
          s=s,
          graph=graph,
          op_map=op_map,
          handle_captures=handle_captures,
          inverse_captures=inverse_captures,
          base_graph=base_graph)

    input_mutations = []
    control_mutations = []
    # Iterate in (pseudo-)topological order; `_copy_non_source` reports any
    # input/control edges that could not be connected yet.
    for op in reversed(ops_to_copy):
      if op in source_ops or op in op_map:
        continue
      new_input_mutations, new_control_mutations = _copy_non_source(
          op=op, graph=graph, op_map=op_map, base_graph=base_graph)
      input_mutations.extend(new_input_mutations)
      control_mutations.extend(new_control_mutations)

    # Mutate the new graph to insert any loops which existed in the source
    # graph due to v1 while_loops.
    #
    # pylint: disable=protected-access
    with graph._mutation_lock():
      for mutation in input_mutations:
        mutation.copied_op._update_input(
            mutation.input_index, op_map[mutation.old_graph_tensor])
      for mutation in control_mutations:
        # Don't lift the TPUReplicateMetadata nodes out of the function, because
        # it has no registered kernels.
        if mutation.old_graph_op.type == "TPUReplicateMetadata":
          continue
        mutation.copied_op._add_control_input(op_map[mutation.old_graph_op])
    # pylint: enable=protected-access

    return op_map
Exemplo n.º 26
0
  def _call(self, *args, **kwds):
    """Calls the graph function.

    The first call traces the Python function and initializes the variables it
    creates; later calls dispatch straight to an already-traced function,
    releasing `self._lock` as early as possible so calls can run in parallel.

    Args:
      *args: Positional arguments forwarded to the traced function.
      **kwds: Keyword arguments forwarded to the traced function.

    Returns:
      The result of running the traced function on `args`/`kwds`.

    Raises:
      ValueError: If variables are created on a non-first call, or if a
        variable created inside the function has been garbage-collected.
    """
    # Lock is acquired manually (not `with`) because each branch below wants
    # to release it at a different point.
    self._lock.acquire()
    if self._created_variables:
      # Release the lock early so that multiple threads can perform the call
      # in parallel.
      self._lock.release()
      # In this case we have created variables on the first call, so we run the
      # defunned version which is guaranteed to never create variables.
      return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    elif self._stateful_fn is not None:
      # Release the lock early so that multiple threads can perform the call
      # in parallel.
      self._lock.release()
      # In this case we have not created variables on the first call. So we can
      # run the first trace but we should fail if variables are created.
      results = self._stateful_fn(*args, **kwds)
      if self._created_variables:
        raise ValueError("Creating variables on a non-first call to a function"
                         " decorated with tf.function.")
      return results

    try:
      # This is the first call of __call__, so we have to initialize.
      initializer_map = object_identity.ObjectIdentityDictionary()
      self._initialize(args, kwds, add_initializers_to=initializer_map)
    finally:
      # At this point we know that the initialization is complete (or less
      # interestingly an exception was raised) so we no longer need a lock.
      self._lock.release()

    if self._created_variables:
      try:
        # Attempt to initialize variables eagerly and without conds by lifting
        # out initialization graphs. This is the only initialization strategy
        # compatible with XLA at the moment.
        self._initialize_uninitialized_variables(initializer_map)
      except lift_to_graph.UnliftableError:
        pass  # Fall through to cond-based initialization.
      else:
        # Lifting succeeded, so variables are initialized and we can run the
        # stateless function.
        return self._stateless_fn(*args, **kwds)
    else:
      canon_args, canon_kwds = \
          self._stateful_fn._function_spec.canonicalize_function_inputs(  # pylint: disable=protected-access
              *args, **kwds)
      # If we did not create any variables the trace we have is good enough.
      return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)  # pylint: disable=protected-access

    # Reached only via the `UnliftableError` pass above: variables were
    # created but their initializers could not be lifted.
    def fn_with_cond(*inner_args, **inner_kwds):
      """Conditionally runs initialization if it's needed."""
      condition = True
      for wr in self._created_variables:
        # `_created_variables` holds weak references; a dead referent means
        # the user dropped all strong references to the variable.
        variable = wr()
        if variable is None:
          raise ValueError(
              "A tf.Variable created inside your tf.function has been"
              " garbage-collected. Your code needs to keep Python references"
              " to variables created inside `tf.function`s.\n"
              "\n"
              "A common way to raise this error is to create and return a"
              " variable only referenced inside your function:\n"
              "\n"
              "@tf.function\n"
              "def f():\n"
              "  v = tf.Variable(1.0)\n"
              "  return v\n"
              "\n"
              "v = f()  # Crashes with this error message!\n"
              "\n"
              "The reason this crashes is that @tf.function annotated"
              " function returns a **`tf.Tensor`** with the **value** of the"
              " variable when the function is called rather than the"
              " variable instance itself. As such there is no code holding a"
              " reference to the `v` created inside the function and Python"
              " garbage collects it.\n"
              "\n"
              "The simplest way to fix this issue is to create variables"
              " outside the function and capture them:\n"
              "\n"
              "v = tf.Variable(1.0)\n"
              "\n"
              "@tf.function\n"
              "def f():\n"
              "  return v\n"
              "\n"
              "f()  # <tf.Tensor: ... numpy=1.>\n"
              "v.assign_add(1.)\n"
              "f()  # <tf.Tensor: ... numpy=2.>")
        condition = math_ops.logical_and(
            condition, resource_variable_ops.var_is_initialized_op(
                variable.handle))
      # We want to call stateless_fn if possible because it avoids recomputing
      # potentially expensive initializers.
      return control_flow_ops.cond(
          condition,
          lambda: self._stateless_fn(*inner_args, **inner_kwds),
          functools.partial(self._concrete_stateful_fn._filtered_call,  # pylint: disable=protected-access
                            inner_args, inner_kwds))

    # We've created variables and are unable to lift the initialization graphs,
    # so we fall back to initializing with conds while running the function.
    canon_args, canon_kwds = \
        self._stateful_fn._function_spec.canonicalize_function_inputs(  # pylint: disable=protected-access
            *args, **kwds)
    return function_lib.defun(fn_with_cond)(*canon_args, **canon_kwds)
Exemplo n.º 27
0
def _lift_unlifted_variables(graph, variable_holder):
    """Finds resource variables in `graph` and lifts them to the outer context.

    When a GraphDef is imported inside a wrap_function, no Python graph
    building code runs, so VarHandleOps create variable resources without
    corresponding Python objects. That works, but leaves the user no way to
    interact with or modify the variables from outside the graph.

    This searches the graph's variable collections, lifts each eligible
    variable out as a regular variable object, and marks the original as a
    capture of the FuncGraph.

    Args:
      graph: The FuncGraph to lift variables from.
      variable_holder: A VariableHolder to record the lifted variables in.
    """
    with graph.as_default():
        global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        local_vars = ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)
        captured_handles = object_identity.ObjectIdentitySet(
            graph.internal_captures)
        replacements = object_identity.ObjectIdentityDictionary()

        def _needs_lifting(candidate):
            # Only graph-mode resource variables built while tracing a
            # function, and not already captured, are lifted.
            in_func_graph = (
                candidate._in_graph_mode  # pylint: disable=protected-access
                and candidate.graph.building_function)
            return (in_func_graph and isinstance(
                candidate, resource_variable_ops.BaseResourceVariable)
                    and candidate.handle not in captured_handles)

        def _lift(candidate):
            # Lift one variable, remember the mapping, and record its handle
            # so the same variable is not lifted twice.
            replacement = _lift_single_variable(candidate, graph,
                                                variable_holder)
            replacements[candidate] = replacement
            captured_handles.add(candidate.handle)
            return replacement

        for candidate in global_vars:
            if _needs_lifting(candidate):
                _lift(candidate)

        for candidate in local_vars:
            if not _needs_lifting(candidate):
                continue
            replacement = _lift(candidate)
            if replacement._in_graph_mode:  # pylint: disable=protected-access
                outer_graph = replacement.graph
                # New variables join the global collection by default; this
                # one belongs only in the local collection, so move it.
                outer_graph.get_collection_ref(
                    ops.GraphKeys.GLOBAL_VARIABLES).remove(replacement)
                outer_graph.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES,
                                              replacement)

        # Rewrite the FuncGraph's own collections to point at the lifted
        # variables, partly for the user and partly so this function stays
        # idempotent when it runs again in prune() calls.
        for collection_name in (ops.GraphKeys.GLOBAL_VARIABLES,
                                ops.GraphKeys.LOCAL_VARIABLES):
            collection = ops.get_collection_ref(collection_name)
            for index, entry in enumerate(collection):
                collection[index] = replacements.get(entry, entry)
                if not resource_variable_ops.is_resource_variable(
                        collection[index]):
                    logging.warning(
                        "Unable to create a python object for variable {} because it is "
                        "a reference variable. It may not be visible to training APIs. "
                        "If this is a problem, consider rebuilding the SavedModel after "
                        "running tf.compat.v1.enable_resource_variables().".
                        format(collection[index]))