def test_latest_checkpoint(self):
  prefix = 'chkpnt_'
  latest_checkpoint = checkpoint_utils.latest_checkpoint(
      self.get_temp_dir(), prefix)
  self.assertIsNone(latest_checkpoint)

  # Create checkpoints and ensure that the latest checkpoint found is
  # always the most recently created path.
  state = build_fake_state()
  for round_num in range(5):
    export_dir = os.path.join(self.get_temp_dir(),
                              '{}{:03d}'.format(prefix, round_num))
    checkpoint_utils.save(state, export_dir)
    latest_checkpoint_path = checkpoint_utils.latest_checkpoint(
        self.get_temp_dir(), prefix)
    self.assertEndsWith(
        latest_checkpoint_path,
        '{:03d}'.format(round_num),
        msg=latest_checkpoint_path)

  # Delete the checkpoints in reverse order and ensure the latest checkpoint
  # decreases.
  for round_num in reversed(range(2, 5)):
    export_dir = os.path.join(self.get_temp_dir(),
                              '{}{:03d}'.format(prefix, round_num))
    tf.io.gfile.rmtree(export_dir)
    latest_checkpoint_path = checkpoint_utils.latest_checkpoint(
        self.get_temp_dir(), prefix)
    self.assertEndsWith(
        latest_checkpoint_path,
        '{:03d}'.format(round_num - 1),
        msg=latest_checkpoint_path)
def test_load_latest_state(self):
  server_optimizer_fn = functools.partial(
      tf.keras.optimizers.SGD, learning_rate=0.1, momentum=0.9)
  iterative_process = tff.learning.build_federated_averaging_process(
      models.model_fn, server_optimizer_fn=server_optimizer_fn)
  server_state = iterative_process.initialize()

  # TODO(b/130724878): These conversions should not be needed.
  obj_1 = Obj.from_anon_tuple(server_state, 1)
  export_dir = os.path.join(self.get_temp_dir(), 'ckpt_1')
  checkpoint_utils.save(obj_1, export_dir)

  # TODO(b/130724878): These conversions should not be needed.
  obj_2 = Obj.from_anon_tuple(server_state, 2)
  export_dir = os.path.join(self.get_temp_dir(), 'ckpt_2')
  checkpoint_utils.save(obj_2, export_dir)

  export_dir = checkpoint_utils.latest_checkpoint(self.get_temp_dir())
  loaded_obj = checkpoint_utils.load(export_dir, obj_1)

  self.assertEqual(os.path.join(self.get_temp_dir(), 'ckpt_2'), export_dir)
  self.assertAllClose(tf.nest.flatten(obj_2), tf.nest.flatten(loaded_obj))
def maybe_read_latest_checkpoint(root_checkpoint_dir, server_state):
  """Returns server_state, round_num, possibly from a recent checkpoint."""
  if root_checkpoint_dir is None:
    latest_checkpoint_dir = None
  else:
    latest_checkpoint_dir = checkpoint_utils.latest_checkpoint(
        root_checkpoint_dir, CHECKPOINT_PREFIX)
    logging.info('Looking for checkpoints in [%s/%s].', root_checkpoint_dir,
                 CHECKPOINT_PREFIX)
  if latest_checkpoint_dir is None:
    write_checkpoint(root_checkpoint_dir, server_state, 0)
    logging.info('No previous checkpoints, initializing experiment.')
    return server_state, 0
  else:
    server_state, round_num = read_checkpoint(latest_checkpoint_dir,
                                              server_state)
    round_num = int(round_num.numpy())
    logging.info('Restarting from checkpoint round %d.', round_num)
    return server_state, round_num
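# A minimal usage sketch for `maybe_read_latest_checkpoint`. The
# `iterative_process`, `client_data`, and `total_rounds` names below are
# hypothetical stand-ins for whatever the caller has set up; the helper itself
# is used exactly as defined above:
#
#   server_state = iterative_process.initialize()
#   server_state, round_num = maybe_read_latest_checkpoint(
#       '/tmp/my_experiment', server_state)
#   while round_num < total_rounds:
#     server_state, _ = iterative_process.next(server_state, client_data)
#     round_num += 1
#     write_checkpoint('/tmp/my_experiment', server_state, round_num)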
def federated_averaging_training_loop(model_fn,
                                      server_optimizer_fn,
                                      client_datasets_fn,
                                      total_rounds=500,
                                      rounds_per_eval=1,
                                      metrics_hook=lambda *args: None):
  """A simple example of a training loop for the Federated Averaging algorithm.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`.
    client_datasets_fn: A function that takes the round number and returns a
      list of `tf.data.Dataset`, one per client.
    total_rounds: Number of rounds to train.
    rounds_per_eval: How often to call the `metrics_hook` function.
    metrics_hook: An optional function that takes arguments (server_state,
      train_metrics, round_num) and performs evaluation.

  Returns:
    Final `ServerState`.
  """
  logging.info('Starting federated_training_loop')
  checkpoint_dir = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name)

  if FLAGS.server_optimizer != 'flars':
    iterative_process = simple_fedavg.build_federated_averaging_process(
        model_fn, server_optimizer_fn=server_optimizer_fn)
    ServerState = simple_fedavg.ServerState  # pylint: disable=invalid-name
  else:
    iterative_process = flars_fedavg.build_federated_averaging_process(
        model_fn, server_optimizer_fn=server_optimizer_fn)
    ServerState = flars_fedavg.ServerState  # pylint: disable=invalid-name

  server_state = ServerState.from_anon_tuple(iterative_process.initialize())
  round_num = None
  train_metrics = {}

  latest_checkpoint_dir = checkpoint_utils.latest_checkpoint(
      checkpoint_dir, CHECKPOINT_PREFIX)
  logging.info('Looking for checkpoints in [%s/%s]', checkpoint_dir,
               CHECKPOINT_PREFIX)
  while latest_checkpoint_dir is not None:
    # Restart from a previous round.
    logging.info('Loading a previous checkpoint')
    try:
      server_state, metrics_hook.results, round_num = read_checkpoint(
          latest_checkpoint_dir, server_state)
      break
    except OSError as e:
      # Likely a corrupted checkpoint, possibly because the job died while
      # writing it. Delete the checkpoint directory and try again.
      logging.error('Exception [%s]', e)
      logging.warning('Deleting likely corrupted checkpoint at [%s]',
                      latest_checkpoint_dir)
      tf.io.gfile.rmtree(latest_checkpoint_dir)
      latest_checkpoint_dir = checkpoint_utils.latest_checkpoint(
          checkpoint_dir, CHECKPOINT_PREFIX)

  if round_num is not None:
    logging.info('Restarted from checkpoint round %d', round_num)
  else:
    # Write the initial checkpoint.
    logging.info('No previous checkpoints, initializing experiment')
    round_num = 0
    metrics_hook(server_state, train_metrics, round_num)
    write_checkpoint(checkpoint_dir, server_state, metrics_hook.results,
                     round_num)

  while round_num < total_rounds:
    round_num += 1
    # Reset the executor to clear the cache, and clear the default graph to
    # garbage collect tf.Functions that will no longer be used.
    tff.framework.set_default_executor(
        tff.framework.create_local_executor(max_fanout=25))
    tf.compat.v1.reset_default_graph()

    round_start_time = time.time()
    data_prep_start_time = time.time()
    federated_train_data = client_datasets_fn(round_num)
    train_metrics['prepare_datasets_secs'] = time.time() - data_prep_start_time

    training_start_time = time.time()
    anon_tuple_server_state, tff_train_metrics = iterative_process.next(
        server_state, federated_train_data)
    server_state = ServerState.from_anon_tuple(anon_tuple_server_state)
    train_metrics.update(tff_train_metrics._asdict(recursive=True))
    train_metrics['training_secs'] = time.time() - training_start_time

    logging.info('Round {:2d} elapsed time: {:.2f}s .'.format(
        round_num, (time.time() - round_start_time)))
    train_metrics['total_round_secs'] = time.time() - round_start_time

    if round_num % FLAGS.rounds_per_checkpoint == 0:
      write_checkpoint_start_time = time.time()
      write_checkpoint(checkpoint_dir, server_state, metrics_hook.results,
                       round_num)
      train_metrics['write_checkpoint_secs'] = (
          time.time() - write_checkpoint_start_time)

    if round_num % rounds_per_eval == 0:
      metrics_hook(server_state, train_metrics, round_num)

  metrics_hook(server_state, train_metrics, total_rounds)
  write_checkpoint(checkpoint_dir, server_state, metrics_hook.results,
                   round_num)
  return server_state
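# A hypothetical invocation sketch for the loop above. The model function and
# client-sampling helper are assumptions for illustration, not part of this
# module; the keyword arguments mirror the signature documented in the
# docstring:
#
#   def sample_client_datasets(round_num):
#     # Return a list of `tf.data.Dataset`, one per sampled client.
#     ...
#
#   final_state = federated_averaging_training_loop(
#       model_fn=models.model_fn,
#       server_optimizer_fn=functools.partial(
#           tf.keras.optimizers.SGD, learning_rate=1.0),
#       client_datasets_fn=sample_client_datasets,
#       total_rounds=100,
#       rounds_per_eval=10)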