Esempio n. 1
0
def test_save_the_best():
    """Checkpoint writes an extra copy to ``dst_best`` only on gated batches.

    ``dst`` is saved after every batch.  The extra save path
    ``(dst_best.name,)`` is attached with ``add_condition`` and gated by
    ``OnLogRecord(track_cost.notification_name)``.  With
    ``save_separately=['log']`` the log is additionally pickled on its own
    next to the checkpoint, as ``<root>_log<ext>``.
    """
    with NamedTemporaryFile(dir=config.temp_dir) as dst,\
            NamedTemporaryFile(dir=config.temp_dir) as dst_best:
        # NamedTemporaryFile only deletes its own file; the separately saved
        # ``<root>_log<ext>`` siblings are unknown to it, so collect their
        # paths and remove them ourselves.  (The one for ``dst`` presumably
        # exists too, since ``dst`` uses the same ``save_separately``.)
        separate_logs = []
        for checkpoint_path in (dst.name, dst_best.name):
            root, ext = os.path.splitext(checkpoint_path)
            separate_logs.append(root + "_log" + ext)
        best_log_path = separate_logs[-1]
        try:
            track_cost = TrackTheBest("cost", after_epoch=False,
                                      after_batch=True)
            main_loop = MockMainLoop(
                extensions=[FinishAfter(after_n_epochs=1),
                            WriteCostExtension(),
                            track_cost,
                            Checkpoint(dst.name, after_batch=True,
                                       save_separately=['log'])
                            .add_condition(
                                ["after_batch"],
                                OnLogRecord(track_cost.notification_name),
                                (dst_best.name,))])
            main_loop.run()

            # The best-checkpoint path is added on rows 4 and 5 only.
            assert main_loop.log[4]['saved_to'] == (dst.name, dst_best.name)
            assert main_loop.log[5]['saved_to'] == (dst.name, dst_best.name)
            assert main_loop.log[6]['saved_to'] == (dst.name,)
            # The last best checkpoint was taken after 5 iterations.
            with open(dst_best.name, 'rb') as src:
                assert load(src).log.status['iterations_done'] == 5
            with open(best_log_path, 'rb') as src:
                assert cPickle.load(src).status['iterations_done'] == 5
        finally:
            for path in separate_logs:
                try:
                    os.remove(path)
                except OSError:
                    # Never written (e.g. the run failed early): nothing to
                    # clean up for this path.
                    pass
Esempio n. 2
0
def test_save_the_best():
    """Checkpoint writes an extra copy to ``dst_best`` only on gated batches.

    ``dst`` is saved after every batch.  The extra save path
    ``(dst_best.name,)`` is attached for the ``"after_batch"`` callback and
    gated by ``OnLogRecord(track_cost.notification_name)``.  With
    ``save_separately=['log']`` the log is additionally pickled on its own
    next to the checkpoint, as ``<root>_log<ext>``.
    """
    with NamedTemporaryFile() as dst,\
            NamedTemporaryFile() as dst_best:
        # NamedTemporaryFile only deletes its own file; the separately saved
        # ``<root>_log<ext>`` siblings are unknown to it, so collect their
        # paths and remove them ourselves.  (The one for ``dst`` presumably
        # exists too, since ``dst`` uses the same ``save_separately``.)
        separate_logs = []
        for checkpoint_path in (dst.name, dst_best.name):
            root, ext = os.path.splitext(checkpoint_path)
            separate_logs.append(root + "_log" + ext)
        best_log_path = separate_logs[-1]
        try:
            track_cost = TrackTheBest("cost")
            main_loop = MockMainLoop(
                extensions=[FinishAfter(after_n_epochs=1),
                            WriteCostExtension(),
                            track_cost,
                            Checkpoint(dst.name, after_batch=True,
                                       save_separately=['log'])
                            .add_condition(
                                "after_batch",
                                OnLogRecord(track_cost.notification_name),
                                (dst_best.name,))])
            main_loop.run()

            # The best-checkpoint path is added on rows 4 and 5 only.
            assert main_loop.log[4]['saved_to'] == (dst.name, dst_best.name)
            assert main_loop.log[5]['saved_to'] == (dst.name, dst_best.name)
            assert main_loop.log[6]['saved_to'] == (dst.name,)
            # The last best checkpoint was taken after 5 iterations.
            with open(dst_best.name, 'rb') as src:
                assert cPickle.load(src).log.status['iterations_done'] == 5
            with open(best_log_path, 'rb') as src:
                assert cPickle.load(src).status['iterations_done'] == 5
        finally:
            for path in separate_logs:
                try:
                    os.remove(path)
                except OSError:
                    # Never written (e.g. the run failed early): nothing to
                    # clean up for this path.
                    pass
Esempio n. 3
0
def test_error():
    """An exception from ``after_batch`` escapes ``run()`` after ``on_error``.

    Checked twice with the same extension: once with a well-behaved
    ``on_error``, and once with an ``on_error`` that itself raises
    ``AttributeError``.  In both cases the original ``KeyError`` must be the
    one that propagates, ``on_error`` must be called exactly once, and the
    current log row must contain ``'got_exception'``.
    """
    ext = TrainingExtension()
    ext.after_batch = MagicMock(side_effect=KeyError)

    error_handlers = (MagicMock(), MagicMock(side_effect=AttributeError))
    for handler in error_handlers:
        ext.on_error = handler
        loop = MockMainLoop(extensions=[ext, FinishAfter(after_epoch=True)])
        assert_raises(KeyError, loop.run)
        ext.on_error.assert_called_once_with()
        assert 'got_exception' in loop.log.current_row
Esempio n. 4
0
def test_track_the_best():
    """Check TrackTheBest's per-epoch best tracking of ``cost``.

    After each epoch, ``status['best_cost']`` must hold the lowest cost seen
    so far.  ``cost_best_so_far`` is set in the current log row only when
    the cost strictly improves: repeating the best (5 after 5) does not
    count.
    """
    loop = MockMainLoop()
    tracker = TrackTheBest("cost")
    tracker.main_loop = loop

    # (cost reported this epoch, expected best afterwards, new best?)
    epochs = [
        (5, 5, True),
        (6, 5, False),
        (5, 5, False),
        (4, 4, True),
    ]
    for cost, expected_best, improved in epochs:
        # Advance the counters as a finished epoch of 10 iterations would.
        loop.status['epochs_done'] += 1
        loop.status['iterations_done'] += 10
        loop.log.current_row['cost'] = cost
        tracker.dispatch('after_epoch')

        assert loop.status['best_cost'] == expected_best
        if improved:
            assert loop.log.current_row['cost_best_so_far']
        else:
            assert loop.log.current_row.get('cost_best_so_far', None) is None
Esempio n. 5
0
def test_track_the_best():
    """Check TrackTheBest's per-batch best tracking of ``cost``.

    After each batch, ``status.best_cost`` must hold the lowest cost seen so
    far.  ``cost_best_so_far`` in the current log row is truthy only on a
    strict improvement and None otherwise: a cost equal to the best does
    not count.
    """
    loop = MockMainLoop()
    tracker = TrackTheBest("cost")
    tracker.main_loop = loop

    # (cost observed, expected best afterwards, flagged as new best?)
    batches = ((5, 5, True), (6, 5, False), (5, 5, False), (4, 4, True))
    for cost, expected_best, flagged in batches:
        loop.status.iterations_done += 1
        loop.log.current_row.cost = cost
        tracker.dispatch('after_batch')

        assert loop.status.best_cost == expected_best
        best_so_far = loop.log.current_row['cost_best_so_far']
        if flagged:
            assert best_so_far
        else:
            assert best_so_far is None
Esempio n. 6
0
def test_training_interrupt():
    """A first SIGINT must not stop the main loop; a second one must.

    The loop runs in a child process over the endless
    ``IterableDataset(count())`` stream, so only the signals (or the
    clean-up below) can end it.
    """
    def process_batch(batch):
        # Slow every batch down so the loop is still busy when the signals
        # arrive.
        time.sleep(0.1)

    algorithm = MockAlgorithm()
    algorithm.process_batch = process_batch

    main_loop = MockMainLoop(
        algorithm=algorithm,
        data_stream=IterableDataset(count()).get_example_stream(),
        extensions=[Printing()]
    )

    p = Process(target=main_loop.run)
    p.start()
    try:
        # NOTE(review): this sleep presumably gives the child time to enter
        # main_loop.run() and install its SIGINT handling; on a very slow
        # machine the first signal could still arrive too early.
        time.sleep(0.1)
        os.kill(p.pid, signal.SIGINT)
        time.sleep(0.1)
        assert p.is_alive()
        os.kill(p.pid, signal.SIGINT)
        # Wait for the exit with an upper bound instead of a fixed sleep,
        # so a slightly slow shutdown is not reported as a failure.
        p.join(timeout=2.0)
        assert not p.is_alive()
    finally:
        # The data stream never ends: if the child ignored the signals (or
        # an assertion above failed) it would run forever, so never leave
        # it behind.
        if p.is_alive():
            p.terminate()
        p.join()
Esempio n. 7
0
def test_timing():
    """Smoke test: a two-epoch run with Timing attached completes cleanly."""
    extensions = [Timing(), FinishAfter(after_n_epochs=2)]
    MockMainLoop(extensions=extensions).run()
Esempio n. 8
0
def test_timing():
    """Run a mock main loop with Timing for two epochs; no error expected."""
    timing = Timing()
    stopper = FinishAfter(after_n_epochs=2)
    loop = MockMainLoop(extensions=[timing, stopper])
    loop.run()
Esempio n. 9
0
def test_timing():
    """Drive Timing by hand with a fake clock and check the logged durations.

    ``Timing`` is given ``lambda: now`` (apparently as its time source).
    ``now`` is a closure cell, so every ``now += ...`` / ``now = ...`` below
    changes what that lambda returns.  The scenario covers:
    - a full first epoch;
    - finishing training;
    - resuming with the clock reset to 0;
    - finishing in the middle of epoch 2;
    - resuming again at clock 2 and running through epoch 3.
    """
    # Build the main loop
    now = 0
    timing = Timing(lambda: now)
    ml = MockMainLoop()
    timing.main_loop = ml

    # Start training
    now += 1
    timing.before_training()

    # Start epoch 1
    now += 2
    timing.before_epoch()
    # 3 (clock at before_epoch) - 1 (clock at before_training) == 2
    assert ml.log[0].initialization_took == 2
    # NOTE(review): private flag, presumably mimicking what the real main
    # loop sets once an epoch has started -- confirm against MainLoop.
    ml.log.status._epoch_started = True

    # Batch 1
    timing.before_batch(None)
    now += 7
    ml.log.status.iterations_done += 1
    timing.after_batch(None)
    # The clock advanced by 7 between before_batch and after_batch.
    assert ml.log[1].iteration_took == 7

    # Batch 2
    timing.before_batch(None)
    now += 8
    ml.log.status.iterations_done += 1
    timing.after_batch(None)
    assert ml.log[2].iteration_took == 8

    # Epoch 1 is done
    ml.log.status.epochs_done += 1
    timing.after_epoch()
    # Clock is 18 here: 18 - 3 (epoch start) == 15, 18 - 1 (training start)
    # == 17.
    assert ml.log[2].epoch_took == 15
    assert ml.log[2].total_took == 17

    # Finish training
    now += 1
    timing.after_training()
    # 19 - 1 == 18
    assert ml.log[2].final_total_took == 18

    # Resume training
    now = 0
    timing.on_resumption()

    # Start epoch 2
    timing.before_epoch()
    # Row 2 (two iterations done so far) must not receive an
    # initialization time on resumption.
    assert ml.log[2].initialization_took is None

    # Batch 3
    timing.before_batch(None)
    now += 6
    ml.log.status.iterations_done += 1
    timing.after_batch(None)
    assert ml.log[3].iteration_took == 6
    # 18 (total before the restart) + 6 (clock 0 -> 6 since resuming) == 24
    assert ml.log[3].total_took == 24

    # Finish training before the end of the current epoch
    timing.after_training()

    # Resume training
    now = 2
    timing.on_resumption()

    # Batch 4
    timing.before_batch(None)
    now += 2
    ml.log.status.iterations_done += 1
    timing.after_batch(None)
    assert ml.log[4].iteration_took == 2
    # 24 + 2 (clock 2 -> 4 since the second resumption) == 26
    assert ml.log[4].total_took == 26

    # Epoch 2 is done
    ml.log.status.epochs_done += 1
    timing.after_epoch()
    # Epoch 2 spans the interruption: 6 before stopping + 2 after == 8
    assert ml.log[4].epoch_took == 8

    # Start epoch 3
    timing.before_epoch()

    # Batch 5
    timing.before_batch(None)
    now += 5
    ml.log.status.iterations_done += 1
    timing.after_batch(None)
    assert ml.log[5].iteration_took == 5
    # 26 + 5 == 31
    assert ml.log[5].total_took == 31

    # Epoch 3 is done
    ml.log.status.epochs_done += 1
    timing.after_epoch()
    # Epoch 3 started at clock 4 and ends at clock 9.
    assert ml.log[5].epoch_took == 5