Beispiel #1
0
def test_connect_close(hook, start_proc):
    """Check that a client can close and re-open its websocket connection.

    An object sent before ``close()`` must still be counted on the server and
    be retrievable through the same pointer after ``connect()``.
    """
    # Client and server use the same id/host/port, so one dict serves both.
    kwargs = {"id": "fed", "host": "localhost", "port": 8763, "hook": hook}
    process_remote_worker = start_proc(WebsocketServerWorker, **kwargs)

    # Terminate the server process even if an assertion fails, so the port is
    # freed instead of leaking into subsequent test runs.
    try:
        # Give the server process time to start listening.
        time.sleep(0.1)

        local_worker = WebsocketClientWorker(**kwargs)

        x = torch.tensor([1, 2, 3])
        x_ptr = x.send(local_worker)

        assert local_worker.objects_count_remote() == 1

        local_worker.close()
        time.sleep(0.1)
        local_worker.connect()

        # The remote object must survive the reconnect.
        assert local_worker.objects_count_remote() == 1

        x_val = x_ptr.get()
        assert (x_val == x).all()

        local_worker.ws.shutdown()
        time.sleep(0.1)
    finally:
        process_remote_worker.terminate()
Beispiel #2
0
def test_websocket_worker_multiple_output_response(hook, start_proc):
    """Check that a remote op with several outputs (``torch.sort``) returns
    one pointer per output, each holding the correct values."""
    kwargs = {
        "id": "socket_multiple_output",
        "host": "localhost",
        "port": 8768,
        "hook": hook,
    }
    # This variant of the start_proc fixture receives the kwargs dict itself.
    server = start_proc(WebsocketServerWorker, kwargs)

    # Always stop the server process, even when an assertion fails.
    try:
        time.sleep(0.1)
        x = torch.tensor([1.0, 3, 2])

        socket_pipe = WebsocketClientWorker(**kwargs)

        x = x.send(socket_pipe)
        # torch.sort yields (sorted values, original indices) -> two pointers.
        p1, p2 = torch.sort(x)
        x1, x2 = p1.get(), p2.get()

        assert (x1 == torch.tensor([1.0, 2, 3])).all()
        assert (x2 == torch.tensor([0, 2, 1])).all()

        del x

        socket_pipe.ws.shutdown()
        time.sleep(0.1)
    finally:
        server.terminate()
Beispiel #3
0
def test_websocket_worker_multiple_output_response(hook, start_proc):
    """Check that a remote op with several outputs (``torch.sort``) returns
    one pointer per output, each holding the correct values."""
    kwargs = {
        "id": "socket_multiple_output",
        "host": "localhost",
        "port": 8768,
        "hook": hook,
    }
    process_remote_worker = start_proc(WebsocketServerWorker, **kwargs)

    # Always stop the server process, even when an assertion fails.
    try:
        time.sleep(0.1)
        x = torch.tensor([1.0, 3, 2])

        local_worker = WebsocketClientWorker(**kwargs)

        x = x.send(local_worker)
        # torch.sort yields (sorted values, original indices) -> two pointers.
        p1, p2 = torch.sort(x)
        x1, x2 = p1.get(), p2.get()

        assert (x1 == torch.tensor([1.0, 2, 3])).all()
        assert (x2 == torch.tensor([0, 2, 1])).all()

        x.get()  # retrieve remote object before closing the websocket connection

        local_worker.ws.shutdown()
    finally:
        process_remote_worker.terminate()
Beispiel #4
0
def test_objects_count_remote(hook, start_proc):
    """Check that objects_count_remote tracks sends and gets on the server."""
    # Client and server use the same id/host/port, so one dict serves both.
    kwargs = {"id": "fed", "host": "localhost", "port": 8764, "hook": hook}
    process_remote_worker = start_proc(WebsocketServerWorker, kwargs)

    # Always stop the server process, even when an assertion fails.
    try:
        time.sleep(0.1)

        local_worker = WebsocketClientWorker(**kwargs)

        x = torch.tensor([1, 2, 3]).send(local_worker)
        assert local_worker.objects_count_remote() == 1

        y = torch.tensor([4, 5, 6]).send(local_worker)
        assert local_worker.objects_count_remote() == 2

        # get() moves x back to the client, so the server holds one object less.
        x.get()
        assert local_worker.objects_count_remote() == 1

        # delete remote object before terminating the websocket connection
        del y
        time.sleep(0.1)
        local_worker.ws.shutdown()
        time.sleep(0.1)
        local_worker.remove_worker_from_local_worker_registry()
    finally:
        process_remote_worker.terminate()
Beispiel #5
0
def test_list_objects_remote(hook, start_proc):
    """Check that list_objects_remote lists every object held by the server."""
    # Client and server use the same id/host/port, so one dict serves both.
    kwargs = {"id": "fed", "host": "localhost", "port": 8765, "hook": hook}
    process_remote_fed1 = start_proc(WebsocketServerWorker, kwargs)

    # Always stop the server process, even when an assertion fails.
    try:
        time.sleep(0.1)

        local_worker = WebsocketClientWorker(**kwargs)

        x = torch.tensor([1, 2, 3]).send(local_worker)

        # The response appears to be the repr of the server's object dict,
        # hence the "tensor" -> "torch.tensor" rewrite before eval.
        # SECURITY: eval is tolerable only because this string comes from a
        # server the test itself started; never eval an untrusted response.
        res = local_worker.list_objects_remote()
        res_dict = eval(res.replace("tensor", "torch.tensor"))
        assert len(res_dict) == 1

        y = torch.tensor([4, 5, 6]).send(local_worker)
        res = local_worker.list_objects_remote()
        res_dict = eval(res.replace("tensor", "torch.tensor"))
        assert len(res_dict) == 2

        # delete x and y before terminating the websocket connection
        del x
        del y
        time.sleep(0.1)
        local_worker.ws.shutdown()
        time.sleep(0.1)
        local_worker.remove_worker_from_local_worker_registry()
    finally:
        process_remote_fed1.terminate()
Beispiel #6
0
def test_websocket_garbage_collection(hook, start_proc):
    """Send a tensor, fetch it back with get(), and check that it is not left
    in the client's ``_objects`` store."""
    # Args for initializing the websocket server and client
    base_kwargs = {"id": "ws_gc", "host": "localhost", "port": 8777, "hook": hook}
    server = start_proc(WebsocketServerWorker, base_kwargs)

    # Always stop the server process, even when an assertion fails.
    try:
        time.sleep(0.1)
        client_worker = WebsocketClientWorker(**base_kwargs)

        sample_data = torch.tensor([1, 2, 3, 4])
        sample_ptr = sample_data.send(client_worker)

        _ = sample_ptr.get()
        # NOTE(review): `in` on a dict tests keys. If _objects is keyed by
        # object id (as it appears to be), testing a tensor here passes
        # regardless of garbage collection. Consider asserting on an id
        # instead; confirm which id the test intends.
        assert sample_data not in client_worker._objects

        client_worker.ws.shutdown()
        client_worker.ws.close()
        time.sleep(0.1)
    finally:
        server.terminate()
Beispiel #7
0
def test_train_config_with_jit_trace_sync(hook,
                                          start_proc):  # pragma: no cover
    """Train a traced model remotely via TrainConfig + fit and check it learns.

    The server is started with a Gaussian-mixture dataset under
    ``dataset_key``. After five ``fit`` calls, every weight and bias of the
    returned model must differ from the local model. The loss on the full
    data must also be lower than before training.
    """
    kwargs = {
        "id": "sync_fit",
        "host": "localhost",
        "port": 9000,
        "hook": hook
    }

    data, target = utils.create_gaussian_mixture_toy_data(100)
    dataset = sy.BaseDataset(data, target)
    dataset_key = "gaussian_mixture"

    process_remote_worker = start_proc(WebsocketServerWorker,
                                       dataset=(dataset, dataset_key),
                                       **kwargs)

    # Always stop the server process, even when an assertion fails.
    try:
        # Give the server process time to start listening.
        time.sleep(0.1)

        local_worker = WebsocketClientWorker(**kwargs)

        @hook.torch.jit.script
        def loss_fn(pred, target):
            return ((target.view(pred.shape).float() - pred.float())**2).mean()

        class Net(torch.nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.fc1 = nn.Linear(2, 3)
                self.fc2 = nn.Linear(3, 2)
                self.fc3 = nn.Linear(2, 1)

            def forward(self, x):
                x = F.relu(self.fc1(x))
                x = F.relu(self.fc2(x))
                x = self.fc3(x)
                return x

        model_untraced = Net()
        model = torch.jit.trace(model_untraced, data)

        pred = model(data)
        loss_before = loss_fn(pred=pred, target=target)

        # Create and send train config
        train_config = sy.TrainConfig(model=model,
                                      loss_fn=loss_fn,
                                      batch_size=2,
                                      epochs=1)
        train_config.send(local_worker)

        for epoch in range(5):
            loss = local_worker.fit(dataset_key=dataset_key)
            if PRINT_IN_UNITTESTS:  # pragma: no cover
                print("-" * 50)
                print("Iteration %s: alice's loss: %s" % (epoch, loss))

        new_model = train_config.model_ptr.get()

        # Every parameter of the trained model must differ from the local one.
        for param_name in ("weight", "bias"):
            for layer_name in ("fc1", "fc2", "fc3"):
                old_param = getattr(model, layer_name)._parameters[param_name]
                new_param = getattr(new_model.obj,
                                    layer_name)._parameters[param_name]
                assert not (old_param == new_param).all(), (layer_name,
                                                            param_name)

        new_model.obj.eval()
        pred = new_model.obj(data)
        loss_after = loss_fn(pred=pred, target=target)

        if PRINT_IN_UNITTESTS:  # pragma: no cover
            print("Loss before training: {}".format(loss_before))
            print("Loss after training: {}".format(loss_after))

        local_worker.ws.shutdown()
        del local_worker
        time.sleep(0.1)
    finally:
        process_remote_worker.terminate()

    assert loss_after < loss_before
Beispiel #8
0
async def test_train_config_with_jit_trace_async(
        hook, start_proc):  # pragma: no cover
    """Train a traced model remotely via TrainConfig + async_fit and check it learns.

    After five ``async_fit`` calls, every weight and bias of the returned
    model must differ from the local model. The loss on the full data must
    also be lower than before training.
    """
    kwargs = {
        "id": "async_fit",
        "host": "localhost",
        "port": 8777,
        "hook": hook
    }
    data, target = utils.create_gaussian_mixture_toy_data(100)
    dataset_key = "gaussian_mixture"

    # Dummy input used only to trace the model graph.
    mock_data = torch.zeros(1, 2)

    # TODO check reason for error (RuntimeError: This event loop is already running) when starting websocket server from pytest-asyncio environment
    # dataset = sy.BaseDataset(data, target)
    # process_remote_worker = start_proc(WebsocketServerWorker, dataset=(dataset, dataset_key), **kwargs)
    # time.sleep(0.1)
    # NOTE(review): with the server start above disabled, this test presumably
    # relies on a server holding `dataset_key` already listening on
    # kwargs["port"]; confirm.

    local_worker = WebsocketClientWorker(**kwargs)

    # Close the websocket even if an assertion or the remote call fails.
    try:
        @hook.torch.jit.script
        def loss_fn(pred, target):
            return ((target.view(pred.shape).float() - pred.float())**2).mean()

        class Net(torch.nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.fc1 = nn.Linear(2, 3)
                self.fc2 = nn.Linear(3, 2)
                self.fc3 = nn.Linear(2, 1)

            def forward(self, x):
                x = F.relu(self.fc1(x))
                x = F.relu(self.fc2(x))
                x = self.fc3(x)
                return x

        model_untraced = Net()
        model = torch.jit.trace(model_untraced, mock_data)

        pred = model(data)
        loss_before = loss_fn(target=target, pred=pred)

        # Create and send train config
        train_config = sy.TrainConfig(model=model,
                                      loss_fn=loss_fn,
                                      batch_size=2,
                                      lr=0.1)
        train_config.send(local_worker)

        for epoch in range(5):
            loss = await local_worker.async_fit(dataset_key=dataset_key)
            if PRINT_IN_UNITTESTS:  # pragma: no cover
                print("-" * 50)
                print("Iteration %s: alice's loss: %s" % (epoch, loss))

        new_model = train_config.model_ptr.get()

        # Every parameter of the trained model must differ from the local one.
        for param_name in ("weight", "bias"):
            for layer_name in ("fc1", "fc2", "fc3"):
                old_param = getattr(model, layer_name)._parameters[param_name]
                new_param = getattr(new_model.obj,
                                    layer_name)._parameters[param_name]
                assert not (old_param == new_param).all(), (layer_name,
                                                            param_name)

        new_model.obj.eval()
        pred = new_model.obj(data)
        loss_after = loss_fn(target=target, pred=pred)
        if PRINT_IN_UNITTESTS:  # pragma: no cover
            print("Loss before training: {}".format(loss_before))
            print("Loss after training: {}".format(loss_after))
    finally:
        local_worker.ws.shutdown()
        # process_remote_worker.terminate()

    assert loss_after < loss_before
Beispiel #9
0
def main():
    """Federated training of a ResNet-18 on chest X-ray data from three hospitals.

    Training data is looked up on each hospital worker by tag. The workers
    are virtual or websocket-backed, depending on ``args.use_virtual``. After
    every federated epoch the model is evaluated locally on
    ``chest_xray/small``. The weights are saved when ``args.save_model`` is
    set.
    """
    args = define_and_get_arguments()

    hook = sy.TorchHook(torch)

    host = "localhost"

    if args.use_virtual:
        # Use the same names as the websocket branch: everything below refers
        # to hospital_a/b/c, so other names would raise a NameError.
        hospital_a = VirtualWorker(id="hospital_a", hook=hook, verbose=args.verbose)
        hospital_b = VirtualWorker(id="hospital_b", hook=hook, verbose=args.verbose)
        hospital_c = VirtualWorker(id="hospital_c", hook=hook, verbose=args.verbose)
    else:
        kwargs_websocket = {
            "host": host,
            "hook": hook,
            "verbose": args.verbose
        }
        hospital_a = WebsocketClientWorker(id="hospital_a",
                                           port=8777,
                                           **kwargs_websocket)
        hospital_b = WebsocketClientWorker(id="hospital_b",
                                           port=8778,
                                           **kwargs_websocket)
        hospital_c = WebsocketClientWorker(id="hospital_c",
                                           port=8779,
                                           **kwargs_websocket)

        print()
        print(
            "*******************************************************************************************************"
        )
        print("building training channels ...")
        print(" #hospital_a, remote tensor reference: ", hospital_a)
        print(" #hospital_b, remote tensor reference: ", hospital_b)
        print(" #hospital_c, remote tensor reference: ", hospital_c)
        print()

    use_cuda = args.cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

    # Each hospital's training data carries the tags #chest_xray, its own
    # hospital tag and #train_tag. The search result's items [0] and [1] are
    # passed to BaseDataset as (data, targets); this order is assumed,
    # TODO confirm.
    hospital_tags = (
        (hospital_a, "#hospital_a"),
        (hospital_b, "#hospital_b"),
        (hospital_c, "#hospital_c"),
    )
    search_results = [
        hospital.search("#chest_xray", tag, "#train_tag")
        for hospital, tag in hospital_tags
    ]
    base_data = [BaseDataset(found[0], found[1]) for found in search_results]

    federated_train_loader = sy.FederatedDataLoader(
        FederatedDataset(base_data),
        batch_size=args.batch_size,
        shuffle=True,
        iter_per_worker=True,
        **kwargs,
    )

    data_transforms = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    test = datasets.ImageFolder('chest_xray/small', data_transforms)

    local_test_loader = torch.utils.data.DataLoader(
        test, batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = resnet.resnet18_simple()

    print(
        "*******************************************************************************************************"
    )
    print("starting federated learning ...")
    for epoch in range(1, args.epochs + 1):
        logger.info(" starting fl training epoch %s/%s", epoch, args.epochs)
        model = fl_train(model, device, federated_train_loader, args.lr,
                         args.federate_after_n_batches)

        logger.info(" starting local inference")
        local_test(model, device, local_test_loader)

    if args.save_model:
        torch.save(model.state_dict(), "./log/chest_xray_resnet18.pt")
Beispiel #10
0
def test_websocket_worker_basic(hook, start_proc, secure, tmpdir):
    """Evaluates that you can do basic tensor operations using
    WebsocketServerWorker in insecure and secure mode."""

    def create_self_signed_cert(cert_path, key_path):
        """Write a throwaway self-signed certificate and its RSA key as PEM."""
        # create a key pair; 2048 bits / SHA-256 because OpenSSL's default
        # security level 2 rejects 1024-bit RSA keys and SHA-1 signatures
        k = crypto.PKey()
        k.generate_key(crypto.TYPE_RSA, 2048)

        # create a self-signed cert valid from now for 1000 seconds
        cert = crypto.X509()
        cert.gmtime_adj_notBefore(0)
        cert.gmtime_adj_notAfter(1000)
        cert.set_pubkey(k)
        cert.sign(k, "sha256")

        # store keys and cert; `with` guarantees the files are flushed and
        # closed before the server process is started and reads them
        with open(cert_path, "wb") as cert_file:
            cert_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
        with open(key_path, "wb") as key_file:
            key_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k))

    kwargs = {
        "id": "secure_fed" if secure else "fed",
        "host": "localhost",
        "port": 8766 if secure else 8765,
        "hook": hook,
    }

    if secure:
        # Create cert and keys
        cert_path = tmpdir.join("test.crt")
        key_path = tmpdir.join("test.key")
        create_self_signed_cert(cert_path, key_path)
        kwargs["cert_path"] = cert_path
        kwargs["key_path"] = key_path

    process_remote_worker = start_proc(WebsocketServerWorker, **kwargs)

    # Always stop the server process, even when an assertion fails.
    try:
        time.sleep(0.1)
        x = torch.ones(5)

        if secure:
            # cert/key paths are server-only; drop them before building the client
            del kwargs["cert_path"]
            del kwargs["key_path"]

        kwargs["secure"] = secure
        local_worker = WebsocketClientWorker(**kwargs)

        x = x.send(local_worker)
        y = x + x
        y = y.get()

        assert (y == torch.ones(5) * 2).all()

        del x

        local_worker.ws.shutdown()
        time.sleep(0.1)
        local_worker.remove_worker_from_local_worker_registry()
    finally:
        process_remote_worker.terminate()
Beispiel #11
0
    category, rand_cat_index = randomChoice(all_categories) #cat = category, it's not a random animal
    #rand_line_index is a relative index for a data point within the random category rand_cat_index
    line, rand_line_index = randomChoice(category_lines[category])
    category_start_index = categories_start_index[category]
    absolute_index = category_start_index + rand_line_index
    return(absolute_index)
  
########## part 3 remote workers set-up and dataset distribution ##############

# Hook PyTorch: adds PySyft's federated-learning functionality. The hook is
# passed to every worker created below.
hook = sy.TorchHook(torch)
# Local alternative for running without remote machines:
#alice = sy.VirtualWorker(hook, id="alice")
#bob = sy.VirtualWorker(hook, id="bob")

# Connect to workers running remotely (e.g. on Raspberry Pis) over websockets.
# NOTE(review): "6" is not a usable host name or IP. Like "ip_bob" below, it
# looks like a placeholder; replace both with the workers' real addresses.
kwargs_websocket_alice = {"host": "6", "hook": hook}
alice = WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket_alice)
kwargs_websocket_bob = {"host": "ip_bob", "hook": hook}
bob = WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket_bob)
# Despite its name, this list holds the websocket (remote) workers above.
workers_virtual = [alice, bob]

langDataset =  LanguageDataset(array_lines_proper_dimension, categories_numpy)

# Split the data points and their categories across the workers via
# federate(), and wrap the result in a federated loader.
# NOTE(review): `args` is not defined in this part of the file; presumably it
# is created earlier. Confirm.
print("assignment starts")
federated_train_loader = sy.FederatedDataLoader(
            langDataset.federate(workers_virtual),
            batch_size=args.batch_size)
print("assignment completed") 
# A timing check showed this stage takes an insignificant amount of time.

print("Generating list of batches for the workers...")
def experiment(no_cuda):
    """Run one federated MNIST experiment and report the memory exchanged.

    Three clients are created, either virtual or websocket workers depending
    on ``args.use_virtual``. Each client receives its own ``Net1`` (sent to
    that client) with an SGD optimizer. A local ``Net2`` with its own
    optimizer is trained by calling ``train`` and ``test`` once per epoch.

    Args:
        no_cuda: passed to ``Arguments``. ``args.no_cuda`` then decides,
            together with ``torch.cuda.is_available()``, whether CUDA is used.

    Returns:
        ``clients_mem``: a tensor with one entry per client. It is handed to
        ``train`` every epoch and printed as "Memory exchanged". Presumably
        ``train`` accumulates into it in place; confirm.
    """
    hook = sy.TorchHook(torch)

    # Initializing arguments, with GPU usage or not
    args = Arguments(no_cuda)

    # Create the three clients.
    if args.use_virtual:
        alice = VirtualWorker(id="alice", hook=hook, verbose=args.verbose)
        bob = VirtualWorker(id="bob", hook=hook, verbose=args.verbose)
        charlie = VirtualWorker(id="charlie", hook=hook, verbose=args.verbose)
    else:
        kwargs_websocket = {"host": "localhost", "hook": hook, "verbose": args.verbose}
        alice = WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
        bob = WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
        charlie = WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    if use_cuda:
        # TODO Quickhack. Actually need to fix the problem moving the model to CUDA
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    torch.manual_seed(args.seed)

    clients = [alice, bob, charlie]
    clients_mem = torch.zeros(len(clients))

    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 0, 'pin_memory': False} if use_cuda else {}

    # Train and test use the same ToTensor + Normalize pipeline.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Federated data loader: .federate() distributes MNIST across the clients.
    federated_train_loader = sy.FederatedDataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=mnist_transform)
        .federate(clients),
        batch_size=args.batch_size, shuffle=True, **kwargs)

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=mnist_transform),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    # One Net1 per client, sent to that client, each with its own optimizer.
    models, optimizers = [], []
    for client in clients:
        client_model = Net1().to(device).send(client)
        models.append(client_model)
        optimizers.append(optim.SGD(params=client_model.parameters(), lr=0.1))

    start = time.time()
    model = Net2().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr)  # TODO momentum is not supported at the moment

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, federated_train_loader, optimizer, epoch,
              models, optimizers, clients_mem)
        test(args, model, device, test_loader, models)
        # Elapsed seconds since the local model was created.
        print(time.time() - start)

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")

    end = time.time()
    print(end - start)
    print("Memory exchanged : ", clients_mem)
    return clients_mem