def test_websocket_workers_search(hook, start_proc):
    """Evaluate that a client can search and find tensors that belong to another party.

    Starts a ``WebsocketServerWorker`` holding one tagged tensor, connects a
    ``WebsocketClientWorker`` to it, and checks that searching by tag returns
    a result whose owner is the local worker ("me") and whose location is the
    server ("fed2"). The search is run twice to check that repeated searches
    keep working. The server process is terminated even if an assertion fails.
    """
    # Sample tensor to store on the server
    sample_data = torch.tensor([1, 2, 3, 4]).tag("#sample_data", "#another_tag")

    # Args shared by the websocket server and client.
    base_kwargs = {"id": "fed2", "host": "localhost", "port": 8767, "hook": hook}
    # Build a separate dict for the server. Mutating an alias of base_kwargs
    # would also pass ``data`` to the client constructor below.
    server_kwargs = {**base_kwargs, "data": [sample_data]}
    server = start_proc(WebsocketServerWorker, server_kwargs)

    try:
        # Short timing allowance for the server process to come up.
        time.sleep(0.1)

        client_worker = WebsocketClientWorker(**base_kwargs)
        try:
            # Search for the tensor on the server by its tags; searching
            # multiple times should still work.
            for _ in range(2):
                results = client_worker.search("#sample_data", "#another_tag")
                assert results
                assert results[0].owner.id == "me"
                assert results[0].location.id == "fed2"
        finally:
            client_worker.ws.shutdown()
            client_worker.ws.close()
            time.sleep(0.1)
    finally:
        server.terminate()
def main():
    """Train a ResNet-18 on chest X-ray data federated across three hospital workers.

    Workers are in-process ``VirtualWorker`` objects when ``args.use_virtual``
    is set. Otherwise they are ``WebsocketClientWorker`` connections to
    localhost ports 8777, 8778 and 8779.

    Each worker's training tensors are located by tag, wrapped in
    ``BaseDataset`` and combined into a ``FederatedDataset``. The model is
    trained with ``fl_train`` for ``args.epochs`` epochs and evaluated after
    every epoch with ``local_test`` on the local ``chest_xray/small`` image
    folder. When ``args.save_model`` is set, the final state dict is written
    to ``./log/chest_xray_resnet18.pt``.

    Raises:
        RuntimeError: if a worker's tag search returns fewer than two tensors.
    """
    import os  # used only to create the checkpoint directory

    args = define_and_get_arguments()

    hook = sy.TorchHook(torch)
    host = "localhost"

    # Both branches bind hospital_a/b/c so the code below works in either mode.
    if args.use_virtual:
        hospital_a = VirtualWorker(id="hospital_a", hook=hook, verbose=args.verbose)
        hospital_b = VirtualWorker(id="hospital_b", hook=hook, verbose=args.verbose)
        hospital_c = VirtualWorker(id="hospital_c", hook=hook, verbose=args.verbose)
    else:
        kwargs_websocket = {"host": host, "hook": hook, "verbose": args.verbose}
        hospital_a = WebsocketClientWorker(id="hospital_a", port=8777, **kwargs_websocket)
        hospital_b = WebsocketClientWorker(id="hospital_b", port=8778, **kwargs_websocket)
        hospital_c = WebsocketClientWorker(id="hospital_c", port=8779, **kwargs_websocket)

        print()
        print(
            "*******************************************************************************************************"
        )
        print("building training channels ...")
        print(" #hospital_a, remote tensor reference: ", hospital_a)
        print(" #hospital_b, remote tensor reference: ", hospital_b)
        print(" #hospital_c, remote tensor reference: ", hospital_c)
        print()

    use_cuda = args.cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    # Extra DataLoader options, only used when running on CUDA.
    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

    # Locate each hospital's training tensors by tag and wrap them as datasets.
    base_data = []
    for worker, hospital_tag in (
        (hospital_a, "#hospital_a"),
        (hospital_b, "#hospital_b"),
        (hospital_c, "#hospital_c"),
    ):
        found = worker.search("#chest_xray", hospital_tag, "#train_tag")
        if len(found) < 2:
            raise RuntimeError(
                f"expected at least 2 tensors tagged #chest_xray {hospital_tag} "
                f"#train_tag on worker {hospital_tag[1:]}, found {len(found)}"
            )
        # NOTE(review): assumes search returns the data tensor before the
        # targets tensor — confirm against how the workers load their data.
        base_data.append(BaseDataset(found[0], found[1]))

    federated_train_loader = sy.FederatedDataLoader(
        FederatedDataset(base_data),
        batch_size=args.batch_size,
        shuffle=True,
        iter_per_worker=True,
        **kwargs,
    )

    # NOTE(review): RandomRotation / RandomHorizontalFlip are random
    # augmentations applied to the evaluation set, so test results vary
    # between runs — confirm this is intended.
    data_transforms = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    test = datasets.ImageFolder("chest_xray/small", data_transforms)
    local_test_loader = torch.utils.data.DataLoader(
        test, batch_size=args.test_batch_size, shuffle=True, **kwargs
    )

    # NOTE(review): the model is not moved to `device` here; presumably
    # fl_train / local_test handle placement — verify.
    model = resnet.resnet18_simple()

    print(
        "*******************************************************************************************************"
    )
    print("starting federated learning ...")
    for epoch in range(1, args.epochs + 1):
        logger.info(" starting fl training epoch %s/%s", epoch, args.epochs)
        model = fl_train(
            model, device, federated_train_loader, args.lr, args.federate_after_n_batches
        )

        logger.info(" starting local inference")
        local_test(model, device, local_test_loader)

    if args.save_model:
        # Make sure the checkpoint directory exists before saving.
        os.makedirs("./log", exist_ok=True)
        torch.save(model.state_dict(), "./log/chest_xray_resnet18.pt")