コード例 #1
0
ファイル: data_structures.py プロジェクト: clsung/tensorflow
 def _track_value(self, value, name):
   """Record `value` as a dependency of this structure and return it.

   The value is first routed through `sticky_attribute_assignment`; the object
   that call returns is what gets checked, tracked, and handed back.

   Args:
     value: The object being stored in this structure.
     name: Passed as `name` to `sticky_attribute_assignment`.

   Returns:
     The object returned by `sticky_attribute_assignment`.

   Raises:
     ValueError: If that object does not inherit from
       `base.CheckpointableBase`.
   """
   tracked = sticky_attribute_assignment(
       checkpointable=self, value=value, name=name)

   if isinstance(tracked, variables.Variable):
     self._extra_variables.append(tracked)

   if not isinstance(tracked, base.CheckpointableBase):
     message = (
         "Only checkpointable objects (such as Layers or Optimizers) may be "
         "stored in a List object. Got %s, which does not inherit from "
         "CheckpointableBase.")
     raise ValueError(message % (tracked,))

   layer_like = (isinstance(tracked, CheckpointableDataStructure)
                 or layer_utils.is_layer(tracked)
                 or layer_utils.has_weights(tracked))
   if not layer_like:
     return tracked

   # Use object identity rather than `in`/__eq__: automatically generated list
   # wrappers keep "[] == []" true, so "[] in [[]]" would report a distinct
   # empty container as already present. (That equality stops holding once
   # one of the lists is mutated.)
   if any(existing is tracked for existing in self._layers):
     return tracked

   self._layers.append(tracked)
   if hasattr(tracked, "_use_resource_variables"):
     # In subclassed models, legacy layers (tf.layers) must always use
     # resource variables.
     tracked._use_resource_variables = True  # pylint: disable=protected-access
   return tracked
コード例 #2
0
 def _layers(self):
     """Collect the Layers and Layer containers currently held in `_values`.

     Empty containers are included. The result is rebuilt on every call
     (nothing is cached), so wrapper objects pick up the values of the thing
     they wrap even if the two have fallen out of sync.

     Returns:
       A new list holding, in `_values` order, every element that is a
       `CheckpointableDataStructure` or is accepted by `layer_utils.is_layer`
       or `layer_utils.has_weights`.
     """
     return [
         candidate
         for candidate in self._values
         if (isinstance(candidate, CheckpointableDataStructure)
             or layer_utils.is_layer(candidate)
             or layer_utils.has_weights(candidate))
     ]
コード例 #3
0
 def _layers(self):
   """Return the Layers and Layer containers found in `_values`.

   Empty containers are kept. Filtering happens on demand on each call so
   that wrapper objects reflect the values of whatever they wrap, even when
   the two are out of sync.

   Returns:
     A fresh list of the matching elements of `_values`, in their original
     order.
   """
   def _is_layer_like(obj):
     # Matches structure wrappers and anything `layer_utils.is_layer` or
     # `layer_utils.has_weights` accepts.
     return (isinstance(obj, CheckpointableDataStructure)
             or layer_utils.is_layer(obj)
             or layer_utils.has_weights(obj))

   return list(filter(_is_layer_like, self._values))
コード例 #4
0
 def _track_value(self, value, name):
     """Add a dependency on `value`.

     Checkpointable values are registered via `self._track_checkpointable`.
     Variables are additionally appended to `self._extra_variables`. Data
     structures and Layers are also recorded once (per object identity) in
     `self._layers`.

     Args:
       value: The object to track; must inherit from
         `checkpointable_lib.CheckpointableBase`.
       name: Passed as `name` to `self._track_checkpointable`.

     Raises:
       ValueError: If `value` does not inherit from
         `checkpointable_lib.CheckpointableBase`.
     """
     if isinstance(value, checkpointable_lib.CheckpointableBase):
         self._track_checkpointable(value, name=name)
         if isinstance(value, variables.Variable):
             self._extra_variables.append(value)
     else:
         raise ValueError((
             "Only checkpointable objects (such as Layers or Optimizers) may be "
             "stored in a List object. Got %s, which does not inherit from "
             "CheckpointableBase.") % (value, ))
     if (isinstance(value, CheckpointableDataStructure)
             or layer_utils.is_layer(value)):
         # Check for object identity rather than using `in` (which falls back
         # to __eq__). Automatically generated list wrappers keep "[] == []"
         # true, so "[] in [[]]" is also true, and an equality-based check
         # would skip tracking a second, distinct empty container.
         if not any(layer is value for layer in self._layers):
             self._layers.append(value)
             if hasattr(value, "_use_resource_variables"):
                 # In subclassed models, legacy layers (tf.layers) must always use
                 # resource variables.
                 value._use_resource_variables = True  # pylint: disable=protected-access
コード例 #5
0
 def _track_value(self, value, name):
   """Add a dependency on `value`.

   Checkpointable values are registered via `self._track_checkpointable`.
   Variables are additionally appended to `self._extra_variables`. Data
   structures and Layers are also recorded once (per object identity) in
   `self._layers`.

   Args:
     value: The object to track; must inherit from
       `checkpointable_lib.CheckpointableBase`.
     name: Passed as `name` to `self._track_checkpointable`.

   Raises:
     ValueError: If `value` does not inherit from
       `checkpointable_lib.CheckpointableBase`.
   """
   if isinstance(value, checkpointable_lib.CheckpointableBase):
     self._track_checkpointable(value, name=name)
     if isinstance(value, variables.Variable):
       self._extra_variables.append(value)
   else:
     raise ValueError(
         ("Only checkpointable objects (such as Layers or Optimizers) may be "
          "stored in a List object. Got %s, which does not inherit from "
          "CheckpointableBase.") % (value,))
   if (isinstance(value, CheckpointableDataStructure)
       or layer_utils.is_layer(value)):
     # Check for object identity rather than using `in` (which falls back to
     # __eq__). Automatically generated list wrappers keep "[] == []" true, so
     # "[] in [[]]" is also true, and an equality-based check would skip
     # tracking a second, distinct empty container.
     if not any(layer is value for layer in self._layers):
       self._layers.append(value)
       if hasattr(value, "_use_resource_variables"):
         # In subclassed models, legacy layers (tf.layers) must always use
         # resource variables.
         value._use_resource_variables = True  # pylint: disable=protected-access