def test_save_the_best():
    """Check that a conditional checkpoint keeps the best-so-far model.

    A ``Checkpoint`` saves to one path after every batch and, through
    ``add_condition``, additionally to a second path whenever the log
    contains the notification record written by ``TrackTheBest``.  The
    test then verifies the ``saved_to`` log records and that both the
    main "best" file and its separately saved log report five
    iterations done.
    """
    with NamedTemporaryFile(dir=config.temp_dir) as regular,\
            NamedTemporaryFile(dir=config.temp_dir) as best:
        track_cost = TrackTheBest("cost", after_epoch=False, after_batch=True)
        finish = FinishAfter(after_n_epochs=1)
        write_cost = WriteCostExtension()
        checkpoint = Checkpoint(
            regular.name, after_batch=True, save_separately=['log'])
        # Extra destination used only when the "best" notification is
        # present in the log.
        checkpoint = checkpoint.add_condition(
            ["after_batch"],
            OnLogRecord(track_cost.notification_name),
            (best.name,))
        main_loop = MockMainLoop(
            extensions=[finish, write_cost, track_cost, checkpoint])
        main_loop.run()

        # Iterations 4 and 5 are expected to hit both destinations,
        # iteration 6 only the regular one.
        both_paths = (regular.name, best.name)
        expected_saved_to = [(4, both_paths),
                             (5, both_paths),
                             (6, (regular.name,))]
        for iteration, paths in expected_saved_to:
            assert main_loop.log[iteration]['saved_to'] == paths

        with open(best.name, 'rb') as src:
            assert load(src).log.status['iterations_done'] == 5

        # The separately saved log is expected at "<root>_log<ext>".
        log_path = "_log".join(os.path.splitext(best.name))
        with open(log_path, 'rb') as src:
            assert cPickle.load(src).status['iterations_done'] == 5
def test_save_the_best():
    """Check that the best-so-far model is checkpointed separately.

    The checkpoint extension writes to ``every_batch`` after each batch
    and, when the ``TrackTheBest`` notification record is in the log,
    also to ``best_only``.  Afterwards the ``saved_to`` records of
    iterations 4-6 are inspected and both the unpickled main loop and
    the separately pickled log must report five iterations done.
    """
    with NamedTemporaryFile() as every_batch,\
            NamedTemporaryFile() as best_only:
        track_cost = TrackTheBest("cost")
        main_loop = MockMainLoop(
            extensions=[
                FinishAfter(after_n_epochs=1),
                WriteCostExtension(),
                track_cost,
                Checkpoint(every_batch.name, after_batch=True,
                           save_separately=['log']).add_condition(
                    "after_batch",
                    OnLogRecord(track_cost.notification_name),
                    (best_only.name,)),
            ])
        main_loop.run()

        saved_everywhere = (every_batch.name, best_only.name)
        saved_regular_only = (every_batch.name,)
        assert main_loop.log[4]['saved_to'] == saved_everywhere
        assert main_loop.log[5]['saved_to'] == saved_everywhere
        assert main_loop.log[6]['saved_to'] == saved_regular_only

        with open(best_only.name, 'rb') as src:
            restored = cPickle.load(src)
            assert restored.log.status['iterations_done'] == 5

        # Path of the log file that was saved separately.
        root, ext = os.path.splitext(best_only.name)
        log_path = "{}_log{}".format(root, ext)
        with open(log_path, 'rb') as src:
            restored_log = cPickle.load(src)
            assert restored_log.status['iterations_done'] == 5
def test_error():
    """Check error handling when an extension callback raises.

    ``after_batch`` of the extension raises ``KeyError``.  In both
    scenarios below the ``KeyError`` must propagate out of
    ``main_loop.run``, ``on_error`` must have been called exactly once
    without arguments, and the current log row must contain
    ``'got_exception'``.
    """
    failing = TrainingExtension()
    failing.after_batch = MagicMock(side_effect=KeyError)

    # Scenario 1: a well-behaved ``on_error`` handler.
    # Scenario 2: ``on_error`` itself raises ``AttributeError``; the
    # original ``KeyError`` is still the one expected from ``run``.
    handlers = (MagicMock(), MagicMock(side_effect=AttributeError))
    for handler in handlers:
        failing.on_error = handler
        main_loop = MockMainLoop(
            extensions=[failing, FinishAfter(after_epoch=True)])
        assert_raises(KeyError, main_loop.run)
        failing.on_error.assert_called_once_with()
        assert 'got_exception' in main_loop.log.current_row
def test_track_the_best():
    """Check ``TrackTheBest`` bookkeeping over four simulated epochs.

    For every epoch a cost is written to the current log row and
    ``after_epoch`` is dispatched.  ``status['best_cost']`` must hold
    the expected best value, and ``'cost_best_so_far'`` must be truthy
    in the current row only when the test expects a new best (an equal
    cost is not expected to count as one).
    """
    main_loop = MockMainLoop()
    extension = TrackTheBest("cost")
    extension.main_loop = main_loop

    # (cost logged this epoch, expected best cost, new best expected?)
    epochs = [
        (5, 5, True),
        (6, 5, False),
        (5, 5, False),
        (4, 4, True),
    ]
    for cost, expected_best, is_new_best in epochs:
        main_loop.status['epochs_done'] += 1
        main_loop.status['iterations_done'] += 10
        main_loop.log.current_row['cost'] = cost
        extension.dispatch('after_epoch')

        assert main_loop.status['best_cost'] == expected_best
        row = main_loop.log.current_row
        if is_new_best:
            assert row['cost_best_so_far']
        else:
            assert row.get('cost_best_so_far', None) is None
def test_track_the_best():
    """Check ``TrackTheBest`` bookkeeping over four simulated batches.

    For every batch a cost is set on the current log row and
    ``after_batch`` is dispatched.  ``status.best_cost`` must hold the
    expected best value, and the ``'cost_best_so_far'`` entry of the
    current row must be truthy when a new best is expected and ``None``
    otherwise (an equal cost is not expected to count as a new best).
    """
    main_loop = MockMainLoop()
    extension = TrackTheBest("cost")
    extension.main_loop = main_loop

    # (cost logged this batch, expected best cost, new best expected?)
    batches = (
        (5, 5, True),
        (6, 5, False),
        (5, 5, False),
        (4, 4, True),
    )
    for cost, expected_best, is_new_best in batches:
        main_loop.status.iterations_done += 1
        main_loop.log.current_row.cost = cost
        extension.dispatch('after_batch')

        assert main_loop.status.best_cost == expected_best
        best_flag = main_loop.log.current_row['cost_best_so_far']
        if is_new_best:
            assert best_flag
        else:
            assert best_flag is None
def test_training_interrupt():
    """Check how a running main loop reacts to two SIGINT signals.

    The main loop runs in a child process with an algorithm whose
    ``process_batch`` sleeps for 0.1 s and a data stream built from
    ``count()``.  After the first SIGINT the process is expected to be
    still alive 0.1 s later; after the second SIGINT it is expected to
    have exited within 0.2 s.

    The child is always cleaned up: if an assertion fails (or sending a
    signal raises) while the child is still running, it is terminated
    before being joined, so a failing test cannot leave a stray process
    behind or block on ``join``.
    """
    def process_batch(batch):
        # Make every batch slow enough for the signals to arrive while
        # training is in progress.
        time.sleep(0.1)

    algorithm = MockAlgorithm()
    algorithm.process_batch = process_batch

    main_loop = MockMainLoop(
        algorithm=algorithm,
        data_stream=IterableDataset(count()).get_example_stream(),
        extensions=[Printing()]
    )

    p = Process(target=main_loop.run)
    p.start()
    try:
        time.sleep(0.1)
        # First interrupt: the process must survive it.
        os.kill(p.pid, signal.SIGINT)
        time.sleep(0.1)
        assert p.is_alive()
        # Second interrupt: the process must stop shortly afterwards.
        os.kill(p.pid, signal.SIGINT)
        time.sleep(0.2)
        assert not p.is_alive()
    finally:
        # On the success path the child is already dead and this is a
        # plain join, exactly as before.  On failure, stop the child
        # first so that join() cannot wait forever.
        if p.is_alive():
            p.terminate()
        p.join()
def test_timing():
    """Smoke test: run a main loop that has the ``Timing`` extension.

    The loop is configured with ``FinishAfter(after_n_epochs=2)``; the
    test only requires that ``run`` completes without raising.
    """
    extensions = [Timing(), FinishAfter(after_n_epochs=2)]
    MockMainLoop(extensions=extensions).run()
def test_timing():
    """Smoke test for the ``Timing`` extension inside a main loop.

    Nothing is asserted explicitly: the test passes when a main loop
    with ``Timing`` and ``FinishAfter(after_n_epochs=2)`` runs to
    completion without an exception.
    """
    timing = Timing()
    finish = FinishAfter(after_n_epochs=2)
    main_loop = MockMainLoop(extensions=[timing, finish])
    main_loop.run()
def test_timing():
    """Drive ``Timing`` by hand with a fake clock and check its records.

    The extension receives a callable returning ``fake_time`` instead of
    a real clock, so every duration in the log is fully determined by
    the increments below.  The callbacks are invoked manually in the
    order a main loop would call them, including two interruptions
    followed by ``on_resumption``.
    """
    # Set up the extension with a controllable clock and a mock loop.
    fake_time = 0

    def fake_clock():
        # Reads the enclosing variable at call time, so later
        # assignments to ``fake_time`` are visible to the extension.
        return fake_time

    timer = Timing(fake_clock)
    main_loop = MockMainLoop()
    timer.main_loop = main_loop

    # Training begins.
    fake_time += 1
    timer.before_training()

    # First epoch begins.
    fake_time += 2
    timer.before_epoch()
    assert main_loop.log[0].initialization_took == 2
    main_loop.log.status._epoch_started = True

    # First batch.
    timer.before_batch(None)
    fake_time += 7
    main_loop.log.status.iterations_done += 1
    timer.after_batch(None)
    assert main_loop.log[1].iteration_took == 7

    # Second batch.
    timer.before_batch(None)
    fake_time += 8
    main_loop.log.status.iterations_done += 1
    timer.after_batch(None)
    assert main_loop.log[2].iteration_took == 8

    # First epoch ends.
    main_loop.log.status.epochs_done += 1
    timer.after_epoch()
    assert main_loop.log[2].epoch_took == 15
    assert main_loop.log[2].total_took == 17

    # Training ends.
    fake_time += 1
    timer.after_training()
    assert main_loop.log[2].final_total_took == 18

    # Training is resumed with the clock reset.
    fake_time = 0
    timer.on_resumption()

    # Second epoch begins.
    timer.before_epoch()
    assert main_loop.log[2].initialization_took is None

    # Third batch.
    timer.before_batch(None)
    fake_time += 6
    main_loop.log.status.iterations_done += 1
    timer.after_batch(None)
    assert main_loop.log[3].iteration_took == 6
    assert main_loop.log[3].total_took == 24

    # Training stops in the middle of the second epoch.
    timer.after_training()

    # Training is resumed again, this time with the clock at 2.
    fake_time = 2
    timer.on_resumption()

    # Fourth batch.
    timer.before_batch(None)
    fake_time += 2
    main_loop.log.status.iterations_done += 1
    timer.after_batch(None)
    assert main_loop.log[4].iteration_took == 2
    assert main_loop.log[4].total_took == 26

    # Second epoch ends.
    main_loop.log.status.epochs_done += 1
    timer.after_epoch()
    assert main_loop.log[4].epoch_took == 8

    # Third epoch begins.
    timer.before_epoch()

    # Fifth batch.
    timer.before_batch(None)
    fake_time += 5
    main_loop.log.status.iterations_done += 1
    timer.after_batch(None)
    assert main_loop.log[5].iteration_took == 5
    assert main_loop.log[5].total_took == 31

    # Third epoch ends.
    main_loop.log.status.epochs_done += 1
    timer.after_epoch()
    assert main_loop.log[5].epoch_took == 5