class TestStdOutput:
    """Exercise ``StdOutput`` by capturing what it prints to stdout."""

    def setup_method(self):
        self.tabular = TabularInput()
        self.std_output = StdOutput(with_timestamp=False)
        self.str_out = io.StringIO()

    def teardown_method(self):
        self.str_out.close()

    def _capture(self, data):
        """Record ``data``, dump it with stdout redirected, return the text."""
        with redirect_stdout(self.str_out):
            self.std_output.record(data)
            self.std_output.dump()
        self.str_out.seek(0)
        return self.str_out.read()

    def test_record_str(self, mock_datetime):
        fake_timestamp(mock_datetime)
        assert self._capture('test') == 'test\n'

    def test_record_tabular(self, mock_datetime):
        fake_timestamp(mock_datetime)
        for key, value in (('foo', 100), ('bar', 55)):
            self.tabular.record(key, value)
        expected = (
            '--- ---\n'
            'bar 55\n'
            'foo 100\n'
            '--- ---\n'
        )  # yapf: disable
        assert self._capture(self.tabular) == expected

    def test_record_with_timestamp(self, mock_datetime):
        fake_timestamp(mock_datetime)
        # Replace the timestamp-free output built in setup_method.
        self.std_output = StdOutput(with_timestamp=True)
        expected = '{} | DOWEL\n'.format(FAKE_TIMESTAMP_SHORT)
        assert self._capture('DOWEL') == expected

    def test_record_unknown(self, mock_datetime):
        # A plain dict is not an accepted record type.
        with pytest.raises(ValueError):
            self.std_output.record({})
class TestTextOutput:
    """Exercise ``TextOutput`` against a temporary log file."""

    def setup_method(self):
        self.log_file = tempfile.NamedTemporaryFile()
        self.text_output = TextOutput(self.log_file.name)
        self.tabular = TabularInput()

    def teardown_method(self):
        self.log_file.close()

    def _read_log(self):
        """Return the whole current contents of the temporary log file."""
        with open(self.log_file.name, 'r') as handle:
            return handle.read()

    def _record_and_dump(self, data):
        """Record ``data`` on the text output, then dump it."""
        self.text_output.record(data)
        self.text_output.dump()

    def test_record(self, mock_datetime):
        fake_timestamp(mock_datetime)
        first_line = '{} | TESTING 123 DOWEL\n'.format(FAKE_TIMESTAMP_SHORT)
        second_line = '{} | MORE TESTING\n'.format(FAKE_TIMESTAMP_SHORT)

        self._record_and_dump('TESTING 123 DOWEL')
        assert self._read_log() == first_line

        # The second record is expected after the first line, not over it.
        self._record_and_dump('MORE TESTING')
        assert self._read_log() == first_line + second_line

    def test_record_no_timestamp(self, mock_datetime):
        fake_timestamp(mock_datetime)
        self.text_output = TextOutput(self.log_file.name, with_timestamp=False)
        self._record_and_dump('TESTING 123 DOWEL')
        assert self._read_log() == 'TESTING 123 DOWEL\n'

    def test_record_tabular(self, mock_datetime):
        fake_timestamp(mock_datetime)
        self.tabular.record('foo', 100)
        self.tabular.record('bar', 55)
        self._record_and_dump(self.tabular)
        expected = (
            '--- ---\n'
            'bar 55\n'
            'foo 100\n'
            '--- ---\n'
        )  # yapf: disable
        assert self._read_log() == expected

    def test_record_unknown(self, mock_datetime):
        # A plain dict is not an accepted record type.
        with pytest.raises(ValueError):
            self.text_output.record({})
# NOTE(review): ``TestCsvOutput`` is defined twice in this file as shown; the
# later definition rebinds the name, so pytest only collects that one and the
# tests in this class never run. Rename one class (or drop the stale copy) if
# both are meant to be active.
class TestCsvOutput:
    """Tests for ``CsvOutput`` writing ``TabularInput`` rows to a CSV file."""

    def setup_method(self):
        self.log_file = tempfile.NamedTemporaryFile()
        self.csv_output = CsvOutput(self.log_file.name)
        self.tabular = TabularInput()
        self.tabular.clear()

    def teardown_method(self):
        self.log_file.close()

    def test_record(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.csv_output.record(self.tabular)
        self.tabular.record('foo', foo * 2)
        self.tabular.record('bar', bar * 2)
        self.csv_output.record(self.tabular)
        self.csv_output.dump()
        correct = [
            {'foo': str(foo), 'bar': str(bar)},
            {'foo': str(foo * 2), 'bar': str(bar * 2)},
        ]  # yapf: disable
        self.assert_csv_matches(correct)
        # No ``<log>.tmp`` file may be left behind after dump().
        assert not os.path.exists('{}.tmp'.format(self.log_file.name))

    def test_key_inconsistent(self):
        # The key set grows over time: 'x' first appears at i == 1 and 'y' at
        # i == 2. The expected CSV below has those columns for every row, with
        # '' in the rows recorded before the key existed.
        for i in range(4):
            self.tabular.record('itr', i)
            self.tabular.record('loss', 100.0 / (2 + i))
            if i > 0:
                self.tabular.record('x', i)
            if i > 1:
                self.tabular.record('y', i + 1)
            self.csv_output.record(self.tabular)
        self.csv_output.dump()

        correct = [{
            'itr': str(0),
            'loss': str(100.0 / 2.),
            'x': '',
            'y': ''
        }, {
            'itr': str(1),
            'loss': str(100.0 / 3.),
            'x': str(1),
            'y': ''
        }, {
            'itr': str(2),
            'loss': str(100.0 / 4.),
            'x': str(2),
            'y': str(3)
        }, {
            'itr': str(3),
            'loss': str(100.0 / 5.),
            'x': str(3),
            'y': str(4)
        }]
        self.assert_csv_matches(correct)

    def test_empty_record(self):
        self.csv_output.record(self.tabular)
        self.csv_output.dump()

        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.csv_output.record(self.tabular)
        self.csv_output.dump()

        # Empty lines are not recorded
        self.assert_csv_matches([{'foo': str(foo), 'bar': str(bar)}])

    def test_unacceptable_type(self):
        with pytest.raises(ValueError):
            self.csv_output.record('foo')

    def assert_csv_matches(self, correct):
        """Assert the CSV log holds exactly the rows in ``correct``, in order.

        Args:
            correct: list of dicts mapping column name to the expected cell
                text. Each row is compared as a dict, so column order does
                not matter, but the number of rows must match.
        """
        # The csv module requires files to be opened with newline=''.
        with open(self.log_file.name, 'r', newline='') as file:
            contents = list(csv.DictReader(file))
        assert len(contents) == len(correct)
        for row, correct_row in zip(contents, correct):
            assert dict(row) == correct_row
# NOTE(review): an earlier ``TestCsvOutput`` class is also defined in this
# file as shown; this definition rebinds the name, so only this class is
# collected by pytest.
class TestCsvOutput:
    """Tests for ``CsvOutput``, including its inconsistent-key warning."""

    def setup_method(self):
        self.log_file = tempfile.NamedTemporaryFile()
        self.csv_output = CsvOutput(self.log_file.name)
        self.tabular = TabularInput()
        self.tabular.clear()

    def teardown_method(self):
        self.log_file.close()

    def test_record(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.csv_output.record(self.tabular)
        self.tabular.record('foo', foo * 2)
        self.tabular.record('bar', bar * 2)
        self.csv_output.record(self.tabular)
        self.csv_output.dump()
        correct = [
            {'foo': str(foo), 'bar': str(bar)},
            {'foo': str(foo * 2), 'bar': str(bar * 2)},
        ]  # yapf: disable
        self.assert_csv_matches(correct)

    def test_record_inconsistent(self):
        import warnings

        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.csv_output.record(self.tabular)
        self.tabular.record('foo', foo * 2)
        self.tabular.record('bar', bar * 2)

        with pytest.warns(CsvOutputWarning):
            self.csv_output.record(self.tabular)

        # A second inconsistent record must not warn again, because the
        # warning is only emitted once. Escalate it to an error so the test
        # fails if it does.
        with warnings.catch_warnings():
            warnings.simplefilter('error', CsvOutputWarning)
            self.csv_output.record(self.tabular)
        self.csv_output.dump()

        # The 'bar' column never appears: only keys from the first record
        # are expected in the file.
        correct = [
            {'foo': str(foo)},
            {'foo': str(foo * 2)},
        ]  # yapf: disable
        self.assert_csv_matches(correct)

    def test_empty_record(self):
        self.csv_output.record(self.tabular)
        assert not self.csv_output._writer

        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.csv_output.record(self.tabular)
        assert not self.csv_output._warned_once

    def test_unacceptable_type(self):
        with pytest.raises(ValueError):
            self.csv_output.record('foo')

    def test_disable_warnings(self):
        import warnings

        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.csv_output.record(self.tabular)
        self.tabular.record('foo', foo * 2)
        self.tabular.record('bar', bar * 2)
        self.csv_output.disable_warnings()

        # This must not warn because warnings were disabled; escalate any
        # CsvOutputWarning to an error to enforce that.
        with warnings.catch_warnings():
            warnings.simplefilter('error', CsvOutputWarning)
            self.csv_output.record(self.tabular)

    def assert_csv_matches(self, correct):
        """Assert the leading CSV rows equal ``correct``, in order.

        Only the first ``len(correct)`` data rows are compared; any rows after
        them are ignored.

        Args:
            correct: list of dicts mapping column name to expected cell text.
        """
        # The csv module requires files to be opened with newline=''.
        with open(self.log_file.name, 'r', newline='') as file:
            reader = csv.DictReader(file)
            for correct_row in correct:
                row = next(reader)
                assert row == correct_row
class TestTabularInput:
    """Tests for ``TabularInput`` recording, prefixes, stats and warnings."""

    def setup_method(self):
        self.tabular = TabularInput()

    def test_str(self):
        foo = 123
        bar = 456
        baz = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        # The list-valued 'baz' is recorded but absent from the expected table.
        self.tabular.record('baz', baz)
        correct_str = (
            '--- ---\n'
            'bar 456\n'
            'foo 123\n'
            '--- ---'
        )  # yapf: disable
        assert str(self.tabular) == correct_str

    def test_record(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        assert self.tabular.as_dict['foo'] == foo
        assert self.tabular.as_dict['bar'] == bar

    def test_record_misc_stat(self):
        self.tabular.record_misc_stat('Foo', [0, 1, 2])
        bar = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        self.tabular.record_misc_stat('Bar', bar, placement='front')
        correct = {
            'FooAverage': 1.0,
            'FooStd': 0.816496580927726,
            'FooMedian': 1.0,
            'FooMin': 0,
            'FooMax': 2,
            'AverageBar': 5.5,
            'StdBar': 2.8722813232690143,
            'MedianBar': 5.5,
            'MinBar': 1,
            'MaxBar': 10,
        }
        assert self.tabular.as_dict == correct

    def test_record_misc_stat_nan(self):
        self.tabular.record_misc_stat('none', None)
        correct = {
            'noneAverage': math.nan,
            'noneStd': math.nan,
            'noneMedian': math.nan,
            'noneMin': math.nan,
            'noneMax': math.nan
        }
        # NaN never compares equal to itself, so check that exactly the
        # expected keys were recorded and that every recorded value is NaN.
        assert self.tabular.as_dict.keys() == correct.keys()
        for key, value in self.tabular.as_dict.items():
            assert math.isnan(value), key

    def test_prefix(self):
        foo = 111
        bar = 222
        with self.tabular.prefix('test_'):
            self.tabular.record('foo', foo)
            self.tabular.record('bar', bar)
        correct = {'test_foo': foo, 'test_bar': bar}
        assert self.tabular.as_dict == correct

    def test_clear(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.tabular.mark_all()
        assert self.tabular.as_dict
        self.tabular.clear()
        assert not self.tabular.as_dict

    def test_clear_warns_not_recorded_once(self):
        import warnings

        self.tabular.record('foo', 1)
        with pytest.warns(TabularInputWarning):
            self.tabular.clear()
        self.tabular.record('foo', 1)
        # This should not trigger a warning, because we already warned once;
        # escalate TabularInputWarning to an error to enforce that.
        with warnings.catch_warnings():
            warnings.simplefilter('error', TabularInputWarning)
            self.tabular.clear()

    def test_disable_warnings(self):
        import warnings

        self.tabular.record('foo', 1)
        with pytest.warns(TabularInputWarning):
            self.tabular.clear()
        self.tabular.record('bar', 2)
        self.tabular.disable_warnings()
        # This should not trigger a warning, because we disabled warnings;
        # escalate TabularInputWarning to an error to enforce that.
        with warnings.catch_warnings():
            warnings.simplefilter('error', TabularInputWarning)
            self.tabular.clear()

    def test_push_prefix(self):
        foo = 111
        bar = 222
        self.tabular.push_prefix('aaa_')
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.tabular.push_prefix('bbb_')
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        correct = {
            'aaa_foo': foo,
            'aaa_bar': bar,
            'aaa_bbb_foo': foo,
            'aaa_bbb_bar': bar,
        }
        assert self.tabular.as_dict == correct

    def test_pop_prefix(self):
        foo = 111
        bar = 222
        self.tabular.push_prefix('aaa_')
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.tabular.push_prefix('bbb_')
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.tabular.pop_prefix()
        self.tabular.record('foopop', foo)
        self.tabular.record('barpop', bar)
        self.tabular.pop_prefix()
        self.tabular.record('foopop', foo)
        self.tabular.record('barpop', bar)
        correct = {
            'aaa_foo': foo,
            'aaa_bar': bar,
            'aaa_bbb_foo': foo,
            'aaa_bbb_bar': bar,
            'aaa_foopop': foo,
            'aaa_barpop': bar,
            'foopop': foo,
            'barpop': bar,
        }
        assert self.tabular.as_dict == correct

    def test_as_primitive_dict(self):
        stuff = {
            'int': int(1),
            'float': float(2.0),
            'bool': bool(True),
            'str': str('Hello, world!'),
            'dict': dict(foo='bar'),
        }
        for k, v in stuff.items():
            self.tabular.record(k, v)
        # The dict-valued entry is expected to be dropped.
        correct = {
            'int': int(1),
            'float': float(2.0),
            'bool': bool(True),
            'str': str('Hello, world!'),
        }
        assert self.tabular.as_primitive_dict == correct