Beispiel #1
0
def test_websocket_workers_search(hook, start_proc):
    """Check that a client can search for tensors held by another party.

    A websocket server worker holding one tagged tensor is started through
    ``start_proc``. A websocket client then searches for that tensor by its
    tags twice; each time the first result must be owned by ``"me"`` and
    located on ``"fed2"``.

    Args:
        hook: Passed to both workers as their ``hook`` argument.
        start_proc: Callable invoked as
            ``start_proc(WebsocketServerWorker, kwargs)``; the object it
            returns must provide ``terminate()``.
    """
    # Sample tensor to store on the server
    sample_data = torch.tensor([1, 2, 3, 4]).tag("#sample_data", "#another_tag")

    # Args shared by the websocket server and client
    base_kwargs = {"id": "fed2", "host": "localhost", "port": 8767, "hook": hook}
    # Build a separate dict for the server: aliasing ``base_kwargs`` would
    # also hand the sample tensor to the client, which is supposed to find
    # it remotely rather than hold it itself.
    server_kwargs = {**base_kwargs, "data": [sample_data]}
    server = start_proc(WebsocketServerWorker, server_kwargs)

    try:
        # Give the server a moment to start listening before connecting.
        time.sleep(0.1)

        client_worker = WebsocketClientWorker(**base_kwargs)
        try:
            # Search for the tensor located on the server by using its tags.
            # The second pass verifies that repeated searches still work.
            for _ in range(2):
                results = client_worker.search("#sample_data", "#another_tag")

                assert results
                assert results[0].owner.id == "me"
                assert results[0].location.id == "fed2"
        finally:
            # Close the client connection even when an assertion failed.
            client_worker.ws.shutdown()
            client_worker.ws.close()
            time.sleep(0.1)
    finally:
        # Always stop the server so a failing test does not leave it running
        # (presumably a separate process still bound to the port).
        server.terminate()
Beispiel #2
0
def main():
    """Train a model with federated learning across three hospital workers.

    Steps, in order:
      1. Parse the arguments with ``define_and_get_arguments()``.
      2. Create three workers (``hospital_a``, ``hospital_b``,
         ``hospital_c``), either as ``VirtualWorker`` instances
         (``args.use_virtual``) or as ``WebsocketClientWorker`` connections
         to localhost ports 8777-8779.
      3. Search each worker for its tagged training tensors and wrap the
         first two results of every search in a ``BaseDataset``.
      4. Train the model returned by ``resnet.resnet18_simple()`` with
         ``fl_train`` for ``args.epochs`` epochs, running ``local_test`` on
         the local ``chest_xray/small`` image folder after each epoch.
      5. If ``args.save_model`` is set, save the model's state dict to
         ``./log/chest_xray_resnet18.pt``.

    Raises:
        IndexError: If a worker's search returns fewer than two results.
    """
    args = define_and_get_arguments()

    hook = sy.TorchHook(torch)

    host = "localhost"

    if args.use_virtual:
        # Bind the virtual workers to the same names as the websocket
        # clients; everything below refers to hospital_a/b/c.
        hospital_a = VirtualWorker(id="hospital_a", hook=hook, verbose=args.verbose)
        hospital_b = VirtualWorker(id="hospital_b", hook=hook, verbose=args.verbose)
        hospital_c = VirtualWorker(id="hospital_c", hook=hook, verbose=args.verbose)
    else:
        kwargs_websocket = {
            "host": host,
            "hook": hook,
            "verbose": args.verbose,
        }
        hospital_a = WebsocketClientWorker(id="hospital_a",
                                           port=8777,
                                           **kwargs_websocket)
        hospital_b = WebsocketClientWorker(id="hospital_b",
                                           port=8778,
                                           **kwargs_websocket)
        hospital_c = WebsocketClientWorker(id="hospital_c",
                                           port=8779,
                                           **kwargs_websocket)

        print()
        print(
            "*******************************************************************************************************"
        )
        print("building training channels ...")
        print(" #hospital_a, remote tensor reference: ", hospital_a)
        print(" #hospital_b, remote tensor reference: ", hospital_b)
        print(" #hospital_c, remote tensor reference: ", hospital_c)
        print()

    use_cuda = args.cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

    # Look up the training tensors on each worker by their tags.
    tagged_workers = (
        (hospital_a, "#hospital_a"),
        (hospital_b, "#hospital_b"),
        (hospital_c, "#hospital_c"),
    )
    search_results = [
        (owner_tag, worker.search("#chest_xray", owner_tag, "#train_tag"))
        for worker, owner_tag in tagged_workers
    ]

    base_data = []
    for owner_tag, found in search_results:
        # Each dataset is built from the first two search results; fail with
        # a clear message instead of a bare index error when they are missing.
        if len(found) < 2:
            raise IndexError(
                "search for %s training tensors returned %d result(s), "
                "expected at least 2" % (owner_tag, len(found))
            )
        base_data.append(BaseDataset(found[0], found[1]))

    federated_train_loader = sy.FederatedDataLoader(
        FederatedDataset(base_data),
        batch_size=args.batch_size,
        shuffle=True,
        iter_per_worker=True,
        **kwargs,
    )

    data_transforms = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    test = datasets.ImageFolder('chest_xray/small', data_transforms)

    local_test_loader = torch.utils.data.DataLoader(
        test, batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = resnet.resnet18_simple()

    print(
        "*******************************************************************************************************"
    )
    print("starting federated learning ...")
    for epoch in range(1, args.epochs + 1):
        logger.info(" starting fl training epoch %s/%s", epoch, args.epochs)
        model = fl_train(model, device, federated_train_loader, args.lr,
                         args.federate_after_n_batches)

        logger.info(" starting local inference")
        local_test(model, device, local_test_loader)

    if args.save_model:
        import os

        # torch.save does not create missing parent directories.
        os.makedirs("./log", exist_ok=True)
        torch.save(model.state_dict(), "./log/chest_xray_resnet18.pt")