def test_state_attributes():
    """Check the attributes of the state returned by ``Evaluator.run``.

    The evaluator wraps a mock that always returns ``1`` and is run over a
    three-element list; the returned state is then inspected.
    """
    batches = [1, 2, 3]
    process_fn = MagicMock(return_value=1)
    evaluator = Evaluator(process_fn)

    final_state = evaluator.run(batches)

    # Three batches were supplied, so three iterations are expected.
    assert final_state.iteration == 3
    # The mocked process function always returns 1.
    assert final_state.output == 1
    # The last element of the input list is expected as the current batch.
    assert final_state.batch == 3
    assert final_state.dataloader == batches
def test_current_validation_iteration_counter_increases_every_iteration():
    """Check ``state.iteration`` counts 1, 2, 3, ... within every run.

    A callable handler is attached to ``ITERATION_STARTED`` and asserts that
    the state's iteration matches its own expectation. A second handler on
    ``STARTED`` resets that expectation, so each of the repeated runs is
    expected to start counting from 1 again.
    """
    validation_batches = [1, 2, 3]
    evaluator = Evaluator(MagicMock(return_value=1))
    num_runs = 5

    class IterationCounter(object):
        """Tracks the expected iteration of a run and the overall call count."""

        def __init__(self):
            self.expected_iteration = 1
            self.calls_seen = 0

        def __call__(self, evaluator, state):
            # The check happens before the counters are advanced.
            assert state.iteration == self.expected_iteration
            self.expected_iteration += 1
            self.calls_seen += 1

        def clear(self):
            # Only the per-run expectation is reset; the total keeps growing.
            self.expected_iteration = 1

    counter_handler = IterationCounter()

    def clear_counter(evaluator, state, counter):
        counter.clear()

    # The counter instance is forwarded to ``clear_counter`` as an extra argument.
    evaluator.add_event_handler(Events.STARTED, clear_counter, counter_handler)
    evaluator.add_event_handler(Events.ITERATION_STARTED, counter_handler)

    for _ in range(num_runs):
        evaluator.run(validation_batches)

    expected_total = num_runs * len(validation_batches)
    assert counter_handler.calls_seen == expected_total
def test_terminate_stops_evaluator_when_called_during_iteration():
    """Check that ``terminate()`` called from a handler stops the run early.

    The handler requests termination at the start of the third iteration of
    a ten-batch run; the returned state must report that iteration.
    """
    num_iterations = 10
    iteration_to_stop = 3  # terminate is requested when this iteration starts
    evaluator = Evaluator(MagicMock(return_value=1))

    def start_of_iteration_handler(evaluator, state):
        if state.iteration != iteration_to_stop:
            return
        evaluator.terminate()

    evaluator.add_event_handler(Events.ITERATION_STARTED, start_of_iteration_handler)

    batches = [None] * num_iterations
    state = evaluator.run(batches)

    # The iteration in which terminate was called should be the last one
    # counted: the counter must not advance past it.
    assert state.iteration == iteration_to_stop
def test_timer():
    """Check ``Timer`` values when attached to a trainer in three ways.

    Both the training and the validation step just sleep for a fixed time.
    Validation is run from an ``EPOCH_COMPLETED`` handler on the trainer.
    Three timers are attached to the trainer:

    * one attached with default arguments, expected to cover training and
      validation together;
    * an averaging timer resumed on ``ITERATION_STARTED`` and paused/stepped
      on ``ITERATION_COMPLETED``, expected to equal one sleep;
    * one resumed on ``EPOCH_STARTED`` and paused on ``EPOCH_COMPLETED``,
      expected to cover the training iterations only.

    Values are compared after rounding to one decimal place.
    """
    sleep_t = 0.2
    n_iter = 3

    def _train_func(engine, batch):
        time.sleep(sleep_t)

    def _test_func(engine, batch):
        time.sleep(sleep_t)

    trainer = Trainer(_train_func)
    tester = Evaluator(_test_func)

    total_timer = Timer()
    batch_timer = Timer(average=True)
    train_timer = Timer()

    # Attach order is kept as-is: the timers are attached before the
    # validation handler below is registered.
    total_timer.attach(trainer)
    batch_timer.attach(
        trainer,
        pause=Events.ITERATION_COMPLETED,
        resume=Events.ITERATION_STARTED,
        step=Events.ITERATION_COMPLETED,
    )
    train_timer.attach(
        trainer,
        pause=Events.EPOCH_COMPLETED,
        resume=Events.EPOCH_STARTED,
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(trainer):
        tester.run(range(n_iter))

    # Run "training"
    trainer.run(range(n_iter))

    def _equal(lhs, rhs):
        # Coarse comparison: both sides rounded to one decimal place.
        return round(lhs, 1) == round(rhs, 1)

    expected_total = 2 * n_iter * sleep_t
    expected_train = n_iter * sleep_t

    assert _equal(total_timer.value(), expected_total)
    assert _equal(batch_timer.value(), sleep_t)
    assert _equal(train_timer.value(), expected_train)

    total_timer.reset()
    assert _equal(total_timer.value(), 0.0)
def test_evaluation_iteration_events_are_fired():
    """Check iteration start/complete handlers fire once per batch, in order.

    Two mocks are registered as ``ITERATION_STARTED`` and
    ``ITERATION_COMPLETED`` handlers and attached to a shared manager mock,
    so the interleaving of their calls can be compared against the expected
    started/complete sequence.
    """
    evaluator = Evaluator(MagicMock(return_value=1))
    call_recorder = Mock()

    on_started = Mock()
    evaluator.add_event_handler(Events.ITERATION_STARTED, on_started)

    on_complete = Mock()
    evaluator.add_event_handler(Events.ITERATION_COMPLETED, on_complete)

    # Attaching both mocks to one parent records their calls in a single,
    # ordered ``mock_calls`` list.
    call_recorder.attach_mock(on_started, 'iteration_started')
    call_recorder.attach_mock(on_complete, 'iteration_complete')

    batches = [(1, 2), (3, 4), (5, 6)]
    evaluator.run(batches)

    n_batches = len(batches)
    assert on_started.call_count == n_batches
    assert on_complete.call_count == n_batches

    # Each batch is expected to produce a "started" call followed by a
    # "complete" call, both receiving the evaluator.
    expected_calls = []
    for _ in batches:
        expected_calls.extend([
            call.iteration_started(evaluator),
            call.iteration_complete(evaluator),
        ])

    assert call_recorder.mock_calls == expected_calls
def test_returns_state():
    """Check that ``Evaluator.run`` returns a ``State``, even for empty data."""
    process_fn = MagicMock(return_value=1)
    evaluator = Evaluator(process_fn)

    result = evaluator.run([])

    assert isinstance(result, State)
def test_creates_state():
    """Check that running sets ``evaluator.state`` to a ``State`` instance."""
    process_fn = MagicMock(return_value=1)
    evaluator = Evaluator(process_fn)

    # The return value is deliberately ignored; only the attribute is checked.
    evaluator.run([])

    assert isinstance(evaluator.state, State)