def test_loops_state_dict():
    """Round-trip a ``FitLoop`` state dict through a freshly created loop.

    Before the round trip, verify that assigning a plain ``object()`` to the
    loop's ``trainer`` attribute raises ``MisconfigurationException``.
    """
    source_loop = FitLoop()

    # A bare ``object()`` is not accepted as a trainer; the setter must raise.
    expected_message = "Loop FitLoop should be connected to a"
    with pytest.raises(MisconfigurationException, match=expected_message):
        source_loop.trainer = object()

    # Connect a stand-in before capturing the state.
    source_loop.connect(Mock())
    saved_state = source_loop.state_dict()

    restored_loop = FitLoop()
    restored_loop.load_state_dict(saved_state)

    # Both loops must now report the same state.
    assert source_loop.state_dict() == restored_loop.state_dict()
# Example #2
# 0
def test_loops_state_dict():
    """Check that a ``FitLoop`` state dict survives a save/load round trip.

    Two loops are attached to the same ``Trainer``; the state captured from
    the first is loaded into the second and the resulting states are compared.
    """
    shared_trainer = Trainer()

    original_loop = FitLoop()
    original_loop.trainer = shared_trainer
    snapshot = original_loop.state_dict()

    reloaded_loop = FitLoop()
    reloaded_loop.trainer = shared_trainer
    reloaded_loop.load_state_dict(snapshot)

    # After loading, both loops must expose identical state.
    assert original_loop.state_dict() == reloaded_loop.state_dict()