def test_iteration_events_are_fired():
    """ITERATION_STARTED / ITERATION_COMPLETED handlers fire once per iteration.

    Runs an ``Engine`` for ``max_epochs`` epochs and checks that both
    handlers are called ``num_batches * max_epochs`` times, strictly
    alternating started/complete, each receiving the engine as argument.
    """
    max_epochs = 5
    num_batches = 3
    # presumably yields `num_batches` batches per epoch — see helper
    data = _create_mock_data_loader(max_epochs, num_batches)
    engine = Engine(MagicMock(return_value=1))

    # Both handlers are attached to one parent mock so that their relative
    # call order is recorded in `mock_manager.mock_calls`.
    mock_manager = Mock()
    iteration_started = Mock()
    engine.add_event_handler(Events.ITERATION_STARTED, iteration_started)
    iteration_complete = Mock()
    engine.add_event_handler(Events.ITERATION_COMPLETED, iteration_complete)
    mock_manager.attach_mock(iteration_started, 'iteration_started')
    mock_manager.attach_mock(iteration_complete, 'iteration_complete')

    engine.run(data, max_epochs=max_epochs)

    expected_iterations = num_batches * max_epochs
    assert iteration_started.call_count == expected_iterations
    assert iteration_complete.call_count == expected_iterations

    # One (started, complete) pair per iteration, in that order.
    expected_calls = [
        call.iteration_started(engine),
        call.iteration_complete(engine),
    ] * expected_iterations
    assert mock_manager.mock_calls == expected_calls
def test_validation_iteration_events_are_fired():
    """Validation-iteration handlers fire once per validation batch per epoch.

    Trains for ``max_epochs`` epochs on a single dummy training item and
    checks that the VALIDATION_ITERATION_STARTED / _COMPLETED handlers are
    each called ``num_batches * max_epochs`` times, strictly alternating
    started/complete, each receiving the trainer as argument.
    """
    max_epochs = 5
    num_batches = 3
    # presumably yields `num_batches` batches per epoch — see helper
    validation_data = _create_mock_data_loader(max_epochs, num_batches)
    trainer = Trainer(
        training_data=[None],
        validation_data=validation_data,
        training_update_function=MagicMock(return_value=1),
        validation_inference_function=MagicMock(return_value=1),
    )

    # Shared parent mock records the interleaving of the two handlers.
    mock_manager = Mock()
    iteration_started = Mock()
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_STARTED, iteration_started)
    iteration_complete = Mock()
    trainer.add_event_handler(TrainingEvents.VALIDATION_ITERATION_COMPLETED, iteration_complete)
    mock_manager.attach_mock(iteration_started, 'iteration_started')
    mock_manager.attach_mock(iteration_complete, 'iteration_complete')

    trainer.run(max_epochs=max_epochs)

    expected_iterations = num_batches * max_epochs
    assert iteration_started.call_count == expected_iterations
    assert iteration_complete.call_count == expected_iterations

    # One (started, complete) pair per validation iteration, in that order.
    expected_calls = [
        call.iteration_started(trainer),
        call.iteration_complete(trainer),
    ] * expected_iterations
    assert mock_manager.mock_calls == expected_calls
def test_evaluation_iteration_events_are_fired():
    """ITERATION_STARTED / ITERATION_COMPLETED handlers fire once per batch.

    Runs an ``Evaluator`` over three batches and checks that both handlers
    are called once per batch, strictly alternating started/complete, each
    receiving ``(evaluator, state)`` where ``state`` is what ``run`` returned.
    """
    evaluator = Evaluator(MagicMock(return_value=1))

    # Shared parent mock records the interleaving of the two handlers.
    mock_manager = Mock()
    iteration_started = Mock()
    evaluator.add_event_handler(Events.ITERATION_STARTED, iteration_started)
    iteration_complete = Mock()
    evaluator.add_event_handler(Events.ITERATION_COMPLETED, iteration_complete)
    mock_manager.attach_mock(iteration_started, 'iteration_started')
    mock_manager.attach_mock(iteration_complete, 'iteration_complete')

    batches = [(1, 2), (3, 4), (5, 6)]
    state = evaluator.run(batches)

    assert iteration_started.call_count == len(batches)
    assert iteration_complete.call_count == len(batches)

    # One (started, complete) pair per batch, in that order.
    expected_calls = [
        call.iteration_started(evaluator, state),
        call.iteration_complete(evaluator, state),
    ] * len(batches)
    assert mock_manager.mock_calls == expected_calls