示例#1
0
def test_state_attributes():
    """Check the attributes of the state returned by ``Trainer.run``.

    After three epochs over a three-item loader the state must report the
    total iteration count, the last update output, the last batch seen, the
    loader itself, and both the reached and the requested epoch count.
    """
    batches = [1, 2, 3]
    n_epochs = 3
    update_output = 1

    trainer = Trainer(MagicMock(return_value=update_output))
    final_state = trainer.run(batches, max_epochs=n_epochs)

    # Counters: one iteration per batch per epoch.
    assert final_state.iteration == len(batches) * n_epochs
    assert final_state.output == update_output
    # The last batch processed is the final element of the loader.
    assert final_state.batch == batches[-1]
    assert final_state.dataloader == batches
    assert final_state.epoch == n_epochs
    assert final_state.max_epochs == n_epochs
示例#2
0
def test_custom_exception_handler():
    """A handler registered for ``EXCEPTION_RAISED`` receives the raised error.

    The update function always raises; the registered handler is expected to
    be invoked with the trainer, the returned state and the exception object.
    """
    error = ValueError()
    on_exception = MagicMock()

    trainer = Trainer(MagicMock(side_effect=error))
    trainer.add_event_handler(Events.EXCEPTION_RAISED, on_exception)
    final_state = trainer.run([1])

    # ``run`` returns normally here because the custom handler swallows the
    # exception; the handler must have been called with that exact error.
    # (``assert_has_calls`` checks presence of the call, not exclusivity.)
    on_exception.assert_has_calls([call(trainer, final_state, error)])
示例#3
0
def test_training_iteration_events_are_fired():
    """ITERATION_STARTED / ITERATION_COMPLETED fire once per batch, in order.

    Both handlers are attached to a shared manager mock so the interleaving
    of the two events can be compared against the expected sequence.
    """
    max_epochs = 5
    num_batches = 3
    total_iterations = num_batches * max_epochs
    data = _create_mock_data_loader(max_epochs, num_batches)

    trainer = Trainer(MagicMock(return_value=1))

    recorder = Mock()
    on_started = Mock()
    trainer.add_event_handler(Events.ITERATION_STARTED, on_started)

    on_completed = Mock()
    trainer.add_event_handler(Events.ITERATION_COMPLETED, on_completed)

    # Attaching the child mocks makes the manager record their calls in a
    # single, globally ordered ``mock_calls`` list.
    recorder.attach_mock(on_started, 'iteration_started')
    recorder.attach_mock(on_completed, 'iteration_complete')

    trainer.run(data, max_epochs=max_epochs)

    assert on_started.call_count == total_iterations
    assert on_completed.call_count == total_iterations

    # Every iteration contributes one "started" followed by one "complete".
    one_iteration = [
        call.iteration_started(trainer),
        call.iteration_complete(trainer),
    ]
    assert recorder.mock_calls == one_iteration * (max_epochs * num_batches)
示例#4
0
def test_terminate_stops_training_mid_epoch():
    """Calling ``terminate`` inside an iteration halts the run at that iteration."""
    batches_per_epoch = 10
    # Iteration 13 overall: with 10 iterations per epoch this falls part way
    # through the 2nd epoch.
    stop_at = batches_per_epoch + 3
    trainer = Trainer(MagicMock(return_value=1))

    def stop_when_reached(engine, state):
        # Request termination only on the chosen iteration.
        if state.iteration != stop_at:
            return
        engine.terminate()

    trainer.add_event_handler(Events.ITERATION_STARTED, stop_when_reached)
    final_state = trainer.run(data=[None] * batches_per_epoch, max_epochs=3)

    # The epoch that contains the stopping iteration: ceil(13 / 10) == 2.
    expected_epoch = np.ceil(stop_at / batches_per_epoch)

    # The current iteration is completed, but the counter is not advanced any
    # further (it is only incremented just before a new iteration starts).
    assert final_state.iteration == stop_at
    assert final_state.epoch == expected_epoch
示例#5
0
def test_current_epoch_counter_increases_every_epoch():
    """``state.epoch`` reads 1, 2, 3, ... at successive EPOCH_STARTED events."""
    max_epochs = 5
    trainer = Trainer(MagicMock(return_value=1))

    # Single-slot mutable cell so the nested handler can advance the expected
    # value between calls.
    expected_epoch = [1]

    def check_epoch(engine, state):
        assert state.epoch == expected_epoch[0]
        expected_epoch[0] += 1

    trainer.add_event_handler(Events.EPOCH_STARTED, check_epoch)

    final_state = trainer.run([1], max_epochs=max_epochs)

    assert final_state.epoch == max_epochs
示例#6
0
def test_terminate_at_end_of_epoch_stops_training():
    """Terminating from an EPOCH_COMPLETED handler ends the run at that epoch."""
    requested_epochs = 5
    final_epoch = 3

    trainer = Trainer(MagicMock(return_value=1))

    def stop_after_final_epoch(engine, state):
        # Only the chosen epoch triggers termination; earlier ones pass through.
        if state.epoch != final_epoch:
            return
        engine.terminate()

    trainer.add_event_handler(Events.EPOCH_COMPLETED, stop_after_final_epoch)

    # The flag must be clear before the run starts.
    assert not trainer.should_terminate

    final_state = trainer.run([1], max_epochs=requested_epochs)

    assert final_state.epoch == final_epoch
    assert trainer.should_terminate
示例#7
0
def test_current_iteration_counter_increases_every_iteration():
    """``state.iteration`` reads 1, 2, 3, ... at successive ITERATION_STARTED events."""
    batches = [1, 2, 3]
    max_epochs = 5
    trainer = Trainer(MagicMock(return_value=1))

    # Single-slot mutable cell so the nested handler can advance the expected
    # value between calls.
    expected_iteration = [1]

    def check_iteration(engine, state):
        assert state.iteration == expected_iteration[0]
        expected_iteration[0] += 1

    trainer.add_event_handler(Events.ITERATION_STARTED, check_iteration)

    final_state = trainer.run(batches, max_epochs=max_epochs)

    # The counter keeps running across epochs rather than resetting.
    assert final_state.iteration == max_epochs * len(batches)
示例#8
0
def test_timer():
    """Timers attached to a trainer report total, per-batch and training-only time.

    Each training and evaluation step sleeps for a fixed duration, so the
    expected timer readings can be derived from the number of steps. An
    evaluation pass over the same number of batches runs after each epoch.
    """
    step_seconds = 0.2
    num_batches = 3

    def _sleepy_train_step(engine, batch):
        time.sleep(step_seconds)

    def _sleepy_eval_step(engine, batch):
        time.sleep(step_seconds)

    trainer = Trainer(_sleepy_train_step)
    tester = Evaluator(_sleepy_eval_step)

    total_timer = Timer()
    batch_timer = Timer(average=True)
    train_timer = Timer()

    # Default attachment: expected below to cover training plus validation.
    total_timer.attach(trainer)
    # Runs only while an iteration is in flight and steps once per iteration,
    # so its averaged value is the time of a single batch.
    batch_timer.attach(
        trainer,
        pause=Events.ITERATION_COMPLETED,
        resume=Events.ITERATION_STARTED,
        step=Events.ITERATION_COMPLETED,
    )
    # Paused at EPOCH_COMPLETED; registered before the validation handler
    # below so that validation time is excluded from its reading.
    train_timer.attach(
        trainer,
        pause=Events.EPOCH_COMPLETED,
        resume=Events.EPOCH_STARTED,
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(trainer):
        tester.run(range(num_batches))

    # Run "training"
    trainer.run(range(num_batches))

    def _same_to_one_decimal(measured, wanted):
        # Timing is noisy, so compare after rounding to one decimal place.
        return round(measured, 1) == round(wanted, 1)

    expected_total = 2 * num_batches * step_seconds
    expected_batch = step_seconds
    expected_train = num_batches * step_seconds

    assert _same_to_one_decimal(total_timer.value(), expected_total)
    assert _same_to_one_decimal(batch_timer.value(), expected_batch)
    assert _same_to_one_decimal(train_timer.value(), expected_train)

    # After a reset the timer reads (approximately) zero again.
    total_timer.reset()
    assert _same_to_one_decimal(total_timer.value(), 0.0)
示例#9
0
def test_terminate_at_start_of_epoch_stops_training_after_completing_iteration():
    """Terminating at EPOCH_STARTED still lets the epoch's first iteration finish."""
    requested_epochs = 5
    terminate_epoch = 3
    batches = [1, 2, 3]

    trainer = Trainer(MagicMock(return_value=1))

    def stop_at_epoch_start(engine, state):
        # Only the chosen epoch triggers termination.
        if state.epoch != terminate_epoch:
            return
        engine.terminate()

    trainer.add_event_handler(Events.EPOCH_STARTED, stop_at_epoch_start)

    # The flag must be clear before the run starts.
    assert not trainer.should_terminate

    final_state = trainer.run(batches, max_epochs=requested_epochs)

    # The epoch is never completed, so its counter stays where it was.
    assert final_state.epoch == terminate_epoch
    assert trainer.should_terminate

    # All iterations of the earlier epochs, plus the single iteration that
    # completes in the terminating epoch.
    iterations_before = (terminate_epoch - 1) * len(batches)
    assert final_state.iteration == iterations_before + 1
示例#10
0
def test_with_trainer(dirname):
    """``ModelCheckpoint`` on EPOCH_COMPLETED keeps only the newest checkpoints.

    With ``n_saved=2`` and four epochs, the directory is expected to hold
    exactly the files for epochs 3 and 4.

    ``dirname`` is presumably a fixture supplying an existing, empty
    directory (the handler is built with ``create_dir=False``) — confirm
    against the test setup.
    """
    def _noop_update(batch):
        pass

    name = 'model'
    filename_template = '{}_{}_{}.pth'

    trainer = Trainer(_noop_update)
    checkpointer = ModelCheckpoint(
        dirname,
        _PREFIX,
        create_dir=False,
        n_saved=2,
        save_interval=1,
    )

    # The extra positional argument maps the checkpoint name to the object
    # to be saved.
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {name: 42})
    trainer.run([0], max_epochs=4)

    surviving_epochs = (3, 4)
    expected = [
        filename_template.format(_PREFIX, name, epoch)
        for epoch in surviving_epochs
    ]

    assert sorted(os.listdir(dirname)) == expected
示例#11
0
def test_default_exception_handler():
    """With no custom handler registered, an update error propagates out of ``run``."""
    failing_update = MagicMock(side_effect=ValueError())
    trainer = Trainer(failing_update)

    # Only the ``run`` call sits inside the context, so nothing else can
    # satisfy the expectation.
    with raises(ValueError):
        trainer.run([1])
示例#12
0
def test_returns_state():
    """``Trainer.run`` returns a ``State`` instance, even for empty data."""
    update_fn = MagicMock(return_value=1)
    result = Trainer(update_fn).run([])

    assert isinstance(result, State)
示例#13
0
def test_stopping_criterion_is_max_epochs():
    """Without early termination the run ends once ``max_epochs`` epochs are done."""
    max_epochs = 5
    trainer = Trainer(MagicMock(return_value=1))

    final_state = trainer.run([1], max_epochs=max_epochs)

    assert final_state.epoch == max_epochs
示例#14
0
def test_creates_state():
    """Running the trainer leaves a ``State`` instance on ``trainer.state``."""
    update_fn = MagicMock(return_value=1)
    trainer = Trainer(update_fn)

    # The return value is deliberately ignored; only the attribute is checked.
    trainer.run([])

    assert isinstance(trainer.state, State)