示例#1
0
class TBOutputTest(TfGraphTestCase):
    """Test case for TensorBoardOutput.

    Each test gets a fresh temporary log directory, an empty TabularInput and
    a TensorBoardOutput writing into that directory.
    """

    def setup_method(self):
        super().setup_method()
        self.log_dir = tempfile.TemporaryDirectory()
        self.tabular = TabularInput()
        self.tabular.clear()
        self.tensor_board_output = TensorBoardOutput(self.log_dir.name)

    def teardown_method(self):
        # Release resources in reverse order of acquisition. The nested
        # try/finally guarantees the temporary directory is removed and the
        # parent teardown still runs even if closing the output raises.
        try:
            self.tensor_board_output.close()
        finally:
            try:
                self.log_dir.cleanup()
            finally:
                super().teardown_method()
示例#2
0
class TestCsvOutput:

    def setup_method(self):
        self.log_file = tempfile.NamedTemporaryFile()
        self.csv_output = CsvOutput(self.log_file.name)
        self.tabular = TabularInput()
        self.tabular.clear()

    def teardown_method(self):
        self.log_file.close()

    def test_record(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.csv_output.record(self.tabular)
        self.tabular.record('foo', foo * 2)
        self.tabular.record('bar', bar * 2)
        self.csv_output.record(self.tabular)
        self.csv_output.dump()

        correct = [
            {'foo': str(foo), 'bar': str(bar)},
            {'foo': str(foo * 2), 'bar': str(bar * 2)},
        ]  # yapf: disable
        self.assert_csv_matches(correct)
        assert not os.path.exists('{}.tmp'.format(self.log_file.name))

    def test_key_inconsistent(self):
        for i in range(4):
            self.tabular.record('itr', i)
            self.tabular.record('loss', 100.0 / (2 + i))

            # the addition of new data to tabular breaks logging to CSV
            if i > 0:
                self.tabular.record('x', i)

            if i > 1:
                self.tabular.record('y', i + 1)

            # this should not produce a warning, because we only warn once
            self.csv_output.record(self.tabular)
            self.csv_output.dump()

        correct = [{
            'itr': str(0),
            'loss': str(100.0 / 2.),
            'x': '',
            'y': ''
        }, {
            'itr': str(1),
            'loss': str(100.0 / 3.),
            'x': str(1),
            'y': ''
        }, {
            'itr': str(2),
            'loss': str(100.0 / 4.),
            'x': str(2),
            'y': str(3)
        }, {
            'itr': str(3),
            'loss': str(100.0 / 5.),
            'x': str(3),
            'y': str(4)
        }]
        self.assert_csv_matches(correct)

    def test_empty_record(self):
        self.csv_output.record(self.tabular)
        self.csv_output.dump()

        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.csv_output.record(self.tabular)
        self.csv_output.dump()
        # Empty lines are not recorded
        self.assert_csv_matches([{'foo': str(foo), 'bar': str(bar)}])

    def test_unacceptable_type(self):
        with pytest.raises(ValueError):
            self.csv_output.record('foo')

    def assert_csv_matches(self, correct):
        """Check the first row of a csv file and compare it to known values."""
        with open(self.log_file.name, 'r') as file:
            contents = list(csv.DictReader(file))
            assert len(contents) == len(correct)

            for row, correct_row in zip(contents, correct):
                assert sorted(list(row.items())) == sorted(
                    list(correct_row.items()))
示例#3
0
class TestCsvOutput:
    def setup_method(self):
        self.log_file = tempfile.NamedTemporaryFile()
        self.csv_output = CsvOutput(self.log_file.name)
        self.tabular = TabularInput()
        self.tabular.clear()

    def teardown_method(self):
        self.log_file.close()

    def test_record(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.csv_output.record(self.tabular)
        self.tabular.record('foo', foo * 2)
        self.tabular.record('bar', bar * 2)
        self.csv_output.record(self.tabular)
        self.csv_output.dump()

        correct = [
            {'foo': str(foo), 'bar': str(bar)},
            {'foo': str(foo * 2), 'bar': str(bar * 2)},
        ]  # yapf: disable
        self.assert_csv_matches(correct)

    def test_record_inconsistent(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.csv_output.record(self.tabular)
        self.tabular.record('foo', foo * 2)
        self.tabular.record('bar', bar * 2)

        with pytest.warns(CsvOutputWarning):
            self.csv_output.record(self.tabular)

        # this should not produce a warning, because we only warn once
        self.csv_output.record(self.tabular)

        self.csv_output.dump()

        correct = [
            {'foo': str(foo)},
            {'foo': str(foo * 2)},
        ]  # yapf: disable
        self.assert_csv_matches(correct)

    def test_empty_record(self):
        self.csv_output.record(self.tabular)
        assert not self.csv_output._writer

        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.csv_output.record(self.tabular)
        assert not self.csv_output._warned_once

    def test_unacceptable_type(self):
        with pytest.raises(ValueError):
            self.csv_output.record('foo')

    def test_disable_warnings(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.csv_output.record(self.tabular)
        self.tabular.record('foo', foo * 2)
        self.tabular.record('bar', bar * 2)

        self.csv_output.disable_warnings()

        # this should not produce a warning, because we disabled warnings
        self.csv_output.record(self.tabular)

    def assert_csv_matches(self, correct):
        """Check the first row of a csv file and compare it to known values."""
        with open(self.log_file.name, 'r') as file:
            reader = csv.DictReader(file)

            for correct_row in correct:
                row = next(reader)
                assert row == correct_row
示例#4
0
class TestTabularInput:
    def setup_method(self):
        self.tabular = TabularInput()

    def test_str(self):
        foo = 123
        bar = 456
        baz = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.tabular.record('baz', baz)

        correct_str = (
            '---  ---\n'
            'bar  456\n'
            'foo  123\n'
            '---  ---'
        )  # yapf: disable
        assert str(self.tabular) == correct_str

    def test_record(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)

        assert self.tabular.as_dict['foo'] == foo
        assert self.tabular.as_dict['bar'] == bar

    def test_record_misc_stat(self):
        self.tabular.record_misc_stat('Foo', [0, 1, 2])
        bar = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        self.tabular.record_misc_stat('Bar', bar, placement='front')

        correct = {
            'FooAverage': 1.0,
            'FooStd': 0.816496580927726,
            'FooMedian': 1.0,
            'FooMin': 0,
            'FooMax': 2,
            'AverageBar': 5.5,
            'StdBar': 2.8722813232690143,
            'MedianBar': 5.5,
            'MinBar': 1,
            'MaxBar': 10,
        }
        assert self.tabular.as_dict == correct

    def test_record_misc_stat_nan(self):
        self.tabular.record_misc_stat('none', None)

        correct = {
            'noneAverage': math.nan,
            'noneStd': math.nan,
            'noneMedian': math.nan,
            'noneMin': math.nan,
            'noneMax': math.nan
        }
        for k, v in self.tabular.as_dict.items():
            assert correct[k] is math.nan

    def test_prefix(self):
        foo = 111
        bar = 222
        with self.tabular.prefix('test_'):
            self.tabular.record('foo', foo)
            self.tabular.record('bar', bar)

        correct = {'test_foo': foo, 'test_bar': bar}
        assert self.tabular.as_dict == correct

    def test_clear(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.tabular.mark_all()

        assert self.tabular.as_dict
        self.tabular.clear()
        assert not self.tabular.as_dict

    def test_clear_warns_not_recorded_once(self):
        self.tabular.record('foo', 1)

        with pytest.warns(TabularInputWarning):
            self.tabular.clear()

        self.tabular.record('foo', 1)
        # This not trigger a warning, because we warned once
        self.tabular.clear()

    def test_disable_warnings(self):
        self.tabular.record('foo', 1)

        with pytest.warns(TabularInputWarning):
            self.tabular.clear()

        self.tabular.record('bar', 2)
        self.tabular.disable_warnings()

        # This should not trigger a warning, because we disabled warnings
        self.tabular.clear()

    def test_push_prefix(self):
        foo = 111
        bar = 222
        self.tabular.push_prefix('aaa_')
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.tabular.push_prefix('bbb_')
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)

        correct = {
            'aaa_foo': foo,
            'aaa_bar': bar,
            'aaa_bbb_foo': foo,
            'aaa_bbb_bar': bar,
        }
        assert self.tabular.as_dict == correct

    def test_pop_prefix(self):
        foo = 111
        bar = 222

        self.tabular.push_prefix('aaa_')
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.tabular.push_prefix('bbb_')
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.tabular.pop_prefix()
        self.tabular.record('foopop', foo)
        self.tabular.record('barpop', bar)
        self.tabular.pop_prefix()
        self.tabular.record('foopop', foo)
        self.tabular.record('barpop', bar)
        correct = {
            'aaa_foo': foo,
            'aaa_bar': bar,
            'aaa_bbb_foo': foo,
            'aaa_bbb_bar': bar,
            'aaa_foopop': foo,
            'aaa_barpop': bar,
            'foopop': foo,
            'barpop': bar,
        }
        assert self.tabular.as_dict == correct

    def test_as_primitive_dict(self):
        stuff = {
            'int': int(1),
            'float': float(2.0),
            'bool': bool(True),
            'str': str('Hello, world!'),
            'dict': dict(foo='bar'),
        }
        for k, v in stuff.items():
            self.tabular.record(k, v)

        correct = {
            'int': int(1),
            'float': float(2.0),
            'bool': bool(True),
            'str': str('Hello, world!'),
        }
        assert self.tabular.as_primitive_dict == correct