Example #1
0
  def test_reload_of_csvfile(self):
    """A second manager on the same path sees the rows written by the first."""
    metrics_dir = os.path.join(self.get_temp_dir(), 'test_dir')
    metrics_path = os.path.join(metrics_dir, 'metrics.csv')

    writer_mngr = csv_manager.CSVMetricsManager(metrics_path)
    for round_num in (0, 5):
      writer_mngr.update_metrics(round_num, _create_scalar_metrics())

    # Re-open the same file with a brand-new manager instance.
    reader_mngr = csv_manager.CSVMetricsManager(metrics_path)
    fieldnames, rows = reader_mngr.get_metrics()
    self.assertCountEqual(fieldnames, ['round_num', 'a/b', 'a/c'])
    self.assertLen(rows, 2, 'There should be 2 rows (for rounds 0 and 5).')
    last_row = rows[-1]
    self.assertEqual(5, last_row['round_num'], 'Last metrics are for round 5.')

    # Reloading must not leave any extra files behind in the directory.
    self.assertEqual(set(os.listdir(metrics_dir)), {'metrics.csv'})
Example #2
0
 def test_update_metrics_adds_empty_str_if_previous_column_not_provided(self):
   """A column missing from a later round is filled with an empty string."""
   metrics_path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
   manager = csv_manager.CSVMetricsManager(metrics_path)
   # Round 0 carries the extra 'a/d' column; round 1 omits it.
   manager.update_metrics(0, _create_scalar_metrics_with_extra_column())
   manager.update_metrics(1, _create_scalar_metrics())
   rows = manager.get_metrics()[1]
   self.assertEqual(rows[1]['a/d'], '')
  def test_save_metrics_adds_column(self, save_mode):
    """Saving a round with a new key adds a column and back-fills old rows."""
    metrics_path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(metrics_path, save_mode=save_mode)

    manager.save_metrics(_create_scalar_metrics(), 0)
    fieldnames, rows = manager.get_metrics()
    self.assertCountEqual(fieldnames, ['round_num', 'a/b', 'a/c'])
    self.assertNotIn('a/d', rows[0])

    manager.save_metrics(_create_scalar_metrics_with_extra_column(), 1)
    fieldnames, rows = manager.get_metrics()
    self.assertCountEqual(fieldnames, ['round_num', 'a/b', 'a/c', 'a/d'])

    # Round 0 was saved before 'a/d' existed, so its value reads back as ''.
    expected_rows = [
        {'round_num': 0, 'a/b': 1.0, 'a/c': 2.0, 'a/d': ''},
        {'round_num': 1, 'a/b': 1.0, 'a/c': 2.0, 'a/d': 3.0},
    ]
    for index, expected_row in enumerate(expected_rows):
      self.assertDictEqual(rows[index], expected_row)
Example #4
0
  def test_rows_are_cleared_and_last_round_num_is_reset(self):
    """clear_rounds_after drops later rows and rewinds the round counter."""
    metrics_path = os.path.join(self.get_temp_dir(), 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(metrics_path)

    for round_num in (0, 5, 10):
      manager.update_metrics(round_num, _create_scalar_metrics())
    rows = manager.get_metrics()[1]
    self.assertLen(rows, 3,
                   'There should be 3 rows (for rounds 0, 5, and 10).')

    manager.clear_rounds_after(last_valid_round_num=7)

    rows = manager.get_metrics()[1]
    self.assertLen(
        rows, 2,
        'After clearing all rounds after last_valid_round_num=7, should be 2 '
        'rows of metrics (for rounds 0 and 5).')
    self.assertEqual(5, rows[-1]['round_num'],
                     'Last metrics retained are for round 5.')

    # After the clear, the manager treats 7 as the last round number, so
    # writing round 7 again is rejected ...
    with self.assertRaises(ValueError):
      manager.update_metrics(7, _create_scalar_metrics())
    # ... while any later round number is accepted without error.
    manager.update_metrics(8, _create_scalar_metrics())
 def test_update_metrics_raises_value_error_if_round_num_is_out_of_order(
         self):
     """Appending a round earlier than the latest one raises ValueError."""
     manager = csv_manager.CSVMetricsManager(
         os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv'))
     manager.update_metrics(1, _create_scalar_metrics())
     # Round 0 precedes the already-recorded round 1.
     with self.assertRaises(ValueError):
         manager.update_metrics(0, _create_scalar_metrics())
Example #6
0
 def test_column_names_with_list(self):
   """A list-valued metric becomes one indexed column per element."""
   metrics_path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
   manager = csv_manager.CSVMetricsManager(metrics_path)
   manager.update_metrics(0, {'a': [3, 4, 5], 'b': 6})
   fieldnames = manager.get_metrics()[0]
   expected_fieldnames = ['a/0', 'a/1', 'a/2', 'b', 'round_num']
   self.assertCountEqual(expected_fieldnames, fieldnames)
Example #7
0
  def test_rows_are_cleared_is_reflected_in_saved_file(self):
    """clear_rounds_after also removes the cleared rows from the CSV on disk."""
    temp_dir = self.get_temp_dir()
    csv_file = os.path.join(temp_dir, 'metrics.csv')
    csv_mngr = csv_manager.CSVMetricsManager(csv_file)

    csv_mngr.update_metrics(0, _create_scalar_metrics())
    csv_mngr.update_metrics(5, _create_scalar_metrics())
    csv_mngr.update_metrics(10, _create_scalar_metrics())

    # Read the same path the manager was constructed with, rather than
    # rebuilding it, so the test cannot silently check a different file.
    with tf.io.gfile.GFile(csv_file, 'r') as csvfile:
      num_lines_before = len(csvfile.readlines())

    # The CSV file should have 4 lines: one header line with the fieldnames,
    # plus one row for each of the 3 calls to `update_metrics`.
    self.assertEqual(num_lines_before, 4)

    csv_mngr.clear_rounds_after(last_valid_round_num=7)

    with tf.io.gfile.GFile(csv_file, 'r') as csvfile:
      num_lines_after = len(csvfile.readlines())

    # The CSV file should have 3 lines: the header plus the 2 rows written for
    # round_nums <= 7 (rounds 0 and 5).
    self.assertEqual(num_lines_after, 3)
 def test_clear_rounds_after_raises_value_error_if_round_num_is_negative(
         self):
     """A negative round_num passed to clear_rounds_after raises ValueError."""
     metrics_path = os.path.join(
         self.get_temp_dir(), 'test_dir', 'metrics.csv')
     manager = csv_manager.CSVMetricsManager(metrics_path)
     manager.update_metrics(0, _create_scalar_metrics())
     negative_round = -1
     with self.assertRaises(ValueError):
         manager.clear_rounds_after(round_num=negative_round)
 def test_clear_rounds_after_raises_runtime_error_if_no_metrics(self):
     """With no metrics recorded, only clearing after round 0 is permitted."""
     metrics_path = os.path.join(
         self.get_temp_dir(), 'test_dir', 'metrics.csv')
     manager = csv_manager.CSVMetricsManager(metrics_path)
     # Nothing has been written yet; clearing after round 0 must not raise ...
     manager.clear_rounds_after(round_num=0)
     # ... but clearing after a positive round raises RuntimeError.
     with self.assertRaises(RuntimeError):
         manager.clear_rounds_after(round_num=1)
 def test_get_metrics_with_nonscalars_returns_list_of_lists(self):
     """Non-scalar tensors read back as their nested-list string form."""
     metrics_path = os.path.join(
         self.get_temp_dir(), 'test_dir', 'metrics.csv')
     manager = csv_manager.CSVMetricsManager(metrics_path)
     manager.update_metrics(0, {
         'a': tf.ones([1], dtype=tf.int32),
         'b': tf.zeros([2, 2], dtype=tf.int32),
     })
     first_row = manager.get_metrics()[1][0]
     self.assertEqual(first_row['a'], '[1]')
     self.assertEqual(first_row['b'], '[[0, 0], [0, 0]]')
 def test_update_metrics_returns_flat_dict_with_scalars(self):
     """update_metrics returns the flattened metrics plus round_num."""
     metrics_path = os.path.join(
         self.get_temp_dir(), 'test_dir', 'metrics.csv')
     manager = csv_manager.CSVMetricsManager(metrics_path)
     returned = manager.update_metrics(0, _create_scalar_metrics())
     # If the returned mapping is also an OrderedDict, key order is checked
     # in addition to the values.
     expected = collections.OrderedDict([
         ('a/b', 1.0),
         ('a/c', 2.0),
         ('round_num', 0.0),
     ])
     self.assertEqual(expected, returned)
Example #12
0
  def test_update_metrics_adds_column_if_previously_unseen_metric_added(self):
    """A new metric key adds a column; earlier rows get '' for it."""
    metrics_path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(metrics_path)

    manager.update_metrics(0, _create_scalar_metrics())
    fieldnames, rows = manager.get_metrics()
    self.assertCountEqual(fieldnames, ['round_num', 'a/b', 'a/c'])
    self.assertNotIn('a/d', rows[0])

    manager.update_metrics(1, _create_scalar_metrics_with_extra_column())
    fieldnames, rows = manager.get_metrics()
    self.assertCountEqual(fieldnames, ['round_num', 'a/b', 'a/c', 'a/d'])
    # Round 0 predates the 'a/d' column, so it is back-filled with ''.
    self.assertEqual(rows[0]['a/d'], '')
Example #13
0
 def test_clear_metrics_with_round_zero_removes_all_metrics(self):
     """clear_metrics(round_num=0) drops every row and resets the round state."""
     csv_file = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
     csv_mngr = csv_manager.CSVMetricsManager(csv_file)
     # `save_metrics` takes the metrics first and the round number second,
     # matching every other `save_metrics` call in these tests.
     csv_mngr.save_metrics(_create_scalar_metrics(), 0)
     csv_mngr.save_metrics(_create_scalar_metrics(), 5)
     csv_mngr.save_metrics(_create_scalar_metrics(), 10)
     _, metrics = csv_mngr.get_metrics()
     self.assertLen(metrics, 3,
                    'There should be 3 rows (for rounds 0, 5, and 10).')
     csv_mngr.clear_metrics(round_num=0)
     _, metrics = csv_mngr.get_metrics()
     self.assertEmpty(metrics)
     # With every row removed, the manager no longer tracks a latest round.
     self.assertIsNone(csv_mngr._latest_round_num)
  def test_nonscalar_metrics_are_saved(self, save_mode):
    """Rows with non-scalar metric values are saved, one row per round."""
    metrics_path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(metrics_path, save_mode=save_mode)
    self.assertEmpty(manager.get_metrics()[1])

    # Each save adds exactly one row: round 0 -> 1 row, round 1 -> 2 rows.
    for round_num in (0, 1):
      manager.save_metrics(_create_nonscalar_metrics(), round_num)
      self.assertLen(manager.get_metrics()[1], round_num + 1)
    # NOTE(review): at this indentation this def is nested inside the body of
    # the preceding test method and will not be collected as a test; confirm
    # the intended indentation.
    def test_nonscalar_metrics_are_appended(self):
        """Rows with non-scalar metric values are appended, one per round."""
        metrics_path = os.path.join(
            self.get_temp_dir(), 'test_dir', 'metrics.csv')
        manager = csv_manager.CSVMetricsManager(metrics_path)
        self.assertEmpty(manager.get_metrics()[1])

        # Each append adds exactly one row: round 0 -> 1 row, round 1 -> 2.
        for round_num in (0, 1):
            manager.update_metrics(round_num, _create_nonscalar_metrics())
            self.assertLen(manager.get_metrics()[1], round_num + 1)
 def test_constructor_raises_value_error_if_csvfile_is_invalid(self):
     """Loading a CSV that lacks a round_num column raises ValueError."""
     row_without_round_num = _create_scalar_metrics()
     # Write the CSV by hand (not through CSVMetricsManager) so that it has no
     # round_num column; the manager should refuse to load it.
     invalid_path = os.path.join(self.get_temp_dir(), 'invalid_metrics.csv')
     with tf.io.gfile.GFile(invalid_path, 'w') as handle:
         dict_writer = csv.DictWriter(
             handle, fieldnames=row_without_round_num.keys())
         dict_writer.writeheader()
         dict_writer.writerow(row_without_round_num)
     with self.assertRaises(ValueError):
         csv_manager.CSVMetricsManager(invalid_path)
 def test_save_metrics_with_alternating_fieldnames(self, save_mode):
   """Keys absent from a round are written as '' in that round's row."""
   metrics_path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
   manager = csv_manager.CSVMetricsManager(metrics_path, save_mode=save_mode)
   per_round_metrics = [{'a': 1}, {'b': 3}, {'a': 2}]
   for round_num, round_metrics in enumerate(per_round_metrics):
     manager.save_metrics(round_metrics, round_num)

   fieldnames, rows = manager.get_metrics()
   self.assertCountEqual(fieldnames, ['round_num', 'a', 'b'])
   expected_rows = [
       {'round_num': 0, 'a': 1, 'b': ''},
       {'round_num': 1, 'a': '', 'b': 3},
       {'round_num': 2, 'a': 2, 'b': ''},
   ]
   for index, expected_row in enumerate(expected_rows):
     self.assertDictEqual(rows[index], expected_row)
 def test_update_metrics_returns_flat_dict_with_nonscalars(self):
     """update_metrics returns flattened non-scalar tensors plus round_num."""
     metrics_path = os.path.join(
         self.get_temp_dir(), 'test_dir', 'metrics.csv')
     manager = csv_manager.CSVMetricsManager(metrics_path)
     returned = manager.update_metrics(0, _create_nonscalar_metrics())
     expected = collections.OrderedDict([
         ('a/b', tf.ones([1])),
         ('a/c', tf.zeros([2, 2])),
         ('round_num', 0.0),
     ])
     # Check key order first, then compare the scalar directly and the tensor
     # values element-wise with assertAllEqual.
     self.assertListEqual(list(expected), list(returned))
     self.assertEqual(expected['round_num'], returned['round_num'])
     for key in ('a/b', 'a/c'):
         self.assertAllEqual(expected[key], returned[key])
 def test_clear_metrics_removes_rounds_equal_to_input_arg(self, save_mode):
   """clear_metrics(round_num=N) removes round N itself and all later rounds."""
   metrics_path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
   manager = csv_manager.CSVMetricsManager(metrics_path, save_mode=save_mode)
   for round_num in (0, 5, 10):
     manager.save_metrics(_create_scalar_metrics(), round_num)
   rows = manager.get_metrics()[1]
   self.assertLen(rows, 3,
                  'There should be 3 rows (for rounds 0, 5, and 10).')

   manager.clear_metrics(round_num=5)
   rows = manager.get_metrics()[1]
   self.assertLen(
       rows, 1, 'After clearing all rounds starting at round_num=5, there '
       'should be 1 row of metrics (for round 0).')
   self.assertEqual(0, rows[-1]['round_num'],
                    'Last metrics retained are for round 0.')
   # The tracked latest round is rewound to the last retained row.
   self.assertEqual(0, manager._latest_round_num)
 def test_clear_metrics_removes_rounds_after_input_arg(self, save_mode):
   """clear_metrics(round_num=N) removes every round later than N."""
   metrics_path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
   manager = csv_manager.CSVMetricsManager(metrics_path, save_mode=save_mode)
   for round_num in (0, 5, 10):
     manager.save_metrics(_create_scalar_metrics(), round_num)
   rows = manager.get_metrics()[1]
   self.assertLen(rows, 3,
                  'There should be 3 rows (for rounds 0, 5, and 10).')

   manager.clear_metrics(round_num=7)
   rows = manager.get_metrics()[1]
   self.assertLen(
       rows, 2, 'After clearing all rounds after round_num=7, should be 2 '
       'rows of metrics (for rounds 0 and 5).')
   self.assertEqual(5, rows[-1]['round_num'],
                    'Last metrics retained are for round 5.')
   # The tracked latest round is the last retained row (5), not the 7 passed in.
   self.assertEqual(5, manager._latest_round_num)
 def test_rows_are_cleared_and_last_round_num_is_reset(self):
     """clear_rounds_after drops later rows and rewinds _latest_round_num."""
     metrics_path = os.path.join(
         self.get_temp_dir(), 'test_dir', 'metrics.csv')
     manager = csv_manager.CSVMetricsManager(metrics_path)
     for round_num in (0, 5, 10):
         manager.update_metrics(round_num, _create_scalar_metrics())
     rows = manager.get_metrics()[1]
     self.assertLen(rows, 3,
                    'There should be 3 rows (for rounds 0, 5, and 10).')

     manager.clear_rounds_after(round_num=7)
     rows = manager.get_metrics()[1]
     self.assertLen(
         rows, 2,
         'After clearing all rounds after round_num=7, should be 2 '
         'rows of metrics (for rounds 0 and 5).')
     self.assertEqual(5, rows[-1]['round_num'],
                      'Last metrics retained are for round 5.')
     # The tracked latest round is the last retained row (5), not 7.
     self.assertEqual(5, manager._latest_round_num)
  def test_save_in_append_mode_appends_and_does_not_read_metrics(
      self, mock_dict_writer, mock_dict_reader, mock_write_to_csv):
    """In APPEND mode, save_metrics neither re-reads nor calls _write_to_csv."""
    fake_reader = mock.MagicMock()
    fake_reader.fieldnames = ['round_num']
    fake_reader.__iter__.return_value = [{'round_num': 0}]
    mock_dict_reader.return_value = fake_reader
    mock_dict_writer.return_value = mock.MagicMock()

    def assert_read_once_and_never_rewritten():
      # list(csv.DictReader) has been consumed exactly once (during
      # construction), and _write_to_csv has not been called at all.
      self.assertEqual(fake_reader.__iter__.call_count, 1)
      self.assertEqual(mock_write_to_csv.call_count, 0)

    metrics_path = os.path.join(self.get_temp_dir(), 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(
        metrics_path, save_mode=csv_manager.SaveMode.APPEND)
    assert_read_once_and_never_rewritten()

    # Saving in APPEND mode must leave both counts unchanged.
    manager.save_metrics({}, 1)
    assert_read_once_and_never_rewritten()
 def test_clear_metrics_raises_if_round_num_is_negative(self, save_mode):
   """clear_metrics rejects a negative round_num with ValueError."""
   metrics_path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
   manager = csv_manager.CSVMetricsManager(metrics_path, save_mode=save_mode)
   manager.save_metrics(_create_scalar_metrics(), 0)
   negative_round = -1
   with self.assertRaises(ValueError):
     manager.clear_metrics(round_num=negative_round)
 def test_constructor_raises_if_save_mode_is_invalid(self):
   """An unrecognized save_mode is rejected at construction time."""
   bad_save_mode = 'invalid_mode'
   metrics_path = os.path.join(self.get_temp_dir(), 'metrics.csv')
   with self.assertRaises(ValueError):
     csv_manager.CSVMetricsManager(metrics_path, save_mode=bad_save_mode)
 def test_column_names(self):
     """Scalar metrics yield their flattened keys plus round_num as columns."""
     manager = csv_manager.CSVMetricsManager(
         os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv'))
     manager.update_metrics(0, _create_scalar_metrics())
     fieldnames = manager.get_metrics()[0]
     self.assertCountEqual(['a/b', 'a/c', 'round_num'], fieldnames)
 def test_csvfile_is_saved(self):
     """Constructing a manager alone leaves the CSV file in its directory."""
     metrics_dir = os.path.join(self.get_temp_dir(), 'test_dir')
     csv_manager.CSVMetricsManager(os.path.join(metrics_dir, 'metrics.csv'))
     # The directory should hold the metrics file and nothing else.
     self.assertEqual(set(os.listdir(metrics_dir)), {'metrics.csv'})