optimizer=optimizer, batch_size=batch_size, optimizer_args=optimizer_args, epochs=5, shuffle=shuffle) arw = {"host":"10.0.0.1","hook":hook} h1 = WebsocketClientWorker(id="h1",port=8778,**arw) train_config.send(h1) message = h1.create_message_execute_command(command_name="start_monitoring",command_owner="self") serialized_message = syft.serde.serialize(message) h1._recv_msg(serialized_message) time.sleep(3) for epoch in range(10): time.sleep(3) loss = h1.fit(dataset_key="train") # ask alice to train using "xor" dataset print("-" * 50) print("Iteration %s: h1's loss: %s" % (epoch, loss)) message = h1.create_message_execute_command(command_name="stop_monitoring",command_owner="self") serialized_message = syft.serde.serialize(message) h1._recv_msg(serialized_message) new_model = train_config.model_ptr.get() h1.close()
def test_train_config_with_jit_trace_sync(hook, start_proc):  # pragma: no cover
    """Train a ``torch.jit.trace``-d model on a remote worker via synchronous ``fit``.

    A ``WebsocketServerWorker`` holding a gaussian-mixture toy dataset is started
    through the ``start_proc`` fixture, and a ``WebsocketClientWorker`` connects to
    it with the same id/host/port. A ``TrainConfig`` wrapping a traced 2-3-2-1 MLP
    and a scripted mean-squared-error loss is sent to the worker, ``fit`` is called
    five times, and the trained model is fetched back.

    Asserts that every weight and bias of the fetched model differs from the local
    (untrained) traced model, and that the loss on the training data decreased.

    Client shutdown and server termination happen in a ``finally`` block, so a
    failing assertion does not leave the server process running (and presumably
    still bound to port 9000) for subsequent tests.
    """
    kwargs = {"id": "sync_fit", "host": "localhost", "port": 9000, "hook": hook}

    data, target = utils.create_gaussian_mixture_toy_data(100)
    dataset = sy.BaseDataset(data, target)
    dataset_key = "gaussian_mixture"

    process_remote_worker = start_proc(
        WebsocketServerWorker, dataset=(dataset, dataset_key), **kwargs
    )
    local_worker = None
    try:
        # Give the server process a moment to start before the client connects.
        time.sleep(0.1)
        local_worker = WebsocketClientWorker(**kwargs)

        @hook.torch.jit.script
        def loss_fn(pred, target):
            # Mean squared error; target is reshaped to pred's shape first.
            return ((target.view(pred.shape).float() - pred.float()) ** 2).mean()

        class Net(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(2, 3)
                self.fc2 = nn.Linear(3, 2)
                self.fc3 = nn.Linear(2, 1)

            def forward(self, x):
                x = F.relu(self.fc1(x))
                x = F.relu(self.fc2(x))
                x = self.fc3(x)
                return x

        model_untraced = Net()
        model = torch.jit.trace(model_untraced, data)

        pred = model(data)
        loss_before = loss_fn(pred=pred, target=target)

        # Create and send train config
        train_config = sy.TrainConfig(model=model, loss_fn=loss_fn, batch_size=2, epochs=1)
        train_config.send(local_worker)

        for epoch in range(5):
            loss = local_worker.fit(dataset_key=dataset_key)
            if PRINT_IN_UNITTESTS:  # pragma: no cover
                print("-" * 50)
                print("Iteration %s: alice's loss: %s" % (epoch, loss))

        new_model = train_config.model_ptr.get()

        # Every parameter of the fetched model must differ from the untrained local model.
        for param_name in ("weight", "bias"):
            for layer_name in ("fc1", "fc2", "fc3"):
                old_param = getattr(model, layer_name)._parameters[param_name]
                new_param = getattr(new_model.obj, layer_name)._parameters[param_name]
                assert not (old_param == new_param).all(), "{}.{} was not updated".format(
                    layer_name, param_name
                )

        new_model.obj.eval()
        pred = new_model.obj(data)
        loss_after = loss_fn(pred=pred, target=target)

        if PRINT_IN_UNITTESTS:  # pragma: no cover
            print("Loss before training: {}".format(loss_before))
            print("Loss after training: {}".format(loss_after))
    finally:
        # Always tear down the client connection and the server process, even
        # when training or one of the assertions above fails.
        if local_worker is not None:
            local_worker.ws.shutdown()
            del local_worker
            time.sleep(0.1)
        process_remote_worker.terminate()

    assert loss_after < loss_before