def test_latest_checkpoint(self):
  """latest_checkpoint tracks both creation and deletion of checkpoints."""
  prefix = 'chkpnt_'
  temp_dir = self.get_temp_dir()

  # With no checkpoints written yet, nothing should be found.
  self.assertIsNone(checkpoint_utils.latest_checkpoint(temp_dir, prefix))

  state = build_fake_state()

  def checkpoint_path(round_num):
    # Path for the checkpoint of a given round, e.g. '<tmp>/chkpnt_003'.
    return os.path.join(temp_dir, '{}{:03d}'.format(prefix, round_num))

  # Write five checkpoints; after each write, the most recently created
  # path must be reported as the latest.
  for round_num in range(5):
    checkpoint_utils.save(state, checkpoint_path(round_num))
    found = checkpoint_utils.latest_checkpoint(temp_dir, prefix)
    self.assertEndsWith(found, '{:03d}'.format(round_num), msg=found)

  # Delete checkpoints newest-first; the reported latest checkpoint must
  # step back one round after each deletion.
  for round_num in range(4, 1, -1):
    tf.io.gfile.rmtree(checkpoint_path(round_num))
    found = checkpoint_utils.latest_checkpoint(temp_dir, prefix)
    self.assertEndsWith(found, '{:03d}'.format(round_num - 1), msg=found)
def test_save_and_load_roundtrip(self):
  """A state saved to disk loads back equal to the original."""
  original = build_fake_state()
  export_dir = os.path.join(self.get_temp_dir(), 'ckpt_1')
  checkpoint_utils.save(original, export_dir)
  restored = checkpoint_utils.load(export_dir, original)
  self.assertEqual(original, restored)
def test_load_latest_state(self):
  """Loading via latest_checkpoint restores the most recent saved state."""
  iterative_process = tff.learning.build_federated_averaging_process(
      models.model_fn,
      server_optimizer_fn=functools.partial(
          tf.keras.optimizers.SGD, learning_rate=0.1, momentum=0.9))
  server_state = iterative_process.initialize()

  # Save two checkpoints for successive rounds.
  # TODO(b/130724878): These conversions should not be needed.
  obj_1 = Obj.from_anon_tuple(server_state, 1)
  checkpoint_utils.save(obj_1, os.path.join(self.get_temp_dir(), 'ckpt_1'))
  # TODO(b/130724878): These conversions should not be needed.
  obj_2 = Obj.from_anon_tuple(server_state, 2)
  checkpoint_utils.save(obj_2, os.path.join(self.get_temp_dir(), 'ckpt_2'))

  # The latest checkpoint must be the round-2 one, and loading it must
  # reproduce obj_2 (obj_1 only supplies the structure to unpack into).
  export_dir = checkpoint_utils.latest_checkpoint(self.get_temp_dir())
  self.assertEqual(os.path.join(self.get_temp_dir(), 'ckpt_2'), export_dir)
  loaded_obj = checkpoint_utils.load(export_dir, obj_1)
  self.assertAllClose(tf.nest.flatten(obj_2), tf.nest.flatten(loaded_obj))
def write_checkpoint(root_checkpoint_dir, server_state, round_num):
  """Write the current experiment state to disk.

  Saves an `ExperimentState` for `round_num` under `root_checkpoint_dir`.
  Before saving, prunes existing checkpoints so that only the oldest one
  (which captures the random model initialization) and the most recent one
  remain.

  Args:
    root_checkpoint_dir: Directory holding all checkpoints for this
      experiment, or `None` to skip checkpointing entirely.
    server_state: The server state to serialize into the checkpoint.
    round_num: Integer round number, used to name the checkpoint directory.
  """
  if root_checkpoint_dir is None:
    # Checkpointing disabled.
    return
  if tf.io.gfile.exists(root_checkpoint_dir):
    # Prune intermediate checkpoints: keep the first (random model
    # initialization) and the most recent, removing everything in between.
    # NOTE(review): the previous comment claimed "more than 5" checkpoints
    # were retained, but the [1:-1] slice below keeps only two; the
    # documentation now reflects the code's actual behavior. If five were
    # intended, the slice should be checkpoints[1:-4] — confirm with owners.
    checkpoints = sorted(
        tf.io.gfile.glob(
            os.path.join(root_checkpoint_dir, CHECKPOINT_PREFIX + '*')))
    to_remove = checkpoints[1:-1]
    logging.info('Cleaning up %s', to_remove)
    for checkpoint in to_remove:
      tf.io.gfile.rmtree(checkpoint)
  state = ExperimentState(round_num, server_state)
  checkpoint_dir = os.path.join(
      root_checkpoint_dir, '{}{:04d}'.format(CHECKPOINT_PREFIX, round_num))
  checkpoint_utils.save(state, checkpoint_dir)
def write_checkpoint(checkpoint_dir, server_state, metrics_dataframe,
                     round_num):
  """Write the current experiment state to disk."""
  # Prune stale checkpoints, always preserving the initialization
  # checkpoint (the first) and the three most recent ones.
  existing = sorted(tf.io.gfile.glob(os.path.join(checkpoint_dir, 'ckpt_*')))
  for stale in existing[1:-3]:
    tf.io.gfile.rmtree(stale)

  # We must flatten the pd.Dataframe to a single string, otherwise we don't
  # know the nested structure (how many rounds have passed) to unpack
  # in `checkpoint_utils.load()` during `read_checkpoint`.
  buffer = io.StringIO()
  metrics_dataframe.to_csv(buffer, header=True)

  experiment_state = ExperimentState(
      round_num=round_num,
      metrics_csv_string=buffer.getvalue(),
      server_state=server_state)
  checkpoint_utils.save(
      experiment_state,
      os.path.join(checkpoint_dir, 'ckpt_{:03d}'.format(round_num)))