def test_reload_of_csvfile(self):
  temp_dir = self.get_temp_dir()
  metrics_mngr = metrics_manager.ScalarMetricsManager(temp_dir, prefix='bar')
  metrics_mngr.update_metrics(0, _create_dummy_metrics())
  metrics_mngr.update_metrics(5, _create_dummy_metrics())

  new_metrics_mngr = metrics_manager.ScalarMetricsManager(
      temp_dir, prefix='bar')

  metrics = new_metrics_mngr.get_metrics()
  self.assertEqual(2, len(metrics.index),
                   'There should be 2 rows of metrics (for rounds 0 and 5).')
  self.assertEqual(5, metrics['round_num'].iloc[-1],
                   'Last metrics are for round 5.')
  self.assertEqual(set(os.listdir(temp_dir)), set(['bar.metrics.csv.bz2']))
def test_rows_are_cleared_and_last_round_num_is_reset(self):
  metrics_mngr = metrics_manager.ScalarMetricsManager(self.get_temp_dir())
  metrics_mngr.update_metrics(0, _create_dummy_metrics())
  metrics_mngr.update_metrics(5, _create_dummy_metrics())
  metrics_mngr.update_metrics(10, _create_dummy_metrics())
  metrics = metrics_mngr.get_metrics()
  self.assertEqual(
      3, len(metrics.index),
      'There should be 3 rows of metrics (for rounds 0, 5, and 10).')

  metrics_mngr.clear_rounds_after(last_valid_round_num=7)

  metrics = metrics_mngr.get_metrics()
  self.assertEqual(
      2, len(metrics.index),
      'After clearing all rounds after last_valid_round_num=7, there should '
      'be 2 rows of metrics (for rounds 0 and 5).')
  self.assertEqual(5, metrics['round_num'].iloc[-1],
                   'Last metrics retained are for round 5.')

  # The internal state of the manager now records 7 as the last round number,
  # so it raises an exception if a user attempts to add new metrics at
  # round 7, ...
  with self.assertRaises(ValueError):
    metrics_mngr.update_metrics(7, _create_dummy_metrics())

  # ... but allows a user to add new metrics at a round number greater than 7.
  metrics_mngr.update_metrics(8, _create_dummy_metrics())  # (No exception.)
def test_fn_writes_metrics(self):
  FLAGS.total_rounds = 1
  FLAGS.rounds_per_eval = 10
  FLAGS.experiment_name = 'test_metrics'
  iterative_process = _build_federated_averaging_process()
  batch = _batch_fn()
  federated_data = [[batch]]

  def client_datasets_fn(round_num):
    del round_num
    return federated_data

  def evaluate(model):
    keras_model = tff.simulation.models.mnist.create_keras_model(
        compile_model=True)
    model.assign_weights_to(keras_model)
    return {'loss': keras_model.evaluate(batch.x, batch.y)}

  temp_filepath = self.get_temp_dir()
  FLAGS.root_output_dir = temp_filepath
  training_loop.run(
      iterative_process, client_datasets_fn, evaluate, test_fn=evaluate)

  results_dir = os.path.join(FLAGS.root_output_dir, 'results',
                             FLAGS.experiment_name)
  scalar_manager = metrics_manager.ScalarMetricsManager(results_dir)
  metrics = scalar_manager.get_metrics()
  self.assertEqual(2, len(metrics.index))
  self.assertIn('eval/loss', metrics.columns)
  self.assertIn('test/loss', metrics.columns)
  self.assertNotIn('train_eval/loss', metrics.columns)
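# The test above relies on module-level helpers that are not shown in this
# section. Below is a minimal sketch of what they might look like, built on
# the TFF MNIST example model; the batch shapes, optimizers, and the
# _uncompiled_model_builder name are illustrative assumptions, not the repo's
# actual definitions.
_Batch = collections.namedtuple('Batch', ['x', 'y'])

def _batch_fn():
  # One dummy batch: a single flattened 28x28 image with an integer label.
  return _Batch(
      x=np.ones([1, 784], dtype=np.float32),
      y=np.ones([1, 1], dtype=np.int64))

def _uncompiled_model_builder():
  keras_model = tff.simulation.models.mnist.create_keras_model(
      compile_model=False)
  input_spec = _Batch(
      x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
      y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64))
  return tff.learning.from_keras_model(
      keras_model=keras_model,
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy())

def _build_federated_averaging_process():
  return tff.learning.build_federated_averaging_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)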
def test_update_metrics_raises_value_error_if_round_num_is_out_of_order(self):
  metrics_mngr = metrics_manager.ScalarMetricsManager(self.get_temp_dir())

  metrics_mngr.update_metrics(1, _create_dummy_metrics())
  with self.assertRaises(ValueError):
    metrics_mngr.update_metrics(0, _create_dummy_metrics())
def _setup_outputs(root_output_dir, experiment_name, hparam_dict):
  """Set up directories for experiment loops, write hyperparameters to disk."""
  if not experiment_name:
    raise ValueError('experiment_name must be specified.')

  create_if_not_exists(root_output_dir)

  checkpoint_dir = os.path.join(root_output_dir, 'checkpoints',
                                experiment_name)
  create_if_not_exists(checkpoint_dir)
  checkpoint_mngr = checkpoint_manager.FileCheckpointManager(checkpoint_dir)

  results_dir = os.path.join(root_output_dir, 'results', experiment_name)
  create_if_not_exists(results_dir)
  metrics_mngr = metrics_manager.ScalarMetricsManager(results_dir)

  summary_logdir = os.path.join(root_output_dir, 'logdir', experiment_name)
  create_if_not_exists(summary_logdir)
  summary_writer = tf.compat.v2.summary.create_file_writer(summary_logdir)

  hparam_dict['metrics_file'] = metrics_mngr.metrics_filename
  hparams_file = os.path.join(results_dir, 'hparams.csv')
  utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

  logging.info('Writing...')
  logging.info('    checkpoints to: %s', checkpoint_dir)
  logging.info('    metrics csv to: %s', metrics_mngr.metrics_filename)
  logging.info('    summaries to: %s', summary_logdir)

  return checkpoint_mngr, metrics_mngr, summary_writer
def test_clear_rounds_after_raises_value_error_if_round_num_is_negative(self):
  metrics_mngr = metrics_manager.ScalarMetricsManager(self.get_temp_dir())
  metrics_mngr.update_metrics(0, _create_dummy_metrics())

  with self.assertRaises(ValueError):
    metrics_mngr.clear_rounds_after(last_valid_round_num=-1)
def test_update_metrics_returns_flat_dict(self):
  metrics_mngr = metrics_manager.ScalarMetricsManager(self.get_temp_dir())
  input_data_dict = _create_dummy_metrics()
  appended_data_dict = metrics_mngr.update_metrics(0, input_data_dict)
  self.assertEqual({
      'a/b': 1.0,
      'a/c': 2.0,
      'round_num': 0.0
  }, appended_data_dict)
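# A minimal sketch of the dummy-metrics helpers these tests rely on, inferred
# from the flattened keys asserted above ('a/b', 'a/c', plus the extra 'a/d'
# column used later); the nested structure and values are assumptions:
def _create_dummy_metrics():
  return collections.OrderedDict([
      ('a', collections.OrderedDict([('b', 1.0), ('c', 2.0)])),
  ])

def _create_dummy_metrics_with_extra_column():
  metrics = _create_dummy_metrics()
  metrics['a']['d'] = 3.0
  return metrics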
def test_clear_rounds_after_raises_runtime_error_if_no_metrics(self):
  metrics_mngr = metrics_manager.ScalarMetricsManager(self.get_temp_dir())

  # Clearing is allowed with no metrics if no rounds have yet completed.
  metrics_mngr.clear_rounds_after(last_valid_round_num=0)

  with self.assertRaises(RuntimeError):
    # Clearing past a positive round number raises if no metrics exist yet.
    metrics_mngr.clear_rounds_after(last_valid_round_num=1)
def test_update_metrics_adds_nan_if_previously_seen_metric_not_provided(self):
  metrics_mngr = metrics_manager.ScalarMetricsManager(self.get_temp_dir())
  metrics_mngr.update_metrics(0, _create_dummy_metrics_with_extra_column())
  metrics_mngr.update_metrics(1, _create_dummy_metrics())
  metrics = metrics_mngr.get_metrics()
  self.assertTrue(np.isnan(metrics.at[1, 'a/d']))
def test_metrics_are_appended(self):
  metrics_mngr = metrics_manager.ScalarMetricsManager(self.get_temp_dir())
  metrics = metrics_mngr.get_metrics()
  self.assertTrue(metrics.empty)

  metrics_mngr.update_metrics(0, _create_dummy_metrics())
  metrics = metrics_mngr.get_metrics()
  self.assertEqual(1, len(metrics.index))

  metrics_mngr.update_metrics(1, _create_dummy_metrics())
  metrics = metrics_mngr.get_metrics()
  self.assertEqual(2, len(metrics.index))
def test_run_federated(self, run_federated_fn):
  total_rounds = 1
  shared_args = collections.OrderedDict(
      client_epochs_per_round=1,
      client_batch_size=10,
      clients_per_round=1,
      client_datasets_random_seed=1,
      total_rounds=total_rounds,
      max_batches_per_client=2,
      iterative_process_builder=iterative_process_builder,
      assign_weights_fn=fed_avg_schedule.ServerState.
      assign_weights_to_keras_model,
      rounds_per_checkpoint=10,
      rounds_per_eval=10,
      rounds_per_train_eval=10,
      max_eval_batches=2)
  root_output_dir = self.get_temp_dir()
  exp_name = 'test_run_federated'
  shared_args['root_output_dir'] = root_output_dir
  shared_args['experiment_name'] = exp_name

  run_federated_fn(**shared_args)

  results_dir = os.path.join(root_output_dir, 'results', exp_name)
  self.assertTrue(tf.io.gfile.exists(results_dir))

  scalar_manager = metrics_manager.ScalarMetricsManager(results_dir)
  metrics = scalar_manager.get_metrics()
  self.assertIn(
      'train/loss',
      metrics.columns,
      msg='The output metrics should have a `train/loss` column if training '
      'is successful.')
  self.assertIn(
      'eval/loss',
      metrics.columns,
      msg='The output metrics should have an `eval/loss` column if validation '
      'metrics computation is successful.')
  self.assertIn(
      'test/loss',
      metrics.columns,
      msg='The output metrics should have a `test/loss` column if test '
      'metrics computation is successful.')
  self.assertLen(
      metrics.index,
      total_rounds + 1,
      msg='The number of rows in the metrics CSV should be the number of '
      'training rounds + 1 (as there is an extra row for validation/test set '
      'metrics after training has completed).')
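# test_run_federated also assumes a module-level iterative_process_builder.
# A plausible sketch under the fed_avg_schedule API referenced above; the
# function signature and optimizer settings are assumptions for illustration:
def iterative_process_builder(model_fn, client_weight_fn=None):
  return fed_avg_schedule.build_fed_avg_process(
      model_fn=model_fn,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      client_lr=0.1,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      server_lr=1.0,
      client_weight_fn=client_weight_fn)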
def test_constructor_raises_value_error_if_csvfile_is_invalid(self):
  dataframe_missing_round_num = pd.DataFrame.from_dict(
      _create_dummy_metrics())
  temp_dir = self.get_temp_dir()
  # This csvfile is 'invalid' in that it was not originally created by an
  # instance of ScalarMetricsManager, and is missing a column for round_num.
  invalid_csvfile = os.path.join(temp_dir, 'foo.metrics.csv.bz2')
  utils_impl.atomic_write_to_csv(dataframe_missing_round_num, invalid_csvfile)

  with self.assertRaises(ValueError):
    metrics_manager.ScalarMetricsManager(temp_dir, prefix='foo')
def test_rows_are_cleared_is_reflected_in_saved_file(self):
  temp_dir = self.get_temp_dir()
  metrics_mngr = metrics_manager.ScalarMetricsManager(temp_dir, prefix='foo')

  metrics_mngr.update_metrics(0, _create_dummy_metrics())
  metrics_mngr.update_metrics(5, _create_dummy_metrics())
  metrics_mngr.update_metrics(10, _create_dummy_metrics())

  file_contents_before = utils_impl.atomic_read_from_csv(
      os.path.join(temp_dir, 'foo.metrics.csv.bz2'))
  self.assertEqual(3, len(file_contents_before.index))

  metrics_mngr.clear_rounds_after(last_valid_round_num=7)

  file_contents_after = utils_impl.atomic_read_from_csv(
      os.path.join(temp_dir, 'foo.metrics.csv.bz2'))
  self.assertEqual(2, len(file_contents_after.index))
def _setup_outputs(root_output_dir,
                   experiment_name,
                   hparam_dict,
                   write_metrics_with_bz2=True,
                   rounds_per_profile=0):
  """Set up directories for experiment loops, write hyperparameters to disk."""
  if not experiment_name:
    raise ValueError('experiment_name must be specified.')

  create_if_not_exists(root_output_dir)

  checkpoint_dir = os.path.join(root_output_dir, 'checkpoints',
                                experiment_name)
  create_if_not_exists(checkpoint_dir)
  checkpoint_mngr = checkpoint_manager.FileCheckpointManager(checkpoint_dir)

  results_dir = os.path.join(root_output_dir, 'results', experiment_name)
  create_if_not_exists(results_dir)
  metrics_mngr = metrics_manager.ScalarMetricsManager(
      results_dir, use_bz2=write_metrics_with_bz2)

  summary_logdir = os.path.join(root_output_dir, 'logdir', experiment_name)
  create_if_not_exists(summary_logdir)
  summary_writer = tf.summary.create_file_writer(summary_logdir)

  if hparam_dict:
    hparam_dict['metrics_file'] = metrics_mngr.metrics_filename
    hparams_file = os.path.join(results_dir, 'hparams.csv')
    utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

  logging.info('Writing...')
  logging.info('    checkpoints to: %s', checkpoint_dir)
  logging.info('    metrics csv to: %s', metrics_mngr.metrics_filename)
  logging.info('    summaries to: %s', summary_logdir)

  @contextlib.contextmanager
  def profiler(round_num):
    if (rounds_per_profile > 0 and round_num % rounds_per_profile == 0):
      with tf.profiler.experimental.Profile(summary_logdir):
        yield
    else:
      yield

  return checkpoint_mngr, metrics_mngr, summary_writer, profiler
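# Both _setup_outputs variants call a small directory helper that is not
# shown in this section. A minimal sketch of what it presumably does:
def create_if_not_exists(path):
  try:
    tf.io.gfile.makedirs(path)
  except tf.errors.OpError:
    logging.info('Skipping creation of directory [%s], already exists', path)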
def test_clear_rounds_after_raises_runtime_error_if_no_metrics(self):
  metrics_mngr = metrics_manager.ScalarMetricsManager(self.get_temp_dir())

  with self.assertRaises(RuntimeError):
    metrics_mngr.clear_rounds_after(last_valid_round_num=0)
def test_csvfile_is_saved(self):
  temp_dir = self.get_temp_dir()
  metrics_manager.ScalarMetricsManager(temp_dir, prefix='foo')
  self.assertEqual(set(os.listdir(temp_dir)), set(['foo.metrics.csv.bz2']))
def test_column_names(self):
  metrics_mngr = metrics_manager.ScalarMetricsManager(self.get_temp_dir())
  metrics_mngr.update_metrics(0, _create_dummy_metrics())
  metrics = metrics_mngr.get_metrics()
  self.assertEqual(['a/b', 'a/c', 'round_num'], metrics.columns.tolist())
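# Putting together the API exercised by these tests, a typical
# ScalarMetricsManager session looks roughly like this (the path and metric
# names are illustrative):
mngr = metrics_manager.ScalarMetricsManager('/tmp/results', prefix='exp')
mngr.update_metrics(0, {'train': {'loss': 1.2}})  # flattened to 'train/loss'
mngr.update_metrics(1, {'train': {'loss': 0.9}})
df = mngr.get_metrics()  # pandas DataFrame with a 'round_num' column
mngr.clear_rounds_after(last_valid_round_num=0)  # e.g. after restoring round 0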