def test_list_objects_remote(hook, start_proc):
    """Check that list_objects_remote reflects each tensor sent to a remote websocket worker."""
    kwargs = {"id": "fed", "host": "localhost", "port": 8765, "hook": hook}
    process_remote_fed1 = start_proc(WebsocketServerWorker, kwargs)

    time.sleep(0.1)  # give the server process time to start listening

    kwargs = {"id": "fed", "host": "localhost", "port": 8765, "hook": hook}
    local_worker = WebsocketClientWorker(**kwargs)

    x = torch.tensor([1, 2, 3]).send(local_worker)

    res = local_worker.list_objects_remote()
    # NOTE(review): eval() on the server's response string — acceptable only because
    # this is a trusted test fixture; never eval untrusted input.
    res_dict = eval(res.replace("tensor", "torch.tensor"))
    assert len(res_dict) == 1

    y = torch.tensor([4, 5, 6]).send(local_worker)
    res = local_worker.list_objects_remote()
    res_dict = eval(res.replace("tensor", "torch.tensor"))
    assert len(res_dict) == 2

    # delete x before terminating the websocket connection
    del x
    del y

    time.sleep(0.1)
    local_worker.ws.shutdown()
    time.sleep(0.1)
    local_worker.remove_worker_from_local_worker_registry()
    process_remote_fed1.terminate()
def test_objects_count_remote(hook, start_proc):
    """Check that objects_count_remote tracks sends and retrievals on a remote worker."""
    kwargs = {"id": "fed", "host": "localhost", "port": 8764, "hook": hook}
    process_remote_worker = start_proc(WebsocketServerWorker, kwargs)

    time.sleep(0.1)  # give the server process time to start listening

    kwargs = {"id": "fed", "host": "localhost", "port": 8764, "hook": hook}
    local_worker = WebsocketClientWorker(**kwargs)

    x = torch.tensor([1, 2, 3]).send(local_worker)
    nr_objects = local_worker.objects_count_remote()
    assert nr_objects == 1

    y = torch.tensor([4, 5, 6]).send(local_worker)
    nr_objects = local_worker.objects_count_remote()
    assert nr_objects == 2

    # get() pulls x back, so the remote count must drop by one
    x.get()
    nr_objects = local_worker.objects_count_remote()
    assert nr_objects == 1

    # delete remote object before terminating the websocket connection
    del y

    time.sleep(0.1)
    local_worker.ws.shutdown()
    time.sleep(0.1)
    local_worker.remove_worker_from_local_worker_registry()
    process_remote_worker.terminate()
def test_websocket_worker_multiple_output_response(hook, start_proc):
    """Evaluates that you can do basic tensor operations using WebsocketServerWorker"""
    # NOTE(review): a function with this exact name appears again later in the file;
    # pytest only collects the last definition — confirm which one is intended.
    kwargs = {
        "id": "socket_multiple_output",
        "host": "localhost",
        "port": 8768,
        "hook": hook
    }
    server = start_proc(WebsocketServerWorker, kwargs)

    time.sleep(0.1)  # give the server process time to start listening
    x = torch.tensor([1.0, 3, 2])

    socket_pipe = WebsocketClientWorker(**kwargs)

    x = x.send(socket_pipe)
    # torch.sort returns (values, indices); both come back as separate pointers
    p1, p2 = torch.sort(x)
    x1, x2 = p1.get(), p2.get()

    assert (x1 == torch.tensor([1.0, 2, 3])).all()
    assert (x2 == torch.tensor([0, 2, 1])).all()

    del x

    socket_pipe.ws.shutdown()
    time.sleep(0.1)
    server.terminate()
def test_websocket_worker_multiple_output_response(hook, start_proc):
    """Evaluates that you can do basic tensor operations using WebsocketServerWorker"""
    # NOTE(review): duplicates the name of an earlier test in this file; pytest only
    # collects the last definition. Also note start_proc is called with **kwargs here
    # but with a positional dict elsewhere — confirm its expected signature.
    kwargs = {
        "id": "socket_multiple_output",
        "host": "localhost",
        "port": 8768,
        "hook": hook
    }
    process_remote_worker = start_proc(WebsocketServerWorker, **kwargs)

    time.sleep(0.1)  # give the server process time to start listening
    x = torch.tensor([1.0, 3, 2])

    local_worker = WebsocketClientWorker(**kwargs)

    x = x.send(local_worker)
    # torch.sort returns (values, indices); both come back as separate pointers
    p1, p2 = torch.sort(x)
    x1, x2 = p1.get(), p2.get()

    assert (x1 == torch.tensor([1.0, 2, 3])).all()
    assert (x2 == torch.tensor([0, 2, 1])).all()

    x.get()  # retrieve remote object before closing the websocket connection

    local_worker.ws.shutdown()
    process_remote_worker.terminate()
def test_websocket_workers_search(hook, start_proc):
    """Evaluates that a client can search and find tensors that belong to another party."""
    # Sample tensor to store on the server
    sample_data = torch.tensor([1, 2, 3, 4]).tag("#sample_data", "#another_tag")

    # Args for initializing the websocket server and client.
    base_kwargs = {"id": "fed2", "host": "localhost", "port": 8767, "hook": hook}

    # BUGFIX: copy before adding server-only args. The original aliased base_kwargs
    # (server_kwargs = base_kwargs), so the "data" entry leaked into the client
    # constructor call below as well.
    server_kwargs = dict(base_kwargs)
    server_kwargs["data"] = [sample_data]
    server = start_proc(WebsocketServerWorker, server_kwargs)

    time.sleep(0.1)  # give the server process time to start listening

    client_worker = WebsocketClientWorker(**base_kwargs)

    # Search for the tensor located on the server by using its tags.
    results = client_worker.search("#sample_data", "#another_tag")

    assert results
    assert results[0].owner.id == "me"
    assert results[0].location.id == "fed2"

    # Searching multiple times should still work.
    results = client_worker.search("#sample_data", "#another_tag")

    assert results
    assert results[0].owner.id == "me"
    assert results[0].location.id == "fed2"

    client_worker.ws.shutdown()
    client_worker.ws.close()
    time.sleep(0.1)
    server.terminate()
def test_plan_execute_remotely(hook, start_proc):
    """Test plan execution remotely."""
    hook.local_worker.is_client_worker = False

    @sy.func2plan
    def my_plan(data):
        x = data * 2
        y = (x - 2) * 10
        return x + y

    # TODO: remove this line when issue #2062 is fixed
    # Force to build plan
    x = th.tensor([-1, 2, 3])
    my_plan(x)

    kwargs = {"id": "test_plan_worker", "host": "localhost", "port": 8799, "hook": hook}
    server = start_proc(WebsocketServerWorker, kwargs)
    time.sleep(0.1)  # give the server process time to start listening
    socket_pipe = WebsocketClientWorker(**kwargs)

    # Ship the built plan and the input to the remote worker, then run it there.
    plan_ptr = my_plan.send(socket_pipe)
    x_ptr = x.send(socket_pipe)
    plan_res = plan_ptr(x_ptr).get()

    # expected: x*2 + (x*2 - 2)*10 for x = [-1, 2, 3]
    assert (plan_res == th.tensor([-42, 24, 46])).all()

    # delete remote object before websocket connection termination
    del x_ptr

    server.terminate()
def test_connect_close(hook, start_proc):
    """A client can close and reopen its websocket connection without losing remote objects."""
    server_kwargs = {"id": "fed", "host": "localhost", "port": 8763, "hook": hook}
    process_remote_worker = start_proc(WebsocketServerWorker, **server_kwargs)

    time.sleep(0.1)

    client_kwargs = {"id": "fed", "host": "localhost", "port": 8763, "hook": hook}
    local_worker = WebsocketClientWorker(**client_kwargs)

    original = torch.tensor([1, 2, 3])
    pointer = original.send(local_worker)
    assert local_worker.objects_count_remote() == 1

    # Drop the connection, then re-establish it: the remote object must survive.
    local_worker.close()
    time.sleep(0.1)
    local_worker.connect()
    assert local_worker.objects_count_remote() == 1

    retrieved = pointer.get()
    assert (retrieved == original).all()

    local_worker.ws.shutdown()
    time.sleep(0.1)
    process_remote_worker.terminate()
def test_execute_plan_remotely(hook, start_proc):
    """Test plan execution remotely."""
    hook.local_worker.is_client_worker = False

    # args_shape lets the plan build itself at decoration time
    @sy.func2plan(args_shape=[(1, )])
    def my_plan(data):
        x = data * 2
        y = (x - 2) * 10
        return x + y

    x = th.tensor([-1, 2, 3])
    local_res = my_plan(x)

    kwargs = {
        "id": "test_plan_worker",
        "host": "localhost",
        "port": 8799,
        "hook": hook
    }
    server = start_proc(WebsocketServerWorker, **kwargs)
    time.sleep(0.1)  # give the server process time to start listening
    socket_pipe = WebsocketClientWorker(**kwargs)

    # Run the same plan remotely; the result must match the local evaluation.
    plan_ptr = my_plan.send(socket_pipe)
    x_ptr = x.send(socket_pipe)
    plan_res = plan_ptr(x_ptr).get()

    assert (plan_res == local_res).all()

    # delete remote object before websocket connection termination
    del x_ptr

    server.terminate()
    hook.local_worker.is_client_worker = True
def connect(local=False):
    """
    Connects to the three hospitals on AWS and returns their WebsocketClientWorkers.

    Args:
        local (bool): Set to true for testing and start_local_workers.py will
            provide three test workers locally.

    Returns:
        tuple(class:`syft.workers.WebsocketClientWorker`): tuple of 3 connected workers.
    """
    hook = sy.TorchHook(torch)
    naming = LH if local else H

    connected = []
    # Build the three workers from their naming entries, in the fixed order h1, h2, h3.
    for label in ("h1", "h2", "h3"):
        worker = WebsocketClientWorker(
            id=getattr(naming, label + "_name"),
            port=getattr(naming, label + "_port"),
            host=getattr(naming, label + "_host"),
            hook=hook,
        )
        logger.info("Connected to worker %s.", label)
        connected.append(worker)

    return tuple(connected)
def test_create_already_existing_worker_with_different_type(hook, start_proc):
    """Recreating an existing worker id as a different worker type must raise."""
    # Register "bob" as a VirtualWorker and park a tensor on him.
    bob = sy.VirtualWorker(hook, "bob")
    _ = th.tensor([1, 2, 3]).send(bob)

    server_kwargs = {"id": "fed1", "host": "localhost", "port": 8765, "hook": hook}
    server = start_proc(WebsocketServerWorker, server_kwargs)

    time.sleep(0.1)

    # Rebuilding "bob" as a websocket client clashes with the virtual worker.
    client_kwargs = {"id": "bob", "host": "localhost", "port": 8765, "hook": hook}
    with pytest.raises(RuntimeError):
        bob = WebsocketClientWorker(**client_kwargs)

    server.terminate()
def test_execute_plan_module_remotely(hook, start_proc):
    """Test plan execution remotely."""
    hook.local_worker.is_client_worker = False

    # Toy module whose forward pass is converted into a plan
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(2, 3)
            self.fc2 = nn.Linear(3, 2)

        @sy.method2plan
        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return F.log_softmax(x, dim=0)

    net = Net()

    x = th.tensor([-1, 2.0])
    local_res = net(x)
    # method2plan defers building until build() is called explicitly
    assert not net.forward.is_built

    net.forward.build(x)

    kwargs = {
        "id": "test_plan_worker_2",
        "host": "localhost",
        "port": 8799,
        "hook": hook
    }
    server = start_proc(WebsocketServerWorker, **kwargs)
    time.sleep(0.1)  # give the server process time to start listening
    socket_pipe = WebsocketClientWorker(**kwargs)

    # Run the plan remotely; result must match the local evaluation.
    plan_ptr = net.send(socket_pipe)
    x_ptr = x.send(socket_pipe)
    remote_res = plan_ptr(x_ptr).get()

    assert (remote_res == local_res).all()

    # delete remote object before websocket connection termination
    del x_ptr

    server.terminate()
    hook.local_worker.is_client_worker = True
def instantiate_websocket_client_worker(**kwargs):  # pragma: no cover
    """Instantiate the websocket client, retrying while the server comes up.

    Connection attempts are retried up to 5 times with a short pause between
    tries; after that the ConnectionRefusedError is re-raised.

    Args:
        **kwargs: forwarded verbatim to WebsocketClientWorker.

    Returns:
        The connected WebsocketClientWorker instance.

    Raises:
        ConnectionRefusedError: if all retries are exhausted.
    """
    max_retries = 5
    # attempts 0..max_retries: same total of 6 tries as the original counter loop
    for attempt in range(max_retries + 1):
        try:
            return WebsocketClientWorker(**kwargs)
        except ConnectionRefusedError:
            if attempt == max_retries:
                raise  # bare raise preserves the original traceback (was `raise e`)
            time.sleep(0.1)
def test_websocket_garbage_collection(hook, start_proc):
    """After get(), the tensor must no longer be registered on the client worker."""
    # Args for initializing the websocket server and client
    base_kwargs = {"id": "ws_gc", "host": "localhost", "port": 8777, "hook": hook}
    server = start_proc(WebsocketServerWorker, base_kwargs)

    time.sleep(0.1)  # give the server process time to start listening
    client_worker = WebsocketClientWorker(**base_kwargs)

    sample_data = torch.tensor([1, 2, 3, 4])
    sample_ptr = sample_data.send(client_worker)

    _ = sample_ptr.get()
    # NOTE(review): _objects presumably maps object ids to objects, so this tests
    # membership against the keys, not the stored tensors — confirm this checks
    # what the test intends.
    assert sample_data not in client_worker._objects

    client_worker.ws.shutdown()
    client_worker.ws.close()
    time.sleep(0.1)
    server.terminate()
def test_websocket_worker(hook):
    """Evaluates that you can do basic tensor operations using WebsocketServerWorker"""
    # NOTE(review): unlike the sibling tests, start_proc is not a parameter here;
    # this only works if start_proc is a module-level name — otherwise NameError.
    # Also: a test with this exact name appears again later in the file; pytest
    # only collects the last definition.
    kwargs = {"id": "fed1", "host": "localhost", "port": 8765, "hook": hook}
    server = start_proc(WebsocketServerWorker, kwargs)

    time.sleep(0.1)  # give the server process time to start listening
    x = torch.ones(5)

    socket_pipe = WebsocketClientWorker(**kwargs)

    x = x.send(socket_pipe)
    y = x + x
    y = y.get()

    assert (y == torch.ones(5) * 2).all()

    del x

    server.terminate()
def test_websocket_worker(hook, start_proc):
    """Evaluates that you can do basic tensor operations using WebsocketServerWorker"""
    # NOTE(review): duplicates the name of an earlier test in this file; pytest only
    # collects the last definition.
    kwargs = {"id": "fed", "host": "localhost", "port": 8766, "hook": hook}
    process_remote_worker = start_proc(WebsocketServerWorker, kwargs)

    time.sleep(0.1)  # give the server process time to start listening
    x = torch.ones(5)

    local_worker = WebsocketClientWorker(**kwargs)

    x = x.send(local_worker)
    y = x + x
    y = y.get()

    assert (y == torch.ones(5) * 2).all()

    del x

    local_worker.ws.shutdown()
    time.sleep(0.1)
    local_worker.remove_worker_from_local_worker_registry()
    process_remote_worker.terminate()
def test_train_plan_locally_and_then_send_it(hook, start_proc):
    """Test training a plan locally and then executing it remotely."""
    hook.local_worker.is_client_worker = False

    # Create toy model
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(2, 3)
            self.fc2 = nn.Linear(3, 2)

        @sy.method2plan
        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return F.log_softmax(x, dim=0)

    net = Net()

    # Create toy data
    x = th.tensor([-1, 2.0])
    y = th.tensor([1.0])

    # Train Model
    opt = optim.SGD(params=net.parameters(), lr=0.01)
    previous_loss = None
    for _ in range(5):
        # 1) erase previous gradients (if they exist)
        opt.zero_grad()
        # 2) make a prediction
        pred = net(x)
        # 3) calculate how much we missed
        loss = ((pred - y)**2).sum()
        # 4) figure out which weights caused us to miss
        loss.backward()
        # 5) change those weights
        opt.step()
        # loss must decrease monotonically over these 5 steps
        if previous_loss is not None:
            assert loss < previous_loss
        previous_loss = loss

    local_res = net(x)
    # build the plan from the (now trained) forward pass before sending it
    net.forward.build(x)

    kwargs = {
        "id": "test_plan_worker_3",
        "host": "localhost",
        "port": 8800,
        "hook": hook
    }
    server = start_proc(WebsocketServerWorker, **kwargs)
    time.sleep(0.1)  # give the server process time to start listening
    socket_pipe = WebsocketClientWorker(**kwargs)

    # Execute the trained plan remotely; result must match the local evaluation.
    plan_ptr = net.send(socket_pipe)
    x_ptr = x.send(socket_pipe)
    remote_res = plan_ptr(x_ptr).get()

    assert (remote_res == local_res).all()

    # delete remote object before websocket connection termination
    del x_ptr

    server.terminate()
    hook.local_worker.is_client_worker = True
def test_websocket_worker_basic(hook, start_proc, secure, tmpdir):
    """Evaluates that you can do basic tensor operations using WebsocketServerWorker
    in insecure and secure mode."""

    def create_self_signed_cert(cert_path, key_path):
        """Write a throwaway self-signed certificate and key pair for the TLS test."""
        # create a key pair
        k = crypto.PKey()
        k.generate_key(crypto.TYPE_RSA, 1024)

        # create a self-signed cert
        cert = crypto.X509()
        cert.gmtime_adj_notBefore(0)
        cert.gmtime_adj_notAfter(1000)
        cert.set_pubkey(k)
        cert.sign(k, "sha1")

        # BUGFIX: store keys and cert with context managers — the original used
        # bare open(...).write(...) and never closed either file handle.
        with open(cert_path, "wb") as cert_file:
            cert_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
        with open(key_path, "wb") as key_file:
            key_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k))

    kwargs = {
        "id": "secure_fed" if secure else "fed",
        "host": "localhost",
        "port": 8766 if secure else 8765,
        "hook": hook,
    }

    if secure:
        # Create cert and keys
        cert_path = tmpdir.join("test.crt")
        key_path = tmpdir.join("test.key")
        create_self_signed_cert(cert_path, key_path)
        kwargs["cert_path"] = cert_path
        kwargs["key_path"] = key_path

    process_remote_worker = start_proc(WebsocketServerWorker, **kwargs)

    time.sleep(0.1)  # give the server process time to start listening
    x = torch.ones(5)

    if secure:
        # server-only args must not reach the client constructor
        del kwargs["cert_path"]
        del kwargs["key_path"]

    kwargs["secure"] = secure
    local_worker = WebsocketClientWorker(**kwargs)

    x = x.send(local_worker)
    y = x + x
    y = y.get()

    assert (y == torch.ones(5) * 2).all()

    del x

    local_worker.ws.shutdown()
    time.sleep(0.1)
    local_worker.remove_worker_from_local_worker_registry()
    process_remote_worker.terminate()
category, rand_cat_index = randomChoice(all_categories) #cat = category, it's not a random animal #rand_line_index is a relative index for a data point within the random category rand_cat_index line, rand_line_index = randomChoice(category_lines[category]) category_start_index = categories_start_index[category] absolute_index = category_start_index + rand_line_index return(absolute_index) ########## part 3 remote workers set-up and dataset distribution ############## hook = sy.TorchHook(torch) # <-- NEW: hook PyTorch ie add extra functionalities to support Federated Learning #alice = sy.VirtualWorker(hook, id="alice") #bob = sy.VirtualWorker(hook, id="bob") #If you have your workers operating remotely, like on Raspberry PIs kwargs_websocket_alice = {"host": "6", "hook": hook} alice = WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket_alice) kwargs_websocket_bob = {"host": "ip_bob", "hook": hook} bob = WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket_bob) workers_virtual = [alice, bob] langDataset = LanguageDataset(array_lines_proper_dimension, categories_numpy) #assign the data points and the corresponding categories to workers. print("assignment starts") federated_train_loader = sy.FederatedDataLoader( langDataset.federate(workers_virtual), batch_size=args.batch_size) print("assignment completed") # time test for this part shows that time taken is not significant for this stage print("Generating list of batches for the workers...")
def main():
    """Run federated MNIST training over three websocket (or virtual) workers."""
    args = define_and_get_arguments()
    print(args)

    hook = sy.TorchHook(torch)
    host = "localhost"

    if args.use_virtual:
        alice = VirtualWorker(id="alice", hook=hook, verbose=args.verbose)
        bob = VirtualWorker(id="bob", hook=hook, verbose=args.verbose)
        charlie = VirtualWorker(id="charlie", hook=hook, verbose=args.verbose)
    else:
        kwargs_websocket = {
            "host": host,
            "hook": hook,
            "verbose": args.verbose
        }
        alice = WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
        bob = WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
        charlie = WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)

    workers = [alice, bob, charlie]

    use_cuda = args.cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

    # Fetch the pre-tagged MNIST training shards registered on each worker.
    tr_alice = alice.search("#mnist", "#alice", "#train_tag")
    tr_bob = bob.search("#mnist", "#bob", "#train_tag")
    tr_charlie = charlie.search("#mnist", "#charlie", "#train_tag")

    base_data = []
    base_data.append(BaseDataset(tr_alice[0], tr_alice[1]))
    base_data.append(BaseDataset(tr_bob[0], tr_bob[1]))
    base_data.append(BaseDataset(tr_charlie[0], tr_charlie[1]))

    federated_train_loader = sy.FederatedDataLoader(
        FederatedDataset(base_data),
        batch_size=args.batch_size,
        shuffle=True,
        iter_per_worker=True,
        **kwargs,
    )

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../data",
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]),
        ),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs,
    )

    model = Net().to(device)

    for epoch in range(1, args.epochs + 1):
        logger.info("Starting epoch %s/%s", epoch, args.epochs)
        model = train(model, device, federated_train_loader, args.lr,
                      args.federate_after_n_batches)
        test(model, device, test_loader)

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
def main():
    """Run federated chest-x-ray training over three hospital workers, then test locally."""
    args = define_and_get_arguments()

    hook = sy.TorchHook(torch)
    host = "localhost"

    if args.use_virtual:
        # BUGFIX: this branch previously bound alice/bob/charlie, so every later
        # reference to hospital_a/b/c raised NameError whenever use_virtual was set.
        hospital_a = VirtualWorker(id="hospital_a", hook=hook, verbose=args.verbose)
        hospital_b = VirtualWorker(id="hospital_b", hook=hook, verbose=args.verbose)
        hospital_c = VirtualWorker(id="hospital_c", hook=hook, verbose=args.verbose)
    else:
        kwargs_websocket = {
            "host": host,
            "hook": hook,
            "verbose": args.verbose
        }
        hospital_a = WebsocketClientWorker(id="hospital_a", port=8777, **kwargs_websocket)
        hospital_b = WebsocketClientWorker(id="hospital_b", port=8778, **kwargs_websocket)
        hospital_c = WebsocketClientWorker(id="hospital_c", port=8779, **kwargs_websocket)

    print()
    print(
        "*******************************************************************************************************"
    )
    print("building training channels ...")
    print(" #hospital_a, remote tensor reference: ", hospital_a)
    print(" #hospital_b, remote tensor reference: ", hospital_b)
    print(" #hospital_c, remote tensor reference: ", hospital_c)
    print()

    workers = [hospital_a, hospital_b, hospital_c]

    use_cuda = args.cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

    # Fetch the pre-tagged x-ray training shards registered on each hospital worker.
    tr_hospital_a = hospital_a.search("#chest_xray", "#hospital_a", "#train_tag")
    tr_hospital_b = hospital_b.search("#chest_xray", "#hospital_b", "#train_tag")
    tr_hospital_c = hospital_c.search("#chest_xray", "#hospital_c", "#train_tag")

    base_data = []
    base_data.append(BaseDataset(tr_hospital_a[0], tr_hospital_a[1]))
    base_data.append(BaseDataset(tr_hospital_b[0], tr_hospital_b[1]))
    base_data.append(BaseDataset(tr_hospital_c[0], tr_hospital_c[1]))

    federated_train_loader = sy.FederatedDataLoader(
        FederatedDataset(base_data),
        batch_size=args.batch_size,
        shuffle=True,
        iter_per_worker=True,
        **kwargs,
    )

    # Local test data: center-crop + light augmentation, normalized to [-1, 1].
    data_transforms = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    test = datasets.ImageFolder('chest_xray/small', data_transforms)
    local_test_loader = torch.utils.data.DataLoader(
        test, batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = resnet.resnet18_simple()

    # print("*******************************************************************************************************")
    # print("model architecture")
    # print(model)
    # print()

    print(
        "*******************************************************************************************************"
    )
    print("starting federated learning ...")

    for epoch in range(1, args.epochs + 1):
        logger.info(" starting fl training epoch %s/%s", epoch, args.epochs)
        model = fl_train(model, device, federated_train_loader, args.lr,
                         args.federate_after_n_batches)
        logger.info(" starting local inference")
        local_test(model, device, local_test_loader)

    if args.save_model:
        torch.save(model.state_dict(), "./log/chest_xray_resnet18.pt")
def main():
    """Federated MNIST training example tracking per-client memory exchanged."""
    args = Arguments(False)
    hook = sy.TorchHook(torch)

    if args.use_virtual:
        alice = VirtualWorker(id="alice", hook=hook, verbose=args.verbose)
        bob = VirtualWorker(id="bob", hook=hook, verbose=args.verbose)
        charlie = VirtualWorker(id="charlie", hook=hook, verbose=args.verbose)
    else:
        kwargs_websocket = {
            "host": "localhost",
            "hook": hook,
            "verbose": args.verbose
        }
        alice = WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
        bob = WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
        charlie = WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)

    workers = [alice, bob, charlie]
    # one memory-exchange counter per client, updated inside train()
    clients_mem = torch.zeros(len(workers))

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    if use_cuda:
        # TODO Quickhack. Actually need to fix the problem moving the model to CUDA
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    print(device)
    kwargs = {"num_workers": 0, "pin_memory": False} if use_cuda else {}

    federated_train_loader = sy.FederatedDataLoader(
        datasets.MNIST(
            "../data",
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]),
        ).federate(tuple(workers)),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs,
    )

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../data",
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]),
        ),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs,
    )

    start = time.time()
    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, federated_train_loader, optimizer, epoch, clients_mem)
        test(args, model, device, test_loader)
        # NOTE(review): per-epoch elapsed-time print; the original (collapsed)
        # layout was ambiguous about whether this sat inside the loop — confirm.
        t = time.time()
        print(t - start)

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")

    end = time.time()
    print(end - start)
    print("Memory exchanged : ", clients_mem)
async def test_train_config_with_jit_trace_async(hook, start_proc):  # pragma: no cover
    """Fit a jit-traced model remotely via async_fit and check that training progressed.

    NOTE(review): the server-process start is commented out below (event-loop clash
    with pytest-asyncio), so this test assumes a worker is already listening on 8777.
    """
    kwargs = {
        "id": "async_fit",
        "host": "localhost",
        "port": 8777,
        "hook": hook
    }
    # data = torch.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)
    # target = torch.tensor([[1.0], [1.0], [0.0], [0.0]], requires_grad=False)
    # dataset_key = "xor"
    data, target = utils.create_gaussian_mixture_toy_data(100)
    dataset_key = "gaussian_mixture"

    # dummy input used only to trace the model graph
    mock_data = torch.zeros(1, 2)

    # TODO check reason for error (RuntimeError: This event loop is already running)
    # when starting websocket server from pytest-asyncio environment
    # dataset = sy.BaseDataset(data, target)
    # process_remote_worker = start_proc(WebsocketServerWorker, dataset=(dataset, dataset_key), **kwargs)
    # time.sleep(0.1)

    local_worker = WebsocketClientWorker(**kwargs)

    @hook.torch.jit.script
    def loss_fn(pred, target):
        return ((target.view(pred.shape).float() - pred.float())**2).mean()

    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(2, 3)
            self.fc2 = nn.Linear(3, 2)
            self.fc3 = nn.Linear(2, 1)

        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    model_untraced = Net()
    model = torch.jit.trace(model_untraced, mock_data)

    pred = model(data)
    loss_before = loss_fn(target=target, pred=pred)

    # Create and send train config
    train_config = sy.TrainConfig(model=model, loss_fn=loss_fn, batch_size=2, lr=0.1)
    train_config.send(local_worker)

    for epoch in range(5):
        loss = await local_worker.async_fit(dataset_key=dataset_key)
        if PRINT_IN_UNITTESTS:  # pragma: no cover
            print("-" * 50)
            print("Iteration %s: alice's loss: %s" % (epoch, loss))

    new_model = train_config.model_ptr.get()

    # training must have modified every layer's weights and biases
    assert not (model.fc1._parameters["weight"] == new_model.obj.fc1._parameters["weight"]).all()
    assert not (model.fc2._parameters["weight"] == new_model.obj.fc2._parameters["weight"]).all()
    assert not (model.fc3._parameters["weight"] == new_model.obj.fc3._parameters["weight"]).all()
    assert not (model.fc1._parameters["bias"] == new_model.obj.fc1._parameters["bias"]).all()
    assert not (model.fc2._parameters["bias"] == new_model.obj.fc2._parameters["bias"]).all()
    assert not (model.fc3._parameters["bias"] == new_model.obj.fc3._parameters["bias"]).all()

    new_model.obj.eval()
    pred = new_model.obj(data)
    loss_after = loss_fn(target=target, pred=pred)

    if PRINT_IN_UNITTESTS:  # pragma: no cover
        print("Loss before training: {}".format(loss_before))
        print("Loss after training: {}".format(loss_after))

    local_worker.ws.shutdown()
    # process_remote_worker.terminate()

    assert loss_after < loss_before
def test_train_config_with_jit_trace_sync(hook, start_proc):  # pragma: no cover
    """Fit a jit-traced model on a remote websocket worker and check training progressed."""
    kwargs = {
        "id": "sync_fit",
        "host": "localhost",
        "port": 9000,
        "hook": hook
    }
    # data = torch.tensor([[-1, 2.0], [0, 1.1], [-1, 2.1], [0, 1.2]], requires_grad=True)
    # target = torch.tensor([[1], [0], [1], [0]])
    data, target = utils.create_gaussian_mixture_toy_data(100)

    dataset = sy.BaseDataset(data, target)
    dataset_key = "gaussian_mixture"
    # server is started with the dataset pre-registered under dataset_key
    process_remote_worker = start_proc(WebsocketServerWorker, dataset=(dataset, dataset_key), **kwargs)

    time.sleep(0.1)  # give the server process time to start listening

    local_worker = WebsocketClientWorker(**kwargs)

    @hook.torch.jit.script
    def loss_fn(pred, target):
        return ((target.view(pred.shape).float() - pred.float())**2).mean()

    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(2, 3)
            self.fc2 = nn.Linear(3, 2)
            self.fc3 = nn.Linear(2, 1)

        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    model_untraced = Net()
    model = torch.jit.trace(model_untraced, data)

    pred = model(data)
    loss_before = loss_fn(pred=pred, target=target)

    # Create and send train config
    train_config = sy.TrainConfig(model=model, loss_fn=loss_fn, batch_size=2, epochs=1)
    train_config.send(local_worker)

    for epoch in range(5):
        loss = local_worker.fit(dataset_key=dataset_key)
        if PRINT_IN_UNITTESTS:  # pragma: no cover
            print("-" * 50)
            print("Iteration %s: alice's loss: %s" % (epoch, loss))

    new_model = train_config.model_ptr.get()

    # assert that the new model has updated (modified) parameters
    assert not (model.fc1._parameters["weight"] == new_model.obj.fc1._parameters["weight"]).all()
    assert not (model.fc2._parameters["weight"] == new_model.obj.fc2._parameters["weight"]).all()
    assert not (model.fc3._parameters["weight"] == new_model.obj.fc3._parameters["weight"]).all()
    assert not (model.fc1._parameters["bias"] == new_model.obj.fc1._parameters["bias"]).all()
    assert not (model.fc2._parameters["bias"] == new_model.obj.fc2._parameters["bias"]).all()
    assert not (model.fc3._parameters["bias"] == new_model.obj.fc3._parameters["bias"]).all()

    new_model.obj.eval()
    pred = new_model.obj(data)
    loss_after = loss_fn(pred=pred, target=target)

    if PRINT_IN_UNITTESTS:  # pragma: no cover
        print("Loss before training: {}".format(loss_before))
        print("Loss after training: {}".format(loss_after))

    # tear down connection and server process before the final assertion
    local_worker.ws.shutdown()
    del local_worker
    time.sleep(0.1)
    process_remote_worker.terminate()

    assert loss_after < loss_before
def experiment(no_cuda):
    """Run one federated MNIST experiment; returns the per-client memory-exchange counters."""
    # Creating num_workers clients
    hook = sy.TorchHook(torch)
    # Initializing arguments, with GPU usage or not
    args = Arguments(no_cuda)

    if args.use_virtual:
        alice = VirtualWorker(id="alice", hook=hook, verbose=args.verbose)
        bob = VirtualWorker(id="bob", hook=hook, verbose=args.verbose)
        charlie = VirtualWorker(id="charlie", hook=hook, verbose=args.verbose)
    else:
        kwargs_websocket = {"host": "localhost", "hook": hook, "verbose": args.verbose}
        alice = WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
        bob = WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
        charlie = WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    if use_cuda:
        # TODO Quickhack. Actually need to fix the problem moving the model to CUDA
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    torch.manual_seed(args.seed)
    clients = [alice, bob, charlie]
    # one memory-exchange counter per client, updated inside train()
    clients_mem = torch.zeros(len(clients))
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 0, 'pin_memory': False} if use_cuda else {}

    # Federated data loader
    federated_train_loader = sy.FederatedDataLoader(  # <-- this is now a FederatedDataLoader
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ]))
        .federate(clients),  # <-- NEW: we distribute the dataset across all the workers, it's now a FederatedDataset
        batch_size=args.batch_size, shuffle=True, **kwargs)

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    # creating the models for each client
    models, optimizers = [], []
    # print(device)
    for i in range(len(clients)):
        # print(i)
        models.append(Net1().to(device))
        # each client gets its own model replica and its own optimizer
        models[i] = models[i].send(clients[i])
        optimizers.append(optim.SGD(params=models[i].parameters(), lr=0.1))

    start = time.time()
    # %%time
    model = Net2().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr)  # TODO momentum is not supported at the moment

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, federated_train_loader, optimizer, epoch, models,
              optimizers, clients_mem)
        test(args, model, device, test_loader, models)
        # NOTE(review): per-epoch elapsed-time print; the original (collapsed)
        # layout was ambiguous about whether this sat inside the loop — confirm.
        t = time.time()
        print(t - start)

    if (args.save_model):
        torch.save(model.state_dict(), "mnist_cnn.pt")

    end = time.time()
    print(end - start)
    print("Memory exchanged : ", clients_mem)
    return clients_mem
# --- Remote training driver: send a TrainConfig to worker h1 and fit repeatedly ---
batch_size = 4
optimizer_args = {"lr" : 0.1, "weight_decay" : 0.01}
max_nr_batches = -1  # not used in this example
shuffle = True

# Bundle the model, loss and optimizer settings the remote worker will execute.
train_config = syft.TrainConfig(model=traced_model,
                                loss_fn=loss_fn,
                                optimizer=optimizer,
                                batch_size=batch_size,
                                optimizer_args=optimizer_args,
                                epochs=5,
                                shuffle=shuffle)

arw = {"host": "10.0.0.1", "hook": hook}
h1 = WebsocketClientWorker(id="h1", port=8778, **arw)
train_config.send(h1)

# Ask the remote worker to start its monitoring command before training.
message = h1.create_message_execute_command(command_name="start_monitoring", command_owner="self")
serialized_message = syft.serde.serialize(message)
h1._recv_msg(serialized_message)

time.sleep(3)
for epoch in range(10):
    time.sleep(3)
    loss = h1.fit(dataset_key="train")  # ask h1 to train using the "train" dataset
    print("-" * 50)
    print("Iteration %s: h1's loss: %s" % (epoch, loss))