Exemple #1
0
  def testDeferredSlotRestoration(self):
    """Checks slot restoration when restore() precedes object creation.

    Two checkpoints of `root` are written: "no_slots" before the optimizer is
    attached to `root`, and "with_slots" after. Both are then restored into an
    initially empty `new_root`, and the variable and slot values are asserted
    as each object is created.
    """
    checkpoint_directory = self.get_temp_dir()

    root = checkpointable.Checkpointable()
    root.var = checkpointable_utils.add_variable(
        root, name="var", initializer=0.)
    optimizer = CheckpointableAdam(0.1)
    # Run one minimize() step in whichever mode is active; presumably this is
    # what makes the "m" slot assigned to below exist.
    if context.in_graph_mode():
      train_op = optimizer.minimize(root.var)
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(train_op)
    else:
      optimizer.minimize(root.var.read_value)
    # First checkpoint: var == 12. and `root` does not reference the optimizer.
    self.evaluate(state_ops.assign(root.var, 12.))
    no_slots_path = checkpointable_utils.save(
        os.path.join(checkpoint_directory, "no_slots"), root)
    # Second checkpoint: optimizer attached, var == 13. and slot "m" == 14.
    root.optimizer = optimizer
    self.evaluate(state_ops.assign(root.var, 13.))
    self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
                                   14.))
    slots_path = checkpointable_utils.save(
        os.path.join(checkpoint_directory, "with_slots"), root)
    new_root = checkpointable.Checkpointable()
    # Load the slot-containing checkpoint (deferred), then immediately overwrite
    # the non-slot variable (also deferred).
    slot_status = checkpointable_utils.restore(
        slots_path, new_root)
    no_slot_status = checkpointable_utils.restore(
        no_slots_path, new_root)
    # `new_root` has no variable yet, so the restore cannot be complete.
    with self.assertRaises(AssertionError):
      no_slot_status.assert_consumed()
    new_root.var = checkpointable_utils.add_variable(
        new_root, name="var", shape=[])
    no_slot_status.assert_consumed()
    self.evaluate(no_slot_status.restore_ops)
    # The value comes from "no_slots", the restore that was requested last.
    self.assertEqual(12., self.evaluate(new_root.var))
    new_root.optimizer = CheckpointableAdam(0.1)
    # The slot checkpoint is still incomplete; the failure message is expected
    # to mention "beta1_power".
    with self.assertRaisesRegexp(AssertionError, "beta1_power"):
      slot_status.assert_consumed()
    self.assertEqual(12., self.evaluate(new_root.var))
    if context.in_eager_mode():
      # Slot variables are only created with restoring initializers when
      # executing eagerly.
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
    else:
      self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
                    None)
    if context.in_graph_mode():
      train_op = new_root.optimizer.minimize(new_root.var)
      # The slot variable now exists; restore() didn't create it, but we should
      # now have a restore op for it.
      self.evaluate(slot_status.restore_ops)
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
      self.evaluate(train_op)
    else:
      new_root.optimizer.minimize(new_root.var.read_value)
    # After minimize() the slot checkpoint must be fully matched.
    slot_status.assert_consumed()
 def testNumberedPath(self):
   """Checks the checkpoint key of a variable owned by a child object."""
   parent = checkpointable.Checkpointable()
   child = checkpointable.Checkpointable()
   parent.leaf = child
   checkpointable_utils.add_variable(child, name="v", shape=[])
   keys_to_variables, _ = checkpointable_utils._serialize_object_graph(parent)
   # Only one variable was added, so the unpacking expects exactly one key.
   [only_key] = keys_to_variables.keys()
   self.assertEqual(r"leaf/v/.ATTRIBUTES/VARIABLE_VALUE", only_key)
 def _get_checkpoint_name(self, name):
   """Returns the checkpoint key generated for a variable called `name`.

   Args:
     name: Local name given to the single variable added to a fresh root.

   Returns:
     The only key among the serialized object graph's named variables.
   """
   holder = checkpointable.Checkpointable()
   checkpointable_utils.add_variable(
       holder, name=name, shape=[1, 2], dtype=dtypes.float64)
   keys_to_variables, _ = checkpointable_utils._serialize_object_graph(holder)
   [key] = keys_to_variables.keys()
   # Make sure we can use this as an op name if we prefix it.
   with ops.name_scope("root/" + key):
     pass
   return key
  def testDeferredSlotRestoration(self):
    """Checks slot restoration when restore() precedes object creation.

    Two checkpoints of `root` are written through `Saver`: "no_slots" before
    the optimizer is attached to `root`, and "with_slots" after. Both are then
    restored into an initially empty `new_root`, and the variable and slot
    values are asserted as each object is created.
    """
    checkpoint_directory = self.get_temp_dir()

    root = checkpointable.Checkpointable()
    root.var = checkpointable_utils.add_variable(
        root, name="var", initializer=0.)
    optimizer = CheckpointableAdam(0.1)
    # Run one minimize() step in whichever mode is active; presumably this is
    # what makes the "m" slot assigned to below exist.
    if context.in_graph_mode():
      train_op = optimizer.minimize(root.var)
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(train_op)
    else:
      optimizer.minimize(root.var.read_value)
    # First checkpoint: var == 12. and `root` does not reference the optimizer.
    self.evaluate(state_ops.assign(root.var, 12.))
    no_slots_path = checkpointable_utils.Saver(root).save(
        os.path.join(checkpoint_directory, "no_slots"))
    # Second checkpoint: optimizer attached, var == 13. and slot "m" == 14.
    root.optimizer = optimizer
    self.evaluate(state_ops.assign(root.var, 13.))
    self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
                                   14.))
    slots_path = checkpointable_utils.Saver(root).save(
        os.path.join(checkpoint_directory, "with_slots"))
    new_root = checkpointable.Checkpointable()
    # Load the slot-containing checkpoint (deferred), then immediately overwrite
    # the non-slot variable (also deferred).
    slot_status = checkpointable_utils.Saver(new_root).restore(slots_path)
    no_slot_status = checkpointable_utils.Saver(new_root).restore(no_slots_path)
    # `new_root` has no variable yet, so the restore cannot be complete.
    with self.assertRaises(AssertionError):
      no_slot_status.assert_consumed()
    new_root.var = checkpointable_utils.add_variable(
        new_root, name="var", shape=[])
    no_slot_status.assert_consumed()
    no_slot_status.run_restore_ops()
    # The value comes from "no_slots", the restore that was requested last.
    self.assertEqual(12., self.evaluate(new_root.var))
    new_root.optimizer = CheckpointableAdam(0.1)
    # The slot checkpoint is still incomplete; the failure message is expected
    # to mention "beta1_power".
    with self.assertRaisesRegexp(AssertionError, "beta1_power"):
      slot_status.assert_consumed()
    self.assertEqual(12., self.evaluate(new_root.var))
    if context.in_eager_mode():
      # Slot variables are only created with restoring initializers when
      # executing eagerly.
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
    else:
      self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
                    None)
    if context.in_graph_mode():
      train_op = new_root.optimizer.minimize(new_root.var)
      # The slot variable now exists; restore() didn't create it, but we should
      # now have a restore op for it.
      slot_status.run_restore_ops()
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
      self.evaluate(train_op)
    else:
      new_root.optimizer.minimize(new_root.var.read_value)
    # After minimize() the slot checkpoint must be fully matched.
    slot_status.assert_consumed()
 def testLocalNameValidation(self):
   """Checks the checkpoint key when a dependency is named ".ATTRIBUTES"."""
   parent = checkpointable.Checkpointable()
   child = checkpointable.Checkpointable()
   # Dots are escaped, which avoids conflicts with reserved names.
   parent._track_checkpointable(child, name=".ATTRIBUTES")
   checkpointable_utils.add_variable(checkpointable=child, name="a", shape=[])
   keys_to_variables, _ = checkpointable_utils._serialize_object_graph(parent)
   [only_key] = keys_to_variables.keys()
   self.assertEqual(only_key, "..ATTRIBUTES/a/.ATTRIBUTES/VARIABLE_VALUE")
  def testInitNotCalled(self):
    """Checks add_variable() on an object whose __init__ skips super()."""

    class NoInit(checkpointable.Checkpointable):
      """Checkpointable subclass whose constructor does nothing."""

      def __init__(self):
        pass

    # __init__ for Checkpointable will be called implicitly.
    uninitialized = NoInit()
    checkpointable_utils.add_variable(uninitialized, "var", shape=[])
Exemple #7
0
  def testOverlappingRestores(self):
    """Checks two deferred restores that end up targeting the same object.

    One dependency object is attached first to one restored root and then to
    another; its value after each attachment is asserted for both orders of
    the restore() requests.
    """
    checkpoint_directory = self.get_temp_dir()
    save_root = checkpointable.Checkpointable()
    save_root.dep = checkpointable.Checkpointable()
    save_root.dep.var = checkpointable_utils.add_variable(
        save_root.dep, name="var", initializer=0.)
    # "first" is saved with var == 12., "second" with var == 13.
    self.evaluate(state_ops.assign(save_root.dep.var, 12.))
    first_path = checkpointable_utils.save(
        os.path.join(checkpoint_directory, "first"), save_root)
    self.evaluate(state_ops.assign(save_root.dep.var, 13.))
    second_path = checkpointable_utils.save(
        os.path.join(checkpoint_directory, "second"), save_root)

    # Request "first" and then "second", each on its own empty root.
    first_root = checkpointable.Checkpointable()
    second_root = checkpointable.Checkpointable()
    first_status = checkpointable_utils.restore(
        first_path, first_root)
    second_status = checkpointable_utils.restore(
        second_path, second_root)
    load_dep = checkpointable.Checkpointable()
    load_dep.var = checkpointable_utils.add_variable(
        load_dep, name="var", shape=[])
    first_root.dep = load_dep
    first_status.assert_consumed()
    self.evaluate(first_status.restore_ops)
    # `load_dep` is not attached to `second_root` yet, so that restore has no
    # ops to run.
    self.assertEqual([], second_status.restore_ops)
    self.assertEqual(12., self.evaluate(load_dep.var))
    second_root.dep = load_dep
    second_status.assert_consumed()
    self.evaluate(second_status.restore_ops)
    self.assertEqual(13., self.evaluate(load_dep.var))

    # Try again with the order of the restore() reversed. The last restore
    # determines the final value.
    first_root = checkpointable.Checkpointable()
    second_root = checkpointable.Checkpointable()
    second_status = checkpointable_utils.restore(
        second_path, second_root)
    first_status = checkpointable_utils.restore(
        first_path, first_root)
    load_dep = checkpointable.Checkpointable()
    load_dep.var = checkpointable_utils.add_variable(
        load_dep, name="var", shape=[])
    first_root.dep = load_dep
    first_status.assert_consumed()
    self.assertEqual([], second_status.restore_ops)
    self.evaluate(first_status.restore_ops)
    self.assertEqual(12., self.evaluate(load_dep.var))
    second_root.dep = load_dep
    second_status.assert_consumed()
    self.evaluate(second_status.restore_ops)
    # Still 12.: "second" was requested before "first" this time.
    self.assertEqual(12., self.evaluate(load_dep.var))
 def testShapeDtype(self):
   """Checks the dtype of variables created with an explicit float64 dtype."""
   owner = checkpointable.Checkpointable()
   scalar = checkpointable_utils.add_variable(
       owner, name="v1", initializer=3., dtype=dtypes.float64)
   self.assertEqual(dtypes.float64, scalar.dtype)
   # Same call shape as above, with the keyword arguments collected first.
   vector_kwargs = dict(
       name="v2",
       shape=[3],
       initializer=init_ops.ones_initializer,
       dtype=dtypes.float64)
   vector = checkpointable_utils.add_variable(owner, **vector_kwargs)
   self.assertEqual(dtypes.float64, vector.dtype)
   self.assertAllEqual([1., 1., 1.], self.evaluate(vector))
  def testOverlappingRestores(self):
    """Checks two deferred restores that end up targeting the same object.

    One dependency object is attached first to one restored root and then to
    another; its value after each attachment is asserted for both orders of
    the restore() requests.
    """
    ckpt_dir = self.get_temp_dir()
    source = checkpointable.Checkpointable()
    source.dep = checkpointable.Checkpointable()
    source.dep.var = checkpointable_utils.add_variable(
        source.dep, name="var", initializer=0.)
    self.evaluate(state_ops.assign(source.dep.var, 12.))
    writer = checkpointable_utils.CheckpointableSaver(source)
    path_12 = writer.save(os.path.join(ckpt_dir, "first"))
    self.evaluate(state_ops.assign(source.dep.var, 13.))
    path_13 = writer.save(os.path.join(ckpt_dir, "second"))

    # Round one: the checkpoint holding 12. is requested before the one
    # holding 13.
    root_12 = checkpointable.Checkpointable()
    root_13 = checkpointable.Checkpointable()
    status_12 = checkpointable_utils.CheckpointableSaver(
        root_12).restore(path_12)
    status_13 = checkpointable_utils.CheckpointableSaver(
        root_13).restore(path_13)
    shared = checkpointable.Checkpointable()
    shared.var = checkpointable_utils.add_variable(
        shared, name="var", shape=[])
    root_12.dep = shared
    status_12.assert_consumed()
    status_12.run_restore_ops()
    self.assertEqual(12., self.evaluate(shared.var))
    root_13.dep = shared
    status_13.assert_consumed()
    status_13.run_restore_ops()
    self.assertEqual(13., self.evaluate(shared.var))

    # Try again with the order of the restore() reversed. The last restore
    # determines the final value.
    root_12 = checkpointable.Checkpointable()
    root_13 = checkpointable.Checkpointable()
    status_13 = checkpointable_utils.CheckpointableSaver(
        root_13).restore(path_13)
    status_12 = checkpointable_utils.CheckpointableSaver(
        root_12).restore(path_12)
    shared = checkpointable.Checkpointable()
    shared.var = checkpointable_utils.add_variable(
        shared, name="var", shape=[])
    root_12.dep = shared
    status_12.assert_consumed()
    status_12.run_restore_ops()
    self.assertEqual(12., self.evaluate(shared.var))
    root_13.dep = shared
    status_13.assert_consumed()
    status_13.run_restore_ops()
    self.assertEqual(12., self.evaluate(shared.var))
  def testAddVariable(self):
    """Checks add_variable() argument validation, naming and checkpoint keys.

    Covers rejected argument combinations, several initializer styles, the
    `.name` of each created variable in graph and eager mode, and the keys
    produced when the owning object is serialized.
    """
    obj = NonLayerCheckpointable()
    # A shape passed together with this constant initializer is rejected.
    with self.assertRaisesRegexp(ValueError, "do not specify shape"):
      checkpointable_utils.add_variable(
          obj, name="shape_specified_twice", shape=[], initializer=1)
    constant_initializer = checkpointable_utils.add_variable(
        obj, name="constant_initializer", initializer=1)
    # Created inside a variable scope; the scope shows up in `.name` below but
    # not in the expected checkpoint key.
    with variable_scope.variable_scope("some_variable_scope"):
      ones_initializer = checkpointable_utils.add_variable(
          obj,
          name="ones_initializer",
          shape=[2],
          initializer=init_ops.ones_initializer(dtype=dtypes.float32))
    # The initializer is passed here as `init_ops.zeros_initializer` itself,
    # without calling it.
    bare_initializer = checkpointable_utils.add_variable(
        obj,
        name="bare_initializer",
        shape=[2, 2],
        dtype=dtypes.float64,
        initializer=init_ops.zeros_initializer)

    # Even in graph mode, there are no naming conflicts between objects, only
    # naming conflicts within an object.
    other_duplicate = resource_variable_ops.ResourceVariable(
        name="duplicate", initial_value=1.)
    duplicate = checkpointable_utils.add_variable(
        obj, name="duplicate", shape=[])
    # Adding the same local name to the same object a second time is an error.
    with self.assertRaisesRegexp(ValueError, "'duplicate' already exists"):
      checkpointable_utils.add_variable(obj, name="duplicate", shape=[])

    if context.in_graph_mode():
      self.evaluate(variables.global_variables_initializer())
    self.assertEqual("constant_initializer:0", constant_initializer.name)
    self.assertEqual(1, self.evaluate(constant_initializer))
    self.assertEqual("some_variable_scope/ones_initializer:0",
                     ones_initializer.name)
    self.assertAllEqual([1, 1], self.evaluate(ones_initializer))
    self.assertAllEqual([[0., 0.],
                         [0., 0.]], self.evaluate(bare_initializer))
    # `a_variable` is not created in this test; presumably it comes from
    # NonLayerCheckpointable's constructor.
    self.assertEqual("a_variable:0", obj.a_variable.name)
    self.assertEqual("duplicate:0", other_duplicate.name)
    if context.in_graph_mode():
      # The .name attribute may be globally influenced, but the checkpoint name
      # won't be (tested below).
      self.assertEqual("duplicate_1:0", duplicate.name)
    else:
      # When executing eagerly, there's no uniquification of variable names. The
      # checkpoint name will be the same.
      self.assertEqual("duplicate:0", duplicate.name)
    named_variables, _ = checkpointable_utils._serialize_object_graph(obj)
    # The expected keys use only the local names given to add_variable().
    expected_checkpoint_names = (
        "a_variable/.ATTRIBUTES/VARIABLE_VALUE",
        "bare_initializer/.ATTRIBUTES/VARIABLE_VALUE",
        "constant_initializer/.ATTRIBUTES/VARIABLE_VALUE",
        "duplicate/.ATTRIBUTES/VARIABLE_VALUE",
        "ones_initializer/.ATTRIBUTES/VARIABLE_VALUE",
    )
    six.assertCountEqual(
        self, expected_checkpoint_names, named_variables.keys())
  def testDeferredSlotRestoration(self):
    """Checks slot restoration when restore() precedes object creation.

    Two checkpoints of `root` are written: "no_slots" before the optimizer is
    attached to `root`, and "with_slots" after. Both are then restored into an
    initially empty `new_root`. This variant asserts restored values without
    evaluating any restore ops explicitly.
    """
    checkpoint_directory = self.get_temp_dir()

    root = checkpointable.Checkpointable()
    root.var = checkpointable_utils.add_variable(
        root, name="var", initializer=0.)
    optimizer = CheckpointableAdam(0.1)
    # Run one minimize() step in whichever mode is active; presumably this is
    # what makes the "m" slot assigned to below exist.
    if context.in_graph_mode():
      train_op = optimizer.minimize(root.var)
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(train_op)
    else:
      optimizer.minimize(root.var.read_value)
    # First checkpoint: var == 12. and `root` does not reference the optimizer.
    self.evaluate(state_ops.assign(root.var, 12.))
    no_slots_path = checkpointable_utils.save(
        os.path.join(checkpoint_directory, "no_slots"), root)
    # Second checkpoint: optimizer attached, var == 13. and slot "m" == 14.
    root.optimizer = optimizer
    self.evaluate(state_ops.assign(root.var, 13.))
    self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
                                   14.))
    slots_path = checkpointable_utils.save(
        os.path.join(checkpoint_directory, "with_slots"), root)
    new_root = checkpointable.Checkpointable()
    # Load the slot-containing checkpoint (deferred), then immediately overwrite
    # the non-slot variable (also deferred).
    slot_status = checkpointable_utils.restore(
        slots_path, new_root)
    no_slot_status = checkpointable_utils.restore(
        no_slots_path, new_root)
    # `new_root` has no variable yet, so the restore cannot be complete.
    with self.assertRaises(AssertionError):
      no_slot_status.assert_consumed()
    new_root.var = checkpointable_utils.add_variable(
        new_root, name="var", shape=[])
    # The value comes from "no_slots", the restore that was requested last.
    self.assertEqual(12., self.evaluate(new_root.var))
    no_slot_status.assert_consumed()
    new_root.optimizer = CheckpointableAdam(0.1)
    # The slot checkpoint is still incomplete; the failure message is expected
    # to mention "beta1_power".
    with self.assertRaisesRegexp(AssertionError, "beta1_power"):
      slot_status.assert_consumed()
    self.assertEqual(12., self.evaluate(new_root.var))
    # Here the slot value is asserted in both modes, before minimize() runs.
    self.assertEqual(14., self.evaluate(
        new_root.optimizer.get_slot(name="m", var=new_root.var)))
    if context.in_graph_mode():
      train_op = new_root.optimizer.minimize(new_root.var)
      self.evaluate(train_op)
    else:
      new_root.optimizer.minimize(new_root.var.read_value)
    # After minimize() the slot checkpoint must be fully matched.
    slot_status.assert_consumed()
  def testDependencyLoop(self):
    """Checks saving and restoring two objects that reference each other."""
    # Note: this test creates garbage during eager execution because it
    # purposefully creates a reference cycle.
    node_a = checkpointable.Checkpointable()
    node_b = checkpointable.Checkpointable()
    node_a.second = node_b
    node_b.first = node_a
    node_a.v = checkpointable_utils.add_variable(
        node_a, "v1", initializer=[3., 1., 4.])
    node_b.v = checkpointable_utils.add_variable(
        node_b, "v2", initializer=[1., 1., 2., 3.])
    self.evaluate(checkpointable_utils.gather_initializers(node_a))
    ckpt_dir = self.get_temp_dir()
    writer = checkpointable_utils.CheckpointableSaver(node_a)
    save_path = writer.save(os.path.join(ckpt_dir, "ckpt"))

    # Test deferred loading
    loaded_a = checkpointable.Checkpointable()
    status = checkpointable_utils.CheckpointableSaver(
        loaded_a).restore(save_path)
    loaded_b = checkpointable.Checkpointable()
    loaded_a.second = loaded_b
    loaded_b.first = loaded_a
    # No variables exist on the loaded objects yet.
    with self.assertRaises(AssertionError):
      status.assert_consumed()
    loaded_a.v = checkpointable_utils.add_variable(
        loaded_a, "v1", shape=[3])
    loaded_b.v = checkpointable_utils.add_variable(
        loaded_b, "v2", shape=[4])
    status.assert_consumed()
    status.run_restore_ops()
    self.assertAllEqual([3., 1., 4.], self.evaluate(loaded_a.v))
    self.assertAllEqual([1., 1., 2., 3.], self.evaluate(loaded_b.v))

    # Test loading when variables have already been created
    self.evaluate(loaded_a.v.assign([2., 7., 1.]))
    self.assertAllEqual([2., 7., 1.], self.evaluate(loaded_a.v))
    self.evaluate(loaded_b.v.assign([2., 7., 1., 8.]))
    self.assertAllEqual([2., 7., 1., 8.], self.evaluate(loaded_b.v))
    status = checkpointable_utils.CheckpointableSaver(loaded_a).restore(
        save_path).assert_consumed()
    status.run_restore_ops()
    self.assertAllEqual([3., 1., 4.], self.evaluate(loaded_a.v))
    self.assertAllEqual([1., 1., 2., 3.], self.evaluate(loaded_b.v))
 def testAmbiguousLoad(self):
   """Checks that one saved object cannot be restored into two objects."""
   # Not OK to split one checkpoint object into two
   ckpt_dir = self.get_temp_dir()
   source = checkpointable.Checkpointable()
   source.dep_one = checkpointable.Checkpointable()
   source.dep_two = checkpointable.Checkpointable()
   # A single object is reachable through both `dep_one` and `dep_two`.
   shared = checkpointable.Checkpointable()
   source.dep_one.dep_three = shared
   source.dep_two.dep_three = shared
   checkpointable_utils.add_variable(shared, name="var", initializer=0.)
   self.evaluate(checkpointable_utils.gather_initializers(source))
   writer = checkpointable_utils.CheckpointableSaver(source)
   save_path = writer.save(os.path.join(ckpt_dir, "ckpt"))
   target = checkpointable.Checkpointable()
   checkpointable_utils.CheckpointableSaver(target).restore(save_path)
   target.dep_one = checkpointable.Checkpointable()
   target.dep_two = checkpointable.Checkpointable()
   target.dep_one.dep_three = checkpointable.Checkpointable()
   # Attaching a second, distinct object under the other path must fail.
   with self.assertRaisesRegexp(
       AssertionError, "resolved to different objects"):
     target.dep_two.dep_three = checkpointable.Checkpointable()
 def testObjectsCombined(self):
   """Checks restoring two saved objects into one shared Python object."""
   # Currently fine to load two checkpoint objects into one Python object
   ckpt_dir = self.get_temp_dir()
   source = checkpointable.Checkpointable()
   source.dep_one = checkpointable.Checkpointable()
   source.dep_two = checkpointable.Checkpointable()
   checkpointable_utils.add_variable(
       source.dep_one, name="var1", initializer=32., dtype=dtypes.float64)
   checkpointable_utils.add_variable(
       source.dep_two, name="var2", initializer=64., dtype=dtypes.float64)
   self.evaluate(checkpointable_utils.gather_initializers(source))
   writer = checkpointable_utils.CheckpointableSaver(source)
   save_path = writer.save(os.path.join(ckpt_dir, "ckpt"))
   target = checkpointable.Checkpointable()
   target.dep_one = checkpointable.Checkpointable()
   # Both dependency names point at the same object on the loading side.
   target.dep_two = target.dep_one
   restored_one = checkpointable_utils.add_variable(
       target.dep_one, name="var1", shape=[], dtype=dtypes.float64)
   restored_two = checkpointable_utils.add_variable(
       target.dep_one, name="var2", shape=[], dtype=dtypes.float64)
   status = checkpointable_utils.CheckpointableSaver(target).restore(
       save_path).assert_consumed()
   status.run_restore_ops()
   self.assertEqual(32., self.evaluate(restored_one))
   self.assertEqual(64., self.evaluate(restored_two))
  def save_counter(self):
    """Returns the variable used to number checkpoints, creating it if needed.

    The variable is an int64 initialized to zero. It is created on first
    access and cached on `self._save_counter`; it is meant to be incremented
    on each save.

    Returns:
      The save counter variable.
    """
    if self._save_counter is not None:
      return self._save_counter
    # Initialized to 0 and incremented before saving.
    self._save_counter = checkpointable_utils.add_variable(
        self, name="save_counter", initializer=0, dtype=dtypes.int64)
    return self._save_counter
 def __init__(self):
   """Initializes the base class, then creates and stores `a_variable`."""
   super(NonLayerCheckpointable, self).__init__()
   created = checkpointable_utils.add_variable(
       self, name="a_variable", shape=[])
   self.a_variable = created
  def testDeferredSlotRestoration(self):
    """Checks slot restoration when restore() precedes object creation.

    Two checkpoints of `root` are written through `CheckpointableSaver`:
    "no_slots" before the optimizer is attached to `root`, and "with_slots"
    after. Both are then restored into an initially empty `new_root`, and the
    variable and slot values are asserted as each object is created.
    """
    checkpoint_directory = self.get_temp_dir()

    root = checkpointable.Checkpointable()
    root.var = checkpointable_utils.add_variable(
        root, name="var", initializer=0.)
    optimizer = adam.AdamOptimizer(0.1)
    if context.in_graph_mode():
      train_op = optimizer.minimize(root.var)
      # Note that `optimizer` has not been added as a dependency of
      # `root`. Create a one-off grouping so that slot variables for `root.var`
      # get initialized too.
      self.evaluate(checkpointable_utils.gather_initializers(
          checkpointable_utils.Checkpoint(root=root, optimizer=optimizer)))
      self.evaluate(train_op)
    else:
      optimizer.minimize(root.var.read_value)
    # First checkpoint: var == 12. and `root` does not reference the optimizer.
    self.evaluate(state_ops.assign(root.var, 12.))
    no_slots_path = checkpointable_utils.CheckpointableSaver(root).save(
        os.path.join(checkpoint_directory, "no_slots"))
    # Second checkpoint: optimizer attached, var == 13. and slot "m" == 14.
    root.optimizer = optimizer
    self.evaluate(state_ops.assign(root.var, 13.))
    self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
                                   14.))
    slots_path = checkpointable_utils.CheckpointableSaver(root).save(
        os.path.join(checkpoint_directory, "with_slots"))
    new_root = checkpointable.Checkpointable()
    # Load the slot-containing checkpoint (deferred), then immediately overwrite
    # the non-slot variable (also deferred).
    slot_status = checkpointable_utils.CheckpointableSaver(
        new_root).restore(slots_path)
    no_slot_status = checkpointable_utils.CheckpointableSaver(
        new_root).restore(no_slots_path)
    # `new_root` has no variable yet, so the restore cannot be complete.
    with self.assertRaises(AssertionError):
      no_slot_status.assert_consumed()
    new_root.var = checkpointable_utils.add_variable(
        new_root, name="var", shape=[])
    no_slot_status.assert_consumed()
    no_slot_status.run_restore_ops()
    # The value comes from "no_slots", the restore that was requested last.
    self.assertEqual(12., self.evaluate(new_root.var))
    new_root.optimizer = adam.AdamOptimizer(0.1)
    # The slot checkpoint is still incomplete; the failure message is expected
    # to mention "beta1_power".
    with self.assertRaisesRegexp(AssertionError, "beta1_power"):
      slot_status.assert_consumed()
    self.assertEqual(12., self.evaluate(new_root.var))
    if context.in_eager_mode():
      # Slot variables are only created with restoring initializers when
      # executing eagerly.
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
    else:
      self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
                    None)
    if context.in_graph_mode():
      train_op = new_root.optimizer.minimize(new_root.var)
      # The slot variable now exists; restore() didn't create it, but we should
      # now have a restore op for it.
      slot_status.run_restore_ops()
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
      self.evaluate(train_op)
    else:
      new_root.optimizer.minimize(new_root.var.read_value)
    # After minimize() the slot checkpoint must be fully matched.
    slot_status.assert_consumed()
 def build(self):
   """Creates the "var" variable with initializer 0. and stores it on self."""
   created = checkpointable_utils.add_variable(self, "var", initializer=0.)
   self.var = created
  def testDeferredSlotRestoration(self):
    """Checks slot restoration when restore() precedes object creation.

    Two checkpoints of the source object are written: "no_slots" before the
    optimizer is attached to it, and "with_slots" after. Both are then
    restored into an initially empty target, and the variable and slot values
    are asserted as each object is created.
    """
    ckpt_dir = self.get_temp_dir()

    source = checkpointable.Checkpointable()
    source.var = checkpointable_utils.add_variable(
        source, name="var", initializer=0.)
    source_optimizer = adam.AdamOptimizer(0.1)
    if not context.executing_eagerly():
      first_step = source_optimizer.minimize(source.var)
      # Note that `source_optimizer` has not been added as a dependency of
      # `source`. Create a one-off grouping so that slot variables for
      # `source.var` get initialized too.
      self.evaluate(checkpointable_utils.gather_initializers(
          checkpointable_utils.Checkpoint(
              root=source, optimizer=source_optimizer)))
      self.evaluate(first_step)
    else:
      source_optimizer.minimize(source.var.read_value)

    # First checkpoint: var == 12. and no optimizer attached to `source`.
    self.evaluate(state_ops.assign(source.var, 12.))
    no_slots_path = checkpointable_utils.CheckpointableSaver(source).save(
        os.path.join(ckpt_dir, "no_slots"))

    # Second checkpoint: optimizer attached, var == 13. and slot "m" == 14.
    source.optimizer = source_optimizer
    self.evaluate(state_ops.assign(source.var, 13.))
    source_slot = source_optimizer.get_slot(name="m", var=source.var)
    self.evaluate(state_ops.assign(source_slot, 14.))
    slots_path = checkpointable_utils.CheckpointableSaver(source).save(
        os.path.join(ckpt_dir, "with_slots"))

    target = checkpointable.Checkpointable()
    # Load the slot-containing checkpoint (deferred), then immediately overwrite
    # the non-slot variable (also deferred).
    slot_status = checkpointable_utils.CheckpointableSaver(
        target).restore(slots_path)
    no_slot_status = checkpointable_utils.CheckpointableSaver(
        target).restore(no_slots_path)
    with self.assertRaises(AssertionError):
      no_slot_status.assert_consumed()
    target.var = checkpointable_utils.add_variable(
        target, name="var", shape=[])
    no_slot_status.assert_consumed()
    no_slot_status.run_restore_ops()
    self.assertEqual(12., self.evaluate(target.var))
    target.optimizer = adam.AdamOptimizer(0.1)
    with self.assertRaisesRegexp(AssertionError, "beta1_power"):
      slot_status.assert_consumed()
    self.assertEqual(12., self.evaluate(target.var))

    if context.executing_eagerly():
      # Slot variables are only created with restoring initializers when
      # executing eagerly.
      eager_slot = target.optimizer.get_slot(name="m", var=target.var)
      self.assertEqual(14., self.evaluate(eager_slot))
      target.optimizer.minimize(target.var.read_value)
    else:
      missing_slot = target.optimizer.get_slot(name="m", var=target.var)
      self.assertIs(missing_slot, None)
      second_step = target.optimizer.minimize(target.var)
      # The slot variable now exists; restore() didn't create it, but we should
      # now have a restore op for it.
      slot_status.run_restore_ops()
      graph_slot = target.optimizer.get_slot(name="m", var=target.var)
      self.assertEqual(14., self.evaluate(graph_slot))
      self.evaluate(second_step)
    slot_status.assert_consumed()