예제 #1
0
class TBOutputTest(TfGraphTestCase):
    """Fixture giving every test a fresh TensorBoardOutput in a temp dir.

    Each test also gets an empty ``TabularInput`` as ``self.tabular``.
    """

    def setUp(self):
        super().setUp()
        # The tabular input does not depend on the output directory, so build
        # it first; the output writer is created last, once its dir exists.
        self.tabular = TabularInput()
        self.tabular.clear()
        self.log_dir = tempfile.TemporaryDirectory()
        self.tensor_board_output = TensorBoardOutput(self.log_dir.name)

    def tearDown(self):
        # Release the writer before deleting the directory it writes into,
        # and let the parent fixture tear down last.
        self.tensor_board_output.close()
        self.log_dir.cleanup()
        super().tearDown()
예제 #2
0
class TestTabularInput(unittest.TestCase):
    def setUp(self):
        self.tabular = TabularInput()

    def test_str(self):
        foo = 123
        bar = 456
        baz = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
        self.tabular.record("foo", foo)
        self.tabular.record("bar", bar)
        self.tabular.record("baz", baz)

        correct_str = "---  ---\n" "bar  456\n" "foo  123\n" "---  ---"  # yapf: disable
        assert str(self.tabular) == correct_str

    def test_record(self):
        foo = 1
        bar = 10
        self.tabular.record("foo", foo)
        self.tabular.record("bar", bar)

        assert self.tabular.as_dict["foo"] == foo
        assert self.tabular.as_dict["bar"] == bar

    def test_record_misc_stat(self):
        self.tabular.record_misc_stat("Foo", [0, 1, 2])
        bar = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        self.tabular.record_misc_stat("Bar", bar, placement="front")

        correct = {
            "FooAverage": 1.0,
            "FooStd": 0.816496580927726,
            "FooMedian": 1.0,
            "FooMin": 0,
            "FooMax": 2,
            "AverageBar": 5.5,
            "StdBar": 2.8722813232690143,
            "MedianBar": 5.5,
            "MinBar": 1,
            "MaxBar": 10,
        }
        self.assertDictEqual(self.tabular.as_dict, correct)

    def test_record_misc_stat_nan(self):
        self.tabular.record_misc_stat("none", None)

        correct = {
            "noneAverage": math.nan,
            "noneStd": math.nan,
            "noneMedian": math.nan,
            "noneMin": math.nan,
            "noneMax": math.nan,
        }
        for k, v in self.tabular.as_dict.items():
            assert correct[k] is math.nan

    def test_prefix(self):
        foo = 111
        bar = 222
        with self.tabular.prefix("test_"):
            self.tabular.record("foo", foo)
            self.tabular.record("bar", bar)

        correct = {"test_foo": foo, "test_bar": bar}
        self.assertDictEqual(self.tabular.as_dict, correct)

    def test_clear(self):
        foo = 1
        bar = 10
        self.tabular.record("foo", foo)
        self.tabular.record("bar", bar)
        self.tabular.mark_all()

        assert self.tabular.as_dict
        self.tabular.clear()
        assert not self.tabular.as_dict

    def test_clear_warns_not_recorded_once(self):
        self.tabular.record("foo", 1)

        with self.assertWarns(TabularInputWarning):
            self.tabular.clear()

        self.tabular.record("foo", 1)
        # This not trigger a warning, because we warned once
        self.tabular.clear()

    def test_disable_warnings(self):
        self.tabular.record("foo", 1)

        with self.assertWarns(TabularInputWarning):
            self.tabular.clear()

        self.tabular.record("bar", 2)
        self.tabular.disable_warnings()

        # This should not trigger a warning, because we disabled warnings
        self.tabular.clear()

    def test_push_prefix(self):
        foo = 111
        bar = 222
        self.tabular.push_prefix("aaa_")
        self.tabular.record("foo", foo)
        self.tabular.record("bar", bar)
        self.tabular.push_prefix("bbb_")
        self.tabular.record("foo", foo)
        self.tabular.record("bar", bar)

        correct = {
            "aaa_foo": foo,
            "aaa_bar": bar,
            "aaa_bbb_foo": foo,
            "aaa_bbb_bar": bar,
        }
        self.assertDictEqual(self.tabular.as_dict, correct)

    def test_pop_prefix(self):
        foo = 111
        bar = 222

        self.tabular.push_prefix("aaa_")
        self.tabular.record("foo", foo)
        self.tabular.record("bar", bar)
        self.tabular.push_prefix("bbb_")
        self.tabular.record("foo", foo)
        self.tabular.record("bar", bar)
        self.tabular.pop_prefix()
        self.tabular.record("foopop", foo)
        self.tabular.record("barpop", bar)
        self.tabular.pop_prefix()
        self.tabular.record("foopop", foo)
        self.tabular.record("barpop", bar)
        correct = {
            "aaa_foo": foo,
            "aaa_bar": bar,
            "aaa_bbb_foo": foo,
            "aaa_bbb_bar": bar,
            "aaa_foopop": foo,
            "aaa_barpop": bar,
            "foopop": foo,
            "barpop": bar,
        }
        self.assertDictEqual(self.tabular.as_dict, correct)

    def test_as_primitive_dict(self):
        stuff = {
            "int": int(1),
            "float": float(2.0),
            "bool": bool(True),
            "str": str("Hello, world!"),
            "dict": dict(foo="bar"),
        }
        for k, v in stuff.items():
            self.tabular.record(k, v)

        correct = {
            "int": int(1),
            "float": float(2.0),
            "bool": bool(True),
            "str": str("Hello, world!"),
        }
        self.assertDictEqual(self.tabular.as_primitive_dict, correct)
예제 #3
0
class TestTabularInput(unittest.TestCase):
    def setUp(self):
        self.tabular = TabularInput()

    def test_str(self):
        foo = 123
        bar = 456
        baz = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.tabular.record('baz', baz)

        correct_str = (
            '---  ---\n'
            'bar  456\n'
            'foo  123\n'
            '---  ---'
        )  # yapf: disable
        assert str(self.tabular) == correct_str

    def test_record(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)

        assert self.tabular.as_dict['foo'] == foo
        assert self.tabular.as_dict['bar'] == bar

    def test_record_misc_stat(self):
        self.tabular.record_misc_stat('Foo', [0, 1, 2])
        bar = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        self.tabular.record_misc_stat('Bar', bar, placement='front')

        correct = {
            'FooAverage': 1.0,
            'FooStd': 0.816496580927726,
            'FooMedian': 1.0,
            'FooMin': 0,
            'FooMax': 2,
            'AverageBar': 5.5,
            'StdBar': 2.8722813232690143,
            'MedianBar': 5.5,
            'MinBar': 1,
            'MaxBar': 10,
        }
        self.assertDictEqual(self.tabular.as_dict, correct)

    def test_record_misc_stat_nan(self):
        self.tabular.record_misc_stat('none', None)

        correct = {
            'noneAverage': math.nan,
            'noneStd': math.nan,
            'noneMedian': math.nan,
            'noneMin': math.nan,
            'noneMax': math.nan
        }
        for k, v in self.tabular.as_dict.items():
            assert correct[k] is math.nan

    def test_prefix(self):
        foo = 111
        bar = 222
        with self.tabular.prefix('test_'):
            self.tabular.record('foo', foo)
            self.tabular.record('bar', bar)

        correct = {'test_foo': foo, 'test_bar': bar}
        self.assertDictEqual(self.tabular.as_dict, correct)

    def test_clear(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.tabular.mark_all()

        assert self.tabular.as_dict
        self.tabular.clear()
        assert not self.tabular.as_dict

    def test_clear_warns_not_recorded_once(self):
        self.tabular.record('foo', 1)

        with self.assertWarns(TabularInputWarning):
            self.tabular.clear()

        self.tabular.record('foo', 1)
        # This not trigger a warning, because we warned once
        self.tabular.clear()

    def test_disable_warnings(self):
        self.tabular.record('foo', 1)

        with self.assertWarns(TabularInputWarning):
            self.tabular.clear()

        self.tabular.record('bar', 2)
        self.tabular.disable_warnings()

        # This should not trigger a warning, because we disabled warnings
        self.tabular.clear()

    def test_push_prefix(self):
        foo = 111
        bar = 222
        self.tabular.push_prefix('aaa_')
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.tabular.push_prefix('bbb_')
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)

        correct = {
            'aaa_foo': foo,
            'aaa_bar': bar,
            'aaa_bbb_foo': foo,
            'aaa_bbb_bar': bar,
        }
        self.assertDictEqual(self.tabular.as_dict, correct)

    def test_pop_prefix(self):
        foo = 111
        bar = 222

        self.tabular.push_prefix('aaa_')
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.tabular.push_prefix('bbb_')
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.tabular.pop_prefix()
        self.tabular.record('foopop', foo)
        self.tabular.record('barpop', bar)
        self.tabular.pop_prefix()
        self.tabular.record('foopop', foo)
        self.tabular.record('barpop', bar)
        correct = {
            'aaa_foo': foo,
            'aaa_bar': bar,
            'aaa_bbb_foo': foo,
            'aaa_bbb_bar': bar,
            'aaa_foopop': foo,
            'aaa_barpop': bar,
            'foopop': foo,
            'barpop': bar,
        }
        self.assertDictEqual(self.tabular.as_dict, correct)

    def test_as_primitive_dict(self):
        stuff = {
            'int': int(1),
            'float': float(2.0),
            'bool': bool(True),
            'str': str('Hello, world!'),
            'dict': dict(foo='bar'),
        }
        for k, v in stuff.items():
            self.tabular.record(k, v)

        correct = {
            'int': int(1),
            'float': float(2.0),
            'bool': bool(True),
            'str': str('Hello, world!'),
        }
        self.assertDictEqual(self.tabular.as_primitive_dict, correct)
예제 #4
0
class TestCsvOutput(unittest.TestCase):
    def setUp(self):
        self.log_file = tempfile.NamedTemporaryFile()
        self.csv_output = CsvOutput(self.log_file.name)
        self.tabular = TabularInput()
        self.tabular.clear()

    def tearDown(self):
        self.log_file.close()

    def test_record(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.csv_output.record(self.tabular)
        self.tabular.record('foo', foo * 2)
        self.tabular.record('bar', bar * 2)
        self.csv_output.record(self.tabular)
        self.csv_output.dump()

        correct = [
            {'foo': str(foo), 'bar': str(bar)},
            {'foo': str(foo * 2), 'bar': str(bar * 2)},
        ]  # yapf: disable
        self.assert_csv_matches(correct)

    def test_record_inconsistent(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.csv_output.record(self.tabular)
        self.tabular.record('foo', foo * 2)
        self.tabular.record('bar', bar * 2)

        with self.assertWarns(CsvOutputWarning):
            self.csv_output.record(self.tabular)

        # this should not produce a warning, because we only warn once
        self.csv_output.record(self.tabular)

        self.csv_output.dump()

        correct = [
            {'foo': str(foo)},
            {'foo': str(foo * 2)},
        ]  # yapf: disable
        self.assert_csv_matches(correct)

    def test_empty_record(self):
        self.csv_output.record(self.tabular)
        assert not self.csv_output._writer

        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.csv_output.record(self.tabular)
        assert not self.csv_output._warned_once

    def test_unacceptable_type(self):
        with self.assertRaises(ValueError):
            self.csv_output.record('foo')

    def test_disable_warnings(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.csv_output.record(self.tabular)
        self.tabular.record('foo', foo * 2)
        self.tabular.record('bar', bar * 2)

        self.csv_output.disable_warnings()

        # this should not produce a warning, because we disabled warnings
        self.csv_output.record(self.tabular)

    def assert_csv_matches(self, correct):
        """Check the first row of a csv file and compare it to known values."""
        with open(self.log_file.name, 'r') as file:
            reader = csv.DictReader(file)

            for correct_row in correct:
                row = next(reader)
                self.assertDictEqual(row, correct_row)