class TBOutputTest(TfGraphTestCase):
    """Shared fixture providing a fresh TensorBoardOutput for each test.

    Defines no tests of its own; presumably subclassed by the concrete
    TensorBoard output tests. Each test gets its own temporary log
    directory, an empty TabularInput, and a TensorBoardOutput writing into
    that directory.
    """

    def setUp(self):
        super().setUp()
        self.log_dir = tempfile.TemporaryDirectory()
        self.tabular = TabularInput()
        self.tabular.clear()
        self.tensor_board_output = TensorBoardOutput(self.log_dir.name)

    def tearDown(self):
        """Close the output, delete the log directory, then run base teardown.

        The steps are chained with ``finally`` so that a failing ``close()``
        neither leaks the temporary directory nor skips the base-class
        teardown.
        """
        try:
            self.tensor_board_output.close()
        finally:
            try:
                self.log_dir.cleanup()
            finally:
                super().tearDown()
class TestTabularInput(unittest.TestCase):
    """Unit tests for TabularInput: recording, prefixes, warnings, views.

    NOTE(review): this module defines another class named TestTabularInput
    further down. That later definition rebinds the name, so test discovery
    only sees the later class and these tests never run. Consider renaming
    one of them.
    """

    def setUp(self):
        self.tabular = TabularInput()

    def test_str(self):
        foo = 123
        bar = 456
        baz = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
        self.tabular.record("foo", foo)
        self.tabular.record("bar", bar)
        self.tabular.record("baz", baz)

        # baz (a list) is expected to be absent from the rendered table,
        # presumably because only primitive values are shown (compare
        # test_as_primitive_dict).
        correct_str = (
            "--- ---\n"
            "bar 456\n"
            "foo 123\n"
            "--- ---"
        )  # yapf: disable
        assert str(self.tabular) == correct_str

    def test_record(self):
        foo = 1
        bar = 10
        self.tabular.record("foo", foo)
        self.tabular.record("bar", bar)
        assert self.tabular.as_dict["foo"] == foo
        assert self.tabular.as_dict["bar"] == bar

    def test_record_misc_stat(self):
        # Default placement appends the statistic name after the key;
        # placement="front" puts it before the key.
        self.tabular.record_misc_stat("Foo", [0, 1, 2])
        bar = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        self.tabular.record_misc_stat("Bar", bar, placement="front")

        correct = {
            "FooAverage": 1.0,
            "FooStd": 0.816496580927726,
            "FooMedian": 1.0,
            "FooMin": 0,
            "FooMax": 2,
            "AverageBar": 5.5,
            "StdBar": 2.8722813232690143,
            "MedianBar": 5.5,
            "MinBar": 1,
            "MaxBar": 10,
        }
        self.assertDictEqual(self.tabular.as_dict, correct)

    def test_record_misc_stat_nan(self):
        self.tabular.record_misc_stat("none", None)

        correct = {
            "noneAverage": math.nan,
            "noneStd": math.nan,
            "noneMedian": math.nan,
            "noneMin": math.nan,
            "noneMax": math.nan,
        }
        # Every expected key must be present and every recorded value must be
        # NaN. NaN is unequal to itself and the recorded value need not be the
        # math.nan object, so check with math.isnan rather than == or `is`.
        self.assertSetEqual(set(self.tabular.as_dict), set(correct))
        for key, value in self.tabular.as_dict.items():
            assert math.isnan(value), key

    def test_prefix(self):
        foo = 111
        bar = 222
        with self.tabular.prefix("test_"):
            self.tabular.record("foo", foo)
            self.tabular.record("bar", bar)

        correct = {"test_foo": foo, "test_bar": bar}
        self.assertDictEqual(self.tabular.as_dict, correct)

    def test_clear(self):
        foo = 1
        bar = 10
        self.tabular.record("foo", foo)
        self.tabular.record("bar", bar)
        # Presumably marks every key as consumed so clear() has nothing to
        # warn about.
        self.tabular.mark_all()
        assert self.tabular.as_dict

        self.tabular.clear()
        assert not self.tabular.as_dict

    def test_clear_warns_not_recorded_once(self):
        self.tabular.record("foo", 1)
        with self.assertWarns(TabularInputWarning):
            self.tabular.clear()
        self.tabular.record("foo", 1)
        # This should not trigger a warning, because we warned once
        self.tabular.clear()

    def test_disable_warnings(self):
        self.tabular.record("foo", 1)
        with self.assertWarns(TabularInputWarning):
            self.tabular.clear()
        self.tabular.record("bar", 2)
        self.tabular.disable_warnings()
        # This should not trigger a warning, because we disabled warnings
        self.tabular.clear()

    def test_push_prefix(self):
        foo = 111
        bar = 222
        self.tabular.push_prefix("aaa_")
        self.tabular.record("foo", foo)
        self.tabular.record("bar", bar)
        # Prefixes stack: the second push nests inside the first.
        self.tabular.push_prefix("bbb_")
        self.tabular.record("foo", foo)
        self.tabular.record("bar", bar)

        correct = {
            "aaa_foo": foo,
            "aaa_bar": bar,
            "aaa_bbb_foo": foo,
            "aaa_bbb_bar": bar,
        }
        self.assertDictEqual(self.tabular.as_dict, correct)

    def test_pop_prefix(self):
        foo = 111
        bar = 222

        self.tabular.push_prefix("aaa_")
        self.tabular.record("foo", foo)
        self.tabular.record("bar", bar)

        self.tabular.push_prefix("bbb_")
        self.tabular.record("foo", foo)
        self.tabular.record("bar", bar)

        # Each pop removes the most recently pushed prefix.
        self.tabular.pop_prefix()
        self.tabular.record("foopop", foo)
        self.tabular.record("barpop", bar)

        self.tabular.pop_prefix()
        self.tabular.record("foopop", foo)
        self.tabular.record("barpop", bar)

        correct = {
            "aaa_foo": foo,
            "aaa_bar": bar,
            "aaa_bbb_foo": foo,
            "aaa_bbb_bar": bar,
            "aaa_foopop": foo,
            "aaa_barpop": bar,
            "foopop": foo,
            "barpop": bar,
        }
        self.assertDictEqual(self.tabular.as_dict, correct)

    def test_as_primitive_dict(self):
        stuff = {
            "int": int(1),
            "float": float(2.0),
            "bool": bool(True),
            "str": str("Hello, world!"),
            "dict": dict(foo="bar"),
        }
        for k, v in stuff.items():
            self.tabular.record(k, v)

        # The dict-valued entry is expected to be dropped.
        correct = {
            "int": int(1),
            "float": float(2.0),
            "bool": bool(True),
            "str": str("Hello, world!"),
        }
        self.assertDictEqual(self.tabular.as_primitive_dict, correct)
class TestTabularInput(unittest.TestCase):
    """Unit tests for TabularInput: recording, prefixes, warnings, views.

    NOTE(review): an earlier class in this module has the same name; this
    definition rebinds the name and shadows it, so only this class is seen
    by test discovery.
    """

    def setUp(self):
        self.tabular = TabularInput()

    def test_str(self):
        foo = 123
        bar = 456
        baz = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.tabular.record('baz', baz)

        # baz (a list) is expected to be absent from the rendered table,
        # presumably because only primitive values are shown (compare
        # test_as_primitive_dict).
        correct_str = (
            '--- ---\n'
            'bar 456\n'
            'foo 123\n'
            '--- ---'
        )  # yapf: disable
        assert str(self.tabular) == correct_str

    def test_record(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        assert self.tabular.as_dict['foo'] == foo
        assert self.tabular.as_dict['bar'] == bar

    def test_record_misc_stat(self):
        # Default placement appends the statistic name after the key;
        # placement='front' puts it before the key.
        self.tabular.record_misc_stat('Foo', [0, 1, 2])
        bar = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        self.tabular.record_misc_stat('Bar', bar, placement='front')

        correct = {
            'FooAverage': 1.0,
            'FooStd': 0.816496580927726,
            'FooMedian': 1.0,
            'FooMin': 0,
            'FooMax': 2,
            'AverageBar': 5.5,
            'StdBar': 2.8722813232690143,
            'MedianBar': 5.5,
            'MinBar': 1,
            'MaxBar': 10,
        }
        self.assertDictEqual(self.tabular.as_dict, correct)

    def test_record_misc_stat_nan(self):
        self.tabular.record_misc_stat('none', None)

        correct = {
            'noneAverage': math.nan,
            'noneStd': math.nan,
            'noneMedian': math.nan,
            'noneMin': math.nan,
            'noneMax': math.nan
        }
        # Every expected key must be present and every recorded value must be
        # NaN. NaN is unequal to itself and the recorded value need not be the
        # math.nan object, so check with math.isnan rather than == or `is`.
        self.assertSetEqual(set(self.tabular.as_dict), set(correct))
        for key, value in self.tabular.as_dict.items():
            assert math.isnan(value), key

    def test_prefix(self):
        foo = 111
        bar = 222
        with self.tabular.prefix('test_'):
            self.tabular.record('foo', foo)
            self.tabular.record('bar', bar)

        correct = {'test_foo': foo, 'test_bar': bar}
        self.assertDictEqual(self.tabular.as_dict, correct)

    def test_clear(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        # Presumably marks every key as consumed so clear() has nothing to
        # warn about.
        self.tabular.mark_all()
        assert self.tabular.as_dict

        self.tabular.clear()
        assert not self.tabular.as_dict

    def test_clear_warns_not_recorded_once(self):
        self.tabular.record('foo', 1)
        with self.assertWarns(TabularInputWarning):
            self.tabular.clear()
        self.tabular.record('foo', 1)
        # This should not trigger a warning, because we warned once
        self.tabular.clear()

    def test_disable_warnings(self):
        self.tabular.record('foo', 1)
        with self.assertWarns(TabularInputWarning):
            self.tabular.clear()
        self.tabular.record('bar', 2)
        self.tabular.disable_warnings()
        # This should not trigger a warning, because we disabled warnings
        self.tabular.clear()

    def test_push_prefix(self):
        foo = 111
        bar = 222
        self.tabular.push_prefix('aaa_')
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        # Prefixes stack: the second push nests inside the first.
        self.tabular.push_prefix('bbb_')
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)

        correct = {
            'aaa_foo': foo,
            'aaa_bar': bar,
            'aaa_bbb_foo': foo,
            'aaa_bbb_bar': bar,
        }
        self.assertDictEqual(self.tabular.as_dict, correct)

    def test_pop_prefix(self):
        foo = 111
        bar = 222

        self.tabular.push_prefix('aaa_')
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)

        self.tabular.push_prefix('bbb_')
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)

        # Each pop removes the most recently pushed prefix.
        self.tabular.pop_prefix()
        self.tabular.record('foopop', foo)
        self.tabular.record('barpop', bar)

        self.tabular.pop_prefix()
        self.tabular.record('foopop', foo)
        self.tabular.record('barpop', bar)

        correct = {
            'aaa_foo': foo,
            'aaa_bar': bar,
            'aaa_bbb_foo': foo,
            'aaa_bbb_bar': bar,
            'aaa_foopop': foo,
            'aaa_barpop': bar,
            'foopop': foo,
            'barpop': bar,
        }
        self.assertDictEqual(self.tabular.as_dict, correct)

    def test_as_primitive_dict(self):
        stuff = {
            'int': int(1),
            'float': float(2.0),
            'bool': bool(True),
            'str': str('Hello, world!'),
            'dict': dict(foo='bar'),
        }
        for k, v in stuff.items():
            self.tabular.record(k, v)

        # The dict-valued entry is expected to be dropped.
        correct = {
            'int': int(1),
            'float': float(2.0),
            'bool': bool(True),
            'str': str('Hello, world!'),
        }
        self.assertDictEqual(self.tabular.as_primitive_dict, correct)
class TestCsvOutput(unittest.TestCase):
    """Unit tests for CsvOutput writing TabularInput rows to a CSV file."""

    def setUp(self):
        self.log_file = tempfile.NamedTemporaryFile()
        self.csv_output = CsvOutput(self.log_file.name)
        self.tabular = TabularInput()
        self.tabular.clear()

    def tearDown(self):
        # Close the output's own handle on the log file as well as the temp
        # file, so no file object is left open between tests.
        try:
            self.csv_output.close()
        finally:
            self.log_file.close()

    def test_record(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.csv_output.record(self.tabular)
        self.tabular.record('foo', foo * 2)
        self.tabular.record('bar', bar * 2)
        self.csv_output.record(self.tabular)
        self.csv_output.dump()

        correct = [
            {'foo': str(foo), 'bar': str(bar)},
            {'foo': str(foo * 2), 'bar': str(bar * 2)},
        ]  # yapf: disable
        self.assert_csv_matches(correct)

    def test_record_inconsistent(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.csv_output.record(self.tabular)
        # Adding a new key after the first record changes the key set.
        self.tabular.record('foo', foo * 2)
        self.tabular.record('bar', bar * 2)

        with self.assertWarns(CsvOutputWarning):
            self.csv_output.record(self.tabular)

        # this should not produce a warning, because we only warn once
        self.csv_output.record(self.tabular)

        self.csv_output.dump()

        # The late 'bar' column is expected to be absent from the file.
        correct = [
            {'foo': str(foo)},
            {'foo': str(foo * 2)},
        ]  # yapf: disable
        self.assert_csv_matches(correct)

    def test_empty_record(self):
        # Recording an empty TabularInput is expected not to create a writer.
        self.csv_output.record(self.tabular)
        assert not self.csv_output._writer

        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.csv_output.record(self.tabular)
        assert not self.csv_output._warned_once

    def test_unacceptable_type(self):
        with self.assertRaises(ValueError):
            self.csv_output.record('foo')

    def test_disable_warnings(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.csv_output.record(self.tabular)
        self.tabular.record('foo', foo * 2)
        self.tabular.record('bar', bar * 2)

        self.csv_output.disable_warnings()

        # this should not produce a warning, because we disabled warnings
        self.csv_output.record(self.tabular)

    def assert_csv_matches(self, correct):
        """Check the leading rows of the CSV log file against known values.

        Args:
            correct: list of dicts, one per expected data row (the header is
                not included), each mapping column name to the string value
                csv.DictReader yields. Rows after len(correct) are not
                checked.
        """
        # The csv module requires files to be opened with newline=''.
        with open(self.log_file.name, 'r', newline='') as csv_file:
            reader = csv.DictReader(csv_file)
            for index, correct_row in enumerate(correct):
                # next(..., None) turns a missing row into a clear assertion
                # failure instead of a bare StopIteration.
                row = next(reader, None)
                self.assertIsNotNone(
                    row, 'CSV log has no row {}; expected {}'.format(
                        index, correct_row))
                self.assertDictEqual(row, correct_row)