    def test_rows_are_cleared_is_reflected_in_saved_file(self):
        temp_dir = self.get_temp_dir()
        metrics_mngr = metrics_manager.ScalarMetricsManager(temp_dir,
                                                            prefix='foo')

        metrics_mngr.update_metrics(0, _create_dummy_metrics())
        metrics_mngr.update_metrics(5, _create_dummy_metrics())
        metrics_mngr.update_metrics(10, _create_dummy_metrics())

        filename = os.path.join(temp_dir, 'foo.metrics.csv')
        with tf.io.gfile.GFile(filename, 'r') as csvfile:
            num_lines_before = len(csvfile.readlines())

        # The CSV file should have 4 lines: one for the fieldnames, and one for
        # each of the 3 calls to `update_metrics`.
        self.assertEqual(num_lines_before, 4)

        metrics_mngr.clear_rounds_after(last_valid_round_num=7)

        with tf.io.gfile.GFile(filename, 'r') as csvfile:
            num_lines_after = len(csvfile.readlines())

        # The CSV file should have 3 lines: one for the fieldnames, and one for
        # each of the 2 calls to `update_metrics` with round_num <= 7.
        self.assertEqual(num_lines_after, 3)
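
The tests in this listing rely on module-level helpers `_create_dummy_metrics()` and `_create_dummy_metrics_with_extra_column()` that are not reproduced here. Judging from the flattened column names ('a/b', 'a/c', and the extra 'a/d') and the flat dict returned by `update_metrics` in Example #11, a plausible sketch of these helpers is the following (an assumption, not the original code):

import collections


def _create_dummy_metrics():
    # Nested structure that ScalarMetricsManager flattens into the
    # 'a/b' and 'a/c' columns seen throughout these tests.
    return collections.OrderedDict([('a', {'b': 1.0, 'c': 2.0})])


def _create_dummy_metrics_with_extra_column():
    # Same as above, plus the extra 'a/d' metric used in the NaN /
    # empty-string backfill tests (the value 3.0 is arbitrary).
    metrics = _create_dummy_metrics()
    metrics['a']['d'] = 3.0
    return metrics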
Example #2
    def test_rows_are_cleared_and_last_round_num_is_reset(self):
        metrics_mngr = metrics_manager.ScalarMetricsManager(
            self.get_temp_dir())

        metrics_mngr.update_metrics(0, _create_dummy_metrics())
        metrics_mngr.update_metrics(5, _create_dummy_metrics())
        metrics_mngr.update_metrics(10, _create_dummy_metrics())
        metrics = metrics_mngr.get_metrics()
        self.assertEqual(
            3, len(metrics.index),
            'There should be 3 rows of metrics (for rounds 0, 5, and 10).')

        metrics_mngr.clear_rounds_after(last_valid_round_num=7)

        metrics = metrics_mngr.get_metrics()
        self.assertEqual(
            2, len(metrics.index),
            'After clearing all rounds after last_valid_round_num=7, there should '
            'be 2 rows of metrics (for rounds 0 and 5).')
        self.assertEqual(5, metrics['round_num'].iloc[-1],
                         'Last metrics retained are for round 5.')

        # The internal state of the manager knows the last round number is 7, so it
        # raises an exception if a user attempts to add new metrics at round 7, ...
        with self.assertRaises(ValueError):
            metrics_mngr.update_metrics(7, _create_dummy_metrics())

        # ... but allows a user to add new metrics at a round number greater than 7.
        metrics_mngr.update_metrics(8,
                                    _create_dummy_metrics())  # (No exception.)
Example #3
  def test_fn_writes_metrics(self):
    experiment_name = 'test_metrics'
    iterative_process = _build_federated_averaging_process()
    batch = _batch_fn()
    federated_data = [[batch]]

    def client_datasets_fn(round_num):
      del round_num
      return federated_data

    def evaluate(model):
      keras_model = tff.simulation.models.mnist.create_keras_model(
          compile_model=True)
      model.assign_weights_to(keras_model)
      return {'loss': keras_model.evaluate(batch.x, batch.y)}

    root_output_dir = self.get_temp_dir()
    training_loop.run(
        iterative_process=iterative_process,
        client_datasets_fn=client_datasets_fn,
        validation_fn=evaluate,
        total_rounds=1,
        experiment_name=experiment_name,
        root_output_dir=root_output_dir,
        rounds_per_eval=10,
        test_fn=evaluate)

    results_dir = os.path.join(root_output_dir, 'results', experiment_name)

    scalar_manager = metrics_manager.ScalarMetricsManager(results_dir)
    metrics = scalar_manager.get_metrics()
    self.assertEqual(2, len(metrics.index))
    self.assertIn('eval/loss', metrics.columns)
    self.assertIn('test/loss', metrics.columns)
    self.assertNotIn('train_eval/loss', metrics.columns)
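
This test assumes two helpers that are not shown in the listing: `_build_federated_averaging_process()`, which presumably wraps `tff.learning.build_federated_averaging_process` around an MNIST model, and `_batch_fn()`, which only needs to return an object exposing `x` and `y` attributes for the Keras `evaluate` call above. A minimal, assumed sketch of the latter:

import collections

import numpy as np

# Hypothetical stand-in for `_batch_fn` (not the original helper): a single
# MNIST-shaped batch packaged as a namedtuple so that `batch.x` and `batch.y`
# work in the `evaluate` function above.
Batch = collections.namedtuple('Batch', ['x', 'y'])


def _batch_fn():
    return Batch(
        x=np.ones([1, 784], dtype=np.float32),  # one flattened 28x28 image
        y=np.ones([1, 1], dtype=np.int64))      # one integer label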
Example #4
    def test_clear_rounds_after_raises_value_error_if_round_num_is_negative(
            self):
        metrics_mngr = metrics_manager.ScalarMetricsManager(
            self.get_temp_dir())
        metrics_mngr.update_metrics(0, _create_dummy_metrics())

        with self.assertRaises(ValueError):
            metrics_mngr.clear_rounds_after(last_valid_round_num=-1)
    def test_reload_of_csvfile(self):
        temp_dir = self.get_temp_dir()
        metrics_mngr = metrics_manager.ScalarMetricsManager(temp_dir,
                                                            prefix='bar')
        metrics_mngr.update_metrics(0, _create_dummy_metrics())
        metrics_mngr.update_metrics(5, _create_dummy_metrics())

        new_metrics_mngr = metrics_manager.ScalarMetricsManager(temp_dir,
                                                                prefix='bar')
        fieldnames, metrics = new_metrics_mngr.get_metrics()
        self.assertCountEqual(fieldnames, ['round_num', 'a/b', 'a/c'])
        self.assertLen(metrics, 2,
                       'There should be 2 rows (for rounds 0 and 5).')
        self.assertEqual(5, metrics[-1]['round_num'],
                         'Last metrics are for round 5.')

        self.assertEqual(set(os.listdir(temp_dir)), set(['bar.metrics.csv']))
    def test_reload_of_csvfile(self):
        temp_dir = self.get_temp_dir()
        metrics_mngr = metrics_manager.ScalarMetricsManager(temp_dir,
                                                            prefix='bar')
        metrics_mngr.update_metrics(0, _create_dummy_metrics())
        metrics_mngr.update_metrics(5, _create_dummy_metrics())

        new_metrics_mngr = metrics_manager.ScalarMetricsManager(temp_dir,
                                                                prefix='bar')
        metrics = new_metrics_mngr.get_metrics()
        self.assertEqual(
            2, len(metrics.index),
            'There should be 2 rows of metrics (for rounds 0 and 5).')
        self.assertEqual(5, metrics['round_num'].iloc[-1],
                         'Last metrics are for round 5.')

        self.assertEqual(set(os.listdir(temp_dir)), set(['bar.metrics.csv']))
Example #7
    def test_update_metrics_raises_value_error_if_round_num_is_out_of_order(
            self):
        metrics_mngr = metrics_manager.ScalarMetricsManager(
            self.get_temp_dir())

        metrics_mngr.update_metrics(1, _create_dummy_metrics())

        with self.assertRaises(ValueError):
            metrics_mngr.update_metrics(0, _create_dummy_metrics())
Example #8
    def test_update_metrics_adds_nan_if_previously_seen_metric_not_provided(
            self):
        metrics_mngr = metrics_manager.ScalarMetricsManager(
            self.get_temp_dir())
        metrics_mngr.update_metrics(0,
                                    _create_dummy_metrics_with_extra_column())
        metrics_mngr.update_metrics(1, _create_dummy_metrics())
        metrics = metrics_mngr.get_metrics()
        self.assertTrue(np.isnan(metrics.at[1, 'a/d']))

    def test_update_metrics_adds_empty_str_if_previous_column_not_provided(
            self):
        metrics_mngr = metrics_manager.ScalarMetricsManager(
            self.get_temp_dir())
        metrics_mngr.update_metrics(0,
                                    _create_dummy_metrics_with_extra_column())
        metrics_mngr.update_metrics(1, _create_dummy_metrics())
        _, metrics = metrics_mngr.get_metrics()
        self.assertEqual(metrics[1]['a/d'], '')
Example #10
    def test_clear_rounds_after_raises_runtime_error_if_no_metrics(self):
        metrics_mngr = metrics_manager.ScalarMetricsManager(
            self.get_temp_dir())

        # Clearing after round 0 is allowed even when no metrics have been recorded.
        metrics_mngr.clear_rounds_after(last_valid_round_num=0)

        with self.assertRaises(RuntimeError):
            # Clearing after a later round raises, since no rounds have
            # completed and there are no metrics to clear.
            metrics_mngr.clear_rounds_after(last_valid_round_num=1)
Example #11
    def test_update_metrics_returns_flat_dict(self):
        metrics_mngr = metrics_manager.ScalarMetricsManager(
            self.get_temp_dir())
        input_data_dict = _create_dummy_metrics()
        appended_data_dict = metrics_mngr.update_metrics(0, input_data_dict)
        self.assertEqual({
            'a/b': 1.0,
            'a/c': 2.0,
            'round_num': 0.0
        }, appended_data_dict)
    def test_metrics_are_appended(self):
        metrics_mngr = metrics_manager.ScalarMetricsManager(
            self.get_temp_dir())
        _, metrics = metrics_mngr.get_metrics()
        self.assertEmpty(metrics)

        metrics_mngr.update_metrics(0, _create_dummy_metrics())
        _, metrics = metrics_mngr.get_metrics()
        self.assertLen(metrics, 1)

        metrics_mngr.update_metrics(1, _create_dummy_metrics())
        _, metrics = metrics_mngr.get_metrics()
        self.assertLen(metrics, 2)
Example #13
    def test_metrics_are_appended(self):
        metrics_mngr = metrics_manager.ScalarMetricsManager(
            self.get_temp_dir())
        metrics = metrics_mngr.get_metrics()
        self.assertTrue(metrics.empty)

        metrics_mngr.update_metrics(0, _create_dummy_metrics())
        metrics = metrics_mngr.get_metrics()
        self.assertEqual(1, len(metrics.index))

        metrics_mngr.update_metrics(1, _create_dummy_metrics())
        metrics = metrics_mngr.get_metrics()
        self.assertEqual(2, len(metrics.index))
    def test_run_federated(self, run_federated_fn):
        total_rounds = 1
        shared_args = collections.OrderedDict(
            client_epochs_per_round=1,
            client_batch_size=10,
            clients_per_round=1,
            client_datasets_random_seed=1,
            total_rounds=total_rounds,
            max_batches_per_client=2,
            iterative_process_builder=iterative_process_builder,
            rounds_per_checkpoint=10,
            rounds_per_eval=10,
            rounds_per_train_eval=10,
            max_eval_batches=2)
        root_output_dir = self.get_temp_dir()
        exp_name = 'test_run_federated'
        shared_args['root_output_dir'] = root_output_dir
        shared_args['experiment_name'] = exp_name

        run_federated_fn(**shared_args)

        results_dir = os.path.join(root_output_dir, 'results', exp_name)
        self.assertTrue(tf.io.gfile.exists(results_dir))

        scalar_manager = metrics_manager.ScalarMetricsManager(results_dir)
        metrics = scalar_manager.get_metrics()

        self.assertIn(
            'train/loss',
            metrics.columns,
            msg=
            'The output metrics should have a `train/loss` column if training '
            'is successful.')
        self.assertIn(
            'eval/loss',
            metrics.columns,
            msg=
            'The output metrics should have an `eval/loss` column if validation'
            ' metrics computation is successful.')
        self.assertIn(
            'test/loss',
            metrics.columns,
            msg='The output metrics should have a `test/loss` column if test '
            'metrics computation is successful.')
        self.assertLen(
            metrics.index,
            total_rounds + 1,
            msg='The number of rows in the metrics CSV should be the number of '
            'training rounds + 1 (as there is an extra row for validation/test '
            'set metrics after training has completed).')
Example #15
    def test_constructor_raises_value_error_if_csvfile_is_invalid(self):
        dataframe_missing_round_num = pd.DataFrame.from_dict(
            _create_dummy_metrics())

        temp_dir = self.get_temp_dir()
        # This csvfile is 'invalid' in that it was not originally created by an
        # instance of ScalarMetricsManager, and is missing a column for
        # round_num.
        invalid_csvfile = os.path.join(temp_dir, 'foo.metrics.csv.bz2')
        utils_impl.atomic_write_to_csv(dataframe_missing_round_num,
                                       invalid_csvfile)

        with self.assertRaises(ValueError):
            metrics_manager.ScalarMetricsManager(temp_dir, prefix='foo')
    def test_update_metrics_adds_column_if_previously_unseen_metric_added(
            self):
        metrics_mngr = metrics_manager.ScalarMetricsManager(
            self.get_temp_dir())
        metrics_mngr.update_metrics(0, _create_dummy_metrics())
        fieldnames, metrics = metrics_mngr.get_metrics()
        self.assertCountEqual(fieldnames, ['round_num', 'a/b', 'a/c'])
        self.assertNotIn('a/d', metrics[0].keys())

        metrics_mngr.update_metrics(1,
                                    _create_dummy_metrics_with_extra_column())
        fieldnames, metrics = metrics_mngr.get_metrics()
        self.assertCountEqual(fieldnames, ['round_num', 'a/b', 'a/c', 'a/d'])
        self.assertEqual(metrics[0]['a/d'], '')
    def test_constructor_raises_value_error_if_csvfile_is_invalid(self):
        metrics_missing_round_num = _create_dummy_metrics()
        temp_dir = self.get_temp_dir()
        # This csvfile is 'invalid' in that it was not originally created by an
        # instance of ScalarMetricsManager, and is missing a column for
        # round_num.
        invalid_csvfile = os.path.join(temp_dir, 'foo.metrics.csv')
        with tf.io.gfile.GFile(invalid_csvfile, 'w') as csvfile:
            writer = csv.DictWriter(
                csvfile, fieldnames=metrics_missing_round_num.keys())
            writer.writeheader()
            writer.writerow(metrics_missing_round_num)

        with self.assertRaises(ValueError):
            metrics_manager.ScalarMetricsManager(temp_dir, prefix='foo')
Example #18
def _setup_outputs(root_output_dir,
                   experiment_name,
                   hparam_dict,
                   write_metrics_with_bz2=True,
                   rounds_per_profile=0):
  """Set up directories for experiment loops, write hyperparameters to disk."""

  if not experiment_name:
    raise ValueError('experiment_name must be specified.')

  create_if_not_exists(root_output_dir)

  checkpoint_dir = os.path.join(root_output_dir, 'checkpoints', experiment_name)
  create_if_not_exists(checkpoint_dir)
  checkpoint_mngr = checkpoint_manager.FileCheckpointManager(checkpoint_dir)

  results_dir = os.path.join(root_output_dir, 'results', experiment_name)
  create_if_not_exists(results_dir)
  metrics_mngr = metrics_manager.ScalarMetricsManager(
      results_dir, use_bz2=write_metrics_with_bz2)

  summary_logdir = os.path.join(root_output_dir, 'logdir', experiment_name)
  create_if_not_exists(summary_logdir)
  summary_writer = tf.summary.create_file_writer(summary_logdir)

  if hparam_dict:
    hparam_dict['metrics_file'] = metrics_mngr.metrics_filename
    hparams_file = os.path.join(results_dir, 'hparams.csv')
    utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)
    with summary_writer.as_default():
      hp.hparams({k: v for k, v in hparam_dict.items() if v is not None})

  logging.info('Writing...')
  logging.info('    checkpoints to: %s', checkpoint_dir)
  logging.info('    metrics csv to: %s', metrics_mngr.metrics_filename)
  logging.info('    summaries to: %s', summary_logdir)

  @contextlib.contextmanager
  def profiler(round_num):
    if (rounds_per_profile > 0 and round_num % rounds_per_profile == 0):
      with tf.profiler.experimental.Profile(summary_logdir):
        yield
    else:
      yield

  return checkpoint_mngr, metrics_mngr, summary_writer, profiler
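
A sketch of how the objects returned by `_setup_outputs` might be wired into a round loop (assumed usage, not part of the listing; the per-round training metric is a placeholder):

import tensorflow as tf

checkpoint_mngr, metrics_mngr, summary_writer, profiler = _setup_outputs(
    root_output_dir='/tmp/exp_outputs',
    experiment_name='demo',
    hparam_dict={'learning_rate': 0.1})
# checkpoint_mngr is unused in this toy sketch.

for round_num in range(1, 4):
  with profiler(round_num):
    train_metrics = {'loss': 1.0 / round_num}  # placeholder for real training
  # Nested dicts are flattened into columns such as 'train/loss'.
  metrics_mngr.update_metrics(round_num, {'train': train_metrics})
  with summary_writer.as_default():
    tf.summary.scalar('train/loss', train_metrics['loss'], step=round_num)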
Example #19
    def test_rows_are_cleared_is_reflected_in_saved_file(self):
        temp_dir = self.get_temp_dir()
        metrics_mngr = metrics_manager.ScalarMetricsManager(temp_dir,
                                                            prefix='foo')

        metrics_mngr.update_metrics(0, _create_dummy_metrics())
        metrics_mngr.update_metrics(5, _create_dummy_metrics())
        metrics_mngr.update_metrics(10, _create_dummy_metrics())

        file_contents_before = utils_impl.atomic_read_from_csv(
            os.path.join(temp_dir, 'foo.metrics.csv.bz2'))
        self.assertEqual(3, len(file_contents_before.index))

        metrics_mngr.clear_rounds_after(last_valid_round_num=7)

        file_contents_after = utils_impl.atomic_read_from_csv(
            os.path.join(temp_dir, 'foo.metrics.csv.bz2'))
        self.assertEqual(2, len(file_contents_after.index))
Example #20
    def test_csvfile_is_saved(self):
        temp_dir = self.get_temp_dir()
        metrics_manager.ScalarMetricsManager(temp_dir, prefix='foo')
        self.assertEqual(set(os.listdir(temp_dir)),
                         set(['foo.metrics.csv.bz2']))
Example #21
    def test_column_names(self):
        metrics_mngr = metrics_manager.ScalarMetricsManager(
            self.get_temp_dir())
        metrics_mngr.update_metrics(0, _create_dummy_metrics())
        metrics = metrics_mngr.get_metrics()
        self.assertEqual(['a/b', 'a/c', 'round_num'], metrics.columns.tolist())

    def test_column_names(self):
        metrics_mngr = metrics_manager.ScalarMetricsManager(
            self.get_temp_dir())
        metrics_mngr.update_metrics(0, _create_dummy_metrics())
        fieldnames, _ = metrics_mngr.get_metrics()
        self.assertCountEqual(['a/b', 'a/c', 'round_num'], fieldnames)
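
Taken together, the examples above exercise a small API surface: construct the manager against an output directory (optionally with a `prefix` and, in the bz2-backed variant, `use_bz2`), append rows with strictly increasing round numbers via `update_metrics`, read them back with `get_metrics`, and truncate with `clear_rounds_after`. A condensed usage sketch, assuming the pandas-backed variant whose `get_metrics()` returns a DataFrame and the same `metrics_manager` import used throughout the examples:

mngr = metrics_manager.ScalarMetricsManager(
    '/tmp/metrics_demo', prefix='experiment')

mngr.update_metrics(0, {'a': {'b': 1.0, 'c': 2.0}})  # round 0
mngr.update_metrics(5, {'a': {'b': 0.5, 'c': 1.5}})  # round numbers must increase

df = mngr.get_metrics()  # DataFrame with 'round_num', 'a/b', 'a/c' columns
mngr.clear_rounds_after(last_valid_round_num=0)  # drops the round-5 row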