コード例 #1
0
ファイル: test_training.py プロジェクト: treiden/blocks
def test_save_the_best():
    """Check that Checkpoint additionally saves to a "best" path when TrackTheBest fires.

    ``Checkpoint`` writes ``dst`` after every batch. Through the
    ``OnLogRecord(track_cost.notification_name)`` condition it also writes
    ``dst_best``. The test expects:

    - log rows 4 and 5 to record both paths and row 6 only ``dst``;
    - the ``dst_best`` checkpoint and its separately saved log to report
      5 iterations done.

    The separately saved log files are not managed by ``NamedTemporaryFile``,
    so they are removed explicitly once the test is over.
    """
    skip_if_configuration_set("log_backend", "sqlite", "Known to be flaky with SQLite log backend.")

    def separate_log_path(path):
        # save_separately=["log"] stores the log next to the checkpoint as
        # "<root>_log<ext>" (the naming this test reads back below).
        root, ext = os.path.splitext(path)
        return root + "_log" + ext

    with NamedTemporaryFile(dir=config.temp_dir) as dst, NamedTemporaryFile(dir=config.temp_dir) as dst_best:
        # NamedTemporaryFile only deletes its own file; the side files holding
        # the separately saved log would otherwise leak into config.temp_dir.
        dst_log_path = separate_log_path(dst.name)
        dst_best_log_path = separate_log_path(dst_best.name)
        try:
            track_cost = TrackTheBest("cost", after_epoch=False, after_batch=True)
            main_loop = MockMainLoop(
                extensions=[
                    FinishAfter(after_n_epochs=1),
                    # NOTE(review): presumably writes a cost that keeps improving
                    # up to iteration 5 and then worsens -- confirm against
                    # WriteCostExtension; the assertions below rely on it.
                    WriteCostExtension(),
                    track_cost,
                    Checkpoint(dst.name, after_batch=True, save_separately=["log"]).add_condition(
                        ["after_batch"], OnLogRecord(track_cost.notification_name), (dst_best.name,)
                    ),
                ]
            )
            main_loop.run()

            assert main_loop.log[4]["saved_to"] == (dst.name, dst_best.name)
            assert main_loop.log[5]["saved_to"] == (dst.name, dst_best.name)
            assert main_loop.log[6]["saved_to"] == (dst.name,)
            with open(dst_best.name, "rb") as src:
                assert load(src).log.status["iterations_done"] == 5
            with open(dst_best_log_path, "rb") as src:
                assert cPickle.load(src).status["iterations_done"] == 5
        finally:
            for log_path in (dst_log_path, dst_best_log_path):
                try:
                    os.remove(log_path)
                except OSError:
                    # The file may never have been written (e.g. the run
                    # failed early); cleanup is deliberately best-effort.
                    pass
コード例 #2
0
def test_save_the_best():
    """Check that Checkpoint also saves to a "best" path on new best costs.

    ``Checkpoint`` writes ``dst`` after every batch. Through the
    ``OnLogRecord(track_cost.notification_name)`` condition it also writes
    ``dst_best``. The test expects:

    - log rows 4 and 5 to record both paths and row 6 only ``dst``;
    - the ``dst_best`` checkpoint and its separately saved log to report
      5 iterations done.

    The separately saved log files are removed explicitly at the end,
    since ``NamedTemporaryFile`` does not know about them.
    """
    skip_if_configuration_set('log_backend', 'sqlite',
                              "Known to be flaky with SQLite log backend.")

    def separate_log_path(path):
        # save_separately=['log'] stores the log next to the checkpoint as
        # "<root>_log<ext>" (the naming this test reads back below).
        root, ext = os.path.splitext(path)
        return root + "_log" + ext

    with NamedTemporaryFile(dir=config.temp_dir) as dst,\
            NamedTemporaryFile(dir=config.temp_dir) as dst_best:
        # NamedTemporaryFile only deletes its own file; the side files
        # holding the separately saved log would otherwise leak.
        dst_log_path = separate_log_path(dst.name)
        dst_best_log_path = separate_log_path(dst_best.name)
        try:
            track_cost = TrackTheBest("cost", after_epoch=False,
                                      after_batch=True)
            # NOTE(review): WriteCostExtension presumably writes a cost that
            # improves up to iteration 5 and then worsens -- confirm; the
            # assertions below rely on it.
            main_loop = MockMainLoop(extensions=[
                FinishAfter(after_n_epochs=1),
                WriteCostExtension(), track_cost,
                Checkpoint(dst.name, after_batch=True,
                           save_separately=['log']).
                add_condition(["after_batch"],
                              OnLogRecord(track_cost.notification_name), (
                                  dst_best.name, ))
            ])
            main_loop.run()

            assert main_loop.log[4]['saved_to'] == (dst.name, dst_best.name)
            assert main_loop.log[5]['saved_to'] == (dst.name, dst_best.name)
            assert main_loop.log[6]['saved_to'] == (dst.name, )
            with open(dst_best.name, 'rb') as src:
                assert load(src).log.status['iterations_done'] == 5
            with open(dst_best_log_path, 'rb') as src:
                assert cPickle.load(src).status['iterations_done'] == 5
        finally:
            for log_path in (dst_log_path, dst_best_log_path):
                try:
                    os.remove(log_path)
                except OSError:
                    # May never have been written if the run failed early;
                    # cleanup is deliberately best-effort.
                    pass
コード例 #3
0
ファイル: test_saveload.py プロジェクト: pdsujnow/blocks
 def test_load_log_and_iteration_state(self):
     """Check we can save the log and iteration state separately.

     A fresh MainLoop gets a ``Load("myweirdmodel.tar", True, True)``
     extension. Its ``before_training`` callbacks are then run by hand.
     The loaded log status must match ``self.main_loop``'s exactly (same keys,
     same values). The next item drawn from both loops' iteration states must
     also match.

     NOTE(review): the two ``True`` flags presumably enable loading of the
     iteration state and the log -- confirm against ``Load``'s signature.
     """
     skip_if_configuration_set("log_backend", "sqlite", 'Bug with log.status["resumed_from"]')
     new_main_loop = MainLoop(
         model=self.model,
         data_stream=self.data_stream,
         algorithm=self.algorithm,
         extensions=[Load("myweirdmodel.tar", True, True)],
     )
     # run() is not called, so attach the extension to its loop by hand.
     new_main_loop.extensions[0].main_loop = new_main_loop
     new_main_loop._run_extensions("before_training")
     # Check the log. zip() would stop at the shorter key list and silently
     # ignore missing or extra keys, so require identical key sets first.
     new_keys = sorted(new_main_loop.log.status.keys())
     old_keys = sorted(self.main_loop.log.status.keys())
     assert new_keys == old_keys
     for key in old_keys:
         assert new_main_loop.log.status[key] == self.main_loop.log.status[key]
     # Check the iteration state: the next item yielded by iteration_state[1]
     # (presumably the epoch iterator) must carry the same "data".
     new = next(new_main_loop.iteration_state[1])["data"]
     old = next(self.main_loop.iteration_state[1])["data"]
     assert_allclose(new, old)
コード例 #4
0
 def test_load_log_and_iteration_state(self):
     """Check we can save the log and iteration state separately.

     A fresh MainLoop gets a ``Load('myweirdmodel.tar', True, True)``
     extension, and its ``before_training`` callbacks are run by hand.
     The loaded log status must match ``self.main_loop``'s exactly (same
     keys, same values). The next item drawn from both loops' iteration
     states must also match.

     NOTE(review): the two ``True`` flags presumably enable loading of the
     iteration state and the log -- confirm against ``Load``'s signature.
     """
     skip_if_configuration_set('log_backend', 'sqlite',
                               'Bug with log.status["resumed_from"]')
     new_main_loop = MainLoop(
         model=self.model,
         data_stream=self.data_stream,
         algorithm=self.algorithm,
         extensions=[Load('myweirdmodel.tar', True, True)])
     # run() is not called, so attach the extension to its loop by hand.
     new_main_loop.extensions[0].main_loop = new_main_loop
     new_main_loop._run_extensions('before_training')
     # Check the log. zip() would stop at the shorter key list and silently
     # ignore missing or extra keys, so require identical key sets first.
     new_keys = sorted(new_main_loop.log.status.keys())
     old_keys = sorted(self.main_loop.log.status.keys())
     assert new_keys == old_keys
     for key in old_keys:
         assert (new_main_loop.log.status[key] ==
                 self.main_loop.log.status[key])
     # Check the iteration state: the next item yielded by
     # iteration_state[1] (presumably the epoch iterator) must carry the
     # same 'data'.
     new = next(new_main_loop.iteration_state[1])['data']
     old = next(self.main_loop.iteration_state[1])['data']
     assert_allclose(new, old)
コード例 #5
0
ファイル: test_training.py プロジェクト: AdityoSanjaya/blocks
def test_save_the_best():
    """Check that Checkpoint also saves to a "best" path on new best costs.

    ``Checkpoint`` writes ``dst`` after every batch. Through the
    ``OnLogRecord(track_cost.notification_name)`` condition it additionally
    writes ``dst_best``. Among other checks, the test expects log rows 4 and
    5 to record both paths, row 6 only ``dst``, and the ``dst_best``
    checkpoint to hold a log with 5 iterations done.
    """
    # Skipped under the SQLite log backend, where it is known to be flaky.
    skip_if_configuration_set('log_backend', 'sqlite',
                              "Known to be flaky with SQLite log backend.")
    with NamedTemporaryFile(dir=config.temp_dir) as dst,\
            NamedTemporaryFile(dir=config.temp_dir) as dst_best:
        # Track the best "cost" after every batch (not after epochs).
        track_cost = TrackTheBest("cost", after_epoch=False, after_batch=True)
        # NOTE(review): WriteCostExtension presumably writes a cost that
        # improves up to iteration 5 and then worsens -- confirm; the
        # assertions below rely on it. The extra (dst_best.name,) argument
        # is passed to the checkpoint only when the OnLogRecord condition
        # for track_cost's notification holds.
        main_loop = MockMainLoop(
            extensions=[FinishAfter(after_n_epochs=1),
                        WriteCostExtension(),
                        track_cost,
                        Checkpoint(dst.name, after_batch=True,
                                   save_separately=['log'])
                        .add_condition(
                            ["after_batch"],
                            OnLogRecord(track_cost.notification_name),
                            (dst_best.name,))])
        main_loop.run()

        # Rows 4 and 5 hit a new best and are saved to both paths;
        # row 6 is saved only to the regular checkpoint path.
        assert main_loop.log[4]['saved_to'] == (dst.name, dst_best.name)
        assert main_loop.log[5]['saved_to'] == (dst.name, dst_best.name)
        assert main_loop.log[6]['saved_to'] == (dst.name,)
        # The best checkpoint must be the one taken after 5 iterations.
        with open(dst_best.name, 'rb') as src:
            assert load(src).log.status['iterations_done'] == 5