Пример #1
0
def test_state_attributes():
    """Check the State returned by ``Evaluator.run`` after a full pass.

    The process function is a mock that always returns 1, so after
    iterating ``[1, 2, 3]`` the state is expected to report three
    iterations, an output of 1, the final batch (3) and the very
    dataloader that was passed in.
    """
    batches = [1, 2, 3]
    process_fn = MagicMock(return_value=1)
    result = Evaluator(process_fn).run(batches)

    assert result.iteration == 3
    assert result.output == 1
    assert result.batch == 3
    assert result.dataloader == batches
Пример #2
0
def test_current_validation_iteration_counter_increases_every_iteration():
    """Check that ``state.iteration`` counts 1, 2, 3, ... within every run.

    A callable registered on ITERATION_STARTED asserts that the iteration
    number it receives matches its own expectation. A STARTED handler
    rewinds that expectation to 1 before each run. Over ``num_runs`` runs
    the callable must have fired once per batch per run.
    """
    validation_batches = [1, 2, 3]
    evaluator = Evaluator(MagicMock(return_value=1))
    num_runs = 5

    class ExpectingCounter(object):
        # Callable handler: `expected` is the iteration number the next call
        # should see within the current run; `calls` counts every call ever.
        def __init__(self):
            self.expected = 1
            self.calls = 0

        def __call__(self, engine, state):
            assert state.iteration == self.expected
            self.expected += 1
            self.calls += 1

        def rewind(self):
            # Only the per-run expectation is reset; `calls` keeps growing.
            self.expected = 1

    counter = ExpectingCounter()

    def reset_before_run(engine, state, tracked):
        tracked.rewind()

    evaluator.add_event_handler(Events.STARTED, reset_before_run, counter)
    evaluator.add_event_handler(Events.ITERATION_STARTED, counter)

    for _ in range(num_runs):
        evaluator.run(validation_batches)

    assert counter.calls == num_runs * len(validation_batches)
Пример #3
0
def test_terminate_stops_evaluator_when_called_during_iteration():
    """Check that ``terminate()`` called from ITERATION_STARTED stops the run early.

    The evaluator is given ``num_iterations`` batches in a single run, but a
    handler calls ``terminate()`` when the 3rd iteration starts. The returned
    state must report ``iteration_to_stop`` rather than ``num_iterations``.
    """
    num_iterations = 10
    iteration_to_stop = 3  # i.e. the 3rd iteration of this single run
    evaluator = Evaluator(MagicMock(return_value=1))

    def start_of_iteration_handler(evaluator, state):
        if state.iteration == iteration_to_stop:
            evaluator.terminate()

    evaluator.add_event_handler(Events.ITERATION_STARTED,
                                start_of_iteration_handler)
    state = evaluator.run([None] * num_iterations)

    # The run is expected to finish the iteration in which terminate() was
    # requested but start no further ones, so the counter stays there.
    assert state.iteration == iteration_to_stop
Пример #4
0
def test_timer():
    """Check Timer measurements for total, per-batch and per-epoch time.

    Both the trainer and the evaluator process functions sleep ``sleep_t``
    seconds per batch, and validation runs on EPOCH_COMPLETED. Given the
    assertions below, the expectations are:

    * ``t_total`` (default attach) covers training plus validation:
      ``2 * n_iter * sleep_t``;
    * ``t_batch`` averages ITERATION_STARTED -> ITERATION_COMPLETED spans:
      ``sleep_t``;
    * ``t_train`` spans EPOCH_STARTED -> EPOCH_COMPLETED and excludes
      validation: ``n_iter * sleep_t``.
    """
    import math

    sleep_t = 0.2
    n_iter = 3
    # Absolute timing-jitter allowance in seconds. This is the same +/-0.05 s
    # window the previous one-decimal rounding comparison gave for these
    # particular expected values.
    tolerance = 0.05

    def _train_func(engine, batch):
        time.sleep(sleep_t)

    def _test_func(engine, batch):
        time.sleep(sleep_t)

    trainer = Trainer(_train_func)
    tester = Evaluator(_test_func)

    t_total = Timer()
    t_batch = Timer(average=True)
    t_train = Timer()

    t_total.attach(trainer)
    t_batch.attach(trainer,
                   pause=Events.ITERATION_COMPLETED,
                   resume=Events.ITERATION_STARTED,
                   step=Events.ITERATION_COMPLETED)
    # Attached before run_validation is registered, presumably so its
    # EPOCH_COMPLETED pause fires before validation starts.
    t_train.attach(trainer,
                   pause=Events.EPOCH_COMPLETED,
                   resume=Events.EPOCH_STARTED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(trainer):
        tester.run(range(n_iter))

    # Run "training"
    trainer.run(range(n_iter))

    def _equal(lhs, rhs):
        # An absolute tolerance is used instead of round(x, 1) == round(y, 1).
        # Rounding both sides places the pass window at rounding boundaries,
        # which becomes lopsided or flaky whenever the expected value is not
        # itself a one-decimal number (e.g. 0.15 rounds to 0.1).
        return math.isclose(lhs, rhs, abs_tol=tolerance)

    assert _equal(t_total.value(), 2 * n_iter * sleep_t)
    assert _equal(t_batch.value(), sleep_t)
    assert _equal(t_train.value(), n_iter * sleep_t)

    # reset() is expected to discard the accumulated time.
    t_total.reset()
    assert _equal(t_total.value(), 0.0)
Пример #5
0
def test_evaluation_iteration_events_are_fired():
    """Check that ITERATION_STARTED/ITERATION_COMPLETED fire once per batch, in order.

    Both handlers are mocks attached to one manager mock, so ``mock_calls``
    records their relative order. Each batch must produce exactly one
    ``iteration_started(evaluator)`` followed by one
    ``iteration_complete(evaluator)``.
    """
    evaluator = Evaluator(MagicMock(return_value=1))

    mock_manager = Mock()
    iteration_started = Mock()
    evaluator.add_event_handler(Events.ITERATION_STARTED, iteration_started)

    iteration_complete = Mock()
    evaluator.add_event_handler(Events.ITERATION_COMPLETED, iteration_complete)

    # Attaching to a shared parent lets the test assert call interleaving.
    mock_manager.attach_mock(iteration_started, 'iteration_started')
    mock_manager.attach_mock(iteration_complete, 'iteration_complete')

    batches = [(1, 2), (3, 4), (5, 6)]
    evaluator.run(batches)
    assert iteration_started.call_count == len(batches)
    assert iteration_complete.call_count == len(batches)

    # One (started, complete) pair per batch; call objects are immutable, so
    # repeating the same pair via list multiplication is safe.
    expected_calls = [
        call.iteration_started(evaluator),
        call.iteration_complete(evaluator),
    ] * len(batches)

    assert mock_manager.mock_calls == expected_calls
Пример #6
0
def test_returns_state():
    """``Evaluator.run`` on an empty dataloader still returns a State instance."""
    process_fn = MagicMock(return_value=1)
    returned = Evaluator(process_fn).run([])

    assert isinstance(returned, State)
Пример #7
0
def test_creates_state():
    """Running the evaluator, even over no data, leaves a State on ``evaluator.state``."""
    engine = Evaluator(MagicMock(return_value=1))
    engine.run([])

    created = engine.state
    assert isinstance(created, State)