コード例 #1
0
def test_fit_loop_done_log_messages(caplog):
    """Check that ``FitLoop.done`` logs a message for each stopping scenario.

    The scenarios run in sequence against one ``FitLoop`` and one mocked
    ``Trainer``, so each scenario restores the attribute it changed before the
    next one starts. ``caplog`` is cleared between scenarios so every
    ``caplog.text`` check only sees messages produced by that scenario.
    """
    fit_loop = FitLoop()
    trainer = Mock(spec=Trainer)
    fit_loop.trainer = trainer

    # Baseline: batches available and no stop requested -> not done, nothing logged.
    trainer.should_stop = False
    trainer.num_training_batches = 5
    assert not fit_loop.done
    assert not caplog.messages

    # Zero training batches -> done.
    trainer.num_training_batches = 0
    assert fit_loop.done
    assert "No training batches" in caplog.text
    caplog.clear()
    trainer.num_training_batches = 5

    # Mocked epoch loop reports global_step == max_steps -> done.
    epoch_loop = Mock()
    epoch_loop.global_step = 10
    fit_loop.connect(epoch_loop=epoch_loop)
    fit_loop.max_steps = 10
    assert fit_loop.done
    assert "max_steps=10` reached" in caplog.text
    caplog.clear()
    fit_loop.max_steps = 20

    # Processed epochs == max_epochs (with should_stop also set) -> done.
    # Note: should_stop stays True for the remaining scenarios.
    fit_loop.epoch_progress.current.processed = 3
    fit_loop.max_epochs = 3
    trainer.should_stop = True
    assert fit_loop.done
    assert "max_epochs=3` reached" in caplog.text
    caplog.clear()
    fit_loop.max_epochs = 5

    # should_stop set and epoch_loop.min_steps = 0 -> done. Capture is raised to
    # DEBUG for the rank_zero logger here, presumably because this message is
    # emitted at DEBUG level.
    fit_loop.epoch_loop.min_steps = 0
    with caplog.at_level(level=logging.DEBUG, logger="pytorch_lightning.utilities.rank_zero"):
        assert fit_loop.done
    assert "should_stop` was set" in caplog.text
    # Clear so the next assertion cannot be satisfied by text from this scenario.
    caplog.clear()

    # should_stop set but min_steps = 100 (above the mocked global_step of 10) -> not done.
    fit_loop.epoch_loop.min_steps = 100
    assert not fit_loop.done
    assert "was signaled to stop but" in caplog.text
コード例 #2
0
def test_loops_state_dict():
    """Round-trip a ``FitLoop`` state dict into a freshly built loop.

    Also checks that assigning a non-``Trainer`` object as the loop's trainer
    raises ``MisconfigurationException``.
    """
    shared_trainer = Trainer()

    source_loop = FitLoop()
    # Anything other than a Trainer must be rejected on assignment.
    with pytest.raises(MisconfigurationException, match="Loop FitLoop should be connected to a"):
        source_loop.trainer = object()

    # Snapshot the state of a loop attached to a real trainer.
    source_loop.trainer = shared_trainer
    saved_state = source_loop.state_dict()

    # Restore that snapshot into a second loop attached to the same trainer.
    target_loop = FitLoop()
    target_loop.trainer = shared_trainer
    target_loop.load_state_dict(saved_state)

    assert source_loop.state_dict() == target_loop.state_dict()
コード例 #3
0
def test_loops_state_dict():
    """Check that a ``FitLoop`` loaded from another loop's state dict reports the same state."""
    shared_trainer = Trainer()

    # Snapshot the state of a loop attached to the trainer.
    source_loop = FitLoop()
    source_loop.trainer = shared_trainer
    saved_state = source_loop.state_dict()

    # Restore that snapshot into a second loop attached to the same trainer.
    target_loop = FitLoop()
    target_loop.trainer = shared_trainer
    target_loop.load_state_dict(saved_state)

    assert source_loop.state_dict() == target_loop.state_dict()