Example #1
0
def test_validation_iteration_events_are_fired_when_validate_is_called_explicitly(
):
    """Check that an explicit ``validate()`` call fires per-iteration events.

    Handlers registered for ``VALIDATION_ITERATION_STARTED`` and
    ``VALIDATION_ITERATION_COMPLETED`` must not have run before
    ``validate()`` is called, and each must have run ``num_batches`` times
    once it returns.
    """
    max_epochs = 5
    num_batches = 3
    validation_data = _create_mock_data_loader(max_epochs, num_batches)

    trainer = Trainer(training_data=[None],
                      validation_data=validation_data,
                      training_update_function=MagicMock(),
                      validation_inference_function=MagicMock(return_value=1))

    iteration_started = Mock()
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_STARTED,
                              iteration_started)

    iteration_complete = Mock()
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_COMPLETED,
                              iteration_complete)

    # Registering the handlers alone must not invoke them.
    assert iteration_started.call_count == 0
    assert iteration_complete.call_count == 0

    trainer.validate()

    assert iteration_started.call_count == num_batches
    assert iteration_complete.call_count == num_batches
Example #2
0
def test_validate_not_called_if_validate_every_epoch_is_false():
    """``run(validate_every_epoch=False)`` must never call ``validate``."""
    trainer = Trainer([1], MagicMock(return_value=1), MagicMock(), MagicMock())
    # Replace validate with a mock so any call is recorded, not executed.
    trainer.validate = MagicMock()

    trainer.run(max_epochs=5, validate_every_epoch=False)

    assert not trainer.validate.called
Example #3
0
def test_validate_is_called_every_epoch_by_default():
    """``run`` without ``validate_every_epoch`` validates once per epoch."""
    epochs_to_run = 5
    trainer = Trainer([1], MagicMock(return_value=1), [1], MagicMock())
    # Replace validate with a mock so the number of calls can be counted.
    trainer.validate = MagicMock()

    trainer.run(max_epochs=epochs_to_run)

    validate_calls = trainer.validate.call_count
    assert validate_calls == epochs_to_run
Example #4
0
def test_terminate_after_training_iteration_skips_validation_run():
    """Terminating from an iteration handler must prevent any validation.

    A handler on ``TRAINING_ITERATION_STARTED`` calls ``terminate()`` once
    ``current_iteration`` equals ``num_iterations_per_epoch - 1``; after
    ``run`` returns, the mocked ``validate`` must not have been called even
    though ``validate_every_epoch=True`` was requested.
    """
    num_iterations_per_epoch = 10
    iteration_to_stop = num_iterations_per_epoch - 1

    def terminate_at_stop_iteration(trainer):
        # Guard clause: do nothing until the chosen iteration is reached.
        if trainer.current_iteration != iteration_to_stop:
            return
        trainer.terminate()

    trainer = Trainer(
        training_data=[None] * num_iterations_per_epoch,
        training_update_function=MagicMock(return_value=1),
        validation_data=MagicMock(),
        validation_inference_function=MagicMock(),
    )
    # Replace validate with a mock so any call is recorded, not executed.
    trainer.validate = MagicMock()
    trainer.add_event_handler(
        TrainingEvents.TRAINING_ITERATION_STARTED,
        terminate_at_stop_iteration,
    )

    trainer.run(max_epochs=3, validate_every_epoch=True)

    assert not trainer.validate.called
Example #5
0
def test_terminate_after_training_iteration_skips_validation_run():
    """Terminating from an iteration handler must prevent any validation.

    ``_validate`` (defined outside this block; presumably it triggers
    ``trainer.validate`` -- confirm against its definition) is registered on
    ``TRAINING_EPOCH_COMPLETED``. A second handler on
    ``TRAINING_ITERATION_STARTED`` calls ``terminate()`` once
    ``current_iteration`` equals ``num_iterations_per_epoch - 1``. After
    ``run`` returns, the mocked ``validate`` must not have been called.
    """
    num_iterations_per_epoch = 10
    iteration_to_stop = num_iterations_per_epoch - 1
    training_data = [None] * num_iterations_per_epoch

    def terminate_at_stop_iteration(trainer):
        # Guard clause: do nothing until the chosen iteration is reached.
        if trainer.current_iteration != iteration_to_stop:
            return
        trainer.terminate()

    trainer = Trainer(MagicMock(return_value=1), MagicMock())
    # Replace validate with a mock so any call is recorded, not executed.
    trainer.validate = MagicMock()

    trainer.add_event_handler(
        TrainingEvents.TRAINING_EPOCH_COMPLETED,
        _validate,
        MagicMock(),
    )
    trainer.add_event_handler(
        TrainingEvents.TRAINING_ITERATION_STARTED,
        terminate_at_stop_iteration,
    )

    trainer.run(training_data, max_epochs=3)

    assert not trainer.validate.called