Exemple #1
0
def test_custom_exception_handler():
    """A user EXCEPTION_RAISED handler receives the error raised by the update function.

    The update function always raises ``value_error``; because the registered
    handler does not re-raise, ``engine.run`` is expected to return normally.
    """
    value_error = ValueError()
    update_function = MagicMock(side_effect=value_error)

    engine = Engine(update_function)
    exception_handler = MagicMock()
    engine.add_event_handler(Events.EXCEPTION_RAISED, exception_handler)
    engine.run([1])

    # Only one call, from _run_once_over_data, since the exception is swallowed there.
    # ``assert_has_calls`` alone would still pass if the handler were invoked
    # more than once, so pin the exact call count as well.
    exception_handler.assert_called_once_with(engine, value_error)
Exemple #2
0
def test_state_attributes():
    """After 3 epochs over a 3-batch loader, ``State`` exposes the final counters."""
    loader = [1, 2, 3]
    final_state = Engine(MagicMock(return_value=1)).run(loader, max_epochs=3)

    # attribute name -> value expected on the returned State
    expected = {
        'iteration': 9,
        'output': 1,
        'batch': 3,
        'dataloader': loader,
        'epoch': 3,
        'max_epochs': 3,
        'metrics': {},
    }
    for attr_name, expected_value in expected.items():
        assert getattr(final_state, attr_name) == expected_value, attr_name
Exemple #3
0
def test_terminate_stops_run_mid_epoch():
    """Calling ``terminate`` from ITERATION_STARTED stops the run part way through an epoch."""
    num_iterations_per_epoch = 10
    iteration_to_stop = num_iterations_per_epoch + 3  # i.e. part way through the 2nd epoch
    engine = Engine(MagicMock(return_value=1))

    def start_of_iteration_handler(engine):
        if engine.state.iteration == iteration_to_stop:
            engine.terminate()

    engine.add_event_handler(Events.ITERATION_STARTED, start_of_iteration_handler)
    state = engine.run(data=[None] * num_iterations_per_epoch, max_epochs=3)

    # Completes the iteration but doesn't increment the counter
    # (that happens just before a new iteration starts).
    assert state.iteration == iteration_to_stop
    # Epochs are counted from 1, so iteration 13 at 10 iterations per epoch
    # lies in epoch ceil(13 / 10) == 2.
    assert state.epoch == np.ceil(iteration_to_stop / num_iterations_per_epoch)
Exemple #4
0
def test_with_engine_no_early_stopping():
    """EarlyStopping(patience=5) never fires on this score sequence, so all 10 epochs run."""
    # At most 3 consecutive scores fall below the running best here, which is
    # fewer than the patience of 5.
    score_sequence = iter([1.0, 0.8, 1.2, 1.23, 0.9, 1.0, 1.1, 1.253, 1.26, 1.2])
    completed_epochs = [0]  # mutable cell so the nested handler can update it

    def noop_update(engine, batch):
        pass

    trainer = Engine(noop_update)
    evaluator = Engine(noop_update)
    early_stopping = EarlyStopping(patience=5,
                                   score_function=lambda engine: next(score_sequence),
                                   trainer=trainer)
    evaluator.add_event_handler(Events.COMPLETED, early_stopping)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        evaluator.run([0])
        completed_epochs[0] += 1

    trainer.run([0], max_epochs=10)
    assert completed_epochs[0] == 10
Exemple #5
0
def test_current_epoch_counter_increases_every_epoch():
    """``state.epoch`` is 1 at the first EPOCH_STARTED and grows by one each epoch."""
    engine = Engine(MagicMock(return_value=1))
    max_epochs = 5

    class EpochCounter(object):
        def __init__(self):
            self.current_epoch_count = 1

        def __call__(self, engine):
            assert engine.state.epoch == self.current_epoch_count
            self.current_epoch_count += 1

    epoch_counter = EpochCounter()
    engine.add_event_handler(Events.EPOCH_STARTED, epoch_counter)

    state = engine.run([1], max_epochs=max_epochs)

    assert state.epoch == max_epochs
    # The per-epoch assertions live inside the handler, so also check that it
    # actually ran once per epoch; otherwise the test could pass vacuously.
    assert epoch_counter.current_epoch_count == max_epochs + 1
Exemple #6
0
def test_iteration_events_are_fired():
    """ITERATION_STARTED and ITERATION_COMPLETED fire once per batch, strictly interleaved."""
    max_epochs = 5
    num_batches = 3
    data = _create_mock_data_loader(max_epochs, num_batches)

    engine = Engine(MagicMock(return_value=1))

    # Both handler mocks are attached to one manager so their calls are recorded
    # in a single ordered list; this checks interleaving, not just the counts.
    mock_manager = Mock()
    iteration_started = Mock()
    engine.add_event_handler(Events.ITERATION_STARTED, iteration_started)

    iteration_complete = Mock()
    engine.add_event_handler(Events.ITERATION_COMPLETED, iteration_complete)

    mock_manager.attach_mock(iteration_started, 'iteration_started')
    mock_manager.attach_mock(iteration_complete, 'iteration_complete')

    engine.run(data, max_epochs=max_epochs)

    total_iterations = num_batches * max_epochs
    assert iteration_started.call_count == total_iterations
    assert iteration_complete.call_count == total_iterations

    expected_calls = [call.iteration_started(engine),
                      call.iteration_complete(engine)] * total_iterations
    assert mock_manager.mock_calls == expected_calls
Exemple #7
0
def test_terminate_at_end_of_epoch_stops_run():
    """``terminate`` from EPOCH_COMPLETED prevents any further epochs from running."""
    max_epochs = 5
    last_epoch_to_run = 3
    engine = Engine(MagicMock(return_value=1))

    def stop_after_last_epoch(running_engine):
        if running_engine.state.epoch != last_epoch_to_run:
            return
        running_engine.terminate()

    engine.add_event_handler(Events.EPOCH_COMPLETED, stop_after_last_epoch)
    assert not engine.should_terminate

    final_state = engine.run([1], max_epochs=max_epochs)

    assert final_state.epoch == last_epoch_to_run
    assert engine.should_terminate
Exemple #8
0
def test_current_iteration_counter_increases_every_iteration():
    """``state.iteration`` is 1 at the first ITERATION_STARTED and grows by one per batch."""
    batches = [1, 2, 3]
    engine = Engine(MagicMock(return_value=1))
    max_epochs = 5

    class IterationCounter(object):
        def __init__(self):
            self.current_iteration_count = 1

        def __call__(self, engine):
            assert engine.state.iteration == self.current_iteration_count
            self.current_iteration_count += 1

    iteration_counter = IterationCounter()
    engine.add_event_handler(Events.ITERATION_STARTED, iteration_counter)

    state = engine.run(batches, max_epochs=max_epochs)

    total_iterations = max_epochs * len(batches)
    assert state.iteration == total_iterations
    # The per-iteration assertions live inside the handler, so also check that it
    # actually ran once per iteration; otherwise the test could pass vacuously.
    assert iteration_counter.current_iteration_count == total_iterations + 1
Exemple #9
0
def test_timer():
    """Timers bound to different event pairs measure total, per-batch and train-only time."""
    delay = 0.2
    n_iter = 3

    def _sleepy_step(engine, batch):
        # Every iteration just sleeps, so elapsed times are predictable multiples of ``delay``.
        time.sleep(delay)

    trainer = Engine(_sleepy_step)
    tester = Engine(_sleepy_step)

    total_timer = Timer()
    batch_timer = Timer(average=True)
    train_timer = Timer()

    # Kept in the original attachment order: train_timer is expected (see its
    # assertion below) to exclude validation time, which presumably relies on its
    # EPOCH_COMPLETED pause being registered before run_validation.
    total_timer.attach(trainer)
    batch_timer.attach(trainer,
                       pause=Events.ITERATION_COMPLETED,
                       resume=Events.ITERATION_STARTED,
                       step=Events.ITERATION_COMPLETED)
    train_timer.attach(trainer,
                       pause=Events.EPOCH_COMPLETED,
                       resume=Events.EPOCH_STARTED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(_engine):
        tester.run(range(n_iter))

    # Run "training"
    trainer.run(range(n_iter))

    def _close_enough(measured, expected):
        # Compare at 0.1 s resolution, presumably to absorb sleep() jitter.
        return round(measured, 1) == round(expected, 1)

    assert _close_enough(total_timer.value(), 2 * n_iter * delay)
    assert _close_enough(batch_timer.value(), delay)
    assert _close_enough(train_timer.value(), n_iter * delay)

    total_timer.reset()
    assert _close_enough(total_timer.value(), 0.0)
Exemple #10
0
def test_terminate_at_start_of_epoch_stops_run_after_completing_iteration():
    """``terminate`` from EPOCH_STARTED still lets the first iteration of that epoch run."""
    max_epochs = 5
    epoch_to_terminate_on = 3
    batches_per_epoch = [1, 2, 3]
    engine = Engine(MagicMock(return_value=1))

    @engine.on(Events.EPOCH_STARTED)
    def start_of_epoch_handler(running_engine):
        if running_engine.state.epoch == epoch_to_terminate_on:
            running_engine.terminate()

    assert not engine.should_terminate

    final_state = engine.run(batches_per_epoch, max_epochs=max_epochs)

    # The epoch is not completed, so the epoch counter is not incremented past it.
    assert final_state.epoch == epoch_to_terminate_on
    assert engine.should_terminate
    # Every iteration of the earlier epochs ran, plus exactly one of the terminating epoch.
    iterations_before = (epoch_to_terminate_on - 1) * len(batches_per_epoch)
    assert final_state.iteration == iterations_before + 1
Exemple #11
0
def test_args_validation():
    """EarlyStopping rejects invalid constructor arguments with an AssertionError."""

    def update_fn(engine, batch):
        pass

    trainer = Engine(update_fn)

    # a negative patience is rejected
    with pytest.raises(AssertionError):
        EarlyStopping(patience=-1, score_function=lambda engine: 0, trainer=trainer)

    # a non-callable score_function is rejected
    with pytest.raises(AssertionError):
        EarlyStopping(patience=2, score_function=12345, trainer=trainer)

    # a missing trainer is rejected
    with pytest.raises(AssertionError):
        EarlyStopping(patience=2, score_function=lambda engine: 0, trainer=None)
Exemple #12
0
def test_simple_no_early_stopping():
    """A single non-improving score (0.8) with patience 2 leaves the trainer running."""
    score_iter = iter([1.0, 0.8, 1.2])

    def noop_update(engine, batch):
        pass

    trainer = Engine(noop_update)
    handler = EarlyStopping(patience=2,
                            score_function=lambda engine: next(score_iter),
                            trainer=trainer)

    assert not trainer.should_terminate
    # Invoke the handler once per score; the score function ignores its engine argument.
    for _ in range(3):
        handler(None)
    assert not trainer.should_terminate
Exemple #13
0
def create_supervised_evaluator(model, inference_fn, metrics=None, cuda=False):
    """
    Factory function for creating an evaluator for supervised models.
    Extended version from ignite's create_supervised_evaluator
    Args:
        model (torch.nn.Module): the model to evaluate. Not used by this function;
            it is kept for signature compatibility. In the callers visible here,
            ``inference_fn`` closes over the model itself.
        inference_fn (function): inference function, run by the engine as
            ``inference_fn(engine, batch)``
        metrics (dict of str: Metric, optional): a map of metric names to Metrics;
            each one is attached to the returned engine under its name
            (default: no metrics)
        cuda (bool, optional): not used by this function; any device transfer
            must happen inside ``inference_fn`` (default: False)
    Returns:
        Engine: an evaluator engine with supervised inference function
    """
    engine = Engine(inference_fn)

    # ``None`` instead of a ``{}`` default avoids one dict object being shared across calls.
    for name, metric in (metrics or {}).items():
        metric.attach(engine, name)

    return engine
Exemple #14
0
def test_with_engine(dirname):
    """ModelCheckpoint with n_saved=2 keeps only the two most recent checkpoint files."""
    def update_fn(engine, batch):
        pass

    name = 'model'
    engine = Engine(update_fn)
    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2, save_interval=1)
    engine.add_event_handler(Events.EPOCH_COMPLETED, handler, {name: 42})
    engine.run([0], max_epochs=4)

    # Four saves happen; only the files suffixed 3 and 4 should remain.
    expected = ['{}_{}_{}.pth'.format(_PREFIX, name, index) for index in (3, 4)]
    saved_files = sorted(os.listdir(dirname))
    assert saved_files == expected
def create_dataflow_checker():
    """Build an Engine whose update step returns each batch unchanged as its output."""

    def _echo_batch(engine, batch):
        return batch

    checker = Engine(_echo_batch)
    return checker
Exemple #16
0
        optimizer.step()
        return loss.cpu()

    def inference_function(engine, batch):
        """Run ``classifier`` on one batch and return ``(probabilities, targets)`` on CPU.

        Args:
            engine: the evaluator engine (not used here).
            batch: an object with ``.text`` and ``.sentiment`` attributes;
                ``text[0]`` is fed to the model and ``text[1]`` holds the
                sequence lengths (presumably a torchtext batch built with
                ``include_lengths=True`` -- confirm).

        Returns:
            tuple: the row-wise (dim=1) softmax of ``classifier``'s output and the
            squeezed ``sentiment`` labels, both moved to CPU.
        """
        classifier.eval()
        # NOTE(review): there is no ``torch.no_grad()`` here, so autograd still
        # records the forward pass during evaluation -- confirm this is intended.
        text, y = batch.text, batch.sentiment
        x = text[0]
        seq_len = text[1].numpy()
        # Sorting the reversed view in place leaves ``seq_len`` in descending order.
        # ``x`` is not reordered to match, so this presumably assumes the batch is
        # already sorted by length -- TODO confirm. For a CPU tensor, ``.numpy()``
        # shares memory, so this also modifies ``text[1]``.
        seq_len[::-1].sort()
        softmax = nn.Softmax(dim=1)

        y_pred = classifier(x, seq_len)
        y_pred = softmax(y_pred)
        return y_pred.cpu(), y.squeeze().cpu()

    trainer = Engine(training_update_function)
    evaluator = create_supervised_evaluator(model=classifier,
                                            inference_fn=inference_function,
                                            metrics={
                                                "loss": Loss(loss_fn),
                                                "acc": CategoricalAccuracy(),
                                                "prec": Precision(),
                                                "rec": Recall()
                                            })
    checkpoint = ModelCheckpoint(ARGS.model_dir,
                                 "sentiment",
                                 save_interval=ARGS.save_interval,
                                 n_saved=5,
                                 create_dir=True,
                                 require_empty=False)
    loader = ModelLoader(classifier, ARGS.model_dir, "sentiment")
Exemple #17
0
def test_returns_state():
    """``Engine.run`` returns a ``State`` instance, even for empty data."""
    result = Engine(MagicMock(return_value=1)).run([])
    assert isinstance(result, State)
Exemple #18
0
def test_default_exception_handler():
    """With no EXCEPTION_RAISED handler registered, the update function's error propagates."""
    failing_update = MagicMock(side_effect=ValueError())
    failing_engine = Engine(failing_update)

    with raises(ValueError):
        failing_engine.run([1])
Exemple #19
0
        return result.detach()

    def evaluation_function(engine, batch):
        """Decode one batch with ``crf_tagger`` and return ``(predicted_tags, gold_tags)``.

        Args:
            engine: the evaluator engine (not used here).
            batch: an object with ``.sentence`` and ``.tags``; ``sentence[0]`` is
                passed through ``embedding`` (presumably token ids) and
                ``sentence[1]`` holds the lengths (presumably a torchtext field
                built with ``include_lengths=True`` -- confirm).

        Returns:
            tuple: the decoded tags as an int32 tensor with its first two
            dimensions swapped, and ``batch.tags`` detached from the graph.
        """
        crf_tagger.eval()
        # NOTE(review): no ``torch.no_grad()`` around the forward pass -- confirm intended.
        sentence = batch.sentence[0]
        sent_len = batch.sentence[1]
        tags = batch.tags

        x = embedding(sentence)
        result = torch.tensor(crf_tagger.decode(x, sent_len.numpy()),
                              dtype=torch.int32)
        # Swap dims 0 and 1, presumably so the layout matches ``batch.tags`` -- confirm.
        result = result.transpose(1, 0)

        return result, tags.detach()

    trainer = Engine(process_function)
    evaluator = create_supervised_evaluator(model=crf_tagger,
                                            inference_fn=evaluation_function,
                                            metrics={
                                                "acc":
                                                SequenceTagAccuracy(
                                                    tags.vocab),
                                            })
    checkpoint = ModelCheckpoint("models",
                                 "postag-en",
                                 save_interval=100,
                                 n_saved=5,
                                 create_dir=True,
                                 require_empty=False)

    trainer.add_event_handler(Events.ITERATION_COMPLETED, checkpoint,
Exemple #20
0
def test_stopping_criterion_is_max_epochs():
    """Without early termination the run stops after exactly ``max_epochs`` epochs."""
    n_epochs = 5
    final_state = Engine(MagicMock(return_value=1)).run([1], max_epochs=n_epochs)
    assert final_state.epoch == n_epochs