def test_optimizer_to_device():
    """Check that ``optimizer_to_device`` relocates optimizer state.

    A small custom SGD subclass seeds its ``state`` with a plain tensor, so
    there is something to move before any optimization step has run. The
    CPU move is always exercised and verified. The CUDA move is only
    exercised when a GPU is available.
    """

    class TestOptimizer(torch.optim.SGD):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # A freshly built SGD has no per-parameter state until step()
            # runs, so inject a tensor entry for the device move to act on.
            self.state["dummy"] = torch.tensor(0)

    layer = torch.nn.Linear(32, 2)
    opt = TestOptimizer(layer.parameters(), lr=0.1)

    optimizer_to_device(opt, "cpu")
    # Verify the CPU move as well, not merely that it did not raise.
    assert_opt_parameters_on_device(opt, "cpu")

    if torch.cuda.is_available():
        optimizer_to_device(opt, "cuda")
        assert_opt_parameters_on_device(opt, "cuda")
# Example No. 2
 def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
     """Restore each optimizer from ``checkpoint`` and move it to ``self.root_device``.

     Saved states are taken from ``checkpoint["optimizer_states"]`` and paired
     positionally with ``self.optimizers``. Because pairing uses ``zip``, it
     stops at the shorter sequence: surplus optimizers or surplus saved
     states are silently ignored.

     Args:
         checkpoint: Mapping that must contain an ``"optimizer_states"``
             sequence, ordered to match ``self.optimizers``.
     """
     saved_states = checkpoint["optimizer_states"]
     for opt, saved in zip(self.optimizers, saved_states):
         opt.load_state_dict(saved)
         # Relocate only after loading, so the freshly restored state is
         # what gets moved; root_device is read on every iteration.
         optimizer_to_device(opt, self.root_device)