def test_optimizer_to_device():
    """Run ``optimizer_to_device`` on an SGD optimizer holding an extra state tensor.

    The optimizer is first sent to ``"cpu"``. When CUDA is available it is
    then sent to ``"cuda"`` and ``assert_opt_parameters_on_device`` is used to
    check the result.
    """

    class TestOptimizer(torch.optim.SGD):
        # SGD variant that stores a tensor under the extra key ``"dummy"`` in
        # ``state`` right after construction.
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.state["dummy"] = torch.tensor(0)

    linear = torch.nn.Linear(32, 2)
    optimizer = TestOptimizer(linear.parameters(), lr=0.1)

    optimizer_to_device(optimizer, "cpu")

    # Without a CUDA device there is nothing further to check.
    if not torch.cuda.is_available():
        return

    optimizer_to_device(optimizer, "cuda")
    assert_opt_parameters_on_device(optimizer, "cuda")
def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
    """Restore the optimizers in ``self.optimizers`` from a checkpoint.

    The states under ``checkpoint["optimizer_states"]`` are paired by position
    with ``self.optimizers``. Each optimizer loads its state and is then
    passed to ``optimizer_to_device`` together with ``self.root_device``
    (presumably to place the restored state on that device; confirm against
    the helper).

    Args:
        checkpoint: Mapping that must contain the key ``"optimizer_states"``.

    Raises:
        KeyError: If ``"optimizer_states"`` is missing from ``checkpoint``.
    """
    saved_states = checkpoint["optimizer_states"]
    # ``zip`` stops at the shorter sequence, so surplus optimizers or surplus
    # saved states are left untouched.
    paired = zip(self.optimizers, saved_states)
    for opt, saved in paired:
        opt.load_state_dict(saved)
        optimizer_to_device(opt, self.root_device)