Пример #1
0
def test_validate_not_called_if_validate_every_epoch_is_false():
    """run() must never call validate() when validate_every_epoch=False."""
    epochs = 5
    trainer = Trainer([1], MagicMock(return_value=1), MagicMock(), MagicMock())
    trainer.validate = MagicMock()  # spy that records any validate() call

    trainer.run(max_epochs=epochs, validate_every_epoch=False)

    assert trainer.validate.call_count == 0
Пример #2
0
def test_current_validation_iteration_counter_increases_every_iteration():
    """current_validation_iteration starts at 0 in every validation run and
    increases by one per validation iteration."""
    validation_batches = [1, 2, 3]
    max_epochs = 5
    trainer = Trainer(MagicMock(return_value=1), MagicMock(return_value=1))

    class IterationCounter(object):
        """Checks the per-run index and counts every validation iteration."""

        def __init__(self):
            self.current_iteration_count = 0  # expected index within the current run
            self.total_count = 0  # iterations seen across all runs

        def __call__(self, trainer):
            assert trainer.current_validation_iteration == self.current_iteration_count
            self.current_iteration_count += 1
            self.total_count += 1

        def clear(self):
            self.current_iteration_count = 0

    counter = IterationCounter()

    def reset_per_run_index(engine, iteration_counter):
        # Each validation run restarts the expected index at 0.
        iteration_counter.clear()

    trainer.add_event_handler(TrainingEvents.TRAINING_EPOCH_COMPLETED,
                              _validate, validation_batches)
    trainer.add_event_handler(TrainingEvents.VALIDATION_STARTING,
                              reset_per_run_index, counter)
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_STARTED,
                              counter)

    trainer.run([1], max_epochs=max_epochs)

    assert counter.total_count == max_epochs * len(validation_batches)
Пример #3
0
def test_validate_is_called_every_epoch_by_default():
    """Without an explicit validate_every_epoch, run() validates once per epoch."""
    epochs = 5
    trainer = Trainer([1], MagicMock(return_value=1), [1], MagicMock())
    trainer.validate = MagicMock()  # spy used to count validate() invocations

    trainer.run(max_epochs=epochs)

    assert trainer.validate.call_count == epochs
Пример #4
0
def test_on_decorator():
    """Handlers registered through the ``trainer.on`` decorator are invoked
    once per training iteration, at both iteration start and completion."""
    max_epochs = 5
    num_batches = 3
    # Presumably yields num_batches batches per epoch for max_epochs epochs.
    training_data = _create_mock_data_loader(max_epochs, num_batches)

    trainer = Trainer(MagicMock(return_value=1), MagicMock())

    class Counter(object):
        def __init__(self, count=0):
            self.count = count

    started_counter = Counter()

    @trainer.on(TrainingEvents.TRAINING_ITERATION_STARTED, started_counter)
    def handle_training_iteration_started(trainer, started_counter):
        started_counter.count += 1

    completed_counter = Counter()

    @trainer.on(TrainingEvents.TRAINING_ITERATION_COMPLETED, completed_counter)
    def handle_training_iteration_completed(trainer, completed_counter):
        completed_counter.count += 1

    trainer.run(training_data, max_epochs=max_epochs)

    # Derive the expectation from the setup instead of a hard-coded 15, so the
    # test stays correct if max_epochs or num_batches is changed.
    expected_iterations = max_epochs * num_batches
    assert started_counter.count == expected_iterations
    assert completed_counter.count == expected_iterations
Пример #5
0
def test_validation_iteration_events_are_fired():
    """VALIDATION_ITERATION_STARTED/COMPLETED fire once per validation batch in
    every validation run, strictly alternating."""
    max_epochs = 5
    num_batches = 3
    validation_data = _create_mock_data_loader(max_epochs, num_batches)

    trainer = Trainer(MagicMock(return_value=1), MagicMock(return_value=1))

    # A shared parent mock records the relative order of its children's calls.
    recorder = Mock()
    on_started = Mock()
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_STARTED,
                              on_started)

    on_completed = Mock()
    trainer.add_event_handler(TrainingEvents.TRAINING_EPOCH_COMPLETED,
                              _validate, validation_data)
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_COMPLETED,
                              on_completed)

    recorder.attach_mock(on_started, 'iteration_started')
    recorder.attach_mock(on_completed, 'iteration_complete')

    trainer.run([None], max_epochs=max_epochs)

    total_iterations = num_batches * max_epochs
    assert on_started.call_count == total_iterations
    assert on_completed.call_count == total_iterations

    one_iteration = [call.iteration_started(trainer),
                     call.iteration_complete(trainer)]
    assert recorder.mock_calls == one_iteration * total_iterations
Пример #6
0
def test_adding_handler_for_non_existent_event_throws_error():
    """Registering a handler for a value that is not a TrainingEvents member
    raises ValueError."""
    trainer = Trainer(MagicMock(), MagicMock(), MagicMock(), MagicMock())

    known_events = TrainingEvents.__members__.values()
    bogus_event = uuid.uuid4()
    while bogus_event in known_events:  # practically never loops
        bogus_event = uuid.uuid4()

    with raises(ValueError):
        trainer.add_event_handler(bogus_event, lambda x: x)
Пример #7
0
def test_custom_exception_handler():
    """A handler registered for EXCEPTION_RAISED receives (trainer, state,
    exception), and run() returns instead of propagating the error."""
    value_error = ValueError()
    trainer = Trainer(MagicMock(side_effect=value_error))

    exception_handler = MagicMock()
    trainer.add_event_handler(Events.EXCEPTION_RAISED, exception_handler)

    state = trainer.run([1])

    # The exception is swallowed once a handler is registered.
    # NOTE(review): assert_has_calls only checks that this call occurred; it
    # does not prove the handler was called exactly once.
    exception_handler.assert_has_calls([call(trainer, state, value_error)])
Пример #8
0
def test_state_attributes():
    """The State returned by run() exposes the final iteration/epoch counters,
    the last batch and output, the dataloader and max_epochs."""
    dataloader = [1, 2, 3]
    epochs = 3
    trainer = Trainer(MagicMock(return_value=1))
    state = trainer.run(dataloader, max_epochs=epochs)

    assert state.iteration == len(dataloader) * epochs
    assert state.output == 1
    assert state.batch == dataloader[-1]
    assert state.dataloader == dataloader
    assert state.epoch == epochs
    assert state.max_epochs == epochs
Пример #9
0
def test_best_loss_updates_if_no_validation_loss_and_training_loss_reduces(
        training_update_loss_per_batch, batches_per_epoch,
        expected_number_of_updates, validate_every_epoch):
    """Without validation data, BEST_LOSS_UPDATED fires the expected number of
    times, and each time best_training_loss equals the latest epoch average."""
    max_epochs = len(training_update_loss_per_batch) // batches_per_epoch
    update_fn = MagicMock(side_effect=training_update_loss_per_batch)
    trainer = Trainer(
        training_data=_create_mock_data_loader_manager(max_epochs,
                                                       batches_per_epoch),
        training_update_function=update_fn,
    )

    def assert_best_equals_latest(trainer):
        best = trainer.best_training_loss
        latest = trainer.avg_training_loss_per_epoch[-1]
        assert (best == latest).all()

    trainer.add_event_listener(TrainingEvents.BEST_LOSS_UPDATED,
                               assert_best_equals_latest)

    update_counter = MagicMock()
    trainer.add_event_listener(TrainingEvents.BEST_LOSS_UPDATED, update_counter)

    trainer.run(max_epochs=max_epochs,
                validate_every_epoch=validate_every_epoch)

    assert update_counter.call_count == expected_number_of_updates
Пример #10
0
def test_validation_iteration_events_are_fired_when_validate_is_called_explicitly():
    """Calling validate() directly fires one STARTED/COMPLETED pair per
    validation batch, in strictly alternating order."""
    max_epochs = 5
    num_batches = 3
    validation_data = _create_mock_data_loader(max_epochs, num_batches)

    trainer = Trainer(training_data=[None],
                      validation_data=validation_data,
                      training_update_function=MagicMock(),
                      validation_inference_function=MagicMock(return_value=1))

    mock_manager = Mock()
    iteration_started = Mock()
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_STARTED,
                              iteration_started)

    iteration_complete = Mock()
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_COMPLETED,
                              iteration_complete)

    # Attaching both mocks to one parent records their relative call order.
    mock_manager.attach_mock(iteration_started, 'iteration_started')
    mock_manager.attach_mock(iteration_complete, 'iteration_complete')

    assert iteration_started.call_count == 0
    assert iteration_complete.call_count == 0

    trainer.validate()

    assert iteration_started.call_count == num_batches
    assert iteration_complete.call_count == num_batches

    # Also verify the two events strictly alternate (started, complete, ...);
    # without this the recorder set up above is never checked.
    expected_calls = [call.iteration_started(trainer),
                      call.iteration_complete(trainer)] * num_batches
    assert mock_manager.mock_calls == expected_calls
Пример #11
0
def test_training_iteration_events_are_fired():
    """TRAINING_ITERATION_STARTED/COMPLETED fire once per batch per epoch and
    strictly alternate."""
    max_epochs = 5
    num_batches = 3
    trainer = Trainer(
        training_data=_create_mock_data_loader(max_epochs, num_batches),
        validation_data=MagicMock(),
        training_update_function=MagicMock(return_value=1),
        validation_inference_function=MagicMock(),
    )

    # A shared parent mock records the relative order of its children's calls.
    recorder = Mock()
    on_started = Mock()
    trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_STARTED,
                              on_started)
    on_completed = Mock()
    trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_COMPLETED,
                              on_completed)
    recorder.attach_mock(on_started, 'iteration_started')
    recorder.attach_mock(on_completed, 'iteration_complete')

    trainer.run(max_epochs=max_epochs, validate_every_epoch=False)

    total_iterations = num_batches * max_epochs
    assert on_started.call_count == total_iterations
    assert on_completed.call_count == total_iterations

    one_iteration = [call.iteration_started(trainer),
                     call.iteration_complete(trainer)]
    assert recorder.mock_calls == one_iteration * total_iterations
Пример #12
0
def test_terminate_stops_training_mid_epoch():
    """terminate() requested at the start of an iteration lets that iteration
    finish and then stops training part way through an epoch."""
    num_iterations_per_epoch = 10
    # state.iteration and state.epoch are 1-based in this API, so iteration 13
    # lies in the 2nd epoch (epoch 1 = iterations 1-10, epoch 2 = 11-20).
    iteration_to_stop = num_iterations_per_epoch + 3
    trainer = Trainer(MagicMock(return_value=1))

    def start_of_iteration_handler(trainer, state):
        if state.iteration == iteration_to_stop:
            trainer.terminate()

    trainer.add_event_handler(Events.ITERATION_STARTED,
                              start_of_iteration_handler)
    state = trainer.run(data=[None] * num_iterations_per_epoch, max_epochs=3)

    # The iteration in which terminate() was requested still completes, and no
    # further iteration starts, so the counter stays on it.
    assert state.iteration == iteration_to_stop
    # 1-based epochs: ceil(13 / 10) == 2, i.e. the 2nd epoch.
    assert state.epoch == np.ceil(iteration_to_stop / num_iterations_per_epoch)
Пример #13
0
def test_current_epoch_counter_increases_every_epoch():
    """state.epoch is 1 at the first EPOCH_STARTED and increases by one for
    each subsequent epoch."""
    trainer = Trainer(MagicMock(return_value=1))
    max_epochs = 5

    class EpochCounter(object):
        def __init__(self):
            self.current_epoch_count = 1  # expected state.epoch at next start

        def __call__(self, trainer, state):
            assert state.epoch == self.current_epoch_count
            self.current_epoch_count += 1

    epoch_counter = EpochCounter()
    trainer.add_event_handler(Events.EPOCH_STARTED, epoch_counter)

    state = trainer.run([1], max_epochs=max_epochs)

    assert state.epoch == max_epochs
    # Guard against a vacuous pass: the in-handler assertion only proves
    # anything if the handler actually ran once per epoch.
    assert epoch_counter.current_epoch_count == max_epochs + 1
Пример #14
0
def test_terminate_at_end_of_epoch_stops_training():
    """terminate() from an EPOCH_COMPLETED handler stops training once that
    epoch has finished."""
    max_epochs = 5
    last_epoch_to_run = 3

    trainer = Trainer(MagicMock(return_value=1))

    def stop_after_last_epoch(engine, state):
        if state.epoch != last_epoch_to_run:
            return
        engine.terminate()

    trainer.add_event_handler(Events.EPOCH_COMPLETED, stop_after_last_epoch)
    assert not trainer.should_terminate

    state = trainer.run([1], max_epochs=max_epochs)

    assert state.epoch == last_epoch_to_run
    assert trainer.should_terminate
Пример #15
0
def test_current_iteration_counter_increases_every_iteration():
    """state.iteration is 1 at the first ITERATION_STARTED and increases by one
    per iteration, continuing across epochs."""
    training_batches = [1, 2, 3]
    trainer = Trainer(MagicMock(return_value=1))
    max_epochs = 5

    class IterationCounter(object):
        def __init__(self):
            self.current_iteration_count = 1  # expected state.iteration next

        def __call__(self, trainer, state):
            assert state.iteration == self.current_iteration_count
            self.current_iteration_count += 1

    iteration_counter = IterationCounter()
    trainer.add_event_handler(Events.ITERATION_STARTED, iteration_counter)

    state = trainer.run(training_batches, max_epochs=max_epochs)

    total_iterations = max_epochs * len(training_batches)
    assert state.iteration == total_iterations
    # Guard against a vacuous pass: the handler must have run once per iteration.
    assert iteration_counter.current_iteration_count == total_iterations + 1
Пример #16
0
def test_avg_validation_loss_per_epoch_updates_correctly(
        validation_update_losses_per_batch, avg_loss_per_epoch,
        batches_per_epoch):
    """avg_validation_loss matches the expected per-epoch averages, and
    best_validation_loss is the epoch average whose first element is smallest."""
    epochs = len(validation_update_losses_per_batch) // batches_per_epoch
    loader_manager = _create_mock_data_loader_manager(epochs, batches_per_epoch)

    trainer = Trainer(
        training_data=[1],
        validation_data=loader_manager,
        training_update_function=MagicMock(return_value=1),
        validation_inference_function=MagicMock(
            side_effect=validation_update_losses_per_batch),
    )

    trainer.run(max_epochs=epochs)

    assert np.array_equal(trainer.avg_validation_loss, avg_loss_per_epoch)
    expected_best = min(avg_loss_per_epoch, key=lambda a: a[0])
    assert np.array_equal(trainer.best_validation_loss, expected_best)
Пример #17
0
def test_avg_training_loss_per_epoch_updates_correctly(
        training_update_losses_per_batch, avg_loss_per_epoch,
        batches_per_epoch):
    """avg_training_loss_per_epoch matches the expected per-epoch averages, and
    best_training_loss is their element-wise minimum over epochs."""
    epochs = len(training_update_losses_per_batch) // batches_per_epoch
    loader_manager = _create_mock_data_loader_manager(epochs, batches_per_epoch)

    trainer = Trainer(
        training_data=loader_manager,
        validation_data=MagicMock(),
        training_update_function=MagicMock(
            side_effect=training_update_losses_per_batch),
        validation_inference_function=MagicMock(),
    )

    trainer.run(max_epochs=epochs, validate_every_epoch=False)

    assert np.array_equal(trainer.avg_training_loss_per_epoch,
                          avg_loss_per_epoch)
    expected_best = np.min(avg_loss_per_epoch, axis=0)
    assert np.array_equal(trainer.best_training_loss, expected_best)
Пример #18
0
def test_timer():
    """Timers attached to trainer events measure total time, average
    per-iteration time, and training-only time; reset() zeroes a timer."""
    sleep_t = 0.2
    n_iter = 3

    def _sleepy_step(batch):
        # Stand-in for real work: each step takes a known, fixed time.
        time.sleep(sleep_t)

    trainer = Trainer(_sleepy_step)
    tester = Evaluator(_sleepy_step)

    # Whole run: n_iter training steps plus n_iter validation steps.
    t_total = Timer()
    # Only between ITERATION_STARTED and ITERATION_COMPLETED, averaged per step.
    t_batch = Timer(average=True)
    # Paused on EPOCH_COMPLETED (registered before the validation handler
    # below), so the validation run is not counted.
    t_train = Timer()

    t_total.attach(trainer)
    t_batch.attach(trainer,
                   pause=Events.ITERATION_COMPLETED,
                   resume=Events.ITERATION_STARTED,
                   step=Events.ITERATION_COMPLETED)
    t_train.attach(trainer,
                   pause=Events.EPOCH_COMPLETED,
                   resume=Events.EPOCH_STARTED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def _run_validation(trainer, state):
        tester.run(range(n_iter))

    trainer.run(range(n_iter))

    def _approx_equal(measured, expected):
        # Sleep timing is imprecise; compare to one decimal place.
        return round(measured, 1) == round(expected, 1)

    assert _approx_equal(t_total.value(), 2 * n_iter * sleep_t)
    assert _approx_equal(t_batch.value(), sleep_t)
    assert _approx_equal(t_train.value(), n_iter * sleep_t)

    t_total.reset()
    assert _approx_equal(t_total.value(), 0.0)
Пример #19
0
def test_terminate_after_training_iteration_skips_validation_run():
    """terminate() from a TRAINING_ITERATION_STARTED handler in the first epoch
    prevents the epoch-end validation from ever calling validate()."""
    num_iterations_per_epoch = 10
    stop_at = num_iterations_per_epoch - 1
    trainer = Trainer(MagicMock(return_value=1), MagicMock())
    trainer.validate = MagicMock()  # spy: must never be invoked

    def start_of_iteration_handler(engine):
        if engine.current_iteration == stop_at:
            engine.terminate()

    trainer.add_event_handler(TrainingEvents.TRAINING_EPOCH_COMPLETED,
                              _validate, MagicMock())
    trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_STARTED,
                              start_of_iteration_handler)
    trainer.run([None] * num_iterations_per_epoch, max_epochs=3)

    assert trainer.validate.call_count == 0
Пример #20
0
def test_best_loss_updates_if_validate_is_called_irregularly(
        num_training_batches, validate_every_epoch):
    """best_training_loss equals the minimum epoch-average training loss even
    when validate() is triggered mid-epoch at irregular iterations."""
    max_epochs = 10
    num_validation_batches = 2

    def random_loss(*args):
        # Both training updates and validation inference return a random loss.
        return np.random.randn()

    def maybe_validate(trainer):
        if trainer.current_iteration % 3 == 0:
            trainer.validate()

    trainer = Trainer(training_data=[None] * num_training_batches,
                      training_update_function=random_loss,
                      validation_data=[None] * num_validation_batches,
                      validation_inference_function=random_loss)

    trainer.add_event_listener(TrainingEvents.TRAINING_ITERATION_COMPLETED,
                               maybe_validate)

    trainer.run(max_epochs=max_epochs,
                validate_every_epoch=validate_every_epoch)

    lowest_epoch_avg = np.min(trainer.avg_training_loss_per_epoch)
    assert np.isclose(lowest_epoch_avg, trainer.best_training_loss)
Пример #21
0
def test_terminate_at_start_of_epoch_stops_training_after_completing_iteration():
    """terminate() from an EPOCH_STARTED handler still lets the first iteration
    of that epoch complete before training stops."""
    max_epochs = 5
    epoch_to_terminate_on = 3
    batches_per_epoch = [1, 2, 3]

    trainer = Trainer(batches_per_epoch, MagicMock(return_value=1),
                      MagicMock(), MagicMock())

    def stop_on_target_epoch(engine):
        if engine.current_epoch == epoch_to_terminate_on:
            engine.terminate()

    trainer.add_event_handler(TrainingEvents.EPOCH_STARTED,
                              stop_on_target_epoch)
    assert not trainer.should_terminate

    trainer.run(max_epochs=max_epochs, validate_every_epoch=False)

    # The terminated epoch never completes, so the epoch counter is not advanced.
    assert trainer.current_epoch == epoch_to_terminate_on
    assert trainer.should_terminate
    # Exactly one iteration of the terminated epoch runs.
    expected_iteration = epoch_to_terminate_on * len(batches_per_epoch) + 1
    assert trainer.current_iteration == expected_iteration
Пример #22
0
def test_adding_multiple_event_handlers():
    """Every handler registered for the same event is invoked exactly once,
    with the trainer as its argument."""
    trainer = Trainer([1], MagicMock(return_value=1), MagicMock(), MagicMock())
    first, second = MagicMock(), MagicMock()
    for handler in (first, second):
        trainer.add_event_handler(TrainingEvents.TRAINING_STARTED, handler)

    trainer.run(validate_every_epoch=False)

    first.assert_called_once_with(trainer)
    second.assert_called_once_with(trainer)
Пример #23
0
def test_terminate_at_start_of_epoch_stops_training_after_completing_iteration():
    """terminate() from an EPOCH_STARTED handler still lets the first iteration
    of that epoch complete before training stops."""
    max_epochs = 5
    epoch_to_terminate_on = 3
    batches_per_epoch = [1, 2, 3]

    trainer = Trainer(MagicMock(return_value=1))

    def stop_on_target_epoch(engine, state):
        if state.epoch == epoch_to_terminate_on:
            engine.terminate()

    trainer.add_event_handler(Events.EPOCH_STARTED, stop_on_target_epoch)
    assert not trainer.should_terminate

    state = trainer.run(batches_per_epoch, max_epochs=max_epochs)

    # The terminated epoch never completes, so state.epoch stays on it.
    assert state.epoch == epoch_to_terminate_on
    assert trainer.should_terminate
    # All iterations of the earlier epochs, plus the first of this one.
    completed_epochs = epoch_to_terminate_on - 1
    assert state.iteration == completed_epochs * len(batches_per_epoch) + 1
Пример #24
0
def test_exception_handler_called_on_error():
    """An EXCEPTION_RAISED handler is notified once with the trainer, and the
    update function's ValueError still propagates out of run()."""
    failing_update = MagicMock(side_effect=ValueError())
    trainer = Trainer([1], failing_update, MagicMock(), MagicMock())

    on_exception = MagicMock()
    trainer.add_event_handler(TrainingEvents.EXCEPTION_RAISED, on_exception)

    with raises(ValueError):
        trainer.run()

    on_exception.assert_called_once_with(trainer)
Пример #25
0
def test_exception_handler_called_on_error():
    """An EXCEPTION_RAISED listener is notified once with the trainer while a
    ValueError still propagates out of run()."""
    # NOTE(review): the update function returns None rather than raising; the
    # ValueError presumably comes from the trainer processing that None loss.
    # Confirm this against the Trainer implementation.
    trainer = Trainer([1], MagicMock(return_value=None), MagicMock(),
                      MagicMock())
    on_exception = MagicMock()
    trainer.add_event_listener(TrainingEvents.EXCEPTION_RAISED, on_exception)

    with raises(ValueError):
        trainer.run()

    on_exception.assert_called_once_with(trainer)
Пример #26
0
def test_terminate_after_training_iteration_skips_validation_run():
    """terminate() from a TRAINING_ITERATION_STARTED handler in the first epoch
    prevents validate() from being called, even with validate_every_epoch=True."""
    num_iterations_per_epoch = 10
    stop_at = num_iterations_per_epoch - 1
    trainer = Trainer(training_data=[None] * num_iterations_per_epoch,
                      training_update_function=MagicMock(return_value=1),
                      validation_data=MagicMock(),
                      validation_inference_function=MagicMock())
    trainer.validate = MagicMock()  # spy: must never be invoked

    def start_of_iteration_handler(engine):
        if engine.current_iteration == stop_at:
            engine.terminate()

    trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_STARTED,
                              start_of_iteration_handler)
    trainer.run(max_epochs=3, validate_every_epoch=True)

    assert trainer.validate.call_count == 0
Пример #27
0
def test_terminate_stops_trainer_when_called_during_validation():
    """terminate() called from a validation-iteration handler finishes that
    validation iteration and then stops the whole run."""
    num_iterations_per_epoch = 10
    # current_epoch and current_validation_iteration are 0-based in this API,
    # so this stops at the 4th batch (index 3) of the 3rd validation run
    # (the one after epoch index 2).
    iteration_to_stop = 3
    epoch_to_stop = 2
    trainer = Trainer(MagicMock(return_value=1), MagicMock(return_value=1))

    def start_of_validation_iteration_handler(trainer):
        if (trainer.current_epoch == epoch_to_stop
                and trainer.current_validation_iteration == iteration_to_stop):
            trainer.terminate()

    trainer.add_event_handler(TrainingEvents.TRAINING_EPOCH_COMPLETED,
                              _validate, [None] * num_iterations_per_epoch)
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_STARTED,
                              start_of_validation_iteration_handler)
    trainer.run([None] * num_iterations_per_epoch, max_epochs=4)

    assert trainer.current_epoch == epoch_to_stop
    # The validation iteration in which terminate() was called still completes.
    assert trainer.current_validation_iteration == iteration_to_stop + 1
    # Training for epoch indices 0..epoch_to_stop finished before that validation.
    assert trainer.current_iteration == (epoch_to_stop +
                                         1) * num_iterations_per_epoch
Пример #28
0
def test_terminate_stops_training_mid_epoch():
    """terminate() at the start of an iteration lets that iteration finish and
    then stops training part way through the 2nd epoch."""
    num_iterations_per_epoch = 10
    iteration_to_stop = num_iterations_per_epoch + 3  # part way through the 2nd epoch
    trainer = Trainer(MagicMock(return_value=1), MagicMock())

    def stop_at_target_iteration(engine):
        if engine.current_iteration == iteration_to_stop:
            engine.terminate()

    trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_STARTED,
                              stop_at_target_iteration)
    trainer.run(training_data=[None] * num_iterations_per_epoch, max_epochs=3)

    # The iteration in which terminate() was called is completed, so the
    # counter has moved one past it.
    assert trainer.current_iteration == iteration_to_stop + 1
    # current_epoch starts from 0 in this API.
    expected_epoch = np.ceil(iteration_to_stop / num_iterations_per_epoch) - 1
    assert trainer.current_epoch == expected_epoch
Пример #29
0
def test_terminate_at_end_of_epoch_stops_training():
    """terminate() from an EPOCH_COMPLETED handler stops training after that
    epoch."""
    max_epochs = 5
    last_epoch_to_run = 3

    trainer = Trainer([1], MagicMock(return_value=1), MagicMock(), MagicMock())

    def stop_after_last_epoch(engine):
        if engine.current_epoch != last_epoch_to_run:
            return
        engine.terminate()

    trainer.add_event_handler(TrainingEvents.EPOCH_COMPLETED,
                              stop_after_last_epoch)
    assert not trainer.should_terminate

    trainer.run(max_epochs=max_epochs, validate_every_epoch=False)

    # The epoch counter is incremented at the end of the loop, hence the +1.
    assert trainer.current_epoch == last_epoch_to_run + 1
    assert trainer.should_terminate
Пример #30
0
def test_current_epoch_counter_increases_every_epoch():
    """current_epoch is 0 at the first EPOCH_STARTED and increases by one for
    each subsequent epoch."""
    trainer = Trainer([1], MagicMock(return_value=1), MagicMock(), MagicMock())
    max_epochs = 5

    class EpochCounter(object):
        def __init__(self):
            self.current_epoch_count = 0  # expected current_epoch at next start

        def __call__(self, trainer):
            assert trainer.current_epoch == self.current_epoch_count
            self.current_epoch_count += 1

    epoch_counter = EpochCounter()
    trainer.add_event_handler(TrainingEvents.EPOCH_STARTED, epoch_counter)

    trainer.run(max_epochs=max_epochs, validate_every_epoch=False)

    assert trainer.current_epoch == max_epochs
    # Guard against a vacuous pass: the in-handler assertion only proves
    # anything if the handler actually ran once per epoch.
    assert epoch_counter.current_epoch_count == max_epochs