def _track_value(self, value, name):
    """Register `value` as a dependency of this structure and return it.

    The incoming object is first routed through `sticky_attribute_assignment`;
    every check below operates on, and the function returns, the object that
    call produces.

    Args:
      value: The object being stored in this data structure.
      name: Name forwarded to `sticky_attribute_assignment`.

    Returns:
      The object returned by `sticky_attribute_assignment`.

    Raises:
      ValueError: If that object is not a `base.CheckpointableBase`.
    """
    tracked = sticky_attribute_assignment(
        checkpointable=self, value=value, name=name)

    # Variables are additionally remembered in `_extra_variables`.
    if isinstance(tracked, variables.Variable):
        self._extra_variables.append(tracked)

    if not isinstance(tracked, base.CheckpointableBase):
        raise ValueError(
            ("Only checkpointable objects (such as Layers or Optimizers) may be "
             "stored in a List object. Got %s, which does not inherit from "
             "CheckpointableBase.") % (tracked,))

    looks_like_layer = (
        isinstance(tracked, CheckpointableDataStructure)
        or layer_utils.is_layer(tracked)
        or layer_utils.has_weights(tracked))
    if looks_like_layer:
        # Deduplicate by object identity, not `==`: automatically generated
        # list wrappers compare equal while empty ("[] == []"), so "[] in [[]]"
        # would wrongly report an unrelated empty container as already present.
        # That equality stops holding once one of the lists is mutated.
        # NOTE(review): if `_layers` is computed on demand from `self._values`
        # (e.g. a property), this append only touches a temporary list --
        # confirm `_layers` is a stored list in this class.
        already_present = any(
            existing is tracked for existing in self._layers)
        if not already_present:
            self._layers.append(tracked)

    if hasattr(tracked, "_use_resource_variables"):
        # In subclassed models, legacy layers (tf.layers) must always use
        # resource variables.
        tracked._use_resource_variables = True  # pylint: disable=protected-access
    return tracked
def _track_value(self, value, name):
    """Add a dependency on `value` and return the object actually tracked.

    `value` is replaced by the result of `sticky_attribute_assignment`, and
    that result is what gets validated, recorded and returned.

    Args:
      value: Object to store in this data structure.
      name: Name forwarded to `sticky_attribute_assignment`.

    Returns:
      The object produced by `sticky_attribute_assignment`.

    Raises:
      ValueError: If the object does not inherit from
        `base.CheckpointableBase`.
    """
    value = sticky_attribute_assignment(
        checkpointable=self, value=value, name=name)
    if isinstance(value, variables.Variable):
        self._extra_variables.append(value)

    if not isinstance(value, base.CheckpointableBase):
        error_template = (
            "Only checkpointable objects (such as Layers or Optimizers) may be "
            "stored in a List object. Got %s, which does not inherit from "
            "CheckpointableBase.")
        raise ValueError(error_template % (value,))

    is_layer_like = (isinstance(value, CheckpointableDataStructure)
                     or layer_utils.is_layer(value)
                     or layer_utils.has_weights(value))
    if is_layer_like:
        # Identity (not equality) check: empty auto-generated list wrappers
        # satisfy "[] == []", which would make "[] in [[]]" true and silently
        # drop a distinct empty container. Equality diverges once either list
        # is mutated, so only identity is a reliable "already tracked" test.
        # NOTE(review): if `_layers` is derived on demand from `self._values`
        # (e.g. a property), appending here mutates a throwaway list --
        # confirm `_layers` is stored state in this class.
        is_new = all(existing is not value for existing in self._layers)
        if is_new:
            self._layers.append(value)

    if hasattr(value, "_use_resource_variables"):
        # Legacy layers (tf.layers) inside subclassed models must always use
        # resource variables.
        value._use_resource_variables = True  # pylint: disable=protected-access
    return value
def _layers(self):
    """Return the Layers and Layer containers found in `self._values`.

    Empty containers are included. A new list is built on every call.

    Returns:
      The entries of `self._values`, in their original order, that are a
      `CheckpointableDataStructure` or that `layer_utils.is_layer` or
      `layer_utils.has_weights` accepts.
    """
    # Derived from `self._values` each time instead of being cached, so a
    # wrapper object reports the contents of the thing it wraps even when the
    # two have fallen out of sync.
    return [
        item
        for item in self._values
        if (isinstance(item, CheckpointableDataStructure)
            or layer_utils.is_layer(item)
            or layer_utils.has_weights(item))
    ]