def test_training_iteration_events_are_fired():
    """ITERATION_STARTED and ITERATION_COMPLETED fire once per batch, in alternating order."""
    max_epochs = 5
    num_batches = 3
    total_iterations = num_batches * max_epochs
    data = _create_mock_data_loader(max_epochs, num_batches)
    trainer = Trainer(MagicMock(return_value=1))

    iteration_started = Mock()
    trainer.add_event_handler(Events.ITERATION_STARTED, iteration_started)
    iteration_complete = Mock()
    trainer.add_event_handler(Events.ITERATION_COMPLETED, iteration_complete)

    # Attach both handlers to one parent mock so that the relative order of
    # their invocations is recorded in a single `mock_calls` list.
    mock_manager = Mock()
    mock_manager.attach_mock(iteration_started, 'iteration_started')
    mock_manager.attach_mock(iteration_complete, 'iteration_complete')

    state = trainer.run(data, max_epochs=max_epochs)

    assert iteration_started.call_count == total_iterations
    assert iteration_complete.call_count == total_iterations

    # Each iteration must yield a "started" call immediately followed by a
    # "complete" call, both receiving the trainer and the returned state.
    one_iteration = [
        call.iteration_started(trainer, state),
        call.iteration_complete(trainer, state),
    ]
    assert mock_manager.mock_calls == one_iteration * total_iterations
def test_custom_exception_handler():
    """A handler registered for EXCEPTION_RAISED receives the error raised by the update function."""
    value_error = ValueError()
    trainer = Trainer(MagicMock(side_effect=value_error))

    exception_handler = MagicMock()
    trainer.add_event_handler(Events.EXCEPTION_RAISED, exception_handler)

    state = trainer.run([1])

    # only one call from _run_once_over_data, since the exception is swallowed
    expected_calls = [call(trainer, state, value_error)]
    exception_handler.assert_has_calls(expected_calls)
def test_terminate_stops_training_mid_epoch():
    """Calling terminate() from ITERATION_STARTED stops training part way through an epoch."""
    num_iterations_per_epoch = 10
    # Iterations 11-20 form the 2nd epoch, so iteration 13 is part way through it.
    iteration_to_stop = num_iterations_per_epoch + 3
    max_epochs = 3
    trainer = Trainer(MagicMock(return_value=1))

    def start_of_iteration_handler(trainer):
        if trainer.state.iteration == iteration_to_stop:
            trainer.terminate()

    trainer.add_event_handler(Events.ITERATION_STARTED, start_of_iteration_handler)
    trainer.run(data=[None] * num_iterations_per_epoch, max_epochs=max_epochs)

    # completes the iteration but doesn't increment counter (this happens just before a new iteration starts)
    assert trainer.state.iteration == iteration_to_stop
    # The expected value treats epochs as 1-based: iteration 13 lies in epoch ceil(13 / 10) == 2.
    assert trainer.state.epoch == np.ceil(iteration_to_stop / num_iterations_per_epoch)
def test_current_epoch_counter_increases_every_epoch():
    """The state.epoch seen at each EPOCH_STARTED counts 1, 2, ..., max_epochs."""
    trainer = Trainer(MagicMock(return_value=1))
    max_epochs = 5

    class EpochCounter(object):
        """Asserts that each EPOCH_STARTED event sees the next epoch number."""

        def __init__(self):
            # Next epoch number we expect to observe; the first epoch is 1.
            self.current_epoch_count = 1

        def __call__(self, trainer, state):
            assert state.epoch == self.current_epoch_count
            self.current_epoch_count += 1

    epoch_counter = EpochCounter()
    trainer.add_event_handler(Events.EPOCH_STARTED, epoch_counter)
    state = trainer.run([1], max_epochs=max_epochs)

    assert state.epoch == max_epochs
    # The per-epoch assertions above only run if the handler is invoked. Without
    # this check, the test would pass vacuously if EPOCH_STARTED never fired (or
    # fired too few times), so confirm the handler ran once per epoch.
    assert epoch_counter.current_epoch_count == max_epochs + 1
def test_terminate_at_end_of_epoch_stops_training():
    """Calling terminate() from EPOCH_COMPLETED stops training after that epoch."""
    max_epochs = 5
    last_epoch_to_run = 3
    trainer = Trainer(MagicMock(return_value=1))

    def end_of_epoch_handler(trainer, state):
        # Only the designated epoch triggers termination.
        if state.epoch != last_epoch_to_run:
            return
        trainer.terminate()

    trainer.add_event_handler(Events.EPOCH_COMPLETED, end_of_epoch_handler)
    assert not trainer.should_terminate

    state = trainer.run([1], max_epochs=max_epochs)

    assert state.epoch == last_epoch_to_run
    assert trainer.should_terminate
def test_current_iteration_counter_increases_every_iteration():
    """The state.iteration seen at each ITERATION_STARTED counts 1, 2, ... across all epochs."""
    training_batches = [1, 2, 3]
    trainer = Trainer(MagicMock(return_value=1))
    max_epochs = 5

    class IterationCounter(object):
        """Asserts that each ITERATION_STARTED event sees the next iteration number."""

        def __init__(self):
            # Next iteration number we expect to observe; the first iteration is 1
            # and numbering continues across epoch boundaries.
            self.current_iteration_count = 1

        def __call__(self, trainer, state):
            assert state.iteration == self.current_iteration_count
            self.current_iteration_count += 1

    iteration_counter = IterationCounter()
    trainer.add_event_handler(Events.ITERATION_STARTED, iteration_counter)
    state = trainer.run(training_batches, max_epochs=max_epochs)

    expected_iterations = max_epochs * len(training_batches)
    assert state.iteration == expected_iterations
    # The per-iteration assertions above only run if the handler is invoked.
    # Without this check, the test would pass vacuously if ITERATION_STARTED
    # never fired (or fired too few times), so confirm one call per iteration.
    assert iteration_counter.current_iteration_count == expected_iterations + 1
def test_with_trainer(dirname):
    """ModelCheckpoint with n_saved=2 leaves only the checkpoints from the last two epochs."""
    def _noop_update(batch):
        pass

    name = 'model'
    trainer = Trainer(_noop_update)

    checkpointer = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2, save_interval=1)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {name: 42})
    trainer.run([0], max_epochs=4)

    # Four epochs are saved one per epoch, but only the newest two (3 and 4) survive.
    assert sorted(os.listdir(dirname)) == [
        '{}_{}_{}.pth'.format(_PREFIX, name, epoch) for epoch in (3, 4)
    ]
def test_terminate_at_start_of_epoch_stops_training_after_completing_iteration():
    """Calling terminate() from EPOCH_STARTED still lets that epoch's first iteration run."""
    max_epochs = 5
    epoch_to_terminate_on = 3
    batches_per_epoch = [1, 2, 3]
    trainer = Trainer(MagicMock(return_value=1))

    def start_of_epoch_handler(trainer):
        if trainer.state.epoch != epoch_to_terminate_on:
            return
        trainer.terminate()

    trainer.add_event_handler(Events.EPOCH_STARTED, start_of_epoch_handler)
    assert not trainer.should_terminate

    trainer.run(batches_per_epoch, max_epochs=max_epochs)

    # epoch is not completed so counter is not incremented
    assert trainer.state.epoch == epoch_to_terminate_on
    assert trainer.should_terminate

    # Every earlier epoch ran in full, plus the single (first) iteration of the
    # epoch in which terminate() was requested.
    fully_completed_epochs = epoch_to_terminate_on - 1
    iterations_before = fully_completed_epochs * len(batches_per_epoch)
    assert trainer.state.iteration == iterations_before + 1