def test_current_epoch_counter_increases_every_epoch(data):
    """Run an engine twice with an ``EpochCounter`` attached to ``EPOCH_STARTED``.

    Both runs use ``max_epochs=5`` and ``epoch_length=2``, and each must finish
    with ``state.epoch == 5``. Between the runs the counter's
    ``current_epoch_count`` is rewound to 1. Presumably this is because the
    second ``run`` starts counting epochs from the beginning again; confirm
    against ``Engine.run`` restart semantics.
    """
    n_epochs = 5
    run_kwargs = dict(max_epochs=n_epochs, epoch_length=2)

    engine = Engine(MagicMock(return_value=1))
    epoch_counter = EpochCounter()
    engine.add_event_handler(Events.EPOCH_STARTED, epoch_counter)

    first_state = engine.run(data, **run_kwargs)
    assert first_state.epoch == n_epochs

    # Rewind the handler's expected epoch before running the engine again.
    epoch_counter.current_epoch_count = 1

    second_state = engine.run(data, **run_kwargs)
    assert second_state.epoch == n_epochs
def test_load_state_dict_integration():
    """Resume an ``Engine`` from a loaded state dict and run it to completion.

    The loaded state dict records ``epoch=5`` of ``max_epochs=100`` with
    ``epoch_length=120``. The iteration and epoch counters are seeded with the
    values expected to come next after that point (``5 * 120 + 1`` and ``6``).
    After the run, the final state is checked explicitly. Without that check,
    the test would also pass if no handler ever fired.
    """
    max_epochs = 100
    epoch_length = 120
    resume_epoch = 5

    engine = Engine(lambda e, b: 1)

    state_dict = {"max_epochs": max_epochs, "epoch_length": epoch_length, "epoch": resume_epoch}
    engine.load_state_dict(state_dict)

    engine.add_event_handler(
        Events.ITERATION_COMPLETED, IterationCounter(resume_epoch * epoch_length + 1)
    )
    engine.add_event_handler(Events.EPOCH_COMPLETED, EpochCounter(resume_epoch + 1))

    data = range(epoch_length)
    state = engine.run(data)

    # Guard against a vacuous pass: the run must actually reach the last epoch.
    assert state.epoch == max_epochs
    assert state.iteration == max_epochs * epoch_length