def test_raises_file_not_found_error_with_no_checkpoint(self):
  """Loading round 0 from an empty directory raises FileNotFoundError."""
  mngr = checkpoint_manager.FileCheckpointManager(self.get_temp_dir())
  template = _create_dummy_state()
  with self.assertRaises(FileNotFoundError):
    mngr.load_checkpoint(template, 0)
def _setup_outputs(root_output_dir, experiment_name, hparam_dict):
  """Set up directories for experiment loops, write hyperparameters to disk.

  Creates `checkpoints/<experiment_name>`, `results/<experiment_name>` and
  `logdir/<experiment_name>` under `root_output_dir`, then writes the
  hyperparameters, plus the path of the metrics file, to `hparams.csv` in the
  results directory.

  Args:
    root_output_dir: Root directory under which all outputs are written.
    experiment_name: Name of the experiment, used as the subdirectory name in
      each output tree. Must be non-empty.
    hparam_dict: Mapping of hyperparameter names to values. A copy is written
      to disk with an extra `metrics_file` entry; the caller's mapping is not
      modified. `None` is treated as an empty mapping.

  Returns:
    A `(checkpoint_mngr, metrics_mngr, summary_writer)` tuple.

  Raises:
    ValueError: If `experiment_name` is empty.
  """
  if not experiment_name:
    raise ValueError('experiment_name must be specified.')

  create_if_not_exists(root_output_dir)

  checkpoint_dir = os.path.join(root_output_dir, 'checkpoints', experiment_name)
  create_if_not_exists(checkpoint_dir)
  checkpoint_mngr = checkpoint_manager.FileCheckpointManager(checkpoint_dir)

  results_dir = os.path.join(root_output_dir, 'results', experiment_name)
  create_if_not_exists(results_dir)
  metrics_mngr = metrics_manager.ScalarMetricsManager(results_dir)

  summary_logdir = os.path.join(root_output_dir, 'logdir', experiment_name)
  create_if_not_exists(summary_logdir)
  summary_writer = tf.compat.v2.summary.create_file_writer(summary_logdir)

  # Work on a copy so recording the metrics path is not a hidden side effect
  # on the caller's dictionary.
  hparams_to_write = dict(hparam_dict or {})
  hparams_to_write['metrics_file'] = metrics_mngr.metrics_filename
  hparams_file = os.path.join(results_dir, 'hparams.csv')
  utils_impl.atomic_write_to_csv(pd.Series(hparams_to_write), hparams_file)

  logging.info('Writing...')
  logging.info('    checkpoints to: %s', checkpoint_dir)
  logging.info('    metrics csv to: %s', metrics_mngr.metrics_filename)
  logging.info('    summaries to: %s', summary_logdir)

  return checkpoint_mngr, metrics_mngr, summary_writer
def test_saves_one_checkpoint(self):
  """Saving a single round produces exactly one `ckpt_1` entry."""
  ckpt_dir = self.get_temp_dir()
  mngr = checkpoint_manager.FileCheckpointManager(ckpt_dir)
  mngr.save_checkpoint(_create_dummy_state(1), 1)
  self.assertCountEqual(os.listdir(ckpt_dir), ['ckpt_1'])
def test_saves_one_checkpoint(self):
  """After one save the directory contains only the `ckpt_1` entry."""
  ckpt_dir = self.get_temp_dir()
  mngr = checkpoint_manager.FileCheckpointManager(ckpt_dir)
  dummy = _create_dummy_state()
  mngr.save_checkpoint(dummy, 1)
  expected_entries = {'ckpt_1'}
  self.assertEqual(set(os.listdir(ckpt_dir)), expected_entries)
def test_returns_none_and_zero_with_no_checkpoints(self):
  """With no saved rounds, the latest checkpoint is reported as (None, 0)."""
  mngr = checkpoint_manager.FileCheckpointManager(self.get_temp_dir())
  template = _create_dummy_structure()
  loaded_state, loaded_round = mngr.load_latest_checkpoint(template)
  self.assertIsNone(loaded_state)
  self.assertEqual(loaded_round, 0)
def test_raises_already_exists_error_with_existing_round_number(self):
  """Saving the same round number twice raises AlreadyExistsError."""
  mngr = checkpoint_manager.FileCheckpointManager(self.get_temp_dir())
  dummy = _create_dummy_state()
  mngr.save_checkpoint(dummy, 1)
  # The second save reuses round 1 and must be rejected.
  with self.assertRaises(tf.errors.AlreadyExistsError):
    mngr.save_checkpoint(dummy, 1)
def test_raises_value_error_with_bad_structure(self):
  """Loading the latest checkpoint with a `None` structure raises ValueError."""
  mngr = checkpoint_manager.FileCheckpointManager(self.get_temp_dir())
  mngr.save_checkpoint(_create_dummy_state(), 1)
  bad_structure = None
  with self.assertRaises(ValueError):
    unused_state, unused_round = mngr.load_latest_checkpoint(bad_structure)
def test_returns_state_with_one_checkpoint(self):
  """A saved round can be loaded back by its round number."""
  mngr = checkpoint_manager.FileCheckpointManager(self.get_temp_dir())
  saved = _create_dummy_state(1)
  mngr.save_checkpoint(saved, 1)
  restored = mngr.load_checkpoint(_create_dummy_state(), 1)
  self.assertEqual(restored, saved)
def test_saves_and_returns_structure_and_zero_with_no_checkpoints(self):
  """With no checkpoints, the default is returned at round 0 and persisted."""
  ckpt_dir = self.get_temp_dir()
  mngr = checkpoint_manager.FileCheckpointManager(ckpt_dir)
  template = _create_dummy_state()
  loaded, loaded_round = mngr.load_latest_checkpoint_or_default(template)
  self.assertEqual(loaded, template)
  self.assertEqual(loaded_round, 0)
  self.assertCountEqual(os.listdir(ckpt_dir), ['ckpt_0'])
def test_returns_state_and_round_num_with_one_checkpoint(self):
  """The latest checkpoint is the single saved round, with its round number."""
  mngr = checkpoint_manager.FileCheckpointManager(self.get_temp_dir())
  mngr.save_checkpoint(_create_dummy_state(), 1)
  template = _create_dummy_structure()
  loaded_state, loaded_round = mngr.load_latest_checkpoint(template)
  self.assertEqual(loaded_state, _create_dummy_state())
  self.assertEqual(loaded_round, 1)
def test_returns_default_and_zero_with_no_checkpoints_and_also_saves(self):
  """The default is returned at round 0 and written out as `ckpt_0`."""
  ckpt_dir = self.get_temp_dir()
  mngr = checkpoint_manager.FileCheckpointManager(ckpt_dir)
  fallback = _create_dummy_structure()
  loaded, loaded_round = mngr.load_latest_checkpoint_or_default(fallback)
  self.assertEqual(loaded, fallback)
  self.assertEqual(loaded_round, 0)
  self.assertEqual(set(os.listdir(ckpt_dir)), {'ckpt_0'})
def test_saves_three_checkpoints(self):
  """Three consecutive saves produce one entry per round."""
  ckpt_dir = self.get_temp_dir()
  mngr = checkpoint_manager.FileCheckpointManager(ckpt_dir)
  for round_num in (1, 2, 3):
    mngr.save_checkpoint(_create_dummy_state(round_num), round_num)
  self.assertCountEqual(os.listdir(ckpt_dir), ['ckpt_1', 'ckpt_2', 'ckpt_3'])
def test_removes_oldest_with_keep_first_false(self):
  """With keep_total=3 and keep_first=False, round 1 is dropped after round 4."""
  ckpt_dir = self.get_temp_dir()
  mngr = checkpoint_manager.FileCheckpointManager(
      ckpt_dir, keep_total=3, keep_first=False)
  shared_state = _create_dummy_state()
  for round_num in range(1, 5):
    mngr.save_checkpoint(shared_state, round_num)
  self.assertEqual(
      set(os.listdir(ckpt_dir)), {'ckpt_2', 'ckpt_3', 'ckpt_4'})
def test_returns_state_with_three_checkpoint_for_third_round(self):
  """With rounds 1-3 saved, loading round 3 returns the third state."""
  mngr = checkpoint_manager.FileCheckpointManager(self.get_temp_dir())
  saved_by_round = {}
  for round_num in (1, 2, 3):
    saved_by_round[round_num] = _create_dummy_state(round_num)
    mngr.save_checkpoint(saved_by_round[round_num], round_num)
  restored = mngr.load_checkpoint(_create_dummy_state(), 3)
  self.assertEqual(restored, saved_by_round[3])
def test_returns_state_and_round_num_with_special_characters(self):
  """Paths containing parentheses, `*`, `?` and `.` round-trip correctly."""
  nested_path = os.path.join(
      self.get_temp_dir(),
      'subdir(subdir_number=1*)/sub.dir(subdir_number=2?)')
  mngr = checkpoint_manager.FileCheckpointManager(nested_path)
  mngr.save_checkpoint(_create_dummy_state(), 1)
  template = _create_dummy_structure()
  loaded_state, loaded_round = mngr.load_latest_checkpoint(template)
  self.assertEqual(loaded_state, _create_dummy_state())
  self.assertEqual(loaded_round, 1)
def test_removes_oldest_with_keep_first_false(self):
  """Only the newest three of four distinct rounds are kept."""
  ckpt_dir = self.get_temp_dir()
  mngr = checkpoint_manager.FileCheckpointManager(
      ckpt_dir, keep_total=3, keep_first=False)
  for round_num in range(1, 5):
    mngr.save_checkpoint(_create_dummy_state(round_num), round_num)
  self.assertCountEqual(os.listdir(ckpt_dir), ['ckpt_2', 'ckpt_3', 'ckpt_4'])
def _setup_outputs(root_output_dir,
                   experiment_name,
                   hparam_dict,
                   write_metrics_with_bz2=True,
                   rounds_per_profile=0):
  """Set up directories for experiment loops, write hyperparameters to disk.

  Creates `checkpoints/<experiment_name>`, `results/<experiment_name>` and
  `logdir/<experiment_name>` under `root_output_dir`. If `hparam_dict` is
  non-empty, writes it, plus the metrics file path, to `hparams.csv` in the
  results directory.

  Args:
    root_output_dir: Root directory under which all outputs are written.
    experiment_name: Name of the experiment, used as the subdirectory name in
      each output tree. Must be non-empty.
    hparam_dict: Optional mapping of hyperparameter names to values. When
      non-empty, a copy with an added `metrics_file` entry is written; the
      caller's mapping is not modified.
    write_metrics_with_bz2: Forwarded as `use_bz2` to `ScalarMetricsManager`.
    rounds_per_profile: If positive, rounds whose number is a multiple of this
      value are wrapped in a TensorFlow profiler session that logs to the
      summary directory. Zero or negative disables profiling.

  Returns:
    A `(checkpoint_mngr, metrics_mngr, summary_writer, profiler)` tuple, where
    `profiler(round_num)` is a context manager to wrap each training round.

  Raises:
    ValueError: If `experiment_name` is empty.
  """
  if not experiment_name:
    raise ValueError('experiment_name must be specified.')

  create_if_not_exists(root_output_dir)

  checkpoint_dir = os.path.join(root_output_dir, 'checkpoints', experiment_name)
  create_if_not_exists(checkpoint_dir)
  checkpoint_mngr = checkpoint_manager.FileCheckpointManager(checkpoint_dir)

  results_dir = os.path.join(root_output_dir, 'results', experiment_name)
  create_if_not_exists(results_dir)
  metrics_mngr = metrics_manager.ScalarMetricsManager(
      results_dir, use_bz2=write_metrics_with_bz2)

  summary_logdir = os.path.join(root_output_dir, 'logdir', experiment_name)
  create_if_not_exists(summary_logdir)
  summary_writer = tf.summary.create_file_writer(summary_logdir)

  if hparam_dict:
    # Work on a copy so recording the metrics path is not a hidden side
    # effect on the caller's dictionary.
    hparams_to_write = dict(hparam_dict)
    hparams_to_write['metrics_file'] = metrics_mngr.metrics_filename
    hparams_file = os.path.join(results_dir, 'hparams.csv')
    utils_impl.atomic_write_to_csv(pd.Series(hparams_to_write), hparams_file)

  logging.info('Writing...')
  logging.info('    checkpoints to: %s', checkpoint_dir)
  logging.info('    metrics csv to: %s', metrics_mngr.metrics_filename)
  logging.info('    summaries to: %s', summary_logdir)

  @contextlib.contextmanager
  def profiler(round_num):
    # Profile only every `rounds_per_profile`-th round; otherwise a no-op.
    if (rounds_per_profile > 0 and round_num % rounds_per_profile == 0):
      with tf.profiler.experimental.Profile(summary_logdir):
        yield
    else:
      yield

  return checkpoint_mngr, metrics_mngr, summary_writer, profiler
def test_checkpoint_manager_saves_state(self):
  """The state returned by `training_loop.run` matches the one checkpointed."""
  FLAGS.total_rounds = 1
  FLAGS.experiment_name = 'checkpoint_manager_saves_state'
  iterative_process = _build_federated_averaging_process()
  federated_data = [[_batch_fn()]]

  def client_datasets_fn(round_num):
    del round_num  # Every round trains on the same single-client data.
    return federated_data

  def evaluate_fn(model):
    del model  # Evaluation output is not inspected by this test.
    return {}

  root_dir = self.get_temp_dir()
  FLAGS.root_output_dir = root_dir
  final_state = training_loop.run(iterative_process, client_datasets_fn,
                                  evaluate_fn)

  ckpt_dir = os.path.join(root_dir, 'checkpoints', FLAGS.experiment_name)
  ckpt_manager = checkpoint_manager.FileCheckpointManager(ckpt_dir)
  restored_state, restored_round = ckpt_manager.load_latest_checkpoint(
      final_state)
  self.assertEqual(restored_round, 0)

  keras_model = tff.simulation.models.mnist.create_keras_model(
      compile_model=True)
  batch = federated_data[0][0]

  def _loss_with_weights_of(server_state):
    # Load the state's model weights into the shared Keras model and score
    # the training batch with them.
    server_state.model.assign_weights_to(keras_model)
    return keras_model.test_on_batch(batch.x, batch.y)

  restored_loss = _loss_with_weights_of(restored_state)
  final_loss = _loss_with_weights_of(final_state)
  self.assertEqual(final_loss, restored_loss)
def _federated_averaging_training_loop(model_fn,
                                       client_optimizer_fn,
                                       server_optimizer_fn,
                                       client_datasets_fn,
                                       evaluate_fn,
                                       total_rounds=500,
                                       rounds_per_eval=1,
                                       metrics_hook=None):
  """A simple example of training loop for the Federated Averaging algorithm.

  Resumes from the latest checkpoint under
  `FLAGS.root_output_dir/FLAGS.exp_name` if one exists. Otherwise it starts
  from a freshly initialized state and saves it as round 0.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`.
    client_datasets_fn: A function that takes the round number, and returns a
      list of `tf.data.Dataset`, one per client.
    evaluate_fn: A function that takes state, performs evaluation, and returns
      evaluations metrics.
    total_rounds: Number of rounds to train.
    rounds_per_eval: How often to call the `metrics_hook` function.
    metrics_hook: A function taking arguments (training metrics,
      evaluation metrics, and round number). Optional.

  Returns:
    Final `ServerState`.

  Raises:
    ValueError: If `FLAGS.server_optimizer` is not `'flars'`.
  """
  logging.info('Starting federated training loop')

  checkpoint_dir = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name)
  checkpoint_manager_obj = checkpoint_manager.FileCheckpointManager(
      checkpoint_dir)

  # Only FLARS is supported; continuing with any other value would leave
  # `iterative_process` undefined, so fail fast with a clear message.
  if FLAGS.server_optimizer != 'flars':
    raise ValueError(
        'Unsupported server_optimizer: {}'.format(FLAGS.server_optimizer))

  iterative_process = flars_fedavg.build_federated_averaging_process(
      model_fn,
      client_optimizer_fn=client_optimizer_fn,
      server_optimizer_fn=server_optimizer_fn)
  ServerState = flars_fedavg.ServerState  # pylint: disable=invalid-name

  # Construct an initial state here to act as a checkpoint template.
  initial_state = ServerState.from_tff_result(iterative_process.initialize())

  logging.info('Looking for checkpoints in \'%s\'', checkpoint_dir)
  state, round_num = checkpoint_manager_obj.load_latest_checkpoint(
      initial_state)

  if state is None:
    logging.info('No previous checkpoints, initializing experiment')
    state = initial_state
    round_num = 0
    if metrics_hook is not None:
      eval_metrics = evaluate_fn(state)
      metrics_hook({}, eval_metrics, round_num)
    checkpoint_manager_obj.save_checkpoint(state, 0)
  else:
    logging.info('Restarted from checkpoint round %d', round_num)

  while round_num < total_rounds:
    round_num += 1
    train_metrics = {}

    # Reset the executor to clear the cache, and clear the default graph to
    # garbage collect tf.Functions that will no longer be used.
    tff.framework.set_default_executor(
        tff.framework.local_executor_factory(max_fanout=25))
    tf.compat.v1.reset_default_graph()

    round_start_time = time.time()
    data_prep_start_time = time.time()
    train_data = client_datasets_fn(round_num)
    train_metrics['prepare_datasets_secs'] = time.time() - data_prep_start_time

    training_start_time = time.time()
    state, tff_train_metrics = iterative_process.next(state, train_data)
    state = ServerState.from_tff_result(state)
    tff_train_metrics = tff_train_metrics._asdict(recursive=True)
    train_metrics.update(tff_train_metrics)
    train_metrics['training_secs'] = time.time() - training_start_time

    logging.info('Round {:2d} elapsed time: {:.2f}s .'.format(
        round_num, (time.time() - round_start_time)))
    train_metrics['total_round_secs'] = time.time() - round_start_time

    is_last_round = round_num == total_rounds

    if round_num % FLAGS.rounds_per_checkpoint == 0 or is_last_round:
      save_checkpoint_start_time = time.time()
      checkpoint_manager_obj.save_checkpoint(state, round_num)
      train_metrics['save_checkpoint_secs'] = (
          time.time() - save_checkpoint_start_time)

    if round_num % rounds_per_eval == 0 or is_last_round:
      if metrics_hook is not None:
        eval_metrics = evaluate_fn(state)
        metrics_hook(train_metrics, eval_metrics, round_num)

  return state
def setUp(self):
  """Gives each test a fresh FileCheckpointManager rooted at a temp dir."""
  super(FileCheckpointManagerLoadCheckpointTest, self).setUp()
  temp_dir = self.get_temp_dir()
  self._checkpoint_mngr = checkpoint_manager.FileCheckpointManager(temp_dir)