コード例 #1
0
def test_current_epoch_counter_increases_every_epoch(data):
    """Check that an ``EPOCH_STARTED`` handler sees every epoch across two runs.

    An ``EpochCounter`` is attached to ``Events.EPOCH_STARTED`` and the engine
    is run to ``max_epochs`` twice. Before the second run the counter's
    ``current_epoch_count`` is reset to 1.
    """
    engine = Engine(MagicMock(return_value=1))
    max_epochs = 5
    epoch_length = 2

    counter = EpochCounter()
    engine.add_event_handler(Events.EPOCH_STARTED, counter)

    state = engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length)
    assert state.epoch == max_epochs
    assert state.iteration == max_epochs * epoch_length
    # The handler only checks anything when it fires; also verify that it
    # advanced once per epoch, so a handler that never ran cannot pass silently.
    # NOTE(review): assumes EpochCounter bumps current_epoch_count by 1 per call.
    assert counter.current_epoch_count == max_epochs + 1

    counter.current_epoch_count = 1
    state = engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length)
    assert state.epoch == max_epochs
    assert state.iteration == max_epochs * epoch_length
    assert counter.current_epoch_count == max_epochs + 1
コード例 #2
0
def test_load_state_dict_integration():
    """Resume an engine from a loaded state dict and run it to completion.

    The state dict describes a run of ``max_epochs`` epochs of
    ``epoch_length`` iterations, currently at ``start_epoch``. The iteration
    and epoch counters are seeded with ``start_epoch * epoch_length + 1`` and
    ``start_epoch + 1`` respectively.
    """
    engine = Engine(lambda e, b: 1)

    max_epochs = 100
    epoch_length = 120
    start_epoch = 5
    state_dict = {
        "max_epochs": max_epochs,
        "epoch_length": epoch_length,
        "epoch": start_epoch,
    }

    engine.load_state_dict(state_dict)
    engine.add_event_handler(
        Events.ITERATION_COMPLETED, IterationCounter(start_epoch * epoch_length + 1)
    )
    engine.add_event_handler(Events.EPOCH_COMPLETED, EpochCounter(start_epoch + 1))
    data = range(epoch_length)
    state = engine.run(data)

    # The counters only check anything when they fire; also verify that the
    # resumed run actually reached the end, so an early stop cannot pass silently.
    assert state.epoch == max_epochs
    assert state.iteration == max_epochs * epoch_length