def in_place_subclassed_model_state_restoration(model):
  """Restores the original state of a model after it was "reset".

  This undoes this action of `_in_place_subclassed_model_reset`, which is
  called in `clone_and_build_model` if `in_place_reset` is set to True.

  Args:
    model: Instance of a Keras model created via subclassing, on which
      `_in_place_subclassed_model_reset` was previously called.
  """
  assert not model._is_graph_network  # Subclassed networks only.
  cache = getattr(model, '_original_attributes_cache', None)
  if cache is None:
    # No cached attributes: restore to the state of a never-called model.
    model.built = False
    model.inputs = None
    model.outputs = None
    return
  # Models have sticky attribute assignment, so we want to be careful to add
  # back the previous attributes and track Layers by their original names
  # without adding dependencies on "utility" attributes which Models exempt
  # when they're constructed.
  model._layers = data_structures.NoDependency([])
  for name, value in cache.items():
    if not isinstance(value, checkpointable.CheckpointableBase):
      # If this value is not already checkpointable, it's probably that way
      # for a reason; we don't want to start tracking data structures that
      # the original Model didn't.
      value = data_structures.NoDependency(value)
    setattr(model, name, value)
  model._original_attributes_cache = None
def __init__(self):
  super(Foo, self).__init__()
  # Assigned normally, so this layer is tracked as a checkpoint dependency.
  self.isdep = keras.layers.Dense(1)
  # Wrapped in NoDependency: these attributes are deliberately excluded from
  # automatic dependency tracking.
  self.notdep = data_structures.NoDependency(
      keras.layers.Dense(2))
  self.notdep_var = data_structures.NoDependency(
      resource_variable_ops.ResourceVariable(1., name='notdep_var'))
def testNestedLists(self):
  # Checkpointables appended into nested (list-within-list) attributes should
  # still be discovered as dependencies of the root object.
  a = tracking.Checkpointable()
  a.l = []
  b = tracking.Checkpointable()
  a.l.append([b])
  c = tracking.Checkpointable()
  a.l[0].append(c)
  a_deps = util.list_objects(a)
  self.assertIn(b, a_deps)
  self.assertIn(c, a_deps)
  # Appending a plain int should not create a tracked object.
  a.l[0].append(1)
  d = tracking.Checkpointable()
  a.l[0].append(d)
  a_deps = util.list_objects(a)
  self.assertIn(d, a_deps)
  self.assertIn(b, a_deps)
  self.assertIn(c, a_deps)
  self.assertNotIn(1, a_deps)
  # Checkpointables in a freshly assigned nested list literal are tracked too.
  e = tracking.Checkpointable()
  f = tracking.Checkpointable()
  a.l1 = [[], [e]]
  a.l1[0].append(f)
  a_deps = util.list_objects(a)
  self.assertIn(e, a_deps)
  self.assertIn(f, a_deps)
  checkpoint = util.Checkpoint(a=a)
  checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
  # Appending a NoDependency-wrapped list keeps the checkpoint saveable.
  a.l[0].append(data_structures.NoDependency([]))
  a.l[0][-1].append(5)
  checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
  # Dirtying the inner list means the root object is unsaveable.
  a.l[0][1] = 2
  with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
    checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
def testDictWrapperNoDependency(self):
  """A NoDependency dict creates no dependencies, even with mutable values."""
  root = tracking.Checkpointable()
  root.d = data_structures.NoDependency({})
  root.d[1] = [3]
  # Only the root itself shows up in the object graph.
  self.assertEqual([root], util.list_objects(root))
  model = training.Model()
  model.sub = root
  save_path = os.path.join(self.get_temp_dir(), "ckpt")
  model.save_weights(save_path)
  model.load_weights(save_path)
def testNonAppendNotCheckpointable(self):
  # Non-append mutations (deleting or overwriting values) are OK when the
  # values aren't tracked.
  root = tracking.Checkpointable()
  root.d = {}
  root.d["a"] = [3]
  root.d[1] = 3
  root.d[1] = 2
  self.assertEqual(2, root.d[1])
  del root.d[1]
  # Overwriting a NoDependency value with another NoDependency value is fine.
  root.d[2] = data_structures.NoDependency(tracking.Checkpointable())
  replacement = tracking.Checkpointable()
  root.d[2] = data_structures.NoDependency(replacement)
  self.assertIs(replacement, root.d[2])
  self.assertEqual([root, root.d, root.d["a"]], util.list_objects(root))
  model = training.Model()
  model.sub = root
  save_path = os.path.join(self.get_temp_dir(), "ckpt")
  model.save_weights(save_path)
  model.load_weights(save_path)
def testNonStringKeyNotCheckpointableValue(self):
  """Non-string keys holding untracked values create no dependencies."""
  root = tracking.Checkpointable()
  root.d = {}
  root.d["a"] = [3]
  root.d[1] = data_structures.NoDependency([3])
  # Only the string-keyed list is part of the object graph.
  self.assertEqual([root, root.d, root.d["a"]], util.list_objects(root))
  model = training.Model()
  model.sub = root
  save_path = os.path.join(self.get_temp_dir(), "ckpt")
  model.save_weights(save_path)
  model.load_weights(save_path)
def testNoDepList(self):
  """NoDependency lists stay plain lists and tolerate arbitrary mutation."""
  model = training.Model()
  model.l1 = data_structures.NoDependency([])
  model.l1.insert(1, 0)
  self.assertTrue(isinstance(model.l1, list))
  checkpoint = util.Checkpoint(a=model)
  checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
  # An unwrapped (tracked) list rejects non-append mutation at save time.
  model.l2 = []
  model.l2.insert(1, 0)
  with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
    checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
def __init__(self):
  # UniqueNameTracker deduplicates names, so the repeated "x" below gets a
  # distinct tracked name.
  self.slotdeps = containers.UniqueNameTracker()
  slots = []
  for initial_value, name in ((3., "x"), (4., "y"), (5., "x")):
    slots.append(self.slotdeps.track(
        resource_variable_ops.ResourceVariable(initial_value), name))
  # The list itself is deliberately untracked; only the tracker holds deps.
  self.slots = data_structures.NoDependency(slots)
def testNoDependency(self):
  """NoDependency wrapping keeps an attribute out of checkpoint deps."""
  root = tracking.Checkpointable()
  tracked = tracking.Checkpointable()
  root.hasdep = tracked
  untracked = tracking.Checkpointable()
  root.nodep = data_structures.NoDependency(untracked)
  # Only the plain assignment created a dependency.
  self.assertEqual(1, len(root._checkpoint_dependencies))
  self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
  # The wrapper is unwrapped on attribute access.
  self.assertIs(root.hasdep, tracked)
  self.assertIs(root.nodep, untracked)

  class NoDependencyModel(training.Model):

    @base.no_automatic_dependency_tracking
    def __init__(self):
      super(NoDependencyModel, self).__init__()
      self.a = []
      self.b = tracking.Checkpointable()

  nodeps = NoDependencyModel()
  # With tracking disabled, nothing but the model itself is discovered.
  self.assertEqual([nodeps], util.list_objects(nodeps))
def commit(self, prefix, session):
  """Commits the latest cached checkpoint under a numbered prefix.

  Renames the cached checkpoint's files to a counter-stamped prefix, evicts
  the oldest checkpoint when the retention limit is reached, and updates the
  checkpoint state file.

  Args:
    prefix: Filename prefix for the committed checkpoint.
    session: Session used to run the save-counter ops.

  Returns:
    The committed checkpoint's full prefix; if nothing was cached, the most
    recently committed prefix (or the empty string when none exist).
  """
  if self._cached_checkpoint is None:
    # Nothing pending: report the newest committed checkpoint, if any.
    # (Fixed `== None` -> `is None`: identity test per PEP 8.)
    return self._checkpoints[-1] if self._checkpoints else ""
  if len(self._checkpoints) == self._max_to_keep:
    # Evict the oldest checkpoint to respect the retention limit.
    for filename in self._get_checkpoint_filenames(self._checkpoints.pop(0)):
      os.remove(filename)
  # Replication from checkpoint.save: lazily initialize and increment the
  # save counter so committed checkpoints get monotonically increasing
  # numbers.
  if self._checkpoint._save_counter is None:
    session.run(self._checkpoint.save_counter.initializer)
  if self._checkpoint._save_assign_op is None:
    self._checkpoint._save_assign_op = data_structures.NoDependency(
        self._checkpoint.save_counter.assign_add(1, read_value=True))
  checkpoint_count = session.run(self._checkpoint._save_assign_op)
  filename_prefix = "%s-%d" % (prefix, checkpoint_count)
  for filename in self._get_checkpoint_filenames(self._cached_checkpoint):
    # Change prefix
    os.rename(
        filename, filename.replace(self._cached_checkpoint, filename_prefix))
  self._checkpoints.append(filename_prefix)
  self._cached_checkpoint = None
  # Update checkpoint state file (@tf.train.latest_checkpoint)
  checkpoint_management.update_checkpoint_state_internal(
      self._directory, self._checkpoints[-1], self._checkpoints)
  return filename_prefix
def _in_place_subclassed_model_reset(model):
  """Substitute for model cloning that works for subclassed models.

  Subclassed models cannot be cloned because their topology is not
  serializable. To "instantiate" an identical model in a new TF graph, we
  reuse the original model object, but we clear its state.

  After calling this function on a model instance, you can use the model
  instance as if it were a model clone (in particular you can use it in a
  new graph).

  This method clears the state of the input model. It is thus destructive.
  However the original state can be restored fully by calling
  `_in_place_subclassed_model_state_restoration`.

  Args:
    model: Instance of a Keras model created via subclassing.

  Raises:
    ValueError: In case the model uses a subclassed model as inner layer.
  """
  assert not model._is_graph_network  # Only makes sense for subclassed networks
  # Retrieve all layers tracked by the model as well as their attribute names
  attributes_cache = {}
  for name in dir(model):
    try:
      value = getattr(model, name)
    except (AttributeError, ValueError, TypeError):
      # dir() can list attributes whose getattr raises (e.g. properties that
      # require a built model); skip those.
      continue
    if isinstance(value, Layer):
      attributes_cache[name] = value
      assert value in model._layers
    elif isinstance(value, (list, tuple)) and name not in (
        'layers', '_layers', 'metrics', '_compile_stateful_metric_functions'):
      # Handle case: list/tuple of layers (also tracked by the Network API).
      if value and all(isinstance(val, Layer) for val in value):
        raise ValueError('We do not support the use of list-of-layers '
                         'attributes in subclassed models used with '
                         '`model_to_estimator` at this time. Found list '
                         'model: %s' % name)
  # Replace layers on the model with fresh layers
  layers_to_names = {value: key for key, value in attributes_cache.items()}
  original_layers = model._layers[:]
  model._layers = data_structures.NoDependency([])
  for layer in original_layers:  # We preserve layer order.
    config = layer.get_config()
    # This will not work for nested subclassed models used as layers.
    # This would be theoretically possible to support, but would add
    # complexity. Only do it if users complain.
    if isinstance(layer, Network) and not layer._is_graph_network:
      raise ValueError('We do not support the use of nested subclassed models '
                       'in `model_to_estimator` at this time. Found nested '
                       'model: %s' % layer)
    # Rebuilding from config yields a freshly-initialized layer of the same
    # class; re-setting the attribute re-registers it under its original name.
    fresh_layer = layer.__class__.from_config(config)
    name = layers_to_names[layer]
    setattr(model, name, fresh_layer)
  # Cache original model build attributes (in addition to layers)
  if (not hasattr(model, '_original_attributes_cache') or
      model._original_attributes_cache is None):
    if model.built:
      # Attributes populated by compile/build that must be restorable later.
      attributes_to_cache = [
          'inputs',
          'outputs',
          '_feed_outputs',
          '_feed_output_names',
          '_feed_output_shapes',
          '_feed_loss_fns',
          'loss_weights_list',
          'targets',
          '_feed_targets',
          'sample_weight_modes',
          'total_loss',
          'sample_weights',
          '_feed_sample_weights',
          '_fit_function',
          '_eval_function',
          'train_function',
          'test_function',
          'predict_function',
          '_collected_trainable_weights',
          '_feed_inputs',
          '_feed_input_names',
          '_feed_input_shapes',
          'optimizer',
      ]
      for name in attributes_to_cache:
        attributes_cache[name] = getattr(model, name)
  # NoDependency so the cache itself is not tracked as a checkpoint dep.
  model._original_attributes_cache = data_structures.NoDependency(
      attributes_cache)
  # Reset built state
  model.built = False
  model.inputs = None
  model.outputs = None
def _no_dependency(self, value):
  """Override to allow CheckpointableBase to disable dependency tracking."""
  # Wrapping in NoDependency signals the tracking machinery to skip `value`.
  return data_structures.NoDependency(value)