def testDeferredSlotRestoration(self):
    """Verify that optimizer slot values are restored once the slots exist.

    Saves one checkpoint before the optimizer is attached to the root
    ("no_slots") and one after ("with_slots"), requests both restores on an
    empty Checkpoint, and then checks each restore status as the variable,
    the optimizer, and the "m" slot are created.
    """
    save_dir = self.get_temp_dir()

    saved_root = trackable_utils.Checkpoint()
    saved_root.var = trackable_utils.add_variable(
        saved_root, name="var", initializer=0.)
    saved_optimizer = adam.AdamOptimizer(0.1)
    if not context.executing_eagerly():
      minimize_op = saved_optimizer.minimize(saved_root.var)
      # `saved_optimizer` is not yet a dependency of `saved_root`, so bundle
      # both in a throwaway Checkpoint to pick up the initializers of the slot
      # variables for `saved_root.var` as well.
      init_group = trackable_utils.Checkpoint(
          root=saved_root, optimizer=saved_optimizer)
      self.evaluate(trackable_utils.gather_initializers(init_group))
      self.evaluate(minimize_op)
    else:
      saved_optimizer.minimize(saved_root.var.read_value)

    # First checkpoint: written before the optimizer hangs off the root.
    self.evaluate(state_ops.assign(saved_root.var, 12.))
    path_without_slots = saved_root.save(os.path.join(save_dir, "no_slots"))

    # Second checkpoint: the optimizer is attached and has distinct values.
    saved_root.optimizer = saved_optimizer
    self.evaluate(state_ops.assign(saved_root.var, 13.))
    m_slot = saved_optimizer.get_slot(name="m", var=saved_root.var)
    self.evaluate(state_ops.assign(m_slot, 14.))
    path_with_slots = saved_root.save(os.path.join(save_dir, "with_slots"))

    restored_root = trackable_utils.Checkpoint()
    # Request the slot-containing restore first (deferred), then the restore
    # that overwrites the non-slot variable (also deferred).
    with_slots_status = restored_root.restore(path_with_slots)
    without_slots_status = restored_root.restore(path_without_slots)
    # `restored_root` has no variable yet, so this must not count as consumed.
    self.assertRaises(AssertionError, without_slots_status.assert_consumed)

    restored_root.var = trackable_utils.add_variable(
        restored_root, name="var", shape=[])
    without_slots_status.assert_consumed()
    without_slots_status.run_restore_ops()
    self.assertEqual(12., self.evaluate(restored_root.var))

    restored_root.optimizer = adam.AdamOptimizer(0.1)
    with_slots_status.assert_existing_objects_matched()
    with self.assertRaisesRegex(AssertionError, "beta1_power"):
      with_slots_status.assert_consumed()
    # The variable must still hold the value from the later restore request.
    self.assertEqual(12., self.evaluate(restored_root.var))

    if context.executing_eagerly():
      # Only eager execution creates slot variables with restoring
      # initializers, so the slot is already present with its saved value.
      self.assertEqual(14., self.evaluate(
          restored_root.optimizer.get_slot(name="m", var=restored_root.var)))
      restored_root.optimizer.minimize(restored_root.var.read_value)
    else:
      self.assertIs(
          restored_root.optimizer.get_slot(name="m", var=restored_root.var),
          None)
      minimize_op = restored_root.optimizer.minimize(restored_root.var)
      # minimize() has now created the slot variable; restore() did not, but
      # a restore op for it should be available at this point.
      with_slots_status.run_restore_ops()
      self.assertEqual(14., self.evaluate(
          restored_root.optimizer.get_slot(name="m", var=restored_root.var)))
      self.evaluate(minimize_op)
    with_slots_status.assert_consumed()
示例#2
0
 def __init__(self):
     """Initialize the base class and create the scalar `a_variable`."""
     super().__init__()
     # `shape=[]` requests a scalar; the variable is registered on `self`
     # under the name "a_variable" and also exposed as an attribute.
     scalar = trackable_utils.add_variable(self, name="a_variable", shape=[])
     self.a_variable = scalar
示例#3
0
    def testDeferredSlotRestoration(self):
        """Verify that optimizer slot values are restored once slots exist.

        Saves one checkpoint before the optimizer is attached to the root
        ("no_slots") and one after ("with_slots"), requests both restores on
        an empty Checkpoint, and then checks each restore status as the
        variable, the optimizer, and the "m" slot are created.
        """
        with self.test_session():
            save_dir = self.get_temp_dir()

            saved_root = tf.train.Checkpoint()
            saved_root.var = trackable_utils.add_variable(
                saved_root, name="var", initializer=0.0)
            saved_optimizer = adam.Adam(0.1)
            train_op = saved_optimizer.apply_gradients(
                zip([1.0], [saved_root.var]))
            # `saved_optimizer` is not yet a dependency of `saved_root`, so
            # bundle both in a throwaway Checkpoint to pick up the
            # initializers of the slot variables for `saved_root.var` too.
            init_group = tf.train.Checkpoint(
                root=saved_root, optimizer=saved_optimizer)
            self.evaluate(trackable_utils.gather_initializers(init_group))
            self.evaluate(train_op)

            # First checkpoint: written before the optimizer is attached.
            self.evaluate(tf.compat.v1.assign(saved_root.var, 12.0))
            path_without_slots = saved_root.save(
                os.path.join(save_dir, "no_slots"))

            # Second checkpoint: optimizer attached, with distinct values.
            saved_root.optimizer = saved_optimizer
            self.evaluate(tf.compat.v1.assign(saved_root.var, 13.0))
            m_slot = saved_optimizer.get_slot(
                slot_name="m", var=saved_root.var)
            self.evaluate(tf.compat.v1.assign(m_slot, 14.0))
            path_with_slots = saved_root.save(
                os.path.join(save_dir, "with_slots"))

            restored_root = tf.train.Checkpoint()
            # Request the slot-containing restore first (deferred), then the
            # restore that overwrites the non-slot variable (also deferred).
            with_slots_status = restored_root.restore(path_with_slots)
            without_slots_status = restored_root.restore(path_without_slots)
            # `restored_root` has no variable yet, so this must not count as
            # consumed.
            self.assertRaises(
                AssertionError, without_slots_status.assert_consumed)

            restored_root.var = trackable_utils.add_variable(
                restored_root, name="var", shape=[])
            without_slots_status.assert_consumed()
            without_slots_status.run_restore_ops()
            self.assertEqual(12.0, self.evaluate(restored_root.var))

            restored_root.optimizer = adam.Adam(0.1)
            with_slots_status.assert_existing_objects_matched()
            if not tf.executing_eagerly():
                self.assertRaisesRegex(
                    AssertionError,
                    "Unresolved object",
                    with_slots_status.assert_consumed,
                )
            # The variable must still hold the value from the later restore
            # request.
            self.assertEqual(12.0, self.evaluate(restored_root.var))

            if not tf.executing_eagerly():
                # Slot variables are not created eagerly when graph building.
                with self.assertRaises(KeyError):
                    restored_root.optimizer.get_slot(
                        slot_name="m", var=restored_root.var)
            else:
                # Only eager execution creates slot variables with restoring
                # initializers, so the saved value is already visible.
                eager_slot_value = self.evaluate(
                    restored_root.optimizer.get_slot(
                        slot_name="m", var=restored_root.var))
                self.assertEqual(14.0, eager_slot_value)

            train_op = restored_root.optimizer.apply_gradients(
                zip([1.0], [restored_root.var]))
            # apply_gradients() has now created the slot variable; restore()
            # did not, but a restore op for it should be available now.
            with_slots_status.run_restore_ops()
            if not tf.executing_eagerly():
                # When graph building the train op has not run yet, so the
                # slot still holds its restored value. In eager mode it has
                # already run, so the value will be different.
                graph_slot_value = self.evaluate(
                    restored_root.optimizer.get_slot(
                        slot_name="m", var=restored_root.var))
                self.assertEqual(14.0, graph_slot_value)
            self.evaluate(train_op)
            with_slots_status.assert_consumed()