示例#1
0
def test_training_iteration_events_are_fired():
    """ITERATION_STARTED and ITERATION_COMPLETED fire once per batch, alternating.

    Both handlers are plain mocks; a parent mock records their calls in one
    ordered log so the interleaving can be compared against the expected
    started/completed sequence.
    """
    max_epochs = 5
    num_batches = 3
    total_iterations = max_epochs * num_batches
    data = _create_mock_data_loader(max_epochs, num_batches)

    trainer = Trainer(MagicMock(return_value=1))

    on_started = Mock()
    on_completed = Mock()
    trainer.add_event_handler(Events.ITERATION_STARTED, on_started)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, on_completed)

    # Attaching both handlers to one parent mock yields a single call log
    # that preserves the relative order of the two handlers.
    recorder = Mock()
    recorder.attach_mock(on_started, 'iteration_started')
    recorder.attach_mock(on_completed, 'iteration_complete')

    state = trainer.run(data, max_epochs=max_epochs)

    assert on_started.call_count == total_iterations
    assert on_completed.call_count == total_iterations

    # One iteration contributes exactly a "started" call followed by a
    # "completed" call, both receiving (trainer, state).
    calls_per_iteration = [
        call.iteration_started(trainer, state),
        call.iteration_complete(trainer, state),
    ]
    assert recorder.mock_calls == calls_per_iteration * total_iterations
示例#2
0
def test_custom_exception_handler():
    """An EXCEPTION_RAISED handler is called with the error raised by the update function."""
    raised = ValueError()
    failing_update = MagicMock(side_effect=raised)
    trainer = Trainer(failing_update)

    on_exception = MagicMock()
    trainer.add_event_handler(Events.EXCEPTION_RAISED, on_exception)

    # If the error propagated out of run(), the check below would never execute.
    state = trainer.run([1])

    # The handler must have been called with (trainer, state, <the raised error>).
    # NOTE(review): the original author expected exactly one such call because
    # the exception is swallowed; assert_has_calls only checks that it is present.
    expected_calls = [call(trainer, state, raised)]
    on_exception.assert_has_calls(expected_calls)
示例#3
0
def test_terminate_stops_training_mid_epoch():
    """terminate() called from an ITERATION_STARTED handler stops the run at that iteration."""
    num_iterations_per_epoch = 10
    # Three iterations past the end of the first epoch, i.e. inside the second one.
    iteration_to_stop = num_iterations_per_epoch + 3
    batches = [None] * num_iterations_per_epoch
    trainer = Trainer(MagicMock(return_value=1))

    def stop_at_target_iteration(trainer):
        if trainer.state.iteration != iteration_to_stop:
            return
        trainer.terminate()

    trainer.add_event_handler(Events.ITERATION_STARTED, stop_at_target_iteration)
    trainer.run(data=batches, max_epochs=3)

    # The iteration counter stays at the iteration in which terminate() was requested.
    assert trainer.state.iteration == iteration_to_stop

    # Expected epoch is the one containing that iteration, counting epochs from 1.
    expected_epoch = np.ceil(iteration_to_stop / num_iterations_per_epoch)
    assert trainer.state.epoch == expected_epoch
示例#4
0
def test_current_epoch_counter_increases_every_epoch():
    """state.epoch observed at EPOCH_STARTED is 1, 2, 3, ... and ends at max_epochs."""
    max_epochs = 5
    trainer = Trainer(MagicMock(return_value=1))

    # Mutable cell so the handler can advance the expected value between calls.
    expected = {'epoch': 1}

    def check_epoch(trainer, state):
        assert state.epoch == expected['epoch']
        expected['epoch'] += 1

    trainer.add_event_handler(Events.EPOCH_STARTED, check_epoch)

    final_state = trainer.run([1], max_epochs=max_epochs)

    assert final_state.epoch == max_epochs
示例#5
0
def test_terminate_at_end_of_epoch_stops_training():
    """terminate() called from an EPOCH_COMPLETED handler ends the run at that epoch."""
    last_epoch_to_run = 3
    max_epochs = 5
    trainer = Trainer(MagicMock(return_value=1))

    def stop_after_last_epoch(trainer, state):
        if state.epoch != last_epoch_to_run:
            return
        trainer.terminate()

    trainer.add_event_handler(Events.EPOCH_COMPLETED, stop_after_last_epoch)

    # The termination flag must be clear before the run starts ...
    assert not trainer.should_terminate

    final_state = trainer.run([1], max_epochs=max_epochs)

    # ... and set afterwards, with the epoch counter short of max_epochs.
    assert final_state.epoch == last_epoch_to_run
    assert trainer.should_terminate
示例#6
0
def test_current_iteration_counter_increases_every_iteration():
    """state.iteration observed at ITERATION_STARTED is 1, 2, 3, ... across all epochs."""
    max_epochs = 5
    training_batches = [1, 2, 3]
    trainer = Trainer(MagicMock(return_value=1))

    # Mutable cell so the handler can advance the expected value between calls.
    expected = {'iteration': 1}

    def check_iteration(trainer, state):
        assert state.iteration == expected['iteration']
        expected['iteration'] += 1

    trainer.add_event_handler(Events.ITERATION_STARTED, check_iteration)

    final_state = trainer.run(training_batches, max_epochs=max_epochs)

    # The counter is not reset between epochs: it ends at epochs * batches.
    assert final_state.iteration == max_epochs * len(training_batches)
示例#7
0
def test_with_trainer(dirname):
    """ModelCheckpoint attached to EPOCH_COMPLETED leaves only the last two files.

    With ``n_saved=2`` and ``save_interval=1`` over a 4-epoch run, the
    directory is expected to contain exactly the files suffixed 3 and 4.
    """
    def noop_update(batch):
        return None

    name = 'model'
    to_save = {name: 42}
    trainer = Trainer(noop_update)
    checkpointer = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2, save_interval=1)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, to_save)
    trainer.run([0], max_epochs=4)

    # File names follow "<prefix>_<name>_<suffix>.pth".
    expected_files = []
    for suffix in (3, 4):
        expected_files.append('{}_{}_{}.pth'.format(_PREFIX, name, suffix))

    found_files = sorted(os.listdir(dirname))
    assert found_files == expected_files
示例#8
0
def test_terminate_at_start_of_epoch_stops_training_after_completing_iteration():
    """terminate() from EPOCH_STARTED stops the run after the first iteration of that epoch."""
    batches_per_epoch = [1, 2, 3]
    epoch_to_terminate_on = 3
    max_epochs = 5
    trainer = Trainer(MagicMock(return_value=1))

    def stop_on_target_epoch(trainer):
        if trainer.state.epoch != epoch_to_terminate_on:
            return
        trainer.terminate()

    trainer.add_event_handler(Events.EPOCH_STARTED, stop_on_target_epoch)
    assert not trainer.should_terminate

    trainer.run(batches_per_epoch, max_epochs=max_epochs)

    # The epoch counter stays on the epoch in which terminate() was requested.
    assert trainer.state.epoch == epoch_to_terminate_on
    assert trainer.should_terminate

    # Every iteration of the earlier epochs, plus the first one of the stopped epoch.
    fully_run_epochs = epoch_to_terminate_on - 1
    expected_iteration = fully_run_epochs * len(batches_per_epoch) + 1
    assert trainer.state.iteration == expected_iteration