def test_custom_exception_handler():
    """A handler registered for EXCEPTION_RAISED receives the error instead of run() raising it."""
    value_error = ValueError()
    update_function = MagicMock(side_effect=value_error)

    engine = Engine(update_function)
    exception_handler = MagicMock()
    engine.add_event_handler(Events.EXCEPTION_RAISED, exception_handler)
    engine.run([1])

    # only one call from _run_once_over_data, since the exception is swallowed.
    # assert_called_once_with (unlike assert_has_calls) also fails on extra calls,
    # so it actually enforces the "only one call" expectation.
    exception_handler.assert_called_once_with(engine, value_error)
def test_state_attributes():
    """The State returned by run() reflects the last batch, output, epoch and counters."""
    batches = [1, 2, 3]
    engine = Engine(MagicMock(return_value=1))
    final_state = engine.run(batches, max_epochs=3)

    # counters
    assert final_state.epoch == 3
    assert final_state.max_epochs == 3
    assert final_state.iteration == 9
    # data seen by the last iteration
    assert final_state.dataloader == batches
    assert final_state.batch == 3
    assert final_state.output == 1
    # no metrics were attached
    assert final_state.metrics == {}
def test_terminate_stops_run_mid_epoch():
    """Calling terminate() from ITERATION_STARTED stops the run part way through an epoch."""
    num_iterations_per_epoch = 10
    # 13th iteration overall, i.e. the 3rd iteration of the 2nd epoch
    iteration_to_stop = num_iterations_per_epoch + 3
    engine = Engine(MagicMock(return_value=1))

    def start_of_iteration_handler(engine):
        if engine.state.iteration == iteration_to_stop:
            engine.terminate()

    engine.add_event_handler(Events.ITERATION_STARTED, start_of_iteration_handler)
    state = engine.run(data=[None] * num_iterations_per_epoch, max_epochs=3)

    # completes the iteration but doesn't increment counter
    # (this happens just before a new iteration starts)
    assert state.iteration == iteration_to_stop
    # epochs are counted from 1, so the stopping epoch is ceil(13 / 10) == 2.
    # Integer ceiling division avoids a float comparison and Python 2 floor division.
    expected_epoch = (iteration_to_stop + num_iterations_per_epoch - 1) // num_iterations_per_epoch
    assert state.epoch == expected_epoch
def test_with_engine_no_early_stopping():
    """EarlyStopping(patience=5) must let all 10 trainer epochs run for this score sequence."""
    epochs_evaluated = {'count': 0}
    score_sequence = iter([1.0, 0.8, 1.2, 1.23, 0.9, 1.0, 1.1, 1.253, 1.26, 1.2])

    def score_function(engine):
        return next(score_sequence)

    def noop_update(engine, batch):
        pass

    trainer = Engine(noop_update)
    evaluator = Engine(noop_update)
    early_stopping = EarlyStopping(patience=5, score_function=score_function, trainer=trainer)
    # the early-stopping check runs each time an evaluation run completes
    evaluator.add_event_handler(Events.COMPLETED, early_stopping)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        evaluator.run([0])
        epochs_evaluated['count'] += 1

    trainer.run([0], max_epochs=10)
    # the sequence never goes 5 evaluations without beating its best score
    assert epochs_evaluated['count'] == 10
def test_current_epoch_counter_increases_every_epoch():
    """state.epoch is 1 at the first EPOCH_STARTED and grows by one every epoch."""
    engine = Engine(MagicMock(return_value=1))
    max_epochs = 5

    class EpochCounter(object):
        def __init__(self):
            self.current_epoch_count = 1

        def __call__(self, engine):
            assert engine.state.epoch == self.current_epoch_count
            self.current_epoch_count += 1

    epoch_counter = EpochCounter()
    engine.add_event_handler(Events.EPOCH_STARTED, epoch_counter)
    state = engine.run([1], max_epochs=max_epochs)
    assert state.epoch == max_epochs
    # the in-handler assertion proves nothing if the handler never fires,
    # so also check it ran exactly once per epoch
    assert epoch_counter.current_epoch_count == max_epochs + 1
def test_iteration_events_are_fired():
    """ITERATION_STARTED / ITERATION_COMPLETED fire once per batch and strictly alternate."""
    max_epochs = 5
    num_batches = 3
    total_iterations = max_epochs * num_batches
    data = _create_mock_data_loader(max_epochs, num_batches)
    engine = Engine(MagicMock(return_value=1))

    iteration_started = Mock()
    iteration_complete = Mock()
    engine.add_event_handler(Events.ITERATION_STARTED, iteration_started)
    engine.add_event_handler(Events.ITERATION_COMPLETED, iteration_complete)

    # attaching both mocks to one parent records their calls in a single ordered list
    mock_manager = Mock()
    mock_manager.attach_mock(iteration_started, 'iteration_started')
    mock_manager.attach_mock(iteration_complete, 'iteration_complete')

    engine.run(data, max_epochs=max_epochs)

    assert iteration_started.call_count == total_iterations
    assert iteration_complete.call_count == total_iterations
    one_iteration = [call.iteration_started(engine), call.iteration_complete(engine)]
    assert mock_manager.mock_calls == one_iteration * total_iterations
def test_terminate_at_end_of_epoch_stops_run():
    """terminate() from an EPOCH_COMPLETED handler prevents any further epochs from running."""
    max_epochs = 5
    last_epoch_to_run = 3
    engine = Engine(MagicMock(return_value=1))

    def stop_after_last_epoch(engine):
        if engine.state.epoch != last_epoch_to_run:
            return
        engine.terminate()

    engine.add_event_handler(Events.EPOCH_COMPLETED, stop_after_last_epoch)
    assert not engine.should_terminate

    final_state = engine.run([1], max_epochs=max_epochs)
    assert final_state.epoch == last_epoch_to_run
    assert engine.should_terminate
def test_current_iteration_counter_increases_every_iteration():
    """state.iteration is 1 at the first ITERATION_STARTED and grows by one across all epochs."""
    batches = [1, 2, 3]
    engine = Engine(MagicMock(return_value=1))
    max_epochs = 5

    class IterationCounter(object):
        def __init__(self):
            self.current_iteration_count = 1

        def __call__(self, engine):
            assert engine.state.iteration == self.current_iteration_count
            self.current_iteration_count += 1

    iteration_counter = IterationCounter()
    engine.add_event_handler(Events.ITERATION_STARTED, iteration_counter)
    state = engine.run(batches, max_epochs=max_epochs)
    total_iterations = max_epochs * len(batches)
    assert state.iteration == total_iterations
    # the in-handler assertion proves nothing if the handler never fires,
    # so also check it ran exactly once per iteration
    assert iteration_counter.current_iteration_count == total_iterations + 1
def test_timer():
    """Timer reports total, average-per-iteration and training-only elapsed time."""
    sleep_t = 0.2
    n_iter = 3

    # training and "testing" steps do the same thing: sleep for sleep_t
    def _sleep_step(engine, batch):
        time.sleep(sleep_t)

    trainer = Engine(_sleep_step)
    tester = Engine(_sleep_step)

    t_total = Timer()
    t_batch = Timer(average=True)
    t_train = Timer()

    # whole trainer run, validation included
    t_total.attach(trainer)
    # only time spent inside iterations, averaged per iteration
    t_batch.attach(trainer,
                   pause=Events.ITERATION_COMPLETED,
                   resume=Events.ITERATION_STARTED,
                   step=Events.ITERATION_COMPLETED)
    # only time inside training epochs; expected to exclude validation (see asserts)
    t_train.attach(trainer, pause=Events.EPOCH_COMPLETED, resume=Events.EPOCH_STARTED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(trainer):
        tester.run(range(n_iter))

    # Run "training"
    trainer.run(range(n_iter))

    def _approx_equal(measured, expected):
        return round(measured, 1) == round(expected, 1)

    assert _approx_equal(t_total.value(), 2 * n_iter * sleep_t)
    assert _approx_equal(t_batch.value(), sleep_t)
    assert _approx_equal(t_train.value(), n_iter * sleep_t)

    t_total.reset()
    assert _approx_equal(t_total.value(), 0.0)
def test_terminate_at_start_of_epoch_stops_run_after_completing_iteration():
    """terminate() from EPOCH_STARTED still lets that epoch's first iteration run."""
    max_epochs = 5
    epoch_to_terminate_on = 3
    batches_per_epoch = [1, 2, 3]
    engine = Engine(MagicMock(return_value=1))

    def stop_at_epoch_start(engine):
        if engine.state.epoch == epoch_to_terminate_on:
            engine.terminate()

    engine.add_event_handler(Events.EPOCH_STARTED, stop_at_epoch_start)
    assert not engine.should_terminate

    final_state = engine.run(batches_per_epoch, max_epochs=max_epochs)
    assert engine.should_terminate
    # the interrupted epoch never completes, so the epoch counter stays on it
    assert final_state.epoch == epoch_to_terminate_on
    # every iteration of the fully completed epochs, plus the single one that still ran
    completed_epochs = epoch_to_terminate_on - 1
    assert final_state.iteration == completed_epochs * len(batches_per_epoch) + 1
def test_args_validation():
    """EarlyStopping rejects invalid patience, score_function and trainer arguments."""
    def update_fn(engine, batch):
        pass

    trainer = Engine(update_fn)

    # a negative patience is rejected
    with pytest.raises(AssertionError):
        EarlyStopping(patience=-1, score_function=lambda engine: 0, trainer=trainer)

    # a non-callable score_function (an int here) is rejected
    with pytest.raises(AssertionError):
        EarlyStopping(patience=2, score_function=12345, trainer=trainer)

    # a missing trainer is rejected
    with pytest.raises(AssertionError):
        EarlyStopping(patience=2, score_function=lambda engine: 0, trainer=None)
def test_simple_no_early_stopping():
    """An improvement within patience=2 keeps EarlyStopping from terminating the trainer."""
    score_iter = iter([1.0, 0.8, 1.2])

    def score_function(engine):
        return next(score_iter)

    def update_fn(engine, batch):
        pass

    trainer = Engine(update_fn)
    early_stopping = EarlyStopping(patience=2, score_function=score_function, trainer=trainer)

    assert not trainer.should_terminate
    # call the handler once per score; score_function ignores its engine argument,
    # so None is passed
    for _ in range(3):
        early_stopping(None)
    assert not trainer.should_terminate
def create_supervised_evaluator(model, inference_fn, metrics=None, cuda=False):
    """
    Factory function for creating an evaluator for supervised models.
    Extended version from ignite's create_supervised_evaluator

    Args:
        model (torch.nn.Module): the model to evaluate. Not used by this factory;
            presumably ``inference_fn`` closes over the model itself. Kept for
            signature compatibility.
        inference_fn (function): inference function ``(engine, batch) -> output``
            used as the engine's process function
        metrics (dict of str: Metric, optional): a map of metric names to Metrics,
            each attached to the engine under its name (default: no metrics)
        cuda (bool, optional): accepted for compatibility but currently ignored;
            moving batches to the GPU is left to ``inference_fn`` (default: False)

    Returns:
        Engine: an evaluator engine with supervised inference function
    """
    engine = Engine(inference_fn)
    # ``None`` default instead of ``{}`` avoids a shared mutable default argument
    for name, metric in (metrics or {}).items():
        metric.attach(engine, name)
    return engine
def test_with_engine(dirname):
    """ModelCheckpoint attached to EPOCH_COMPLETED keeps only the last n_saved files."""
    def update_fn(engine, batch):
        pass

    name = 'model'
    engine = Engine(update_fn)
    checkpointer = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2, save_interval=1)
    engine.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {name: 42})
    engine.run([0], max_epochs=4)

    # with n_saved=2, only the checkpoints from epochs 3 and 4 should remain
    expected = []
    for epoch in (3, 4):
        expected.append('{}_{}_{}.pth'.format(_PREFIX, name, epoch))
    assert sorted(os.listdir(dirname)) == expected
def create_dataflow_checker():
    """Return an Engine whose update step outputs each batch unchanged."""
    return Engine(lambda engine, batch: batch)
optimizer.step() return loss.cpu() def inference_function(engine, batch): classifier.eval() text, y = batch.text, batch.sentiment x = text[0] seq_len = text[1].numpy() seq_len[::-1].sort() softmax = nn.Softmax(dim=1) y_pred = classifier(x, seq_len) y_pred = softmax(y_pred) return y_pred.cpu(), y.squeeze().cpu() trainer = Engine(training_update_function) evaluator = create_supervised_evaluator(model=classifier, inference_fn=inference_function, metrics={ "loss": Loss(loss_fn), "acc": CategoricalAccuracy(), "prec": Precision(), "rec": Recall() }) checkpoint = ModelCheckpoint(ARGS.model_dir, "sentiment", save_interval=ARGS.save_interval, n_saved=5, create_dir=True, require_empty=False) loader = ModelLoader(classifier, ARGS.model_dir, "sentiment")
def test_returns_state():
    """run() returns a State object even when the dataset is empty."""
    result = Engine(MagicMock(return_value=1)).run([])
    assert isinstance(result, State)
def test_default_exception_handler():
    """With no EXCEPTION_RAISED handler, an error in the update function propagates from run()."""
    engine = Engine(MagicMock(side_effect=ValueError()))
    with raises(ValueError):
        engine.run([1])
return result.detach() def evaluation_function(engine, batch): crf_tagger.eval() sentence = batch.sentence[0] sent_len = batch.sentence[1] tags = batch.tags x = embedding(sentence) result = torch.tensor(crf_tagger.decode(x, sent_len.numpy()), dtype=torch.int32) result = result.transpose(1, 0) return result, tags.detach() trainer = Engine(process_function) evaluator = create_supervised_evaluator(model=crf_tagger, inference_fn=evaluation_function, metrics={ "acc": SequenceTagAccuracy( tags.vocab), }) checkpoint = ModelCheckpoint("models", "postag-en", save_interval=100, n_saved=5, create_dir=True, require_empty=False) trainer.add_event_handler(Events.ITERATION_COMPLETED, checkpoint,
def test_stopping_criterion_is_max_epochs():
    """With no termination requested, the run ends after exactly max_epochs epochs."""
    max_epochs = 5
    final_state = Engine(MagicMock(return_value=1)).run([1], max_epochs=max_epochs)
    assert final_state.epoch == max_epochs