Exemple #1
0
def in_place_subclassed_model_state_restoration(model):
    """Undo `_in_place_subclassed_model_reset` on a subclassed Keras model.

    When the model carries a non-None `_original_attributes_cache`, the model's
    `_layers` is first reset to an empty `NoDependency` list. Every cached
    attribute is then written back onto the model, and the cache is cleared.
    Cached values that are not already checkpointable are wrapped in
    `data_structures.NoDependency` before being written back.

    Without such a cache, the model is put back into the state of a
    never-called model: `built` becomes False and `inputs`/`outputs` become
    None.

    Args:
      model: Instance of a Keras model created via subclassing, on which
        `_in_place_subclassed_model_reset` was previously called.
    """
    assert not model._is_graph_network

    cached_attributes = getattr(model, '_original_attributes_cache', None)
    if cached_attributes is None:
        # Nothing cached: restore to the state of a never-called model.
        model.built = False
        model.inputs = None
        model.outputs = None
        return

    # Models have sticky attribute assignment. The previous attributes are
    # therefore re-added carefully: Layers are tracked again under their
    # original names, and no dependencies are added on "utility" attributes
    # that Models exempt when they are constructed.
    model._layers = data_structures.NoDependency([])
    for attr_name, attr_value in cached_attributes.items():
        if isinstance(attr_value, checkpointable.CheckpointableBase):
            restored_value = attr_value
        else:
            # A value that is not already checkpointable is probably that way
            # for a reason. Do not start tracking data structures that the
            # original Model did not track.
            restored_value = data_structures.NoDependency(attr_value)
        setattr(model, attr_name, restored_value)
    model._original_attributes_cache = None
 def __init__(self):
     """Create one directly assigned Dense layer and two untracked objects.

     `isdep` is assigned as-is. `notdep` (a Dense layer) and `notdep_var`
     (a ResourceVariable named 'notdep_var') are wrapped in
     `data_structures.NoDependency` before assignment.
     """
     super(Foo, self).__init__()
     # Plain assignment: attribute tracking applies to this layer.
     self.isdep = keras.layers.Dense(1)
     # Wrapped in NoDependency so assignment does not register a dependency.
     untracked_layer = keras.layers.Dense(2)
     self.notdep = data_structures.NoDependency(untracked_layer)
     untracked_variable = resource_variable_ops.ResourceVariable(
         1., name='notdep_var')
     self.notdep_var = data_structures.NoDependency(untracked_variable)
 def testNestedLists(self):
     """Checkpointable objects inside nested lists are listed as dependencies.

     The test also checks three save behaviours:
       * Appending a NoDependency-wrapped list to an inner list, and then
         mutating that wrapped list, still allows saving.
       * Replacing an element of a tracked inner list makes the next save
         raise ValueError.
       * Non-checkpointable elements are not listed.
     """
     a = tracking.Checkpointable()
     a.l = []
     b = tracking.Checkpointable()
     a.l.append([b])
     c = tracking.Checkpointable()
     a.l[0].append(c)
     a_deps = util.list_objects(a)
     self.assertIn(b, a_deps)
     self.assertIn(c, a_deps)
     # Appending a plain int must not break tracking of the other elements,
     # and the int itself must not show up among the listed objects.
     a.l[0].append(1)
     d = tracking.Checkpointable()
     a.l[0].append(d)
     a_deps = util.list_objects(a)
     self.assertIn(d, a_deps)
     self.assertIn(b, a_deps)
     self.assertIn(c, a_deps)
     self.assertNotIn(1, a_deps)
     # A nested list assigned in one go (not built by appends) is also tracked,
     # including elements appended to its inner lists afterwards.
     e = tracking.Checkpointable()
     f = tracking.Checkpointable()
     a.l1 = [[], [e]]
     a.l1[0].append(f)
     a_deps = util.list_objects(a)
     self.assertIn(e, a_deps)
     self.assertIn(f, a_deps)
     checkpoint = util.Checkpoint(a=a)
     checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
     # Appending a NoDependency-wrapped list and then mutating it is still
     # saveable.
     a.l[0].append(data_structures.NoDependency([]))
     a.l[0][-1].append(5)
     checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
     # Dirtying the inner list means the root object is unsaveable.
     a.l[0][1] = 2
     with self.assertRaisesRegexp(ValueError,
                                  "A list element was replaced"):
         checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
Exemple #4
0
 def testDictWrapperNoDependency(self):
   """A dict wrapped in NoDependency contributes no checkpoint objects.

   Only the root object is listed. A Model holding that root can still save
   and reload its weights.
   """
   root = tracking.Checkpointable()
   root.d = data_structures.NoDependency({})
   root.d[1] = [3]
   self.assertEqual([root], util.list_objects(root))

   holder = training.Model()
   holder.sub = root
   weights_path = os.path.join(self.get_temp_dir(), "ckpt")
   holder.save_weights(weights_path)
   holder.load_weights(weights_path)
Exemple #5
0
 def testNonAppendNotCheckpointable(self):
   """Non-append dict mutations on untracked values keep the root saveable.

   The test exercises two kinds of mutation:
     * overwriting and deleting plain (non-checkpointable) values;
     * overwriting a NoDependency-wrapped value with another wrapped value.
   Afterwards only the root, the dict and its list value are listed as
   objects, and a Model holding the root can save and reload its weights.
   """
   # Non-append mutations (deleting or overwriting values) are OK when the
   # values aren't tracked.
   a = tracking.Checkpointable()
   a.d = {}
   a.d["a"] = [3]
   # Overwrite, then delete, a plain int value.
   a.d[1] = 3
   a.d[1] = 2
   self.assertEqual(2, a.d[1])
   del a.d[1]
   # Overwrite one NoDependency-wrapped entry with another.
   a.d[2] = data_structures.NoDependency(tracking.Checkpointable())
   second = tracking.Checkpointable()
   a.d[2] = data_structures.NoDependency(second)
   # Reading the entry back yields the wrapped object itself.
   self.assertIs(second, a.d[2])
   # Neither the deleted int nor the NoDependency entries are listed.
   self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
   model = training.Model()
   model.sub = a
   save_path = os.path.join(self.get_temp_dir(), "ckpt")
   model.save_weights(save_path)
   model.load_weights(save_path)
Exemple #6
0
 def testNonStringKeyNotCheckpointableValue(self):
   """A NoDependency value under a non-string key is not listed.

   The list stored under the string key "a" is still listed. A Model holding
   the root can save and reload its weights.
   """
   root = tracking.Checkpointable()
   root.d = {}
   root.d["a"] = [3]
   root.d[1] = data_structures.NoDependency([3])
   expected_objects = [root, root.d, root.d["a"]]
   self.assertEqual(expected_objects, util.list_objects(root))

   holder = training.Model()
   holder.sub = root
   weights_path = os.path.join(self.get_temp_dir(), "ckpt")
   holder.save_weights(weights_path)
   holder.load_weights(weights_path)
Exemple #7
0
 def testNoDepList(self):
   """Inserting into a NoDependency list is fine; into a tracked list is not.

   `insert` on the NoDependency-wrapped `l1` leaves the model saveable. The
   same `insert` on the plain list `l2` makes the next save raise ValueError.
   """
   model = training.Model()
   model.l1 = data_structures.NoDependency([])
   model.l1.insert(1, 0)
   self.assertIsInstance(model.l1, list)
   checkpoint = util.Checkpoint(a=model)
   checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))

   model.l2 = []
   model.l2.insert(1, 0)
   with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
     checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
 def __init__(self):
   """Register three variables with a name tracker and keep them untracked.

   Each variable is passed to `containers.UniqueNameTracker.track`. The
   requested names are "x", "y" and "x", so "x" is requested twice. The
   values returned by `track` are collected, in order, into `self.slots`,
   which is wrapped in `data_structures.NoDependency`.
   """
   self.slotdeps = containers.UniqueNameTracker()
   tracker = self.slotdeps
   requests = ((3., "x"), (4., "y"), (5., "x"))
   slots = [
       tracker.track(resource_variable_ops.ResourceVariable(initial), name)
       for initial, name in requests
   ]
   self.slots = data_structures.NoDependency(slots)
    def testNoDependency(self):
        """NoDependency keeps an attribute out of the checkpoint dependencies.

        The test also checks that attributes assigned inside an `__init__`
        decorated with `base.no_automatic_dependency_tracking` are not listed.
        """
        parent = tracking.Checkpointable()
        tracked_child = tracking.Checkpointable()
        parent.hasdep = tracked_child
        untracked_child = tracking.Checkpointable()
        parent.nodep = data_structures.NoDependency(untracked_child)

        # Only `hasdep` is registered as a dependency...
        dependencies = parent._checkpoint_dependencies
        self.assertEqual(1, len(dependencies))
        self.assertIs(dependencies[0].ref, parent.hasdep)
        # ...but both attributes read back as the original objects.
        self.assertIs(parent.hasdep, tracked_child)
        self.assertIs(parent.nodep, untracked_child)

        class NoDependencyModel(training.Model):
            @base.no_automatic_dependency_tracking
            def __init__(self):
                super(NoDependencyModel, self).__init__()
                self.a = []
                self.b = tracking.Checkpointable()

        model = NoDependencyModel()
        self.assertEqual([model], util.list_objects(model))
Exemple #10
0
    def commit(self, prefix, session):
        """Commit the cached checkpoint under a numbered prefix.

        If no checkpoint is cached, nothing is written. The latest committed
        prefix is returned instead, or "" if nothing has been committed yet.

        Otherwise the following steps run in order:
          1. If `_max_to_keep` committed checkpoints already exist, the files
             of the oldest one are deleted.
          2. The checkpoint's save counter is initialized if needed and then
             incremented.
          3. The cached checkpoint's files are renamed to "<prefix>-<count>".
          4. The checkpoint state file is updated so that
             `tf.train.latest_checkpoint` sees the new prefix.

        Args:
          prefix: Path prefix for the committed checkpoint. The incremented
            save-counter value is appended as "-<count>".
          session: Session used to run the save-counter initializer and the
            increment op.

        Returns:
          The prefix of the newly committed checkpoint. When nothing was
          cached, returns the latest committed prefix, or "" if none exists.
        """
        if self._cached_checkpoint is None:
            return self._checkpoints[-1] if self._checkpoints else ""

        if len(self._checkpoints) == self._max_to_keep:
            # Evict the oldest committed checkpoint before adding a new one.
            oldest = self._checkpoints.pop(0)
            for filename in self._get_checkpoint_filenames(oldest):
                os.remove(filename)

        # Replication from checkpoint.save: ensure the save counter exists and
        # build the increment op once, storing it on the checkpoint for reuse.
        if self._checkpoint._save_counter is None:
            session.run(self._checkpoint.save_counter.initializer)
        if self._checkpoint._save_assign_op is None:
            self._checkpoint._save_assign_op = data_structures.NoDependency(
                self._checkpoint.save_counter.assign_add(1, read_value=True))

        checkpoint_count = session.run(self._checkpoint._save_assign_op)
        filename_prefix = "%s-%d" % (prefix, checkpoint_count)

        cached_prefix = self._cached_checkpoint
        for filename in self._get_checkpoint_filenames(cached_prefix):
            # Swap the temporary cached prefix for the committed one.
            os.rename(filename, filename.replace(cached_prefix, filename_prefix))

        self._checkpoints.append(filename_prefix)
        self._cached_checkpoint = None
        # Update checkpoint state file (@tf.train.latest_checkpoint)
        checkpoint_management.update_checkpoint_state_internal(
            self._directory, self._checkpoints[-1], self._checkpoints)
        return filename_prefix
Exemple #11
0
def _in_place_subclassed_model_reset(model):
    """Substitute for model cloning that works for subclassed models.

    Subclassed models cannot be cloned because their topology is not
    serializable. To "instantiate" an identical model in a new TF graph, we
    reuse the original model object, but we clear its state.

    After calling this function on a model instance, you can use the model
    instance as if it were a model clone (in particular you can use it in a new
    graph).

    This method clears the state of the input model. It is thus destructive.
    However the original state can be restored fully by calling
    `in_place_subclassed_model_state_restoration`.

    The steps are:
      1. Record every Layer-valued attribute of the model.
      2. Replace each layer in `model._layers`, keeping their order, with a
         fresh instance built from its config.
      3. Store the recorded attributes in `model._original_attributes_cache`.
         If no cache existed yet and the model is built, its build/compile
         attributes are stored there as well.
      4. Mark the model as unbuilt, with `inputs`/`outputs` set to None.

    Args:
      model: Instance of a Keras model created via subclassing.

    Raises:
      ValueError: In case the model uses a subclassed model as inner layer, or
        has a non-empty list/tuple attribute whose elements are all layers.
    """
    assert not model._is_graph_network  # Only makes sense for subclassed networks
    # Retrieve all layers tracked by the model as well as their attribute names.
    # Attributes whose lookup raises AttributeError, ValueError or TypeError
    # are skipped.
    attributes_cache = {}
    for name in dir(model):
        try:
            value = getattr(model, name)
        except (AttributeError, ValueError, TypeError):
            continue
        if isinstance(value, Layer):
            attributes_cache[name] = value
            assert value in model._layers
        elif isinstance(value, (list, tuple)) and name not in (
                'layers', '_layers', 'metrics',
                '_compile_stateful_metric_functions'):
            # Handle case: list/tuple of layers (also tracked by the Network API).
            if value and all(isinstance(val, Layer) for val in value):
                raise ValueError(
                    'We do not support the use of list-of-layers '
                    'attributes in subclassed models used with '
                    '`model_to_estimator` at this time. Found list '
                    'model: %s' % name)

    # Replace layers on the model with fresh layers.
    # NOTE(review): `layers_to_names[layer]` below raises KeyError for any layer
    # in `model._layers` that is not held directly by an attribute found above.
    # Confirm this cannot happen for supported models.
    layers_to_names = {value: key for key, value in attributes_cache.items()}
    original_layers = model._layers[:]
    model._layers = data_structures.NoDependency([])
    for layer in original_layers:  # We preserve layer order.
        # This will not work for nested subclassed models used as layers.
        # This would be theoretically possible to support, but would add complexity.
        # Only do it if users complain.
        # Run this check before `get_config()`. In Keras (TF 1.x),
        # `Network.get_config()` raises NotImplementedError for subclassed
        # networks, which would otherwise mask the ValueError documented above.
        # Verify against the Keras version in use.
        if isinstance(layer, Network) and not layer._is_graph_network:
            raise ValueError(
                'We do not support the use of nested subclassed models '
                'in `model_to_estimator` at this time. Found nested '
                'model: %s' % layer)
        config = layer.get_config()
        fresh_layer = layer.__class__.from_config(config)
        name = layers_to_names[layer]
        setattr(model, name, fresh_layer)

    # Cache original model build attributes (in addition to layers), but only
    # when no cache exists yet.
    if (not hasattr(model, '_original_attributes_cache')
            or model._original_attributes_cache is None):
        if model.built:
            attributes_to_cache = [
                'inputs',
                'outputs',
                '_feed_outputs',
                '_feed_output_names',
                '_feed_output_shapes',
                '_feed_loss_fns',
                'loss_weights_list',
                'targets',
                '_feed_targets',
                'sample_weight_modes',
                'total_loss',
                'sample_weights',
                '_feed_sample_weights',
                '_fit_function',
                '_eval_function',
                'train_function',
                'test_function',
                'predict_function',
                '_collected_trainable_weights',
                '_feed_inputs',
                '_feed_input_names',
                '_feed_input_shapes',
                'optimizer',
            ]
            for name in attributes_to_cache:
                attributes_cache[name] = getattr(model, name)
    # NOTE(review): this assignment runs even when a cache already existed.
    # A second reset without an intervening restoration therefore replaces the
    # originally cached layers with the layers present at that point.
    # Confirm this is intended.
    model._original_attributes_cache = data_structures.NoDependency(
        attributes_cache)
    # Reset built state
    model.built = False
    model.inputs = None
    model.outputs = None
Exemple #12
0
 def _no_dependency(self, value):
     """Wrap `value` in `data_structures.NoDependency` and return the wrapper.

     Overrides the CheckpointableBase hook so that dependency tracking can be
     disabled for `value`.
     """
     wrapped = data_structures.NoDependency(value)
     return wrapped