def test_nonscalar_metrics_are_written(self):
    """Updating with non-scalar metrics leaves exactly one entry in the logdir."""
    logdir = os.path.join(self.get_temp_dir(), 'logdir')
    manager = tensorboard_manager.TensorBoardManager(summary_dir=logdir)

    manager.update_metrics(0, _create_nonscalar_metrics())

    # The summary directory must exist and contain a single listed entry.
    self.assertTrue(tf.io.gfile.exists(logdir))
    written_entries = tf.io.gfile.listdir(logdir)
    self.assertLen(written_entries, 1)
def test_update_hparams_returns_flat_dict(self):
    """`update_hparams` returns a dict keyed by 'a/b' and 'a/c'."""
    manager = tensorboard_manager.TensorBoardManager(
        summary_dir=self.get_temp_dir())
    hparams = _create_scalar_metrics()

    returned_hparams = manager.update_hparams(hparams)

    expected_hparams = {'a/b': 1.0, 'a/c': 2.0}
    self.assertEqual(expected_hparams, returned_hparams)
def test_update_metrics_raises_value_error_if_round_num_is_out_of_order(self):
    """Writing round 0 after round 1 has been written raises `ValueError`."""
    manager = tensorboard_manager.TensorBoardManager(
        summary_dir=self.get_temp_dir())

    # Record a later round first so that the next update goes backwards.
    manager.update_metrics(1, _create_scalar_metrics())

    with self.assertRaises(ValueError):
        manager.update_metrics(0, _create_scalar_metrics())
def _setup_outputs(root_output_dir,
                   experiment_name,
                   hparam_dict,
                   rounds_per_profile=0):
    """Set up directories for experiment loops, write hyperparameters to disk.

    Creates `checkpoints/<experiment_name>` and `results/<experiment_name>`
    under `root_output_dir`, and constructs the checkpoint, CSV metrics and
    TensorBoard managers pointing at those locations (TensorBoard summaries
    go to `logdir/<experiment_name>`).

    Args:
      root_output_dir: Root directory under which all outputs are placed. It
        is created if it does not exist.
      experiment_name: Non-empty name used as the per-experiment subdirectory.
      hparam_dict: Optional mapping of hyperparameter names to values. When
        non-empty, it is written (together with an added `metrics_file`
        entry) to `hparams.csv` in the results directory, and its non-`None`
        entries are passed to the TensorBoard manager. The mapping passed in
        by the caller is not modified.
      rounds_per_profile: If positive, the returned `profiler` context manager
        profiles every round whose number is a multiple of this value. With
        the default of 0 no profiling is done.

    Returns:
      A tuple `(checkpoint_mngr, metrics_mngr, tb_mngr, profiler)`, where
      `profiler` is a context-manager factory taking the round number.

    Raises:
      ValueError: If `experiment_name` is empty.
    """
    if not experiment_name:
        raise ValueError('experiment_name must be specified.')

    create_if_not_exists(root_output_dir)

    checkpoint_dir = os.path.join(root_output_dir, 'checkpoints',
                                  experiment_name)
    create_if_not_exists(checkpoint_dir)
    checkpoint_mngr = tff.simulation.FileCheckpointManager(checkpoint_dir)

    results_dir = os.path.join(root_output_dir, 'results', experiment_name)
    create_if_not_exists(results_dir)
    csv_file = os.path.join(results_dir, 'experiment.metrics.csv')
    metrics_mngr = tff.simulation.CSVMetricsManager(csv_file)

    summary_logdir = os.path.join(root_output_dir, 'logdir', experiment_name)
    tb_mngr = tensorboard_manager.TensorBoardManager(
        summary_dir=summary_logdir)

    if hparam_dict:
        # Work on a shallow copy so that adding `metrics_file` does not leak
        # into the caller's dictionary.
        hparams = dict(hparam_dict)
        hparams['metrics_file'] = metrics_mngr.metrics_filename
        hparams_file = os.path.join(results_dir, 'hparams.csv')
        utils_impl.atomic_write_to_csv(pd.Series(hparams), hparams_file)
        # Entries whose value is None are left out of the TensorBoard hparams.
        tb_mngr.update_hparams(
            {k: v for k, v in hparams.items() if v is not None})

    logging.info('Writing...')
    logging.info(' checkpoints to: %s', checkpoint_dir)
    logging.info(' metrics csv to: %s', metrics_mngr.metrics_filename)
    logging.info(' summaries to: %s', summary_logdir)

    @contextlib.contextmanager
    def profiler(round_num):
        """Profile the wrapped block on every `rounds_per_profile`-th round."""
        if rounds_per_profile > 0 and round_num % rounds_per_profile == 0:
            with tf.profiler.experimental.Profile(summary_logdir):
                yield
        else:
            yield

    return checkpoint_mngr, metrics_mngr, tb_mngr, profiler