def test_state_attributes():
    """Run three epochs over three batches and check each state field."""
    batches = [1, 2, 3]
    n_epochs = 3
    update_output = 1

    engine = Trainer(MagicMock(return_value=update_output))
    final_state = engine.run(batches, max_epochs=n_epochs)

    # One iteration per batch per epoch.
    assert final_state.iteration == len(batches) * n_epochs
    # The state is expected to hold the last update result and last batch.
    assert final_state.output == update_output
    assert final_state.batch == batches[-1]
    assert final_state.dataloader == batches
    assert final_state.epoch == n_epochs
    assert final_state.max_epochs == n_epochs
def test_custom_exception_handler():
    """A handler registered for EXCEPTION_RAISED receives the raised error."""
    raised = ValueError()
    engine = Trainer(MagicMock(side_effect=raised))

    on_exception = MagicMock()
    engine.add_event_handler(Events.EXCEPTION_RAISED, on_exception)

    final_state = engine.run([1])

    # run() returned normally, so the error did not propagate to the caller;
    # the handler should have been called with the engine, its state and the
    # very exception object the update function raised.
    expected_calls = [call(engine, final_state, raised)]
    on_exception.assert_has_calls(expected_calls)
def test_training_iteration_events_are_fired():
    """ITERATION_STARTED/COMPLETED fire once per batch, strictly alternating."""
    n_epochs = 5
    n_batches = 3
    loader = _create_mock_data_loader(n_epochs, n_batches)
    engine = Trainer(MagicMock(return_value=1))

    # A parent mock records the calls of its attached children in one
    # ordered log, which lets the test check how the two events interleave.
    recorder = Mock()

    on_started = Mock()
    engine.add_event_handler(Events.ITERATION_STARTED, on_started)
    on_complete = Mock()
    engine.add_event_handler(Events.ITERATION_COMPLETED, on_complete)

    recorder.attach_mock(on_started, 'iteration_started')
    recorder.attach_mock(on_complete, 'iteration_complete')

    engine.run(loader, max_epochs=n_epochs)

    assert on_started.call_count == n_batches * n_epochs
    assert on_complete.call_count == n_batches * n_epochs

    # Every iteration contributes a "started" call followed by a "complete".
    one_iteration = [
        call.iteration_started(engine),
        call.iteration_complete(engine),
    ]
    assert recorder.mock_calls == one_iteration * (n_epochs * n_batches)
def test_terminate_stops_training_mid_epoch():
    """Calling terminate() from an iteration handler halts the run mid-epoch."""
    num_iterations_per_epoch = 10
    # Three iterations beyond the end of the first pass over the data.
    iteration_to_stop = num_iterations_per_epoch + 3
    engine = Trainer(MagicMock(return_value=1))

    def start_of_iteration_handler(trainer, state):
        # Request termination exactly once, at the chosen iteration.
        if state.iteration != iteration_to_stop:
            return
        trainer.terminate()

    engine.add_event_handler(Events.ITERATION_STARTED,
                             start_of_iteration_handler)

    data = [None] * num_iterations_per_epoch
    final_state = engine.run(data=data, max_epochs=3)

    # The counter is expected to rest on the iteration that asked to stop:
    # no further iteration starts, so it is never advanced again.
    assert final_state.iteration == iteration_to_stop

    # The epoch reported is the one that contains the stopping iteration.
    expected_epoch = np.ceil(iteration_to_stop / num_iterations_per_epoch)
    assert final_state.epoch == expected_epoch
def test_current_epoch_counter_increases_every_epoch():
    """state.epoch reads 1, 2, 3, ... at each successive EPOCH_STARTED."""
    n_epochs = 5
    engine = Trainer(MagicMock(return_value=1))

    class EpochCounter(object):
        """Handler that checks the epoch number against its own tally."""

        def __init__(self):
            # The first EPOCH_STARTED is expected to report epoch 1.
            self.current_epoch_count = 1

        def __call__(self, trainer, state):
            assert state.epoch == self.current_epoch_count
            self.current_epoch_count = self.current_epoch_count + 1

    engine.add_event_handler(Events.EPOCH_STARTED, EpochCounter())

    final_state = engine.run([1], max_epochs=n_epochs)

    assert final_state.epoch == n_epochs
def test_terminate_at_end_of_epoch_stops_training():
    """terminate() called on EPOCH_COMPLETED prevents any further epochs."""
    n_epochs = 5
    final_epoch = 3
    engine = Trainer(MagicMock(return_value=1))

    def end_of_epoch_handler(trainer, state):
        # Only the chosen epoch triggers termination.
        if state.epoch != final_epoch:
            return
        trainer.terminate()

    engine.add_event_handler(Events.EPOCH_COMPLETED, end_of_epoch_handler)

    # The flag must be clear before the run and set once it has stopped.
    assert not engine.should_terminate

    final_state = engine.run([1], max_epochs=n_epochs)

    assert final_state.epoch == final_epoch
    assert engine.should_terminate
def test_current_iteration_counter_increases_every_iteration():
    """state.iteration reads 1, 2, 3, ... at each ITERATION_STARTED."""
    batches = [1, 2, 3]
    n_epochs = 5
    engine = Trainer(MagicMock(return_value=1))

    class IterationCounter(object):
        """Handler that checks the iteration number against its own tally."""

        def __init__(self):
            # The first ITERATION_STARTED is expected to report iteration 1.
            self.current_iteration_count = 1

        def __call__(self, trainer, state):
            assert state.iteration == self.current_iteration_count
            self.current_iteration_count = self.current_iteration_count + 1

    engine.add_event_handler(Events.ITERATION_STARTED, IterationCounter())

    final_state = engine.run(batches, max_epochs=n_epochs)

    # The count carries on across epochs rather than restarting.
    assert final_state.iteration == n_epochs * len(batches)
def test_timer():
    """Check total, per-batch-average and training-only Timer readings.

    Both engines sleep a fixed time per batch, and a validation pass is run
    from the trainer's EPOCH_COMPLETED event, so the expected readings are
    simple multiples of the sleep time (compared to one decimal place).
    """
    sleep_t = 0.2
    n_iter = 3

    def _train_func(engine, batch):
        time.sleep(sleep_t)

    def _test_func(engine, batch):
        time.sleep(sleep_t)

    trainer = Trainer(_train_func)
    tester = Evaluator(_test_func)

    total_timer = Timer()
    batch_timer = Timer(average=True)
    train_timer = Timer()

    total_timer.attach(trainer)
    batch_timer.attach(
        trainer,
        pause=Events.ITERATION_COMPLETED,
        resume=Events.ITERATION_STARTED,
        step=Events.ITERATION_COMPLETED,
    )
    train_timer.attach(
        trainer,
        pause=Events.EPOCH_COMPLETED,
        resume=Events.EPOCH_STARTED,
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(trainer):
        tester.run(range(n_iter))

    # Run "training"
    trainer.run(range(n_iter))

    def _same_to_one_decimal(lhs, rhs):
        return round(lhs, 1) == round(rhs, 1)

    # Total is expected to span the training and the validation batches.
    assert _same_to_one_decimal(total_timer.value(), 2 * n_iter * sleep_t)
    # The averaged timer is expected to report the time of a single batch.
    assert _same_to_one_decimal(batch_timer.value(), sleep_t)
    # The training timer is expected to cover the training batches only.
    assert _same_to_one_decimal(train_timer.value(), n_iter * sleep_t)

    total_timer.reset()
    assert _same_to_one_decimal(total_timer.value(), 0.0)
def test_terminate_at_start_of_epoch_stops_training_after_completing_iteration():
    """terminate() on EPOCH_STARTED still lets that epoch's first batch run."""
    n_epochs = 5
    stop_epoch = 3
    batches = [1, 2, 3]
    engine = Trainer(MagicMock(return_value=1))

    def start_of_epoch_handler(trainer, state):
        # Only the chosen epoch triggers termination.
        if state.epoch != stop_epoch:
            return
        trainer.terminate()

    engine.add_event_handler(Events.EPOCH_STARTED, start_of_epoch_handler)

    assert not engine.should_terminate

    final_state = engine.run(batches, max_epochs=n_epochs)

    # The run stops inside the chosen epoch, so the counter stays on it.
    assert final_state.epoch == stop_epoch
    assert engine.should_terminate

    # All batches of the earlier epochs, plus the one iteration that is
    # completed in the epoch where termination was requested.
    full_epochs_done = stop_epoch - 1
    assert final_state.iteration == full_epochs_done * len(batches) + 1
def test_with_trainer(dirname):
    """ModelCheckpoint attached to a Trainer leaves the two newest files.

    The handler is created with ``n_saved=2`` and fires after each of four
    epochs; the directory is then expected to hold only the files for
    epochs 3 and 4.
    """
    def update_fn(batch):
        pass

    name = 'model'
    engine = Trainer(update_fn)
    checkpointer = ModelCheckpoint(
        dirname,
        _PREFIX,
        create_dir=False,
        n_saved=2,
        save_interval=1,
    )
    to_save = {name: 42}
    engine.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, to_save)

    engine.run([0], max_epochs=4)

    filename_template = '{}_{}_{}.pth'
    expected = [filename_template.format(_PREFIX, name, epoch)
                for epoch in (3, 4)]
    found = sorted(os.listdir(dirname))
    assert found == expected
def test_default_exception_handler():
    """With no EXCEPTION_RAISED handler, the update's error reaches run()."""
    failing_update = MagicMock(side_effect=ValueError())
    engine = Trainer(failing_update)

    with raises(ValueError):
        engine.run([1])
def test_returns_state():
    """run() hands back a State object, even for an empty data source."""
    engine = Trainer(MagicMock(return_value=1))

    result = engine.run([])

    assert isinstance(result, State)
def test_stopping_criterion_is_max_epochs():
    """An uninterrupted run ends with state.epoch equal to max_epochs."""
    n_epochs = 5
    engine = Trainer(MagicMock(return_value=1))

    final_state = engine.run([1], max_epochs=n_epochs)

    assert final_state.epoch == n_epochs
def test_creates_state():
    """After run(), the engine exposes a State through its ``state`` attribute."""
    engine = Trainer(MagicMock(return_value=1))

    engine.run([])

    stored = engine.state
    assert isinstance(stored, State)