def _track_value(self, value, name):
    """Register `value` as a dependency of this data structure.

    The incoming object is first passed through `sticky_attribute_assignment`;
    the object that call returns is what gets validated, recorded and returned.

    Args:
      value: The object being stored.
      name: Forwarded to `sticky_attribute_assignment` together with `value`.

    Returns:
      The object returned by `sticky_attribute_assignment`.

    Raises:
      ValueError: If that object is not a `base.CheckpointableBase`.
    """
    tracked = sticky_attribute_assignment(
        checkpointable=self, value=value, name=name)

    # Variables are additionally recorded in `_extra_variables`; this happens
    # before the checkpointable check below.
    if isinstance(tracked, variables.Variable):
        self._extra_variables.append(tracked)

    if not isinstance(tracked, base.CheckpointableBase):
        message = (
            "Only checkpointable objects (such as Layers or Optimizers) may be "
            "stored in a List object. Got %s, which does not inherit from "
            "CheckpointableBase.") % (tracked,)
        raise ValueError(message)

    # Only Layers, objects with weights, and nested data structures are
    # collected into `_layers`; everything else is returned as-is.
    layer_like = (isinstance(tracked, CheckpointableDataStructure)
                  or layer_utils.is_layer(tracked)
                  or layer_utils.has_weights(tracked))
    if not layer_like:
        return tracked

    # Identity (`is`) rather than `==` is used so that distinct empty
    # containers are not de-duplicated: auto-generated list wrappers compare
    # equal while empty ("[] == []", hence "[] in [[]]"), which stops being
    # true once one of them is mutated.
    already_tracked = any(existing is tracked for existing in self._layers)
    if already_tracked:
        return tracked

    self._layers.append(tracked)
    if hasattr(tracked, "_use_resource_variables"):
        # Legacy tf.layers inside subclassed models must always use resource
        # variables.
        tracked._use_resource_variables = True  # pylint: disable=protected-access
    return tracked
예제 #2
0
 def _track_value(self, value, name):
   """Register `value` as a dependency of this data structure.

   The incoming object is first passed through `sticky_attribute_assignment`;
   the object that call returns is what gets validated, recorded and returned.

   Args:
     value: The object being stored.
     name: Forwarded to `sticky_attribute_assignment` together with `value`.

   Returns:
     The object returned by `sticky_attribute_assignment`.

   Raises:
     ValueError: If that object is not a `base.CheckpointableBase`.
   """
   tracked = sticky_attribute_assignment(
       checkpointable=self, value=value, name=name)

   # Variables are additionally recorded in `_extra_variables`; this happens
   # before the checkpointable check below.
   if isinstance(tracked, variables.Variable):
     self._extra_variables.append(tracked)

   if not isinstance(tracked, base.CheckpointableBase):
     message = (
         "Only checkpointable objects (such as Layers or Optimizers) may be "
         "stored in a List object. Got %s, which does not inherit from "
         "CheckpointableBase.") % (tracked,)
     raise ValueError(message)

   # Only Layers, objects with weights, and nested data structures are
   # collected into `_layers`; everything else is returned as-is.
   layer_like = (isinstance(tracked, CheckpointableDataStructure)
                 or layer_utils.is_layer(tracked)
                 or layer_utils.has_weights(tracked))
   if not layer_like:
     return tracked

   # Identity (`is`) rather than `==` is used so that distinct empty
   # containers are not de-duplicated: auto-generated list wrappers compare
   # equal while empty ("[] == []", hence "[] in [[]]"), which stops being
   # true once one of them is mutated.
   already_tracked = any(existing is tracked for existing in self._layers)
   if already_tracked:
     return tracked

   self._layers.append(tracked)
   if hasattr(tracked, "_use_resource_variables"):
     # Legacy tf.layers inside subclassed models must always use resource
     # variables.
     tracked._use_resource_variables = True  # pylint: disable=protected-access
   return tracked
예제 #3
0
 def _layers(self):
     """Return every Layer or Layer container held in `self._values`.

     An element qualifies if it is a `CheckpointableDataStructure` (empty ones
     included), a layer according to `layer_utils.is_layer`, or an object for
     which `layer_utils.has_weights` is true. The list is rebuilt from
     `self._values` on every call, so wrapper objects see the current values
     of the thing they wrap even if the two have drifted out of sync.
     """
     return [
         candidate for candidate in self._values
         if isinstance(candidate, CheckpointableDataStructure)
         or layer_utils.is_layer(candidate)
         or layer_utils.has_weights(candidate)
     ]
예제 #4
0
 def _layers(self):
   """Return every Layer or Layer container held in `self._values`.

   An element qualifies if it is a `CheckpointableDataStructure` (empty ones
   included), a layer according to `layer_utils.is_layer`, or an object for
   which `layer_utils.has_weights` is true. The list is rebuilt from
   `self._values` on every call, so wrapper objects see the current values of
   the thing they wrap even if the two have drifted out of sync.
   """
   return [
       candidate for candidate in self._values
       if isinstance(candidate, CheckpointableDataStructure)
       or layer_utils.is_layer(candidate)
       or layer_utils.has_weights(candidate)
   ]