def testRestoreOnAssign(self):
   """Checks that attaching a variable after restore() loads its saved value.

   A checkpoint holding `var1` (4.) and `var2` (8.) is written from one graph.
   In a second graph it is restored into an object that only has `var2`. A
   variable for `var1` is then created, set to -2., and attached as
   `second.var1`; the test expects the attachment to overwrite it with 4.
   """
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
   save_graph = ops.Graph()
   with save_graph.as_default(), self.test_session(save_graph):
     first = checkpointable.Checkpointable()
     first.var1 = variable_scope.get_variable(
         name="outside_var", initializer=0.)
     first.var2 = variable_scope.get_variable(
         name="blah", initializer=0.)
     self.evaluate(first.var1.assign(4.))
     self.evaluate(first.var2.assign(8.))
     save_path = checkpointable_utils.save(
         checkpoint_prefix, root_checkpointable=first)
   restore_graph = ops.Graph()
   with restore_graph.as_default(), self.test_session(restore_graph):
     second = checkpointable.Checkpointable()
     second.var2 = variable_scope.get_variable(
         name="blah", initializer=0.)
     checkpointable_utils.restore(save_path, root_checkpointable=second)
     # Created after restore() and not yet attached to `second`.
     recreated_var1 = variable_scope.get_variable(
         name="outside_var", initializer=0.)
     # var2 was already attached when restore() ran, so it has the saved value.
     self.assertEqual(8., self.evaluate(second.var2))
     self.evaluate(recreated_var1.assign(-2.))
     self.assertEqual(-2., self.evaluate(recreated_var1))
     # Attaching the variable is expected to apply the pending restore.
     second.var1 = recreated_var1
     self.assertEqual(4., self.evaluate(recreated_var1))
Ejemplo n.º 2
0
  def testDeferredSlotRestoration(self):
    """Restores optimizer slot variables that are created after restore().

    Two checkpoints of `var` are written. `no_slots` holds 12. and is saved
    before the optimizer is attached to `root`. `with_slots` holds 13. and has
    the "m" slot set to 14. Both are restored into a fresh object, with the
    `no_slots` restore issued last. This variant runs `restore_ops` explicitly.
    """
    checkpoint_directory = self.get_temp_dir()

    root = checkpointable.Checkpointable()
    root.var = checkpointable_utils.add_variable(
        root, name="var", initializer=0.)
    optimizer = CheckpointableAdam(0.1)
    # Run minimize() once so the "m" slot for `root.var` exists (it is read
    # back via get_slot below).
    if context.in_graph_mode():
      train_op = optimizer.minimize(root.var)
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(train_op)
    else:
      optimizer.minimize(root.var.read_value)
    self.evaluate(state_ops.assign(root.var, 12.))
    no_slots_path = checkpointable_utils.save(
        os.path.join(checkpoint_directory, "no_slots"), root)
    # The optimizer becomes a dependency of `root` only after `no_slots` was
    # written, so only the second checkpoint contains optimizer state.
    root.optimizer = optimizer
    self.evaluate(state_ops.assign(root.var, 13.))
    self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
                                   14.))
    slots_path = checkpointable_utils.save(
        os.path.join(checkpoint_directory, "with_slots"), root)
    new_root = checkpointable.Checkpointable()
    # Load the slot-containing checkpoint (deferred), then immediately overwrite
    # the non-slot variable (also deferred).
    slot_status = checkpointable_utils.restore(
        slots_path, new_root)
    no_slot_status = checkpointable_utils.restore(
        no_slots_path, new_root)
    # `new_root.var` does not exist yet, so the checkpoint is not consumed.
    with self.assertRaises(AssertionError):
      no_slot_status.assert_consumed()
    new_root.var = checkpointable_utils.add_variable(
        new_root, name="var", shape=[])
    no_slot_status.assert_consumed()
    self.evaluate(no_slot_status.restore_ops)
    # The later restore() (`no_slots`, 12.) wins over the earlier one (13.).
    self.assertEqual(12., self.evaluate(new_root.var))
    new_root.optimizer = CheckpointableAdam(0.1)
    # The failure message is expected to name "beta1_power", i.e. optimizer
    # state from the slot checkpoint that has not been matched yet.
    with self.assertRaisesRegexp(AssertionError, "beta1_power"):
      slot_status.assert_consumed()
    self.assertEqual(12., self.evaluate(new_root.var))
    if context.in_eager_mode():
      # Slot variables are only created with restoring initializers when
      # executing eagerly.
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
    else:
      self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
                    None)
    if context.in_graph_mode():
      train_op = new_root.optimizer.minimize(new_root.var)
      # The slot variable now exists; restore() didn't create it, but we should
      # now have a restore op for it.
      self.evaluate(slot_status.restore_ops)
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
      self.evaluate(train_op)
    else:
      new_root.optimizer.minimize(new_root.var.read_value)
    # After minimize() the slot checkpoint is expected to be fully matched.
    slot_status.assert_consumed()
Ejemplo n.º 3
0
  def testOverlappingRestores(self):
    """Restores two checkpoints of one variable into a shared dependency.

    `first` stores 12. and `second` stores 13. Each checkpoint is restored into
    its own root, and the same `load_dep` is then attached under both roots.
    The variable ends up with the value from the restore() call issued last,
    in both the forward and the reversed call order. This variant runs
    `restore_ops` explicitly.
    """
    checkpoint_directory = self.get_temp_dir()
    save_root = checkpointable.Checkpointable()
    save_root.dep = checkpointable.Checkpointable()
    save_root.dep.var = checkpointable_utils.add_variable(
        save_root.dep, name="var", initializer=0.)
    self.evaluate(state_ops.assign(save_root.dep.var, 12.))
    first_path = checkpointable_utils.save(
        os.path.join(checkpoint_directory, "first"), save_root)
    self.evaluate(state_ops.assign(save_root.dep.var, 13.))
    second_path = checkpointable_utils.save(
        os.path.join(checkpoint_directory, "second"), save_root)

    first_root = checkpointable.Checkpointable()
    second_root = checkpointable.Checkpointable()
    first_status = checkpointable_utils.restore(
        first_path, first_root)
    second_status = checkpointable_utils.restore(
        second_path, second_root)
    load_dep = checkpointable.Checkpointable()
    load_dep.var = checkpointable_utils.add_variable(
        load_dep, name="var", shape=[])
    first_root.dep = load_dep
    first_status.assert_consumed()
    self.evaluate(first_status.restore_ops)
    # `second_root` has no `dep` yet, so its status has no restore ops.
    self.assertEqual([], second_status.restore_ops)
    self.assertEqual(12., self.evaluate(load_dep.var))
    second_root.dep = load_dep
    second_status.assert_consumed()
    self.evaluate(second_status.restore_ops)
    # `second` was restored last, so its value wins.
    self.assertEqual(13., self.evaluate(load_dep.var))

    # Try again with the order of the restore() reversed. The last restore
    # determines the final value.
    first_root = checkpointable.Checkpointable()
    second_root = checkpointable.Checkpointable()
    second_status = checkpointable_utils.restore(
        second_path, second_root)
    first_status = checkpointable_utils.restore(
        first_path, first_root)
    load_dep = checkpointable.Checkpointable()
    load_dep.var = checkpointable_utils.add_variable(
        load_dep, name="var", shape=[])
    first_root.dep = load_dep
    first_status.assert_consumed()
    self.assertEqual([], second_status.restore_ops)
    self.evaluate(first_status.restore_ops)
    self.assertEqual(12., self.evaluate(load_dep.var))
    second_root.dep = load_dep
    second_status.assert_consumed()
    self.evaluate(second_status.restore_ops)
    # Attaching under `second_root` must not override the more recent restore.
    self.assertEqual(12., self.evaluate(load_dep.var))
  def testDepAfterVar(self):
    """Restores a dependency whose variable is created before it is attached.

    `DepAfterVar.add_dep` builds the `Dependency`, which creates its variable,
    and only afterwards assigns it to `self.dep`. The test checks that a
    restore() issued before `add_dep` still delivers the saved value (-14.).
    """

    class Dependency(checkpointable.Checkpointable):

      def build(self):
        self.var = checkpointable_utils.add_variable(
            self, "var", initializer=0.)

    class DepAfterVar(checkpointable.Checkpointable):

      def add_dep(self):
        # Variable first, then the dependency edge.
        dep = Dependency()
        dep.build()
        self.dep = dep

    dep_after_var = DepAfterVar()
    dep_after_var.add_dep()
    self.evaluate(state_ops.assign(dep_after_var.dep.var, -14.))
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    save_path = checkpointable_utils.save(
        checkpoint_prefix, dep_after_var)

    loaded_dep_after_var = DepAfterVar()
    status = checkpointable_utils.restore(
        save_path, loaded_dep_after_var)
    # The dependency and its variable only come into existence here.
    loaded_dep_after_var.add_dep()
    status.assert_consumed()
    self.assertEqual(-14., self.evaluate(loaded_dep_after_var.dep.var))
  def testLateDependencyTracking(self):
    """Restores into a dependency that is attached before its variable exists.

    `LateDependencies.add_dep` assigns the `Dependency` to `self.dep` first and
    then calls `build()`, which creates its variable. Restore is issued before
    `add_dep`, so the checkpoint cannot be consumed until `add_dep` runs.
    """

    class Dependency(checkpointable.Checkpointable):

      def build(self):
        self.var = checkpointable_utils.add_variable(
            self, "var", initializer=0.)

    class LateDependencies(checkpointable.Checkpointable):

      def add_dep(self):
        # Dependency edge first, then the variable.
        self.dep = Dependency()
        self.dep.build()

    original = LateDependencies()
    original.add_dep()
    self.evaluate(state_ops.assign(original.dep.var, 123.))
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    save_path = checkpointable_utils.save(checkpoint_prefix, original)
    load_into = LateDependencies()
    status = checkpointable_utils.restore(save_path, load_into)
    # Nothing has been attached to `load_into` yet.
    with self.assertRaises(AssertionError):
      status.assert_consumed()
    load_into.add_dep()
    status.assert_consumed()
    self.assertEqual(123., self.evaluate(load_into.dep.var))
Ejemplo n.º 6
0
 def testObjectsCombined(self):
   """Two checkpointed objects may be restored into one Python object.

   On the saving side, `dep_one` owns "var1" (32.) and `dep_two` owns "var2"
   (64.). On the loading side, both attribute names refer to a single object
   that owns both variables. Each variable must still receive its own value.
   """
   temp_dir = self.get_temp_dir()
   saving_root = checkpointable.Checkpointable()
   saving_root.dep_one = checkpointable.Checkpointable()
   saving_root.dep_two = checkpointable.Checkpointable()
   checkpointable_utils.add_variable(
       saving_root.dep_one, name="var1", initializer=32., dtype=dtypes.float64)
   checkpointable_utils.add_variable(
       saving_root.dep_two, name="var2", initializer=64., dtype=dtypes.float64)
   self.evaluate(variables.global_variables_initializer())
   written_path = checkpointable_utils.save(
       os.path.join(temp_dir, "ckpt"), saving_root)

   loading_root = checkpointable.Checkpointable()
   combined = checkpointable.Checkpointable()
   loading_root.dep_one = combined
   loading_root.dep_two = combined
   var_one = checkpointable_utils.add_variable(
       combined, name="var1", shape=[], dtype=dtypes.float64)
   var_two = checkpointable_utils.add_variable(
       combined, name="var2", shape=[], dtype=dtypes.float64)
   # assert_consumed() hands back the status used to run the restore ops.
   load_status = checkpointable_utils.restore(
       written_path, loading_root).assert_consumed()
   self.evaluate(load_status.restore_ops)
   self.assertEqual(32., self.evaluate(var_one))
   self.assertEqual(64., self.evaluate(var_two))
  def testDeferredSlotRestoration(self):
    """Restores optimizer slot variables that are created after restore().

    Two checkpoints of `var` are written. `no_slots` holds 12. and is saved
    before the optimizer is attached to `root`. `with_slots` holds 13. and has
    the "m" slot set to 14. Both are restored into a fresh object, with the
    `no_slots` restore issued last. This variant never evaluates
    `restore_ops`; values are asserted right after objects are created.
    """
    checkpoint_directory = self.get_temp_dir()

    root = checkpointable.Checkpointable()
    root.var = checkpointable_utils.add_variable(
        root, name="var", initializer=0.)
    optimizer = CheckpointableAdam(0.1)
    # Run minimize() once so the "m" slot for `root.var` exists (it is read
    # back via get_slot below).
    if context.in_graph_mode():
      train_op = optimizer.minimize(root.var)
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(train_op)
    else:
      optimizer.minimize(root.var.read_value)
    self.evaluate(state_ops.assign(root.var, 12.))
    no_slots_path = checkpointable_utils.save(
        os.path.join(checkpoint_directory, "no_slots"), root)
    # The optimizer becomes a dependency of `root` only after `no_slots` was
    # written, so only the second checkpoint contains optimizer state.
    root.optimizer = optimizer
    self.evaluate(state_ops.assign(root.var, 13.))
    self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
                                   14.))
    slots_path = checkpointable_utils.save(
        os.path.join(checkpoint_directory, "with_slots"), root)
    new_root = checkpointable.Checkpointable()
    # Load the slot-containing checkpoint (deferred), then immediately overwrite
    # the non-slot variable (also deferred).
    slot_status = checkpointable_utils.restore(
        slots_path, new_root)
    no_slot_status = checkpointable_utils.restore(
        no_slots_path, new_root)
    # `new_root.var` does not exist yet, so the checkpoint is not consumed.
    with self.assertRaises(AssertionError):
      no_slot_status.assert_consumed()
    new_root.var = checkpointable_utils.add_variable(
        new_root, name="var", shape=[])
    # The later restore() (`no_slots`, 12.) wins over the earlier one (13.).
    self.assertEqual(12., self.evaluate(new_root.var))
    no_slot_status.assert_consumed()
    new_root.optimizer = CheckpointableAdam(0.1)
    # The failure message is expected to name "beta1_power", i.e. optimizer
    # state from the slot checkpoint that has not been matched yet.
    with self.assertRaisesRegexp(AssertionError, "beta1_power"):
      slot_status.assert_consumed()
    self.assertEqual(12., self.evaluate(new_root.var))
    # The "m" slot is expected to hold the checkpointed value already.
    self.assertEqual(14., self.evaluate(
        new_root.optimizer.get_slot(name="m", var=new_root.var)))
    if context.in_graph_mode():
      train_op = new_root.optimizer.minimize(new_root.var)
      self.evaluate(train_op)
    else:
      new_root.optimizer.minimize(new_root.var.read_value)
    # After minimize() the slot checkpoint is expected to be fully matched.
    slot_status.assert_consumed()
  def testOverlappingRestores(self):
    """Restores two checkpoints of one variable into a shared dependency.

    `first` stores 12. and `second` stores 13. Each checkpoint is restored into
    its own root, and the same `load_dep` is then attached under both roots.
    The variable ends up with the value from the restore() call issued last,
    in both the forward and the reversed call order. This variant asserts
    values without evaluating `restore_ops`.
    """
    checkpoint_directory = self.get_temp_dir()
    save_root = checkpointable.Checkpointable()
    save_root.dep = checkpointable.Checkpointable()
    save_root.dep.var = checkpointable_utils.add_variable(
        save_root.dep, name="var", initializer=0.)
    self.evaluate(state_ops.assign(save_root.dep.var, 12.))
    first_path = checkpointable_utils.save(
        os.path.join(checkpoint_directory, "first"), save_root)
    self.evaluate(state_ops.assign(save_root.dep.var, 13.))
    second_path = checkpointable_utils.save(
        os.path.join(checkpoint_directory, "second"), save_root)

    first_root = checkpointable.Checkpointable()
    second_root = checkpointable.Checkpointable()
    first_status = checkpointable_utils.restore(
        first_path, first_root)
    second_status = checkpointable_utils.restore(
        second_path, second_root)
    load_dep = checkpointable.Checkpointable()
    load_dep.var = checkpointable_utils.add_variable(
        load_dep, name="var", shape=[])
    first_root.dep = load_dep
    first_status.assert_consumed()
    self.assertEqual(12., self.evaluate(load_dep.var))
    second_root.dep = load_dep
    second_status.assert_consumed()
    # `second` was restored last, so its value wins.
    self.assertEqual(13., self.evaluate(load_dep.var))

    # Try again with the order of the restore() reversed. The last restore
    # determines the final value.
    first_root = checkpointable.Checkpointable()
    second_root = checkpointable.Checkpointable()
    second_status = checkpointable_utils.restore(
        second_path, second_root)
    first_status = checkpointable_utils.restore(
        first_path, first_root)
    load_dep = checkpointable.Checkpointable()
    load_dep.var = checkpointable_utils.add_variable(
        load_dep, name="var", shape=[])
    first_root.dep = load_dep
    first_status.assert_consumed()
    self.assertEqual(12., self.evaluate(load_dep.var))
    second_root.dep = load_dep
    second_status.assert_consumed()
    # Attaching under `second_root` must not override the more recent restore.
    self.assertEqual(12., self.evaluate(load_dep.var))
Ejemplo n.º 9
0
  def testDependencyLoop(self):
    """Saves and restores two objects that reference each other.

    `first.second` and `second.first` form a cycle, and each object owns one
    variable. The test covers two cases: deferred loading into fresh objects
    whose variables are created after restore(), and restoring over variables
    that already exist and hold other values. This variant runs
    `restore_ops` explicitly.
    """
    # Note: this test creates garbage during eager execution because it
    # purposefully creates a reference cycle.
    first = checkpointable.Checkpointable()
    second = checkpointable.Checkpointable()
    first.second = second
    second.first = first
    first.v = checkpointable_utils.add_variable(
        first, "v1", initializer=[3., 1., 4.])
    second.v = checkpointable_utils.add_variable(
        second, "v2", initializer=[1., 1., 2., 3.])
    self.evaluate(variables.global_variables_initializer())
    checkpoint_directory = self.get_temp_dir()
    save_path = checkpointable_utils.save(
        os.path.join(checkpoint_directory, "ckpt"), first)

    # Test deferred loading
    first_load = checkpointable.Checkpointable()
    status = checkpointable_utils.restore(save_path, first_load)
    second_load = checkpointable.Checkpointable()
    first_load.second = second_load
    second_load.first = first_load
    # The cycle is rebuilt, but the variables do not exist yet.
    with self.assertRaises(AssertionError):
      status.assert_consumed()
    first_load.v = checkpointable_utils.add_variable(
        first_load, "v1", shape=[3])
    second_load.v = checkpointable_utils.add_variable(
        second_load, "v2", shape=[4])
    status.assert_consumed()
    self.evaluate(status.restore_ops)
    self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v))
    self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v))

    # Test loading when variables have already been created
    self.evaluate(first_load.v.assign([2., 7., 1.]))
    self.assertAllEqual([2., 7., 1.], self.evaluate(first_load.v))
    self.evaluate(second_load.v.assign([2., 7., 1., 8.]))
    self.assertAllEqual([2., 7., 1., 8.], self.evaluate(second_load.v))
    # assert_consumed() hands back the status used to run the restore ops.
    status = checkpointable_utils.restore(
        save_path, first_load).assert_consumed()
    self.evaluate(status.restore_ops)
    self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v))
    self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v))
 def testManyRestoresGraph(self):
   """Restores after the first should not modify the graph.

   Saves a variable together with an Adam optimizer's state and restores the
   checkpoint once. The test then checks that restoring the same checkpoint a
   second time adds no operations to the graph.
   """
   with context.graph_mode():
     graph = ops.Graph()
     with graph.as_default(), self.test_session(graph):
       checkpoint_directory = self.get_temp_dir()
       checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
       obj = checkpointable.Checkpointable()
       obj.var = variable_scope.get_variable(name="v", initializer=0.)
       obj.opt = CheckpointableAdam(0.1)
       obj.opt.minimize(obj.var.read_value())
       self.evaluate(variables.global_variables_initializer())
       save_path = checkpointable_utils.save(
           checkpoint_prefix, root_checkpointable=obj)
       # First restore: allowed to add ops to the graph.
       checkpointable_utils.restore(
           save_path, root_checkpointable=obj)
       before_ops = graph.get_operations()
       # Second restore: must leave the op list unchanged.
       checkpointable_utils.restore(
           save_path, root_checkpointable=obj)
       self.assertEqual(before_ops, graph.get_operations())
  def testDependencyLoop(self):
    """Saves and restores two objects that reference each other.

    `first.second` and `second.first` form a cycle, and each object owns one
    variable. The test covers two cases: deferred loading into fresh objects
    whose variables are created after restore(), and restoring over variables
    that already exist and hold other values. This variant asserts values
    without evaluating `restore_ops`.
    """
    # Note: this test creates garbage during eager execution because it
    # purposefully creates a reference cycle.
    first = checkpointable.Checkpointable()
    second = checkpointable.Checkpointable()
    first.second = second
    second.first = first
    first.v = checkpointable_utils.add_variable(
        first, "v1", initializer=[3., 1., 4.])
    second.v = checkpointable_utils.add_variable(
        second, "v2", initializer=[1., 1., 2., 3.])
    self.evaluate(variables.global_variables_initializer())
    checkpoint_directory = self.get_temp_dir()
    save_path = checkpointable_utils.save(
        os.path.join(checkpoint_directory, "ckpt"), first)

    # Test deferred loading
    first_load = checkpointable.Checkpointable()
    status = checkpointable_utils.restore(save_path, first_load)
    second_load = checkpointable.Checkpointable()
    first_load.second = second_load
    second_load.first = first_load
    # The cycle is rebuilt, but the variables do not exist yet.
    with self.assertRaises(AssertionError):
      status.assert_consumed()
    first_load.v = checkpointable_utils.add_variable(
        first_load, "v1", shape=[3])
    second_load.v = checkpointable_utils.add_variable(
        second_load, "v2", shape=[4])
    status.assert_consumed()
    self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v))
    self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v))

    # Test loading when variables have already been created
    self.evaluate(first_load.v.assign([2., 7., 1.]))
    self.assertAllEqual([2., 7., 1.], self.evaluate(first_load.v))
    self.evaluate(second_load.v.assign([2., 7., 1., 8.]))
    self.assertAllEqual([2., 7., 1., 8.], self.evaluate(second_load.v))
    checkpointable_utils.restore(
        save_path, first_load).assert_consumed()
    self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v))
    self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v))
 def testAmbiguousLoad(self):
   """Restoring one checkpointed object into two Python objects must fail.

   In the saved graph, a single `dep_three` is reachable through both
   `dep_one` and `dep_two`. The loading side attaches two distinct objects at
   those paths, and the second attachment is expected to raise an
   AssertionError.
   """
   # Not OK to split one checkpoint object into two
   temp_dir = self.get_temp_dir()
   saving_root = checkpointable.Checkpointable()
   saving_root.dep_one = checkpointable.Checkpointable()
   saving_root.dep_two = checkpointable.Checkpointable()
   shared_leaf = checkpointable.Checkpointable()
   saving_root.dep_one.dep_three = shared_leaf
   saving_root.dep_two.dep_three = shared_leaf
   checkpointable_utils.add_variable(shared_leaf, name="var", initializer=0.)
   self.evaluate(variables.global_variables_initializer())
   written_path = checkpointable_utils.save(
       os.path.join(temp_dir, "ckpt"), saving_root)

   loading_root = checkpointable.Checkpointable()
   checkpointable_utils.restore(written_path, loading_root)
   loading_root.dep_one = checkpointable.Checkpointable()
   loading_root.dep_two = checkpointable.Checkpointable()
   loading_root.dep_one.dep_three = checkpointable.Checkpointable()
   # A second, different object at the other path of the same saved node.
   with self.assertRaisesRegexp(AssertionError,
                                "resolved to different objects"):
     loading_root.dep_two.dep_three = checkpointable.Checkpointable()
 def testUsageGraph(self):
   """Expected usage when graph building.

   Simulates three training runs, each in a fresh graph. The first run finds
   no checkpoint and initializes variables. Later runs restore the latest
   checkpoint from the directory. Every run trains `num_training_steps`
   steps and saves. The global step and save counter must continue from the
   restored values.
   """
   with context.graph_mode():
     num_training_steps = 10
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
     for training_continuation in range(3):
       with ops.Graph().as_default():
         network = MyNetwork()
         optimizer = CheckpointableAdam(0.001)
         root = Checkpoint(
             optimizer=optimizer, network=network,
             global_step=training_util.get_or_create_global_step())
         input_value = constant_op.constant([[3.]])
         train_op = optimizer.minimize(
             network(input_value),
             global_step=root.global_step)
         # Accessed only for its side effect; presumably this creates the save
         # counter so the initializer below covers it. TODO confirm.
         root.save_counter  # pylint: disable=pointless-statement
         init_op = variables.global_variables_initializer()
         checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
         with self.test_session(graph=ops.get_default_graph()) as session:
           if checkpoint_path is None:
             # Only the very first run should find no checkpoint.
             self.assertEqual(0, training_continuation)
             session.run(init_op)
             # Another alternative would be to run initializers automatically
             # if no checkpoint is being loaded. This would make deferred
             # loading a bit more useful with graph execution.
           else:
             checkpointable_utils.restore(
                 save_path=checkpoint_path,
                 root_checkpointable=root,
                 session=session)
           for _ in range(num_training_steps):
             session.run(train_op)
           root.save(file_prefix=checkpoint_prefix,
                     session=session)
           # Step count and save count accumulate across restored runs.
           self.assertEqual((training_continuation + 1) * num_training_steps,
                            session.run(root.global_step))
           self.assertEqual(training_continuation + 1,
                            session.run(root.save_counter))
 def testObjectsCombined(self):
   """Two checkpointed objects may be restored into one Python object.

   On the saving side, `dep_one` owns "var1" (32.) and `dep_two` owns "var2"
   (64.). On the loading side, both attribute names refer to a single object
   that owns both variables. Each variable must still receive its own value.
   """
   temp_dir = self.get_temp_dir()
   saving_root = checkpointable.Checkpointable()
   saving_root.dep_one = checkpointable.Checkpointable()
   saving_root.dep_two = checkpointable.Checkpointable()
   checkpointable_utils.add_variable(
       saving_root.dep_one, name="var1", initializer=32., dtype=dtypes.float64)
   checkpointable_utils.add_variable(
       saving_root.dep_two, name="var2", initializer=64., dtype=dtypes.float64)
   self.evaluate(variables.global_variables_initializer())
   written_path = checkpointable_utils.save(
       os.path.join(temp_dir, "ckpt"), saving_root)

   loading_root = checkpointable.Checkpointable()
   combined = checkpointable.Checkpointable()
   loading_root.dep_one = combined
   loading_root.dep_two = combined
   var_one = checkpointable_utils.add_variable(
       combined, name="var1", shape=[], dtype=dtypes.float64)
   var_two = checkpointable_utils.add_variable(
       combined, name="var2", shape=[], dtype=dtypes.float64)
   # Both variables already exist, so the checkpoint must be fully consumed.
   checkpointable_utils.restore(written_path, loading_root).assert_consumed()
   self.assertEqual(32., self.evaluate(var_one))
   self.assertEqual(64., self.evaluate(var_two))
 def restore(self, save_path):
   """Restore the checkpoint at `save_path` with this object as the root.

   Args:
     save_path: Path of the checkpoint to load.

   Returns:
     The value returned by `checkpointable_utils.restore`.
   """
   load_status = checkpointable_utils.restore(
       root_checkpointable=self, save_path=save_path)
   return load_status