def _initialized_model(self):
     input_value = constant_op.constant([[3.]])
     model = MyModel()
     optimizer = adam.AdamOptimizer(0.001)
     optimizer_step = training_util.get_or_create_global_step()
     root_checkpointable = util.Checkpoint(optimizer=optimizer,
                                           model=model,
                                           optimizer_step=optimizer_step)
     train_op = optimizer.minimize(functools.partial(model, input_value),
                                   global_step=optimizer_step)
     self.evaluate(util.gather_initializers(root_checkpointable))
     self.evaluate(train_op)
     # Assign known values to a regular variable, a slot variable, and a
     # non-slot Optimizer variable so they can be checked after loading.
     self.evaluate(model._named_dense.bias.assign([1.]))
     self.evaluate(
         optimizer.get_slot(var=model._named_dense.bias,
                            name="m").assign([2.]))
     beta_1_power, _ = optimizer._get_beta_accumulators()
     self.evaluate(beta_1_power.assign(3.))
     return root_checkpointable
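
# Sketch (added illustration, not part of the test above): the same
# save-then-verify round trip with the public TF 2.x tf.train.Checkpoint API.
# The variable names below are illustrative stand-ins, not the test's objects.
import os
import tempfile

import tensorflow as tf

bias = tf.Variable([1.])        # stands in for the model's bias variable
slot_like = tf.Variable([2.])   # stands in for an optimizer slot variable
ckpt = tf.train.Checkpoint(bias=bias, slot_like=slot_like)
save_path = ckpt.save(os.path.join(tempfile.mkdtemp(), "ckpt"))

bias.assign([0.])               # clobber the known values...
slot_like.assign([0.])
ckpt.restore(save_path).assert_consumed()  # ...then restore and check them
assert bias.numpy()[0] == 1. and slot_like.numpy()[0] == 2.
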
 def testDictionariesBasic(self):
   a = training.Model()
   b = training.Model()
   a.attribute = {"b": b}
   c = training.Model()
   a.attribute["c"] = []
   a.attribute["c"].append(c)
   a_deps = util.list_objects(a)
   self.assertIn(b, a_deps)
   self.assertIn(c, a_deps)
   self.assertIs(b, a.attribute["b"])
   six.assertCountEqual(
       self,
       ["b", "c"],
       [dep.name for dep in a.attribute._checkpoint_dependencies])
   self.assertEqual([b, c], a.layers)
   self.assertEqual([b, c], a.attribute.layers)
   self.assertEqual([c], a.attribute["c"].layers)
   checkpoint = util.Checkpoint(a=a)
   save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
   checkpoint.restore(save_path).assert_consumed()
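
# Sketch (added, assumed-equivalent behavior with the public TF 2.x API):
# dict and list attributes on a tf.Module (or Keras Model) are wrapped and
# tracked, so the variables inside them are saved and restored automatically.
import os
import tempfile

import tensorflow as tf

class Holder(tf.Module):
  def __init__(self):
    super().__init__()
    self.attribute = {"b": tf.Variable(1.), "c": [tf.Variable(2.)]}

holder = Holder()
ckpt = tf.train.Checkpoint(a=holder)
path = ckpt.save(os.path.join(tempfile.mkdtemp(), "ckpt"))
ckpt.restore(path).assert_consumed()
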
 def testDeferredRestorationUsageEager(self):
   """An idiomatic eager execution example."""
   num_training_steps = 10
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
   for training_continuation in range(3):
     model = MyModel()
     optimizer = adam.AdamOptimizer(0.001)
     root = checkpointable_utils.Checkpoint(
         optimizer=optimizer, model=model,
         optimizer_step=training_util.get_or_create_global_step())
     root.restore(core_saver.latest_checkpoint(checkpoint_directory))
     for _ in range(num_training_steps):
       # TODO(allenl): Use a Dataset and serialize/checkpoint it.
       input_value = constant_op.constant([[3.]])
       optimizer.minimize(
           lambda: model(input_value),  # pylint: disable=cell-var-from-loop
           global_step=root.optimizer_step)
     root.save(file_prefix=checkpoint_prefix)
     self.assertEqual((training_continuation + 1) * num_training_steps,
                      root.optimizer_step.numpy())
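
# Sketch (added): the equivalent idiomatic eager loop with the public TF 2.x
# API, using tf.train.CheckpointManager to keep recent checkpoints. The model,
# optimizer, and step count below are illustrative assumptions.
import tempfile

import tensorflow as tf

model = tf.keras.layers.Dense(1)
optimizer = tf.keras.optimizers.Adam(0.001)
ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
manager = tf.train.CheckpointManager(ckpt, tempfile.mkdtemp(), max_to_keep=3)
ckpt.restore(manager.latest_checkpoint)  # no restoration on the first run

for _ in range(10):
  x = tf.constant([[3.]])
  with tf.GradientTape() as tape:
    loss = tf.reduce_sum(model(x))
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
manager.save()
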
 def testWithDefun(self):
   num_training_steps = 2
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
   for training_continuation in range(3):
     with test_util.device(use_gpu=True):
       model = MyModel()
       # Don't actually train so we can test variable values
       optimizer = adam.AdamOptimizer(0.)
       root = checkpointable_utils.Checkpoint(
           optimizer=optimizer, model=model,
           global_step=training_util.get_or_create_global_step())
       checkpoint_path = checkpoint_management.latest_checkpoint(
           checkpoint_directory)
       status = root.restore(save_path=checkpoint_path)
       def train_fn():
         @def_function.function
         def _call_model(x):
           return model(x)
         with backprop.GradientTape() as tape:
           loss = _call_model(constant_op.constant([[3.]]))
         gradients = tape.gradient(loss, model.variables)
         return optimizer.apply_gradients(zip(gradients, model.variables),
                                          global_step=root.global_step)
       if not context.executing_eagerly():
         train_fn = functools.partial(
             self.evaluate, train_fn())
       status.initialize_or_restore()
       for _ in range(num_training_steps):
         train_fn()
       if training_continuation > 0:
         status.assert_consumed()
         self.assertAllClose([[42.]], self.evaluate(model.variables[0]))
       else:
         self.evaluate(model.variables[0].assign([[42.]]))
       root.save(file_prefix=checkpoint_prefix)
       self.assertEqual((training_continuation + 1) * num_training_steps,
                        self.evaluate(root.global_step))
       self.assertEqual(training_continuation + 1,
                        self.evaluate(root.save_counter))
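
# Sketch (added): the same forward-pass-under-tf.function pattern with the
# public TF 2.x API; the objects below are illustrative, not the test's MyModel.
import os
import tempfile

import tensorflow as tf

model = tf.keras.layers.Dense(1)
model.build((None, 1))                      # create variables up front
optimizer = tf.keras.optimizers.Adam(0.001)
ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)

@tf.function
def call_model(x):
  return model(x)

with tf.GradientTape() as tape:
  loss = tf.reduce_sum(call_model(tf.constant([[3.]])))
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
ckpt.save(os.path.join(tempfile.mkdtemp(), "ckpt"))
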
 def test_table(self, cycles):
     # TODO(b/123408779): Handle generic TrackableResources and enable this test
     self.skipTest("Need to handle generic TrackableResources")
     vocab_path = self._make_asset("alpha\nbeta\ngamma\n")
     initializer = lookup_ops.TextFileInitializer(
         vocab_path,
         key_dtype=dtypes.string,
         key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
         value_dtype=dtypes.int64,
         value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
     root = util.Checkpoint(
         table=lookup_ops.HashTable(initializer, default_value=-1))
     root.table_user = def_function.function(
         root.table.lookup,
         input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
     self.assertEqual(
         2,
         root.table_user(constant_op.constant("gamma")).numpy())
     imported = self.cycle(root, cycles)
     self.assertEqual(
         2,
         imported.table_user(constant_op.constant("gamma")).numpy())
 def testUsageGraph(self):
     """Expected usage when graph building."""
     with context.graph_mode():
         num_training_steps = 10
         checkpoint_directory = self.get_temp_dir()
         checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
         for training_continuation in range(3):
             with ops.Graph().as_default():
                 model = MyModel()
                 optimizer = adam.AdamOptimizer(0.001)
                 root = checkpointable_utils.Checkpoint(
                     optimizer=optimizer,
                     model=model,
                     global_step=training_util.get_or_create_global_step())
                 input_value = constant_op.constant([[3.]])
                 train_op = optimizer.minimize(model(input_value),
                                               global_step=root.global_step)
                 checkpoint_path = core_saver.latest_checkpoint(
                     checkpoint_directory)
                 with self.test_session(
                         graph=ops.get_default_graph()) as session:
                     status = root.restore(save_path=checkpoint_path)
                     status.initialize_or_restore(session=session)
                     if checkpoint_path is None:
                         self.assertEqual(0, training_continuation)
                         with self.assertRaises(AssertionError):
                             status.assert_consumed()
                     else:
                         status.assert_consumed()
                     for _ in range(num_training_steps):
                         session.run(train_op)
                     root.save(file_prefix=checkpoint_prefix,
                               session=session)
                     self.assertEqual(
                         (training_continuation + 1) * num_training_steps,
                         session.run(root.global_step))
                     self.assertEqual(training_continuation + 1,
                                      session.run(root.save_counter))
 def testCustomNumbering(self):
   directory = self.get_temp_dir()
   step = variables.Variable(0, dtype=dtypes.int64)
   checkpoint = util.Checkpoint(step=step)
   manager = checkpoint_management.CheckpointManager(
       checkpoint, directory, max_to_keep=2)
   self.evaluate(step.initializer)
   for i in range(5):
     path = manager.save(checkpoint_number=step)
     expected_suffix = "-%d" % (2 * i,)
     if not path.endswith(expected_suffix):
       self.fail("%s should have suffix %s" % (path, expected_suffix))
     self.evaluate(step.assign_add(2))
   self.assertEqual(5, self.evaluate(checkpoint.save_counter))
   # Test regular integers
   last_path = manager.save(checkpoint_number=32)
   self.assertIn("-32", last_path)
   self.assertEqual(last_path, manager.latest_checkpoint)
   self.assertEqual(
       last_path, checkpoint_management.latest_checkpoint(directory))
   state = checkpoint_management.get_checkpoint_state(directory)
   # Only the most recent two checkpoints are saved
   self.assertEqual([path, last_path], state.all_model_checkpoint_paths)
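
# Sketch (added): custom numbering with the public tf.train.CheckpointManager;
# save() accepts an explicit checkpoint_number (a Python int or an integer
# tf.Variable). Paths and values below are illustrative.
import tempfile

import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(ckpt, tempfile.mkdtemp(), max_to_keep=2)

for _ in range(5):
  path = manager.save(checkpoint_number=step)  # suffixes -0, -2, -4, ...
  step.assign_add(2)
print(manager.checkpoints)        # only the two most recent paths are kept
print(manager.latest_checkpoint)
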
 def testDeferredRestorationUsageEager(self):
   """An idiomatic eager execution example."""
   num_training_steps = 10
   checkpoint_directory = self.get_temp_dir()
   for training_continuation in range(3):
     with self.test_scope():
       model = Subclassed()
       optimizer = adam.Adam(0.001)
       root = checkpointable_utils.Checkpoint(
           optimizer=optimizer, model=model)
       manager = checkpoint_management.CheckpointManager(
           root, checkpoint_directory, max_to_keep=2)
       root.restore(manager.latest_checkpoint)
       for _ in range(num_training_steps):
         input_value = constant_op.constant([[3.]])
         with backprop.GradientTape() as tape:
           loss = model(input_value)
         variables = model.trainable_variables
         gradients = tape.gradient(loss, variables)
         optimizer.apply_gradients(zip(gradients, variables))
       manager.save()
       self.assertEqual((training_continuation + 1) * num_training_steps,
                        root.optimizer.iterations.numpy())
    def testGraphDistributionStrategy(self):
        self.skipTest("b/121381184")
        num_training_steps = 10
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

        def _train_fn(optimizer, model):
            input_value = constant_op.constant([[3.]])
            return optimizer.minimize(functools.partial(model, input_value),
                                      global_step=root.optimizer_step)

        for training_continuation in range(3):
            with ops.Graph().as_default():
                strategy = mirrored_strategy.MirroredStrategy()
                with strategy.scope():
                    model = MyModel()
                    optimizer = adam.AdamOptimizer(0.001)
                    root = checkpointable_utils.Checkpoint(
                        optimizer=optimizer,
                        model=model,
                        optimizer_step=training_util.get_or_create_global_step(
                        ))
                    status = root.restore(
                        checkpoint_management.latest_checkpoint(
                            checkpoint_directory))
                    train_op = strategy.extended.call_for_each_replica(
                        functools.partial(_train_fn, optimizer, model))
                    with self.session() as session:
                        if training_continuation > 0:
                            status.assert_consumed()
                        status.initialize_or_restore()
                        for _ in range(num_training_steps):
                            session.run(train_op)
                        root.save(file_prefix=checkpoint_prefix)
                self.assertEqual(
                    (training_continuation + 1) * num_training_steps,
                    root.optimizer_step.numpy())
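
# Sketch (added): with the public TF 2.x API, variables created under a
# tf.distribute strategy scope are checkpointed the same way as ordinary
# variables; this eager illustration stands in for the graph-mode test above.
import os
import tempfile

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  dense = tf.keras.layers.Dense(1)
  dense.build((None, 1))
  optimizer = tf.keras.optimizers.Adam(0.001)

ckpt = tf.train.Checkpoint(model=dense, optimizer=optimizer)
save_path = ckpt.save(os.path.join(tempfile.mkdtemp(), "ckpt"))
ckpt.restore(save_path)
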
    def testMultipleGraphsNonSlotVariables(self):
        with context.graph_mode():
            checkpoint_directory = self.get_temp_dir()
            checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
            optimizer = adam.AdamOptimizer(0.001)
            # Construct a model in one graph
            first_graph = ops.Graph()
            first_session = session_lib.Session(graph=first_graph)
            with first_graph.as_default(), first_session.as_default():
                first_variable = resource_variable_ops.ResourceVariable([1.])
                first_root_checkpointable = checkpointable_utils.Checkpoint(
                    optimizer=optimizer, variable=first_variable)
                train_op = optimizer.minimize(first_variable.read_value)
                self.evaluate(
                    checkpointable_utils.gather_initializers(
                        first_root_checkpointable))
                self.evaluate(train_op)
                self.evaluate(first_variable.assign([1.]))
                self.evaluate(
                    optimizer.get_slot(var=first_variable,
                                       name="m").assign([2.]))
                beta1_power, _ = optimizer._get_beta_accumulators()
                self.evaluate(beta1_power.assign(3.))

            # Save and load in a second graph
            second_graph = ops.Graph()
            with second_graph.as_default(), session_lib.Session(
                    graph=second_graph):
                second_variable = resource_variable_ops.ResourceVariable([1.])
                second_root_checkpointable = checkpointable_utils.Checkpoint(
                    optimizer=optimizer, variable=second_variable)
                train_op = optimizer.minimize(second_variable.read_value)
                second_root_checkpointable.restore(
                    None).initialize_or_restore()
                self.evaluate(train_op)
                self.evaluate(second_variable.assign([4.]))
                self.evaluate(
                    optimizer.get_slot(var=second_variable,
                                       name="m").assign([5.]))
                beta1_power, _ = optimizer._get_beta_accumulators()
                self.evaluate(beta1_power.assign(6.))
                save_path = second_root_checkpointable.save(checkpoint_prefix)
                self.evaluate(second_variable.assign([7.]))
                self.evaluate(
                    optimizer.get_slot(var=second_variable,
                                       name="m").assign([8.]))
                beta1_power, _ = optimizer._get_beta_accumulators()
                self.assertAllEqual(6., self.evaluate(beta1_power))
                status = second_root_checkpointable.restore(save_path)
                status.assert_consumed().run_restore_ops()
                self.assertAllEqual([4.], self.evaluate(second_variable))
                self.assertAllEqual([5.],
                                    self.evaluate(
                                        optimizer.get_slot(var=second_variable,
                                                           name="m")))
                beta1_power, _ = optimizer._get_beta_accumulators()
                self.assertAllEqual(6., self.evaluate(beta1_power))

            # Check that the first graph is unmolested
            with first_graph.as_default(), first_session.as_default():
                self.assertAllEqual([1.], self.evaluate(first_variable))
                self.assertAllEqual([2.],
                                    self.evaluate(
                                        optimizer.get_slot(var=first_variable,
                                                           name="m")))
                beta1_power, _ = optimizer._get_beta_accumulators()
                self.assertAllEqual(3., self.evaluate(beta1_power))
    def testDeferredSlotRestoration(self):
        checkpoint_directory = self.get_temp_dir()

        root = checkpointable.Checkpointable()
        root.var = checkpointable_utils.add_variable(root,
                                                     name="var",
                                                     initializer=0.)
        optimizer = adam.AdamOptimizer(0.1)
        if context.executing_eagerly():
            optimizer.minimize(root.var.read_value)
        else:
            train_op = optimizer.minimize(root.var)
            # Note that `optimizer` has not been added as a dependency of
            # `root`. Create a one-off grouping so that slot variables for `root.var`
            # get initialized too.
            self.evaluate(
                checkpointable_utils.gather_initializers(
                    checkpointable_utils.Checkpoint(root=root,
                                                    optimizer=optimizer)))
            self.evaluate(train_op)
        self.evaluate(state_ops.assign(root.var, 12.))
        no_slots_path = checkpointable_utils.CheckpointableSaver(root).save(
            os.path.join(checkpoint_directory, "no_slots"))
        root.optimizer = optimizer
        self.evaluate(state_ops.assign(root.var, 13.))
        self.evaluate(
            state_ops.assign(optimizer.get_slot(name="m", var=root.var), 14.))
        slots_path = checkpointable_utils.CheckpointableSaver(root).save(
            os.path.join(checkpoint_directory, "with_slots"))
        new_root = checkpointable.Checkpointable()
        # Load the slot-containing checkpoint (deferred), then immediately overwrite
        # the non-slot variable (also deferred).
        slot_status = checkpointable_utils.CheckpointableSaver(
            new_root).restore(slots_path)
        no_slot_status = checkpointable_utils.CheckpointableSaver(
            new_root).restore(no_slots_path)
        with self.assertRaises(AssertionError):
            no_slot_status.assert_consumed()
        new_root.var = checkpointable_utils.add_variable(new_root,
                                                         name="var",
                                                         shape=[])
        no_slot_status.assert_consumed()
        no_slot_status.run_restore_ops()
        self.assertEqual(12., self.evaluate(new_root.var))
        new_root.optimizer = adam.AdamOptimizer(0.1)
        with self.assertRaisesRegexp(AssertionError, "beta1_power"):
            slot_status.assert_consumed()
        self.assertEqual(12., self.evaluate(new_root.var))
        if context.executing_eagerly():
            # Slot variables are only created with restoring initializers when
            # executing eagerly.
            self.assertEqual(
                14.,
                self.evaluate(
                    new_root.optimizer.get_slot(name="m", var=new_root.var)))
        else:
            self.assertIs(
                new_root.optimizer.get_slot(name="m", var=new_root.var), None)
        if context.executing_eagerly():
            new_root.optimizer.minimize(new_root.var.read_value)
        else:
            train_op = new_root.optimizer.minimize(new_root.var)
            # The slot variable now exists; restore() didn't create it, but we should
            # now have a restore op for it.
            slot_status.run_restore_ops()
            self.assertEqual(
                14.,
                self.evaluate(
                    new_root.optimizer.get_slot(name="m", var=new_root.var)))
            self.evaluate(train_op)
        slot_status.assert_consumed()
 def testSaveRestore(self):
     model = MyModel()
     optimizer = adam.AdamOptimizer(0.001)
     root_checkpointable = checkpointable_utils.Checkpoint(
         optimizer=optimizer, model=model)
     input_value = constant_op.constant([[3.]])
     if context.executing_eagerly():
         optimizer.minimize(lambda: model(input_value))
     else:
         train_op = optimizer.minimize(model(input_value))
         # TODO(allenl): Make initialization more pleasant when graph building.
         root_checkpointable.save_counter  # pylint: disable=pointless-statement
         self.evaluate(
             checkpointable_utils.gather_initializers(root_checkpointable))
         self.evaluate(train_op)
     prefix = os.path.join(self.get_temp_dir(), "ckpt")
     self.evaluate(state_ops.assign(model._named_dense.variables[1], [42.]))
     m_bias_slot = optimizer.get_slot(model._named_dense.variables[1], "m")
     self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
     save_path = root_checkpointable.save(file_prefix=prefix)
     self.evaluate(state_ops.assign(model._named_dense.variables[1], [43.]))
     self.evaluate(state_ops.assign(root_checkpointable.save_counter, 3))
     optimizer_variables = self.evaluate(optimizer.variables())
     self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
     # Immediate restoration
     status = root_checkpointable.restore(
         save_path=save_path).assert_consumed()
     status.run_restore_ops()
     self.assertAllEqual([42.],
                         self.evaluate(model._named_dense.variables[1]))
     self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter))
     self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
     if not context.executing_eagerly():
         return  # Restore-on-create is only supported when executing eagerly
     on_create_model = MyModel()
     on_create_optimizer = adam.AdamOptimizer(
         0.001,
          # Preserve beta1_power and beta2_power when applying gradients so we can
         # test that they've been restored correctly.
         beta1=1.0,
         beta2=1.0)
     on_create_root = checkpointable_utils.Checkpoint(
         optimizer=on_create_optimizer, model=on_create_model)
     # Deferred restoration
     status = on_create_root.restore(save_path=save_path)
     on_create_model(constant_op.constant([[3.]]))  # create variables
     self.assertAllEqual(1, self.evaluate(on_create_root.save_counter))
     self.assertAllEqual([42.],
                         self.evaluate(
                             on_create_model._named_dense.variables[1]))
     on_create_m_bias_slot = on_create_optimizer.get_slot(
         on_create_model._named_dense.variables[1], "m")
     # Optimizer slot variables are created when the original variable is
     # restored.
     self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
     self.assertAllEqual(optimizer_variables[2:],
                         self.evaluate(on_create_optimizer.variables()))
     dummy_var = resource_variable_ops.ResourceVariable([1.])
     on_create_optimizer.minimize(loss=dummy_var.read_value)
     status.assert_consumed()
     beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators()
     self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power))
     self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power))
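
# Sketch (added): the restore-on-create behavior exercised above, shown with
# the public TF 2.x API. restore() is called before the target's variables
# exist; values are filled in when the layer is first built. Names below are
# illustrative assumptions.
import os
import tempfile

import tensorflow as tf

source = tf.keras.layers.Dense(1, kernel_initializer="ones")
source(tf.constant([[3.]]))                       # create the source variables
save_path = tf.train.Checkpoint(model=source).save(
    os.path.join(tempfile.mkdtemp(), "ckpt"))

target = tf.keras.layers.Dense(1)
target_ckpt = tf.train.Checkpoint(model=target)
status = target_ckpt.restore(save_path)           # deferred: nothing to fill yet
target(tf.constant([[3.]]))                       # variables created -> restored
status.assert_existing_objects_matched()
assert target.kernel.numpy()[0, 0] == 1.
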
 def testNamingWithOptimizer(self):
     input_value = constant_op.constant([[3.]])
     model = MyModel()
     # A nuisance Model using the same optimizer. Its slot variables should not
     # go in the checkpoint, since it is never depended on.
     other_model = MyModel()
     optimizer = adam.AdamOptimizer(0.001)
     optimizer_step = training_util.get_or_create_global_step()
     root_checkpointable = checkpointable_utils.Checkpoint(
         optimizer=optimizer, model=model, optimizer_step=optimizer_step)
     if context.executing_eagerly():
         optimizer.minimize(lambda: model(input_value),
                            global_step=optimizer_step)
         optimizer.minimize(lambda: other_model(input_value),
                            global_step=optimizer_step)
     else:
         train_op = optimizer.minimize(model(input_value),
                                       global_step=optimizer_step)
         optimizer.minimize(other_model(input_value),
                            global_step=optimizer_step)
         self.evaluate(
             checkpointable_utils.gather_initializers(root_checkpointable))
         self.evaluate(train_op)
     named_variables, serialized_graph, _ = (
         checkpointable_utils._serialize_object_graph(root_checkpointable,
                                                      saveables_cache=None))
     expected_checkpoint_names = (
         # Created in the root node, so no prefix.
         "optimizer_step",
         "model/_second/kernel",
         "model/_named_dense/kernel",
         "model/_named_dense/bias",
         # non-Layer dependency of the model
         "model/_non_layer/a_variable",
         # The optimizer creates two non-slot variables
         "optimizer/beta1_power",
         "optimizer/beta2_power",
         # Slot variables
         "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
         "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
         "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m",
         "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v",
         "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m",
         "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v",
     )
     suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
     expected_checkpoint_names = [
         name + suffix for name in expected_checkpoint_names
     ]
     # The Dense layers also save get_config() JSON
     expected_checkpoint_names.extend([
         "model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
         "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"
     ])
     named_variables = {v.name: v for v in named_variables}
     six.assertCountEqual(self, expected_checkpoint_names,
                          named_variables.keys())
     # Check that we've mapped to the right variable objects (not exhaustive)
     self.assertEqual("global_step",
                      named_variables["optimizer_step" + suffix].full_name)
     self.assertEqual(
         "my_model/dense_1/kernel",
         named_variables["model/_second/kernel" + suffix].full_name)
     self.assertEqual(
         "my_model/dense/kernel",
         named_variables["model/_named_dense/kernel" + suffix].full_name)
     self.assertEqual(
         "beta1_power",
         named_variables["optimizer/beta1_power" + suffix].full_name)
     self.assertEqual(
         "beta2_power",
         named_variables["optimizer/beta2_power" + suffix].full_name)
     # Spot check the generated protocol buffers.
     self.assertEqual("optimizer",
                      serialized_graph.nodes[0].children[1].local_name)
     optimizer_node = serialized_graph.nodes[
         serialized_graph.nodes[0].children[1].node_id]
     self.assertEqual("beta1_power", optimizer_node.children[0].local_name)
     self.assertEqual(
         "beta1_power", serialized_graph.nodes[
             optimizer_node.children[0].node_id].attributes[0].full_name)
     self.assertEqual(
         "my_model/dense/kernel",
         serialized_graph.nodes[optimizer_node.slot_variables[
             0].original_variable_node_id].attributes[0].full_name)
     # We strip off the :0 suffix, as variable.name-based saving does.
     self.assertEqual(
         "my_model/dense/kernel/Adam",
         serialized_graph.nodes[optimizer_node.slot_variables[
             0].slot_variable_node_id].attributes[0].full_name)
     self.assertEqual(
         "my_model/dense/kernel/Adam:0",
         optimizer.get_slot(var=model._named_dense.kernel, name="m").name)
     self.assertEqual(
         "model/_named_dense/kernel" + suffix,
         serialized_graph.nodes[optimizer_node.slot_variables[
             0].original_variable_node_id].attributes[0].checkpoint_key)
     self.assertEqual("m", optimizer_node.slot_variables[0].slot_name)
     self.assertEqual(
         "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m" + suffix,
         serialized_graph.nodes[optimizer_node.slot_variables[
             0].slot_variable_node_id].attributes[0].checkpoint_key)
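
# Sketch (added): the object-path checkpoint keys asserted above can be listed
# from any checkpoint with the public inspection helper tf.train.list_variables;
# the layer and paths below are illustrative.
import os
import tempfile

import tensorflow as tf

dense = tf.keras.layers.Dense(1)
dense(tf.constant([[3.]]))
save_path = tf.train.Checkpoint(model=dense).save(
    os.path.join(tempfile.mkdtemp(), "ckpt"))

# Each entry is (checkpoint_key, shape); variable keys end with the
# "/.ATTRIBUTES/VARIABLE_VALUE" suffix used in the assertions above.
for key, shape in tf.train.list_variables(save_path):
  print(key, shape)
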
  def testSaveRestoreState(self, mock_time):
    directory = self.get_temp_dir()
    mock_time.time.return_value = 3.
    checkpoint = util.Checkpoint()
    first_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2)
    first_time = 10000.
    first_name = os.path.join(directory, "ckpt-1")
    mock_time.time.return_value = first_time
    first_manager.save()
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual([first_time], state.all_model_checkpoint_timestamps)
    self.assertEqual(3., state.last_preserved_timestamp)
    second_time = first_time + 3610.
    second_name = os.path.join(directory, "ckpt-2")
    mock_time.time.return_value = second_time
    first_manager.save()
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual([first_time, second_time],
                     state.all_model_checkpoint_timestamps)
    self.assertEqual(3., state.last_preserved_timestamp)
    self.assertEqual([first_name, second_name], first_manager.checkpoints)
    self.assertEqual(second_name, first_manager.latest_checkpoint)
    del first_manager

    second_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory,
        max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
    self.assertEqual([first_name, second_name], second_manager.checkpoints)
    self.assertEqual(second_name, second_manager.latest_checkpoint)
    third_name = os.path.join(directory, "ckpt-3")
    third_time = second_time + 3600. * 0.2
    mock_time.time.return_value = third_time
    second_manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_name))
    self.assertEqual([second_name, third_name],
                     second_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(first_time, state.last_preserved_timestamp)
    fourth_time = third_time + 3600. * 0.5
    mock_time.time.return_value = fourth_time
    fourth_name = os.path.join(directory, "ckpt-4")
    second_manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
    self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
    self.assertEqual([third_name, fourth_name],
                     second_manager.checkpoints)
    fifth_time = fourth_time + 3600. * 0.5
    mock_time.time.return_value = fifth_time
    fifth_name = os.path.join(directory, "ckpt-5")
    second_manager.save()
    self.assertEqual([fourth_name, fifth_name],
                     second_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(first_time, state.last_preserved_timestamp)
    del second_manager
    third_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory,
        max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
    self.assertEqual(fifth_name, third_manager.latest_checkpoint)
    mock_time.time.return_value += 10.
    third_manager.save()
    sixth_name = os.path.join(directory, "ckpt-6")
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(fourth_time, state.last_preserved_timestamp)
    self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(fourth_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(fifth_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(sixth_name))
    self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
    self.assertFalse(checkpoint_management.checkpoint_exists(third_name))
    self.assertEqual([fifth_name, sixth_name],
                     third_manager.checkpoints)
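
# Sketch (added): the retention policy exercised above corresponds to the
# public tf.train.CheckpointManager arguments shown here; a toy illustration.
import tempfile

import tensorflow as tf

ckpt = tf.train.Checkpoint(step=tf.Variable(0))
manager = tf.train.CheckpointManager(
    ckpt, tempfile.mkdtemp(),
    max_to_keep=2,                      # keep only the two most recent saves
    keep_checkpoint_every_n_hours=1.5)  # but preserve one save per 1.5 hours
for _ in range(5):
  manager.save()
print(manager.checkpoints)              # at most the two most recent paths
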
  def test_initialize_if_not_restoring(self):
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    optimizer_only_prefix = os.path.join(checkpoint_directory, "opt")
    with test_util.device(use_gpu=True):
      model = MyModel()
      optimizer = adam.AdamOptimizer(0.001)
      root = checkpointable_utils.Checkpoint(
          model=model,  # Do not save the optimizer with the checkpoint.
          global_step=training_util.get_or_create_global_step())
      optimizer_checkpoint = checkpointable_utils.Checkpoint(
          optimizer=optimizer)

      checkpoint_path = checkpoint_management.latest_checkpoint(
          checkpoint_directory)
      status = root.restore(save_path=checkpoint_path)
      input_value = constant_op.constant([[3.]])
      train_fn = functools.partial(
          optimizer.minimize,
          functools.partial(model, input_value),
          global_step=root.global_step)
      if not context.executing_eagerly():
        train_fn = functools.partial(self.evaluate, train_fn())
      status.initialize_or_restore()
      self.evaluate([v.initializer for v in optimizer.variables()])
      train_fn()
      model_save_path = root.save(file_prefix=checkpoint_prefix)
      self.evaluate(optimizer.variables()[0].assign(42.))
      optimizer_save_path = optimizer_checkpoint.save(optimizer_only_prefix)

    # Restore into a graph with the optimizer
    with test_util.device(use_gpu=True):
      model = MyModel()
      optimizer = adam.AdamOptimizer(0.001)
      root = checkpointable_utils.Checkpoint(
          optimizer=optimizer, model=model,
          global_step=training_util.get_or_create_global_step())
      status = root.restore(save_path=model_save_path)
      input_value = constant_op.constant([[3.]])
      train_fn = functools.partial(
          optimizer.minimize,
          functools.partial(model, input_value),
          global_step=root.global_step)
      if not context.executing_eagerly():
        train_fn = functools.partial(self.evaluate, train_fn())
      status.initialize_or_restore()
      train_fn()
      with self.assertRaises(AssertionError):
        status.assert_existing_objects_matched()
      with self.assertRaises(AssertionError):
        status.assert_consumed()

    # Make sure initialization doesn't clobber later restores
    with test_util.device(use_gpu=True):
      model = MyModel()
      optimizer = adam.AdamOptimizer(0.001, beta1=1.0)
      root = checkpointable_utils.Checkpoint(
          optimizer=optimizer, model=model,
          global_step=training_util.get_or_create_global_step())
      opt_root = checkpointable_utils.Checkpoint(
          optimizer=optimizer)
      status = root.restore(save_path=model_save_path)
      init_only_optimizer_status = opt_root.restore(save_path=None)
      optimizer_status = opt_root.restore(save_path=optimizer_save_path)
      input_value = constant_op.constant([[3.]])
      train_fn = functools.partial(
          optimizer.minimize,
          functools.partial(model, input_value),
          global_step=root.global_step)
      if not context.executing_eagerly():
        train_fn = functools.partial(self.evaluate, train_fn())
      optimizer_status.run_restore_ops()
      status.initialize_or_restore()
      init_only_optimizer_status.initialize_or_restore()
      train_fn()
      self.assertEqual(42., self.evaluate(optimizer.variables()[0]))
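
# Sketch (added): the status assertions used above have the same public-API
# semantics. Restoring a checkpoint with extra values into a smaller object
# graph passes assert_existing_objects_matched() but would fail
# assert_consumed(); expect_partial() silences the unused-value warnings.
# The variables below are toy stand-ins.
import os
import tempfile

import tensorflow as tf

v = tf.Variable(1.)
w = tf.Variable(2.)
full_path = tf.train.Checkpoint(v=v, w=w).save(
    os.path.join(tempfile.mkdtemp(), "ckpt"))

v2 = tf.Variable(0.)
small = tf.train.Checkpoint(v=v2)
status = small.restore(full_path)
status.assert_existing_objects_matched()  # v2 was matched and restored
status.expect_partial()                   # w's saved value is intentionally unused
assert v2.numpy() == 1.
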
 def test_checkpointing_not_implemented(self):
   checkpoint_directory = self.get_temp_dir()
   checkpoint = checkpointable_utils.Checkpoint(net=MyNetwork())
   with self.assertRaises(NotImplementedError):
     checkpoint.save(checkpoint_directory)