def test_reload_of_csvfile(self):
    """A new manager pointed at an existing CSV file picks up prior rows."""
    temp_dir = os.path.join(self.get_temp_dir(), 'test_dir')
    csv_file = os.path.join(temp_dir, 'metrics.csv')
    original_manager = csv_manager.CSVMetricsManager(csv_file)
    for round_num in (0, 5):
        original_manager.update_metrics(round_num, _create_scalar_metrics())
    # A fresh manager on the same file should see the previously written rows.
    reloaded_manager = csv_manager.CSVMetricsManager(csv_file)
    fieldnames, metrics = reloaded_manager.get_metrics()
    self.assertCountEqual(fieldnames, ['round_num', 'a/b', 'a/c'])
    self.assertLen(metrics, 2, 'There should be 2 rows (for rounds 0 and 5).')
    self.assertEqual(5, metrics[-1]['round_num'],
                     'Last metrics are for round 5.')
    # Reloading must not create any extra files alongside the CSV.
    self.assertEqual(set(os.listdir(temp_dir)), set(['metrics.csv']))
def test_update_metrics_adds_empty_str_if_previous_column_not_provided(self):
    """A later row missing a previously seen column gets '' filled in."""
    metrics_file = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(metrics_file)
    # Round 0 introduces the extra column 'a/d'; round 1 omits it.
    manager.update_metrics(0, _create_scalar_metrics_with_extra_column())
    manager.update_metrics(1, _create_scalar_metrics())
    _, metrics = manager.get_metrics()
    self.assertEqual(metrics[1]['a/d'], '')
def test_save_metrics_adds_column(self, save_mode):
    """A metric first seen at a later round backfills '' in earlier rows."""
    csv_file = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(csv_file, save_mode=save_mode)
    manager.save_metrics(_create_scalar_metrics(), 0)
    fieldnames, metrics = manager.get_metrics()
    # Before the extra column appears, it is absent from both the header and
    # the stored row.
    self.assertCountEqual(fieldnames, ['round_num', 'a/b', 'a/c'])
    self.assertNotIn('a/d', metrics[0].keys())
    manager.save_metrics(_create_scalar_metrics_with_extra_column(), 1)
    fieldnames, metrics = manager.get_metrics()
    self.assertCountEqual(fieldnames, ['round_num', 'a/b', 'a/c', 'a/d'])
    # The earlier row is padded with '' for the new column; the later row
    # carries the real value.
    self.assertDictEqual(
        metrics[0], {'round_num': 0, 'a/b': 1.0, 'a/c': 2.0, 'a/d': ''})
    self.assertDictEqual(
        metrics[1], {'round_num': 1, 'a/b': 1.0, 'a/c': 2.0, 'a/d': 3.0})
def test_clear_rounds_after_rejects_round_nums_at_or_below_cleared_round(self):
    """Clearing rounds resets state; stale round numbers are then rejected.

    NOTE: this test previously shared its name with another test in this
    file, so Python silently shadowed it and it never ran; it is renamed to
    a unique name describing the behavior only it covers (updates at or
    below the cleared round raise, updates beyond it succeed).
    """
    csv_file = os.path.join(self.get_temp_dir(), 'metrics.csv')
    csv_mngr = csv_manager.CSVMetricsManager(csv_file)
    csv_mngr.update_metrics(0, _create_scalar_metrics())
    csv_mngr.update_metrics(5, _create_scalar_metrics())
    csv_mngr.update_metrics(10, _create_scalar_metrics())
    _, metrics = csv_mngr.get_metrics()
    self.assertLen(metrics, 3,
                   'There should be 3 rows (for rounds 0, 5, and 10).')
    csv_mngr.clear_rounds_after(last_valid_round_num=7)
    _, metrics = csv_mngr.get_metrics()
    self.assertLen(
        metrics, 2,
        'After clearing all rounds after last_valid_round_num=7, should be 2 '
        'rows of metrics (for rounds 0 and 5).')
    self.assertEqual(5, metrics[-1]['round_num'],
                     'Last metrics retained are for round 5.')
    # The internal state of the manager knows the last round number is 7, so
    # it raises an exception if a user attempts to add new metrics at round
    # 7, ...
    with self.assertRaises(ValueError):
        csv_mngr.update_metrics(7, _create_scalar_metrics())
    # ... but allows a user to add new metrics at a round number greater
    # than 7.
    csv_mngr.update_metrics(8, _create_scalar_metrics())  # (No exception.)
def test_update_metrics_raises_value_error_if_round_num_is_out_of_order(
        self):
    """Writing a round number that does not increase is rejected."""
    path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(path)
    manager.update_metrics(1, _create_scalar_metrics())
    # Round 0 arrives after round 1 has already been recorded.
    with self.assertRaises(ValueError):
        manager.update_metrics(0, _create_scalar_metrics())
def test_column_names_with_list(self):
    """List-valued metrics expand into one indexed column per element."""
    path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(path)
    manager.update_metrics(0, {'a': [3, 4, 5], 'b': 6})
    fieldnames, _ = manager.get_metrics()
    self.assertCountEqual(['a/0', 'a/1', 'a/2', 'b', 'round_num'], fieldnames)
def test_rows_are_cleared_is_reflected_in_saved_file(self):
    """clear_rounds_after physically removes rows from the backing file."""
    temp_dir = self.get_temp_dir()
    csv_file = os.path.join(temp_dir, 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(csv_file)
    for round_num in (0, 5, 10):
        manager.update_metrics(round_num, _create_scalar_metrics())
    with tf.io.gfile.GFile(csv_file, 'r') as f:
        num_lines_before = len(f.readlines())
    # The CSV file should have 4 lines: one header line plus one line per
    # `update_metrics` call.
    self.assertEqual(num_lines_before, 4)
    manager.clear_rounds_after(last_valid_round_num=7)
    with tf.io.gfile.GFile(csv_file, 'r') as f:
        num_lines_after = len(f.readlines())
    # The CSV file should have 3 lines: the header plus the two rows whose
    # round_num is <= 7 (rounds 0 and 5).
    self.assertEqual(num_lines_after, 3)
def test_clear_rounds_after_raises_value_error_if_round_num_is_negative(
        self):
    """Negative round numbers are rejected by clear_rounds_after."""
    path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(path)
    manager.update_metrics(0, _create_scalar_metrics())
    with self.assertRaises(ValueError):
        manager.clear_rounds_after(round_num=-1)
def test_clear_rounds_after_raises_runtime_error_if_no_metrics(self):
    """With no metrics recorded, only clear_rounds_after(0) is permitted."""
    csv_file = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    csv_mngr = csv_manager.CSVMetricsManager(csv_file)
    # Clearing at round 0 is allowed even though no rounds have completed.
    csv_mngr.clear_rounds_after(round_num=0)
    # Clearing at any round greater than 0 raises, because there are no
    # recorded metrics to anchor that round.  (The original comment here was
    # a copy-paste of the one above and contradicted this assertion.)
    with self.assertRaises(RuntimeError):
        csv_mngr.clear_rounds_after(round_num=1)
def test_get_metrics_with_nonscalars_returns_list_of_lists(self):
    """Tensor-valued metrics are read back as their string renderings."""
    path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(path)
    manager.update_metrics(0, {
        'a': tf.ones([1], dtype=tf.int32),
        'b': tf.zeros([2, 2], dtype=tf.int32),
    })
    _, metrics = manager.get_metrics()
    self.assertEqual(metrics[0]['a'], '[1]')
    self.assertEqual(metrics[0]['b'], '[[0, 0], [0, 0]]')
def test_update_metrics_returns_flat_dict_with_scalars(self):
    """update_metrics returns the flattened row, round_num included."""
    path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(path)
    flattened = manager.update_metrics(0, _create_scalar_metrics())
    expected = collections.OrderedDict(
        [('a/b', 1.0), ('a/c', 2.0), ('round_num', 0.0)])
    self.assertEqual(expected, flattened)
def test_update_metrics_adds_column_if_previously_unseen_metric_added(self):
    """A brand-new metric widens the header and backfills earlier rows."""
    path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(path)
    manager.update_metrics(0, _create_scalar_metrics())
    fieldnames, metrics = manager.get_metrics()
    # Initially 'a/d' exists neither in the header nor in row 0.
    self.assertCountEqual(fieldnames, ['round_num', 'a/b', 'a/c'])
    self.assertNotIn('a/d', metrics[0].keys())
    manager.update_metrics(1, _create_scalar_metrics_with_extra_column())
    fieldnames, metrics = manager.get_metrics()
    # After round 1 introduces 'a/d', the header grows and the earlier row
    # reads back '' for the new column.
    self.assertCountEqual(fieldnames, ['round_num', 'a/b', 'a/c', 'a/d'])
    self.assertEqual(metrics[0]['a/d'], '')
def test_clear_metrics_with_round_zero_removes_all_metrics(self):
    """clear_metrics(round_num=0) empties the manager and resets its state.

    Fix: `save_metrics` takes (metrics, round_num) — the order used by every
    other test in this file — but the arguments here were reversed, which
    would write the integer round as the metrics payload.
    """
    csv_file = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    csv_mngr = csv_manager.CSVMetricsManager(csv_file)
    csv_mngr.save_metrics(_create_scalar_metrics(), 0)
    csv_mngr.save_metrics(_create_scalar_metrics(), 5)
    csv_mngr.save_metrics(_create_scalar_metrics(), 10)
    _, metrics = csv_mngr.get_metrics()
    self.assertLen(metrics, 3,
                   'There should be 3 rows (for rounds 0, 5, and 10).')
    csv_mngr.clear_metrics(round_num=0)
    _, metrics = csv_mngr.get_metrics()
    self.assertEmpty(metrics)
    self.assertIsNone(csv_mngr._latest_round_num)
def test_nonscalar_metrics_are_saved(self, save_mode):
    """Each save_metrics call with tensor metrics adds exactly one row."""
    path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(path, save_mode=save_mode)
    _, metrics = manager.get_metrics()
    self.assertEmpty(metrics)
    for expected_rows, round_num in ((1, 0), (2, 1)):
        manager.save_metrics(_create_nonscalar_metrics(), round_num)
        _, metrics = manager.get_metrics()
        self.assertLen(metrics, expected_rows)
def test_nonscalar_metrics_are_appended(self):
    """Each update_metrics call with tensor metrics adds exactly one row."""
    path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(path)
    _, metrics = manager.get_metrics()
    self.assertEmpty(metrics)
    for expected_rows, round_num in ((1, 0), (2, 1)):
        manager.update_metrics(round_num, _create_nonscalar_metrics())
        _, metrics = manager.get_metrics()
        self.assertLen(metrics, expected_rows)
def test_constructor_raises_value_error_if_csvfile_is_invalid(self):
    """A pre-existing CSV whose header lacks 'round_num' is rejected."""
    row_without_round_num = _create_scalar_metrics()
    # Hand-write a CSV that CSVMetricsManager did not create: it is
    # 'invalid' because its header has no round_num column.
    invalid_csvfile = os.path.join(self.get_temp_dir(), 'invalid_metrics.csv')
    with tf.io.gfile.GFile(invalid_csvfile, 'w') as csvfile:
        writer = csv.DictWriter(
            csvfile, fieldnames=row_without_round_num.keys())
        writer.writeheader()
        writer.writerow(row_without_round_num)
    with self.assertRaises(ValueError):
        csv_manager.CSVMetricsManager(invalid_csvfile)
def test_save_metrics_with_alternating_fieldnames(self, save_mode):
    """Columns missing from a given round read back as empty strings."""
    path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(path, save_mode=save_mode)
    # Rounds alternate between providing 'a' and providing 'b'.
    manager.save_metrics({'a': 1}, 0)
    manager.save_metrics({'b': 3}, 1)
    manager.save_metrics({'a': 2}, 2)
    fieldnames, metrics = manager.get_metrics()
    self.assertCountEqual(fieldnames, ['round_num', 'a', 'b'])
    expected_rows = [
        {'round_num': 0, 'a': 1, 'b': ''},
        {'round_num': 1, 'a': '', 'b': 3},
        {'round_num': 2, 'a': 2, 'b': ''},
    ]
    for actual_row, expected_row in zip(metrics, expected_rows):
        self.assertDictEqual(actual_row, expected_row)
def test_update_metrics_returns_flat_dict_with_nonscalars(self):
    """Tensor metrics come back flattened, with round_num appended."""
    path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(path)
    returned = manager.update_metrics(0, _create_nonscalar_metrics())
    expected = collections.OrderedDict({
        'a/b': tf.ones([1]),
        'a/c': tf.zeros([2, 2]),
        'round_num': 0.0,
    })
    # Key order must match exactly; tensor values are compared elementwise.
    self.assertListEqual(list(expected.keys()), list(returned.keys()))
    self.assertEqual(expected['round_num'], returned['round_num'])
    self.assertAllEqual(expected['a/b'], returned['a/b'])
    self.assertAllEqual(expected['a/c'], returned['a/c'])
def test_clear_metrics_removes_rounds_equal_to_input_arg(self, save_mode):
    """clear_metrics drops the row at round_num itself, not just later ones."""
    path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(path, save_mode=save_mode)
    for round_num in (0, 5, 10):
        manager.save_metrics(_create_scalar_metrics(), round_num)
    _, metrics = manager.get_metrics()
    self.assertLen(metrics, 3,
                   'There should be 3 rows (for rounds 0, 5, and 10).')
    manager.clear_metrics(round_num=5)
    _, metrics = manager.get_metrics()
    self.assertLen(
        metrics, 1,
        'After clearing all rounds starting at round_num=5, there '
        'should be 1 row of metrics (for round 0).')
    self.assertEqual(0, metrics[-1]['round_num'],
                     'Last metrics retained are for round 0.')
    self.assertEqual(0, manager._latest_round_num)
def test_clear_metrics_removes_rounds_after_input_arg(self, save_mode):
    """clear_metrics with a round between rows keeps everything earlier."""
    path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(path, save_mode=save_mode)
    for round_num in (0, 5, 10):
        manager.save_metrics(_create_scalar_metrics(), round_num)
    _, metrics = manager.get_metrics()
    self.assertLen(metrics, 3,
                   'There should be 3 rows (for rounds 0, 5, and 10).')
    manager.clear_metrics(round_num=7)
    _, metrics = manager.get_metrics()
    self.assertLen(
        metrics, 2,
        'After clearing all rounds after round_num=7, should be 2 '
        'rows of metrics (for rounds 0 and 5).')
    self.assertEqual(5, metrics[-1]['round_num'],
                     'Last metrics retained are for round 5.')
    self.assertEqual(5, manager._latest_round_num)
def test_rows_are_cleared_and_last_round_num_is_reset(self):
    """clear_rounds_after drops later rows and rewinds _latest_round_num."""
    path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(path)
    for round_num in (0, 5, 10):
        manager.update_metrics(round_num, _create_scalar_metrics())
    _, metrics = manager.get_metrics()
    self.assertLen(metrics, 3,
                   'There should be 3 rows (for rounds 0, 5, and 10).')
    manager.clear_rounds_after(round_num=7)
    _, metrics = manager.get_metrics()
    self.assertLen(
        metrics, 2,
        'After clearing all rounds after round_num=7, should be 2 '
        'rows of metrics (for rounds 0 and 5).')
    self.assertEqual(5, metrics[-1]['round_num'],
                     'Last metrics retained are for round 5.')
    self.assertEqual(5, manager._latest_round_num)
def test_save_in_append_mode_appends_and_does_not_read_metrics(
        self, mock_dict_writer, mock_dict_reader, mock_write_to_csv):
    """In APPEND mode, saving neither re-reads nor rewrites the whole file.

    The mocks are injected by the test's patch decorators: the reader/writer
    stand in for csv.DictReader/csv.DictWriter, and `mock_write_to_csv`
    observes the manager's full-file rewrite helper.
    """
    fake_reader = mock.MagicMock()
    fake_reader.fieldnames = ['round_num']
    fake_reader.__iter__.return_value = [{'round_num': 0}]
    mock_dict_reader.return_value = fake_reader
    mock_dict_writer.return_value = mock.MagicMock()
    manager = csv_manager.CSVMetricsManager(
        os.path.join(self.get_temp_dir(), 'metrics.csv'),
        save_mode=csv_manager.SaveMode.APPEND)
    # Construction reads the existing rows exactly once and rewrites nothing.
    self.assertEqual(fake_reader.__iter__.call_count, 1)
    self.assertEqual(mock_write_to_csv.call_count, 0)
    # Saving appends: still just the single read, and still no full rewrite.
    manager.save_metrics({}, 1)
    self.assertEqual(fake_reader.__iter__.call_count, 1)
    self.assertEqual(mock_write_to_csv.call_count, 0)
def test_clear_metrics_raises_if_round_num_is_negative(self, save_mode):
    """Negative round numbers are rejected by clear_metrics."""
    path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(path, save_mode=save_mode)
    manager.save_metrics(_create_scalar_metrics(), 0)
    with self.assertRaises(ValueError):
        manager.clear_metrics(round_num=-1)
def test_constructor_raises_if_save_mode_is_invalid(self):
    """An unrecognized save_mode value is rejected at construction."""
    path = os.path.join(self.get_temp_dir(), 'metrics.csv')
    with self.assertRaises(ValueError):
        csv_manager.CSVMetricsManager(path, save_mode='invalid_mode')
def test_column_names(self):
    """Nested scalar metrics flatten into slash-joined column names."""
    path = os.path.join(self.get_temp_dir(), 'test_dir', 'metrics.csv')
    manager = csv_manager.CSVMetricsManager(path)
    manager.update_metrics(0, _create_scalar_metrics())
    fieldnames, _ = manager.get_metrics()
    self.assertCountEqual(['a/b', 'a/c', 'round_num'], fieldnames)
def test_csvfile_is_saved(self):
    """Constructing a manager creates the CSV file (and nothing else)."""
    temp_dir = os.path.join(self.get_temp_dir(), 'test_dir')
    csv_manager.CSVMetricsManager(os.path.join(temp_dir, 'metrics.csv'))
    self.assertEqual(set(os.listdir(temp_dir)), set(['metrics.csv']))