def test_metrics_v1(self):
    """Check metric readings around three `CheckpointV1` writes and a restore.

    The readings come from helpers on this test case (`_get_time_saved`,
    `_get_write_histogram_proto`, `_get_read_histogram_proto`,
    `_get_checkpoint_size`) that are defined outside this method; what each
    one measures is assumed from its name.
    """
    label = util._CHECKPOINT_V1
    file_prefix = os.path.join(self.get_temp_dir(), 'ckpt')

    with self.cached_session():
        checkpoint = util.CheckpointV1()
        variable = variables_lib.Variable(1.)
        self.evaluate(variable.initializer)
        checkpoint.v = variable

        # Before any write, both readings are expected to be zero.
        self.assertEqual(self._get_time_saved(label), 0.0)
        self.assertEqual(self._get_write_histogram_proto(label).num, 0.0)

        # `expected_count` is the 1-based index of the current write.
        for expected_count in range(1, 4):
            saved_before = self._get_time_saved(label)
            save_path = checkpoint.write(file_prefix=file_prefix)
            size_on_disk = util._get_checkpoint_size(save_path)
            self.assertEqual(
                self._get_checkpoint_size(label, size_on_disk),
                expected_count)
            self.assertGreaterEqual(
                self._get_time_saved(label), saved_before)

        # Three writes and no reads so far.
        self.assertEqual(self._get_write_histogram_proto(label).num, 3.0)
        self.assertEqual(self._get_read_histogram_proto(label).num, 0.0)

        saved_before = self._get_time_saved(label)
        # Restores from the path returned by the last write in the loop.
        checkpoint.restore(save_path)
        self.assertEqual(self._get_read_histogram_proto(label).num, 1.0)
        # Restoring a checkpoint in the same "job" does not increase
        # training time saved.
        self.assertEqual(self._get_time_saved(label), saved_before)
def testUsageGraph(self):
    """Expected usage when graph building.

    Runs three training continuations, each in a fresh graph and session.
    Every continuation restores from the latest checkpoint in the temp
    directory (if any), runs ten training steps, saves, and then checks the
    global step and the checkpoint's save counter.
    """
    with context.graph_mode():
        steps_per_run = 10
        ckpt_dir = self.get_temp_dir()
        ckpt_prefix = os.path.join(ckpt_dir, "ckpt")
        for training_continuation in range(3):
            completed_runs = training_continuation + 1
            with ops.Graph().as_default():
                model = MyModel()
                optimizer = adam.AdamOptimizer(0.001)
                global_step = training_util.get_or_create_global_step()
                root = util.CheckpointV1(
                    optimizer=optimizer,
                    model=model,
                    global_step=global_step)
                input_value = constant_op.constant([[3.]])
                loss = model(input_value)
                train_op = optimizer.minimize(
                    loss, global_step=root.global_step)
                latest_path = checkpoint_management.latest_checkpoint(
                    ckpt_dir)
                with self.session(
                        graph=ops.get_default_graph()) as session:
                    status = root.restore(save_path=latest_path)
                    status.initialize_or_restore(session=session)
                    if latest_path is not None:
                        status.assert_consumed()
                    else:
                        # A missing checkpoint is only expected on the
                        # first continuation; assert_consumed() must then
                        # raise.
                        self.assertEqual(0, training_continuation)
                        with self.assertRaises(AssertionError):
                            status.assert_consumed()
                    for _ in range(steps_per_run):
                        session.run(train_op)
                    root.save(file_prefix=ckpt_prefix, session=session)
                    # Both counters carry over across continuations.
                    self.assertEqual(
                        completed_runs * steps_per_run,
                        session.run(root.global_step))
                    self.assertEqual(
                        completed_runs,
                        session.run(root.save_counter))
def testUsageGraph(self):
    """Expected usage when graph building.

    Runs three training continuations, each in a fresh graph and session.
    Every continuation restores from the latest checkpoint in the temp
    directory (if any), applies gradients for ten steps, saves, and then
    checks the optimizer's iteration count and the checkpoint's save
    counter.
    """
    with context.graph_mode():
        steps_per_run = 10
        ckpt_dir = self.get_temp_dir()
        ckpt_prefix = os.path.join(ckpt_dir, "ckpt")
        for training_continuation in range(3):
            completed_runs = training_continuation + 1
            with ops.Graph().as_default():
                model = MyModel()
                optimizer = adam.Adam(0.001)
                root = trackable_utils.CheckpointV1(
                    optimizer=optimizer, model=model)
                input_value = constant_op.constant([[3.]])
                with backprop.GradientTape() as tape:
                    loss = model(input_value)
                trainable = model.trainable_variables
                grads = tape.gradient(loss, trainable)
                grads_and_vars = zip(grads, trainable)
                train_op = optimizer.apply_gradients(grads_and_vars)
                latest_path = checkpoint_management.latest_checkpoint(
                    ckpt_dir)
                with self.session(
                        graph=ops.get_default_graph()) as session:
                    status = root.restore(save_path=latest_path)
                    status.initialize_or_restore(session=session)
                    if latest_path is not None:
                        status.assert_consumed()
                        status.assert_existing_objects_matched()
                    else:
                        # A missing checkpoint is only expected on the
                        # first continuation; both status assertions must
                        # then raise.
                        self.assertEqual(0, training_continuation)
                        with self.assertRaises(AssertionError):
                            status.assert_consumed()
                        with self.assertRaises(AssertionError):
                            status.assert_existing_objects_matched()
                    for _ in range(steps_per_run):
                        session.run(train_op)
                    root.save(file_prefix=ckpt_prefix, session=session)
                    # Both counters carry over across continuations.
                    self.assertEqual(
                        completed_runs * steps_per_run,
                        session.run(root.optimizer.iterations))
                    self.assertEqual(
                        completed_runs,
                        session.run(root.save_counter))
def test_metrics_v1(self):
    """Check metric readings around three `CheckpointV1` writes and a restore.

    The readings come from helpers on this test case (`_get_time_saved`,
    `_get_write_durations`, `_get_read_durations`) that are defined outside
    this method; what each one measures is assumed from its name.
    """
    api = 'V1'
    file_prefix = os.path.join(self.get_temp_dir(), 'ckpt')

    with self.cached_session():
        checkpoint = util.CheckpointV1()
        variable = variables_lib.Variable(1.)
        self.evaluate(variable.initializer)
        checkpoint.v = variable

        # Before any write, both readings are expected to be zero.
        self.assertEqual(self._get_time_saved(api), 0.0)
        self.assertEqual(self._get_write_durations(api).num, 0.0)

        for _ in range(3):
            saved_before = self._get_time_saved(api)
            save_path = checkpoint.write(file_prefix=file_prefix)
            self.assertGreaterEqual(
                self._get_time_saved(api), saved_before)

        # Three writes and no reads so far.
        self.assertEqual(self._get_write_durations(api).num, 3.0)
        self.assertEqual(self._get_read_durations(api).num, 0.0)

        saved_before = self._get_time_saved(api)
        # Restores from the path returned by the last write in the loop.
        checkpoint.restore(save_path)
        self.assertEqual(self._get_read_durations(api).num, 1.0)
        # Restoring a checkpoint does not increase training time saved.
        self.assertEqual(self._get_time_saved(api), saved_before)