def test_save_and_restore(ray_start_2_cpus, num_replicas):  # noqa: F811
    """Check that a restored trainer holds the same weights as the saved one.

    Trains one ``PyTorchTrainer``, saves a checkpoint, restores it into a
    second, freshly constructed trainer, and asserts that both models have
    identical state dicts (same keys, element-wise equal tensors).

    Args:
        ray_start_2_cpus: Fixture argument; not referenced in the body
            (presumably it starts the Ray runtime -- confirm against the
            fixture definition).
        num_replicas: Forwarded as ``num_replicas`` to both trainers.
    """
    # Local import so this block stays self-contained; only used for cleanup.
    import shutil

    def make_trainer():
        # Both trainers must be configured identically so the checkpoint
        # written by the first can be loaded by the second.
        return PyTorchTrainer(
            model_creator,
            data_creator,
            optimizer_creator,
            loss_creator=lambda config: nn.MSELoss(),
            num_replicas=num_replicas)

    # mkdtemp() leaves deletion to the caller, so remove the directory in
    # the outer ``finally`` even when the test fails.
    checkpoint_dir = tempfile.mkdtemp()
    filename = os.path.join(checkpoint_dir, "checkpoint")
    try:
        trainer1 = make_trainer()
        try:
            trainer1.train()
            trainer1.save(filename)
            model1 = trainer1.get_model()
        finally:
            trainer1.shutdown()

        trainer2 = make_trainer()
        try:
            trainer2.restore(filename)
            # The checkpoint is deleted before the model is read back, so
            # the comparison below cannot depend on the file still existing.
            os.remove(filename)
            model2 = trainer2.get_model()

            model1_state_dict = model1.state_dict()
            model2_state_dict = model2.state_dict()

            assert set(model1_state_dict.keys()) == set(
                model2_state_dict.keys())

            for k in model1_state_dict:
                assert torch.equal(model1_state_dict[k], model2_state_dict[k])
        finally:
            # Previously missing: the second trainer was never shut down.
            trainer2.shutdown()
    finally:
        shutil.rmtree(checkpoint_dir, ignore_errors=True)
def test_multi_model(ray_start_2_cpus, num_replicas):  # noqa: F811
    """Check save/restore round-trips when the trainer manages two models.

    Builds a ``PyTorchTrainer`` whose creators return two ``nn.Linear(1, 1)``
    models and two SGD optimizers, trains it with a custom train function,
    saves a checkpoint, restores it into a second trainer, and asserts that
    every restored model has the same state dict as its original.

    Args:
        ray_start_2_cpus: Fixture argument; not referenced in the body
            (presumably it starts the Ray runtime -- confirm against the
            fixture definition).
        num_replicas: Forwarded as ``num_replicas`` to both trainers.
    """
    # Local import so this block stays self-contained; only used for cleanup.
    import shutil

    def custom_train(config, models, dataloader, criterion, optimizers,
                     **kwargs):
        # Train each (model, optimizer) pair with the single-model ``train``
        # helper and report the results under "model_<index>".
        result = {}
        for i, (model, optimizer) in enumerate(zip(models, optimizers)):
            result["model_{}".format(i)] = train(config, model, dataloader,
                                                 criterion, optimizer)
        return result

    def multi_model_creator(config):
        return nn.Linear(1, 1), nn.Linear(1, 1)

    def multi_optimizer_creator(models, config):
        opts = [
            torch.optim.SGD(model.parameters(), lr=0.0001) for model in models
        ]
        return opts[0], opts[1]

    # mkdtemp() leaves deletion to the caller, so remove the directory in
    # the outer ``finally`` even when the test fails.
    checkpoint_dir = tempfile.mkdtemp()
    filename = os.path.join(checkpoint_dir, "checkpoint")
    try:
        trainer1 = PyTorchTrainer(
            multi_model_creator,
            data_creator,
            multi_optimizer_creator,
            loss_creator=lambda config: nn.MSELoss(),
            train_function=custom_train,
            num_replicas=num_replicas)
        try:
            trainer1.train()
            trainer1.save(filename)
            models1 = trainer1.get_model()
        finally:
            trainer1.shutdown()

        # trainer2 is only restored, never trained, so it is built without
        # ``train_function`` exactly as before.
        trainer2 = PyTorchTrainer(
            multi_model_creator,
            data_creator,
            multi_optimizer_creator,
            loss_creator=lambda config: nn.MSELoss(),
            num_replicas=num_replicas)
        try:
            trainer2.restore(filename)
            # The checkpoint is deleted before the models are read back, so
            # the comparison below cannot depend on the file still existing.
            os.remove(filename)
            models2 = trainer2.get_model()

            # zip() stops at the shorter input; without this check a restore
            # that returned fewer models would pass without comparing all.
            assert len(models1) == len(models2)

            for model_1, model_2 in zip(models1, models2):
                model1_state_dict = model_1.state_dict()
                model2_state_dict = model_2.state_dict()

                assert set(model1_state_dict.keys()) == set(
                    model2_state_dict.keys())

                for k in model1_state_dict:
                    assert torch.equal(model1_state_dict[k],
                                       model2_state_dict[k])
        finally:
            # Shut down even when an assertion above fails.
            trainer2.shutdown()
    finally:
        shutil.rmtree(checkpoint_dir, ignore_errors=True)