def test_current_iteration_counter_increases_every_iteration():
    """Run an engine over a fixed list of batches twice and check the final iteration count.

    An ``IterationCounter`` is attached to ``Events.ITERATION_STARTED``. After each
    ``engine.run`` the engine's ``state.iteration`` must equal
    ``max_epochs * len(batches)``.
    """
    # NOTE(review): a second function with this exact name follows (taking a
    # ``data`` argument). If both live in the same module, the later definition
    # replaces this one and pytest will never collect it -- consider renaming.
    process_fn = MagicMock(return_value=1)
    engine = Engine(process_fn)

    data_batches = [1, 2, 3]
    n_epochs = 5

    iteration_counter = IterationCounter()
    engine.add_event_handler(Events.ITERATION_STARTED, iteration_counter)

    expected_iterations = n_epochs * len(data_batches)

    final_state = engine.run(data_batches, max_epochs=n_epochs)
    assert final_state.iteration == expected_iterations

    # Presumably rewinds the handler's expected iteration index so it can be
    # reused for the second run -- confirm against IterationCounter's definition.
    iteration_counter.current_iteration_count = 1

    final_state = engine.run(data_batches, max_epochs=n_epochs)
    assert final_state.iteration == expected_iterations
def test_current_iteration_counter_increases_every_iteration(data):
    """Run an engine over ``data`` twice with an explicit epoch length and check the iteration count.

    An ``IterationCounter`` is attached to ``Events.ITERATION_STARTED``. After each
    ``engine.run`` the engine's ``state.iteration`` must equal
    ``max_epochs * epoch_length``.

    ``data`` is supplied by the caller (presumably a pytest fixture or
    parametrization -- confirm where it is defined).
    """
    # NOTE(review): an earlier function with this exact name exists; in a
    # single module this later definition is the one that survives.
    engine = Engine(MagicMock(return_value=1))

    n_epochs = 5
    iters_per_epoch = 3

    iteration_counter = IterationCounter()
    engine.add_event_handler(Events.ITERATION_STARTED, iteration_counter)

    expected_iterations = n_epochs * iters_per_epoch

    final_state = engine.run(data, max_epochs=n_epochs, epoch_length=iters_per_epoch)
    assert final_state.iteration == expected_iterations

    # Presumably rewinds the handler's expected iteration index so it can be
    # reused for the second run -- confirm against IterationCounter's definition.
    iteration_counter.current_iteration_count = 1

    final_state = engine.run(data, max_epochs=n_epochs, epoch_length=iters_per_epoch)
    assert final_state.iteration == expected_iterations
def test_load_state_dict_integration():
    """Resume an engine from a loaded state dict and check that counting continues from it.

    The engine is loaded with epoch 5 of an epoch length of 120. The attached
    counters are constructed with starting values of ``5 * 120 + 1`` (iterations)
    and ``6`` (epochs), i.e. one past the restored position.
    """

    def _update(engine, batch):
        # Trivial update step: the output value is irrelevant to this test.
        return 1

    engine = Engine(_update)

    iters_per_epoch = 120
    resumed_epoch = 5
    engine.load_state_dict({"max_epochs": 100, "epoch_length": iters_per_epoch, "epoch": resumed_epoch})

    first_iteration = resumed_epoch * iters_per_epoch + 1
    engine.add_event_handler(Events.ITERATION_COMPLETED, IterationCounter(first_iteration))
    engine.add_event_handler(Events.EPOCH_COMPLETED, EpochCounter(resumed_epoch + 1))

    engine.run(range(iters_per_epoch))