def test_validation_iteration_events_are_fired_when_validate_is_called_explicitly():
    """Calling ``trainer.validate()`` directly fires the validation iteration events.

    Expects one VALIDATION_ITERATION_STARTED and one VALIDATION_ITERATION_COMPLETED
    event per validation batch, with each batch's STARTED event preceding its
    COMPLETED event.
    """
    max_epochs = 5
    num_batches = 3
    validation_data = _create_mock_data_loader(max_epochs, num_batches)
    trainer = Trainer(training_data=[None],
                      validation_data=validation_data,
                      training_update_function=MagicMock(),
                      validation_inference_function=MagicMock(return_value=1))

    iteration_started = Mock()
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_STARTED, iteration_started)
    iteration_complete = Mock()
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_COMPLETED, iteration_complete)

    # Attaching both handlers to one parent mock records their calls in a single
    # ordered ``mock_calls`` list, which lets us check the interleaving below.
    mock_manager = Mock()
    mock_manager.attach_mock(iteration_started, 'iteration_started')
    mock_manager.attach_mock(iteration_complete, 'iteration_complete')

    # Nothing should fire before validate() is invoked.
    assert iteration_started.call_count == 0
    assert iteration_complete.call_count == 0

    trainer.validate()

    assert iteration_started.call_count == num_batches
    assert iteration_complete.call_count == num_batches
    # Compare only call names so the check does not depend on which arguments
    # the trainer passes to its handlers.
    call_names = [name for name, _args, _kwargs in mock_manager.mock_calls]
    assert call_names == ['iteration_started', 'iteration_complete'] * num_batches
def test_validate_not_called_if_validate_every_epoch_is_false():
    """``run(..., validate_every_epoch=False)`` must never invoke ``trainer.validate``."""
    epochs = 5
    trainer = Trainer([1], MagicMock(return_value=1), MagicMock(), MagicMock())
    # Replace validate with a spy so we can observe whether run() calls it.
    trainer.validate = MagicMock()

    trainer.run(max_epochs=epochs, validate_every_epoch=False)

    assert not trainer.validate.called
def test_validate_is_called_every_epoch_by_default():
    """When ``validate_every_epoch`` is omitted, ``run`` calls ``validate`` once per epoch."""
    trainer = Trainer([1], MagicMock(return_value=1), [1], MagicMock())
    validate_spy = MagicMock()
    trainer.validate = validate_spy
    epochs = 5

    trainer.run(max_epochs=epochs)

    assert validate_spy.call_count == epochs
def test_terminate_after_training_iteration_skips_validate_every_epoch_run():
    """Terminating during training prevents the per-epoch validation run.

    This test is deliberately named differently from the other
    ``test_terminate_after_training_iteration_skips_validation_run`` in this
    module. A second ``def`` with the same name rebinds it, so pytest would
    never collect this definition.
    """
    num_iterations_per_epoch = 10
    iteration_to_stop = num_iterations_per_epoch - 1
    trainer = Trainer(training_data=[None] * num_iterations_per_epoch,
                      training_update_function=MagicMock(return_value=1),
                      validation_data=MagicMock(),
                      validation_inference_function=MagicMock())

    def terminate_at_iteration(trainer):
        # Registered on TRAINING_ITERATION_STARTED, so this runs at the start of
        # an iteration rather than at its end.
        if trainer.current_iteration == iteration_to_stop:
            trainer.terminate()

    trainer.validate = MagicMock()
    trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_STARTED, terminate_at_iteration)
    trainer.run(max_epochs=3, validate_every_epoch=True)

    # Assuming current_iteration counts up from the start of the run,
    # iteration_to_stop is reached inside the first epoch. No epoch therefore
    # completes, and validate must not be called despite validate_every_epoch=True.
    assert trainer.validate.call_count == 0
def test_terminate_after_training_iteration_skips_validation_run():
    """Terminating before the first epoch finishes leaves ``trainer.validate`` uncalled.

    Validation is wired through ``_validate`` on TRAINING_EPOCH_COMPLETED.
    ``_validate`` is defined elsewhere and presumably forwards to
    ``trainer.validate`` (confirm).
    """
    iterations_per_epoch = 10
    stop_iteration = iterations_per_epoch - 1
    trainer = Trainer(MagicMock(return_value=1), MagicMock())

    def stop_training(trainer):
        # Guard clause: only act on the target iteration.
        if trainer.current_iteration != stop_iteration:
            return
        trainer.terminate()

    trainer.validate = MagicMock()
    trainer.add_event_handler(TrainingEvents.TRAINING_EPOCH_COMPLETED, _validate, MagicMock())
    trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_STARTED, stop_training)

    training_batches = [None] * iterations_per_epoch
    trainer.run(training_batches, max_epochs=3)

    assert trainer.validate.call_count == 0