def test_state(tmpdir):
    """Verify that a ``LightningOptimizer`` compares equal to the optimizer it was built from.

    The ``tmpdir`` argument is accepted but not used by the test body.
    """
    linear = torch.nn.Linear(3, 4)
    base_opt = torch.optim.Adam(linear.parameters())
    wrapped_opt = LightningOptimizer(base_opt)

    # Each attribute must compare equal both before and after it is assigned
    # back through the LightningOptimizer instance.
    for attr in ("state", "param_groups", "defaults"):
        assert getattr(base_opt, attr) == getattr(wrapped_opt, attr)
        setattr(wrapped_opt, attr, getattr(base_opt, attr))
        assert getattr(base_opt, attr) == getattr(wrapped_opt, attr)

    assert isinstance(wrapped_opt, LightningOptimizer)
    assert isinstance(wrapped_opt, Adam)
    assert isinstance(wrapped_opt, Optimizer)

    # Compare instance dictionaries, leaving out the keys that are excluded
    # from the comparison.
    remaining = dict(wrapped_opt.__dict__)
    for excluded_key in ("_optimizer", "_optimizer_idx", "_trainer"):
        remaining.pop(excluded_key, None)
    assert remaining == base_opt.__dict__

    assert base_opt.state_dict() == wrapped_opt.state_dict()
    assert base_opt.state == wrapped_opt.state
def test_state(tmpdir):
    """Check that a ``LightningOptimizer`` mirrors the optimizer it was built from.

    Compares ``state``, ``param_groups`` and ``defaults`` before and after
    assigning them through the ``LightningOptimizer``, checks the instance
    types, and compares the instance ``__dict__`` (minus a fixed set of
    excluded keys) and ``state_dict()`` against the original optimizer.

    The ``tmpdir`` argument is accepted but not used by the test body.
    """
    model = torch.nn.Linear(3, 4)
    optimizer = torch.optim.Adam(model.parameters())
    lightning_optimizer = LightningOptimizer(optimizer)

    # test state
    assert optimizer.state == lightning_optimizer.state
    lightning_optimizer.state = optimizer.state
    assert optimizer.state == lightning_optimizer.state

    # test param_groups
    assert optimizer.param_groups == lightning_optimizer.param_groups
    lightning_optimizer.param_groups = optimizer.param_groups
    assert optimizer.param_groups == lightning_optimizer.param_groups

    # test defaults
    assert optimizer.defaults == lightning_optimizer.defaults
    lightning_optimizer.defaults = optimizer.defaults
    assert optimizer.defaults == lightning_optimizer.defaults

    assert isinstance(lightning_optimizer, LightningOptimizer)
    assert isinstance(lightning_optimizer, Adam)
    assert isinstance(lightning_optimizer, Optimizer)

    # Keys left out of the ``__dict__`` comparison. A set is used because the
    # collection is only ever tested for membership; the previously duplicated
    # "__setstate__" entry is listed once.
    special_attrs = {
        "_accumulate_grad_batches",
        "_optimizer",
        "_optimizer_idx",
        "_support_closure",
        "_trainer",
        "__getstate__",
        "__setstate__",
        "state_dict",
        "load_state_dict",
        "zero_grad",
        "add_param_group",
    }
    lightning_dict = {
        k: v for k, v in lightning_optimizer.__dict__.items() if k not in special_attrs
    }
    assert lightning_dict == optimizer.__dict__

    assert optimizer.state_dict() == lightning_optimizer.state_dict()
    assert optimizer.state == lightning_optimizer.state