Example #1
def test_list_objects_remote(hook, start_proc):

    kwargs = {"id": "fed", "host": "localhost", "port": 8765, "hook": hook}
    process_remote_fed1 = start_proc(WebsocketServerWorker, kwargs)

    time.sleep(0.1)

    kwargs = {"id": "fed", "host": "localhost", "port": 8765, "hook": hook}
    local_worker = WebsocketClientWorker(**kwargs)

    x = torch.tensor([1, 2, 3]).send(local_worker)

    res = local_worker.list_objects_remote()
    res_dict = eval(res.replace("tensor", "torch.tensor"))
    assert len(res_dict) == 1

    y = torch.tensor([4, 5, 6]).send(local_worker)
    res = local_worker.list_objects_remote()
    res_dict = eval(res.replace("tensor", "torch.tensor"))
    assert len(res_dict) == 2

    # delete x and y before terminating the websocket connection
    del x
    del y
    time.sleep(0.1)
    local_worker.ws.shutdown()
    time.sleep(0.1)
    local_worker.remove_worker_from_local_worker_registry()
    process_remote_fed1.terminate()
Example #2
def test_objects_count_remote(hook, start_proc):

    kwargs = {"id": "fed", "host": "localhost", "port": 8764, "hook": hook}
    process_remote_worker = start_proc(WebsocketServerWorker, kwargs)

    time.sleep(0.1)

    kwargs = {"id": "fed", "host": "localhost", "port": 8764, "hook": hook}
    local_worker = WebsocketClientWorker(**kwargs)

    x = torch.tensor([1, 2, 3]).send(local_worker)

    nr_objects = local_worker.objects_count_remote()
    assert nr_objects == 1

    y = torch.tensor([4, 5, 6]).send(local_worker)
    nr_objects = local_worker.objects_count_remote()
    assert nr_objects == 2

    x.get()
    nr_objects = local_worker.objects_count_remote()
    assert nr_objects == 1

    # delete remote object before terminating the websocket connection
    del y
    time.sleep(0.1)
    local_worker.ws.shutdown()
    time.sleep(0.1)
    local_worker.remove_worker_from_local_worker_registry()
    process_remote_worker.terminate()
Example #3
def test_websocket_worker_multiple_output_response(hook, start_proc):
    """Evaluates that you can do basic tensor operations using
    WebsocketServerWorker"""

    kwargs = {
        "id": "socket_multiple_output",
        "host": "localhost",
        "port": 8768,
        "hook": hook
    }
    server = start_proc(WebsocketServerWorker, kwargs)

    time.sleep(0.1)
    x = torch.tensor([1.0, 3, 2])

    socket_pipe = WebsocketClientWorker(**kwargs)

    x = x.send(socket_pipe)
    p1, p2 = torch.sort(x)
    x1, x2 = p1.get(), p2.get()

    assert (x1 == torch.tensor([1.0, 2, 3])).all()
    assert (x2 == torch.tensor([0, 2, 1])).all()

    del x

    socket_pipe.ws.shutdown()
    time.sleep(0.1)
    server.terminate()
Example #4
def test_websocket_worker_multiple_output_response(hook, start_proc):
    """Evaluates that you can do basic tensor operations using
    WebsocketServerWorker"""

    kwargs = {
        "id": "socket_multiple_output",
        "host": "localhost",
        "port": 8768,
        "hook": hook
    }
    process_remote_worker = start_proc(WebsocketServerWorker, **kwargs)

    time.sleep(0.1)
    x = torch.tensor([1.0, 3, 2])

    local_worker = WebsocketClientWorker(**kwargs)

    x = x.send(local_worker)
    p1, p2 = torch.sort(x)
    x1, x2 = p1.get(), p2.get()

    assert (x1 == torch.tensor([1.0, 2, 3])).all()
    assert (x2 == torch.tensor([0, 2, 1])).all()

    x.get()  # retrieve remote object before closing the websocket connection

    local_worker.ws.shutdown()
    process_remote_worker.terminate()
Example #5
def test_websocket_workers_search(hook, start_proc):
    """Evaluates that a client can search and find tensors that belong
    to another party"""
    # Sample tensor to store on the server
    sample_data = torch.tensor([1, 2, 3, 4]).tag("#sample_data", "#another_tag")
    # Args for initializing the websocket server and client
    base_kwargs = {"id": "fed2", "host": "localhost", "port": 8767, "hook": hook}
    server_kwargs = base_kwargs.copy()  # copy so the client kwargs below are not modified
    server_kwargs["data"] = [sample_data]
    server = start_proc(WebsocketServerWorker, server_kwargs)

    time.sleep(0.1)

    client_worker = WebsocketClientWorker(**base_kwargs)

    # Search for the tensor located on the server by using its tag
    results = client_worker.search("#sample_data", "#another_tag")

    assert results
    assert results[0].owner.id == "me"
    assert results[0].location.id == "fed2"

    # Search multiple times should still work
    results = client_worker.search("#sample_data", "#another_tag")

    assert results
    assert results[0].owner.id == "me"
    assert results[0].location.id == "fed2"

    client_worker.ws.shutdown()
    client_worker.ws.close()
    time.sleep(0.1)
    server.terminate()
Example #6
def test_plan_execute_remotely(hook, start_proc):
    """Test plan execution remotely."""
    hook.local_worker.is_client_worker = False

    @sy.func2plan
    def my_plan(data):
        x = data * 2
        y = (x - 2) * 10
        return x + y

    # TODO: remove this line when issue #2062 is fixed
    # Force to build plan
    x = th.tensor([-1, 2, 3])
    my_plan(x)

    kwargs = {"id": "test_plan_worker", "host": "localhost", "port": 8799, "hook": hook}
    server = start_proc(WebsocketServerWorker, kwargs)

    time.sleep(0.1)
    socket_pipe = WebsocketClientWorker(**kwargs)

    plan_ptr = my_plan.send(socket_pipe)
    x_ptr = x.send(socket_pipe)
    plan_res = plan_ptr(x_ptr).get()

    assert (plan_res == th.tensor([-42, 24, 46])).all()

    # delete remote object before websocket connection termination
    del x_ptr

    server.terminate()
Example #7
def test_connect_close(hook, start_proc):
    kwargs = {"id": "fed", "host": "localhost", "port": 8763, "hook": hook}
    process_remote_worker = start_proc(WebsocketServerWorker, **kwargs)

    time.sleep(0.1)

    kwargs = {"id": "fed", "host": "localhost", "port": 8763, "hook": hook}
    local_worker = WebsocketClientWorker(**kwargs)

    x = torch.tensor([1, 2, 3])
    x_ptr = x.send(local_worker)

    assert local_worker.objects_count_remote() == 1

    local_worker.close()

    time.sleep(0.1)

    local_worker.connect()

    assert local_worker.objects_count_remote() == 1

    x_val = x_ptr.get()
    assert (x_val == x).all()

    local_worker.ws.shutdown()

    time.sleep(0.1)

    process_remote_worker.terminate()
Example #8
def test_execute_plan_remotely(hook, start_proc):
    """Test plan execution remotely."""
    hook.local_worker.is_client_worker = False

    @sy.func2plan(args_shape=[(1, )])
    def my_plan(data):
        x = data * 2
        y = (x - 2) * 10
        return x + y

    x = th.tensor([-1, 2, 3])
    local_res = my_plan(x)

    kwargs = {
        "id": "test_plan_worker",
        "host": "localhost",
        "port": 8799,
        "hook": hook
    }
    server = start_proc(WebsocketServerWorker, **kwargs)

    time.sleep(0.1)
    socket_pipe = WebsocketClientWorker(**kwargs)

    plan_ptr = my_plan.send(socket_pipe)
    x_ptr = x.send(socket_pipe)
    plan_res = plan_ptr(x_ptr).get()

    assert (plan_res == local_res).all()

    # delete remote object before websocket connection termination
    del x_ptr

    server.terminate()
    hook.local_worker.is_client_worker = True
Example #9
def connect(local=False):
    """
    Connects to the three hospitals on AWS and returns their WebsocketClientWorkers.
    Args:
        local (bool): Set to true for testing and start_local_workers.py will provide
            three test workers locally.
    Returns:
        tuple(class:`syft.workers.WebsocketClientWorker`): tuple of 3 connected workers.

    """
    hook = sy.TorchHook(torch)
    if local:
        naming = LH
    else:
        naming = H

    h1 = WebsocketClientWorker(
        id=naming.h1_name,
        port=naming.h1_port,
        host=naming.h1_host,
        hook=hook
    )
    logger.info("Connected to worker h1.")

    h2 = WebsocketClientWorker(
        id=naming.h2_name,
        port=naming.h2_port,
        host=naming.h2_host,
        hook=hook
    )
    logger.info("Connected to worker h2.")

    h3 = WebsocketClientWorker(
        id=naming.h3_name,
        port=naming.h3_port,
        host=naming.h3_host,
        hook=hook
    )
    logger.info("Connected to worker h3.")
    return h1, h2, h3
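
A minimal usage sketch for connect(), assuming the three test workers from start_local_workers.py are already listening on the LH ports; the tensor and the calls below are illustrative only:

# Connect to the three locally running test workers (local=True uses the LH naming)
h1, h2, h3 = connect(local=True)

# Send an illustrative tensor to the first worker and pull it back
x_ptr = torch.tensor([1, 2, 3]).send(h1)
assert h1.objects_count_remote() == 1
x = x_ptr.get()

for worker in (h1, h2, h3):
    worker.ws.shutdown()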
Example #10
def test_create_already_existing_worker_with_different_type(hook, start_proc):
    # Send a tensor to the virtual worker bob
    bob = sy.VirtualWorker(hook, "bob")
    _ = th.tensor([1, 2, 3]).send(bob)

    kwargs = {"id": "fed1", "host": "localhost", "port": 8765, "hook": hook}
    server = start_proc(WebsocketServerWorker, kwargs)

    time.sleep(0.1)

    # Recreates bob as a different type of worker
    kwargs = {"id": "bob", "host": "localhost", "port": 8765, "hook": hook}
    with pytest.raises(RuntimeError):
        bob = WebsocketClientWorker(**kwargs)

    server.terminate()
Example #11
def test_execute_plan_module_remotely(hook, start_proc):
    """Test plan execution remotely."""
    hook.local_worker.is_client_worker = False

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(2, 3)
            self.fc2 = nn.Linear(3, 2)

        @sy.method2plan
        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return F.log_softmax(x, dim=0)

    net = Net()

    x = th.tensor([-1, 2.0])
    local_res = net(x)
    assert not net.forward.is_built

    net.forward.build(x)

    kwargs = {
        "id": "test_plan_worker_2",
        "host": "localhost",
        "port": 8799,
        "hook": hook
    }
    server = start_proc(WebsocketServerWorker, **kwargs)

    time.sleep(0.1)
    socket_pipe = WebsocketClientWorker(**kwargs)

    plan_ptr = net.send(socket_pipe)
    x_ptr = x.send(socket_pipe)
    remote_res = plan_ptr(x_ptr).get()

    assert (remote_res == local_res).all()

    # delete remote object before websocket connection termination
    del x_ptr

    server.terminate()
    hook.local_worker.is_client_worker = True
Example #12
def instantiate_websocket_client_worker(**kwargs):  # pragma: no cover
    """ Helper function to instantiate the websocket client.
    If connection is refused, we wait a bit and try again.
    After 5 failed tries, a ConnectionRefusedError is raised.
    """
    retry_counter = 0
    connection_open = False
    while not connection_open:
        try:
            local_worker = WebsocketClientWorker(**kwargs)
            connection_open = True
        except ConnectionRefusedError as e:
            if retry_counter < 5:
                retry_counter += 1
                time.sleep(0.1)
            else:
                raise e
    return local_worker
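
A hedged sketch of how this retry helper could replace the fixed time.sleep(0.1) used in the tests above; the worker id "retry_demo" and port 8790 are illustrative, not taken from the test suite, and hook/start_proc are the same fixtures as in the tests:

# retry_demo id and port 8790 are illustrative values only
kwargs = {"id": "retry_demo", "host": "localhost", "port": 8790, "hook": hook}
server_process = start_proc(WebsocketServerWorker, **kwargs)

# The helper retries the connection (up to 5 attempts) instead of sleeping a fixed amount
local_worker = instantiate_websocket_client_worker(**kwargs)

x_ptr = torch.tensor([1, 2, 3]).send(local_worker)
assert local_worker.objects_count_remote() == 1

del x_ptr
local_worker.ws.shutdown()
server_process.terminate()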
Example #13
def test_websocket_garbage_collection(hook, start_proc):
    # Args for initializing the websocket server and client
    base_kwargs = {"id": "ws_gc", "host": "localhost", "port": 8777, "hook": hook}
    server = start_proc(WebsocketServerWorker, base_kwargs)

    time.sleep(0.1)
    client_worker = WebsocketClientWorker(**base_kwargs)

    sample_data = torch.tensor([1, 2, 3, 4])
    sample_ptr = sample_data.send(client_worker)

    _ = sample_ptr.get()
    assert sample_data not in client_worker._objects

    client_worker.ws.shutdown()
    client_worker.ws.close()
    time.sleep(0.1)
    server.terminate()
Example #14
def test_websocket_worker(hook):
    """Evaluates that you can do basic tensor operations using
    WebsocketServerWorker"""

    kwargs = {"id": "fed1", "host": "localhost", "port": 8765, "hook": hook}
    server = start_proc(WebsocketServerWorker, kwargs)

    time.sleep(0.1)
    x = torch.ones(5)

    socket_pipe = WebsocketClientWorker(**kwargs)

    x = x.send(socket_pipe)
    y = x + x
    y = y.get()

    assert (y == torch.ones(5) * 2).all()

    del x

    server.terminate()
Example #15
def test_websocket_worker(hook, start_proc):
    """Evaluates that you can do basic tensor operations using
    WebsocketServerWorker"""

    kwargs = {"id": "fed", "host": "localhost", "port": 8766, "hook": hook}
    process_remote_worker = start_proc(WebsocketServerWorker, kwargs)

    time.sleep(0.1)
    x = torch.ones(5)

    local_worker = WebsocketClientWorker(**kwargs)

    x = x.send(local_worker)
    y = x + x
    y = y.get()

    assert (y == torch.ones(5) * 2).all()

    del x

    local_worker.ws.shutdown()
    time.sleep(0.1)
    local_worker.remove_worker_from_local_worker_registry()
    process_remote_worker.terminate()
Example #16
def test_train_plan_locally_and_then_send_it(hook, start_proc):
    """Test training a plan locally and then executing it remotely."""
    hook.local_worker.is_client_worker = False

    # Create toy model
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(2, 3)
            self.fc2 = nn.Linear(3, 2)

        @sy.method2plan
        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return F.log_softmax(x, dim=0)

    net = Net()

    # Create toy data
    x = th.tensor([-1, 2.0])
    y = th.tensor([1.0])

    # Train Model
    opt = optim.SGD(params=net.parameters(), lr=0.01)
    previous_loss = None

    for _ in range(5):
        # 1) erase previous gradients (if they exist)
        opt.zero_grad()

        # 2) make a prediction
        pred = net(x)

        # 3) calculate how much we missed
        loss = ((pred - y)**2).sum()

        # 4) figure out which weights caused us to miss
        loss.backward()

        # 5) change those weights
        opt.step()

        if previous_loss is not None:
            assert loss < previous_loss

        previous_loss = loss

    local_res = net(x)
    net.forward.build(x)

    kwargs = {
        "id": "test_plan_worker_3",
        "host": "localhost",
        "port": 8800,
        "hook": hook
    }
    server = start_proc(WebsocketServerWorker, **kwargs)

    time.sleep(0.1)
    socket_pipe = WebsocketClientWorker(**kwargs)

    plan_ptr = net.send(socket_pipe)
    x_ptr = x.send(socket_pipe)
    remote_res = plan_ptr(x_ptr).get()

    assert (remote_res == local_res).all()

    # delete remote object before websocket connection termination
    del x_ptr

    server.terminate()
    hook.local_worker.is_client_worker = True
Example #17
def test_websocket_worker_basic(hook, start_proc, secure, tmpdir):
    """Evaluates that you can do basic tensor operations using
    WebsocketServerWorker in insecure and secure mode."""
    def create_self_signed_cert(cert_path, key_path):
        # create a key pair
        k = crypto.PKey()
        k.generate_key(crypto.TYPE_RSA, 1024)

        # create a self-signed cert
        cert = crypto.X509()
        cert.gmtime_adj_notBefore(0)
        cert.gmtime_adj_notAfter(1000)
        cert.set_pubkey(k)
        cert.sign(k, "sha1")

        # store keys and cert
        with open(cert_path, "wb") as cert_file:
            cert_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
        with open(key_path, "wb") as key_file:
            key_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k))

    kwargs = {
        "id": "secure_fed" if secure else "fed",
        "host": "localhost",
        "port": 8766 if secure else 8765,
        "hook": hook,
    }

    if secure:
        # Create cert and keys
        cert_path = tmpdir.join("test.crt")
        key_path = tmpdir.join("test.key")
        create_self_signed_cert(cert_path, key_path)
        kwargs["cert_path"] = cert_path
        kwargs["key_path"] = key_path

    process_remote_worker = start_proc(WebsocketServerWorker, **kwargs)

    time.sleep(0.1)
    x = torch.ones(5)

    if secure:
        # unused args
        del kwargs["cert_path"]
        del kwargs["key_path"]

    kwargs["secure"] = secure
    local_worker = WebsocketClientWorker(**kwargs)

    x = x.send(local_worker)
    y = x + x
    y = y.get()

    assert (y == torch.ones(5) * 2).all()

    del x

    local_worker.ws.shutdown()
    time.sleep(0.1)
    local_worker.remove_worker_from_local_worker_registry()
    process_remote_worker.terminate()
Example #18
    category, rand_cat_index = randomChoice(all_categories)  # cat = category, it's not a random animal
    # rand_line_index is a relative index for a data point within the random category rand_cat_index
    line, rand_line_index = randomChoice(category_lines[category])
    category_start_index = categories_start_index[category]
    absolute_index = category_start_index + rand_line_index
    return absolute_index

########## part 3 remote workers set-up and dataset distribution ##############

hook = sy.TorchHook(torch)  # <-- NEW: hook PyTorch, i.e. add extra functionality to support Federated Learning
#alice = sy.VirtualWorker(hook, id="alice")  
#bob = sy.VirtualWorker(hook, id="bob")  

# If you have your workers operating remotely, like on Raspberry Pis
kwargs_websocket_alice = {"host": "6", "hook": hook}
alice = WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket_alice)
kwargs_websocket_bob = {"host": "ip_bob", "hook": hook}
bob = WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket_bob)
workers_virtual = [alice, bob]

langDataset = LanguageDataset(array_lines_proper_dimension, categories_numpy)

#assign the data points and the corresponding categories to workers.
print("assignment starts")
federated_train_loader = sy.FederatedDataLoader(
            langDataset.federate(workers_virtual),
            batch_size=args.batch_size)
print("assignment completed") 
# time test for this part shows that time taken is not significant for this stage

print("Generating list of batches for the workers...")
Example #19
def main():
    args = define_and_get_arguments()
    print(args)
    hook = sy.TorchHook(torch)

    host = "localhost"

    if args.use_virtual:
        alice = VirtualWorker(id="alice", hook=hook, verbose=args.verbose)
        bob = VirtualWorker(id="bob", hook=hook, verbose=args.verbose)
        charlie = VirtualWorker(id="charlie", hook=hook, verbose=args.verbose)
    else:
        kwargs_websocket = {
            "host": host,
            "hook": hook,
            "verbose": args.verbose
        }
        alice = WebsocketClientWorker(id="alice",
                                      port=8777,
                                      **kwargs_websocket)
        bob = WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
        charlie = WebsocketClientWorker(id="charlie",
                                        port=8779,
                                        **kwargs_websocket)

    workers = [alice, bob, charlie]

    use_cuda = args.cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

    # Search each worker for its share of the training data by tag
    tr_alice = alice.search("#mnist", "#alice", "#train_tag")
    tr_bob = bob.search("#mnist", "#bob", "#train_tag")
    tr_charlie = charlie.search("#mnist", "#charlie", "#train_tag")

    base_data = []
    base_data.append(BaseDataset(tr_alice[0], tr_alice[1]))
    base_data.append(BaseDataset(tr_bob[0], tr_bob[1]))
    base_data.append(BaseDataset(tr_charlie[0], tr_charlie[1]))

    federated_train_loader = sy.FederatedDataLoader(
        FederatedDataset(base_data),
        batch_size=args.batch_size,
        shuffle=True,
        iter_per_worker=True,
        **kwargs,
    )

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../data",
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]),
        ),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs,
    )

    model = Net().to(device)

    for epoch in range(1, args.epochs + 1):
        logger.info("Starting epoch %s/%s", epoch, args.epochs)
        model = train(model, device, federated_train_loader, args.lr,
                      args.federate_after_n_batches)
        test(model, device, test_loader)

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
Example #20
def main():
    args = define_and_get_arguments()

    hook = sy.TorchHook(torch)

    host = "localhost"

    if args.use_virtual:
        hospital_a = VirtualWorker(id="hospital_a", hook=hook, verbose=args.verbose)
        hospital_b = VirtualWorker(id="hospital_b", hook=hook, verbose=args.verbose)
        hospital_c = VirtualWorker(id="hospital_c",
                                   hook=hook,
                                   verbose=args.verbose)
    else:
        kwargs_websocket = {
            "host": host,
            "hook": hook,
            "verbose": args.verbose
        }
        hospital_a = WebsocketClientWorker(id="hospital_a",
                                           port=8777,
                                           **kwargs_websocket)
        hospital_b = WebsocketClientWorker(id="hospital_b",
                                           port=8778,
                                           **kwargs_websocket)
        hospital_c = WebsocketClientWorker(id="hospital_c",
                                           port=8779,
                                           **kwargs_websocket)

        print()
        print(
            "*******************************************************************************************************"
        )
        print("building training channels ...")
        print(" #hospital_a, remote tensor reference: ", hospital_a)
        print(" #hospital_b, remote tensor reference: ", hospital_b)
        print(" #hospital_c, remote tensor reference: ", hospital_c)
        print()

    workers = [hospital_a, hospital_b, hospital_c]

    use_cuda = args.cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

    # Search each hospital worker for its share of the training data by tag
    tr_hospital_a = hospital_a.search("#chest_xray", "#hospital_a",
                                      "#train_tag")
    tr_hospital_b = hospital_b.search("#chest_xray", "#hospital_b",
                                      "#train_tag")
    tr_hospital_c = hospital_c.search("#chest_xray", "#hospital_c",
                                      "#train_tag")

    base_data = []
    base_data.append(BaseDataset(tr_hospital_a[0], tr_hospital_a[1]))
    base_data.append(BaseDataset(tr_hospital_b[0], tr_hospital_b[1]))
    base_data.append(BaseDataset(tr_hospital_c[0], tr_hospital_c[1]))

    federated_train_loader = sy.FederatedDataLoader(
        FederatedDataset(base_data),
        batch_size=args.batch_size,
        shuffle=True,
        iter_per_worker=True,
        **kwargs,
    )

    data_transforms = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    test = datasets.ImageFolder('chest_xray/small', data_transforms)

    local_test_loader = torch.utils.data.DataLoader(
        test, batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = resnet.resnet18_simple()
    # print("*******************************************************************************************************")
    # print("model architecture")
    # print(model)
    # print()

    print(
        "*******************************************************************************************************"
    )
    print("starting federated learning ...")
    for epoch in range(1, args.epochs + 1):
        logger.info(" starting fl training epoch %s/%s", epoch, args.epochs)
        model = fl_train(model, device, federated_train_loader, args.lr,
                         args.federate_after_n_batches)

        logger.info(" starting local inference")
        local_test(model, device, local_test_loader)

    if args.save_model:
        torch.save(model.state_dict(), "./log/chest_xray_resnet18.pt")
Example #21
def main():
    args = Arguments(False)

    hook = sy.TorchHook(torch)

    if args.use_virtual:
        alice = VirtualWorker(id="alice", hook=hook, verbose=args.verbose)
        bob = VirtualWorker(id="bob", hook=hook, verbose=args.verbose)
        charlie = VirtualWorker(id="charlie", hook=hook, verbose=args.verbose)
    else:
        kwargs_websocket = {
            "host": "localhost",
            "hook": hook,
            "verbose": args.verbose
        }
        alice = WebsocketClientWorker(id="alice",
                                      port=8777,
                                      **kwargs_websocket)
        bob = WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
        charlie = WebsocketClientWorker(id="charlie",
                                        port=8779,
                                        **kwargs_websocket)

    workers = [alice, bob, charlie]
    clients_mem = torch.zeros(len(workers))

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    if use_cuda:
        # TODO Quickhack. Actually need to fix the problem moving the model to CUDA
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")
    print(device)

    kwargs = {"num_workers": 0, "pin_memory": False} if use_cuda else {}

    federated_train_loader = sy.FederatedDataLoader(
        datasets.MNIST(
            "../data",
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]),
        ).federate(tuple(workers)),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs,
    )

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../data",
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]),
        ),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs,
    )

    start = time.time()
    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, federated_train_loader, optimizer, epoch,
              clients_mem)
        test(args, model, device, test_loader)
        t = time.time()
        print(t - start)

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")

    end = time.time()
    print(end - start)
    print("Memory exchanged : ", clients_mem)
Example #22
async def test_train_config_with_jit_trace_async(
        hook, start_proc):  # pragma: no cover
    kwargs = {
        "id": "async_fit",
        "host": "localhost",
        "port": 8777,
        "hook": hook
    }
    # data = torch.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)
    # target = torch.tensor([[1.0], [1.0], [0.0], [0.0]], requires_grad=False)
    # dataset_key = "xor"
    data, target = utils.create_gaussian_mixture_toy_data(100)
    dataset_key = "gaussian_mixture"

    mock_data = torch.zeros(1, 2)

    # TODO check reason for error (RuntimeError: This event loop is already running) when starting websocket server from pytest-asyncio environment
    # dataset = sy.BaseDataset(data, target)

    # process_remote_worker = start_proc(WebsocketServerWorker, dataset=(dataset, dataset_key), **kwargs)

    # time.sleep(0.1)

    local_worker = WebsocketClientWorker(**kwargs)

    @hook.torch.jit.script
    def loss_fn(pred, target):
        return ((target.view(pred.shape).float() - pred.float())**2).mean()

    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(2, 3)
            self.fc2 = nn.Linear(3, 2)
            self.fc3 = nn.Linear(2, 1)

        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    model_untraced = Net()

    model = torch.jit.trace(model_untraced, mock_data)

    pred = model(data)
    loss_before = loss_fn(target=target, pred=pred)

    # Create and send train config
    train_config = sy.TrainConfig(model=model,
                                  loss_fn=loss_fn,
                                  batch_size=2,
                                  lr=0.1)
    train_config.send(local_worker)

    for epoch in range(5):
        loss = await local_worker.async_fit(dataset_key=dataset_key)
        if PRINT_IN_UNITTESTS:  # pragma: no cover
            print("-" * 50)
            print("Iteration %s: alice's loss: %s" % (epoch, loss))

    new_model = train_config.model_ptr.get()

    assert not (model.fc1._parameters["weight"]
                == new_model.obj.fc1._parameters["weight"]).all()
    assert not (model.fc2._parameters["weight"]
                == new_model.obj.fc2._parameters["weight"]).all()
    assert not (model.fc3._parameters["weight"]
                == new_model.obj.fc3._parameters["weight"]).all()
    assert not (model.fc1._parameters["bias"]
                == new_model.obj.fc1._parameters["bias"]).all()
    assert not (model.fc2._parameters["bias"]
                == new_model.obj.fc2._parameters["bias"]).all()
    assert not (model.fc3._parameters["bias"]
                == new_model.obj.fc3._parameters["bias"]).all()

    new_model.obj.eval()
    pred = new_model.obj(data)
    loss_after = loss_fn(target=target, pred=pred)
    if PRINT_IN_UNITTESTS:  # pragma: no cover
        print("Loss before training: {}".format(loss_before))
        print("Loss after training: {}".format(loss_after))

    local_worker.ws.shutdown()
    # process_remote_worker.terminate()

    assert loss_after < loss_before
Example #23
def test_train_config_with_jit_trace_sync(hook,
                                          start_proc):  # pragma: no cover
    kwargs = {
        "id": "sync_fit",
        "host": "localhost",
        "port": 9000,
        "hook": hook
    }
    # data = torch.tensor([[-1, 2.0], [0, 1.1], [-1, 2.1], [0, 1.2]], requires_grad=True)
    # target = torch.tensor([[1], [0], [1], [0]])

    data, target = utils.create_gaussian_mixture_toy_data(100)

    dataset = sy.BaseDataset(data, target)

    dataset_key = "gaussian_mixture"
    process_remote_worker = start_proc(WebsocketServerWorker,
                                       dataset=(dataset, dataset_key),
                                       **kwargs)

    time.sleep(0.1)

    local_worker = WebsocketClientWorker(**kwargs)

    @hook.torch.jit.script
    def loss_fn(pred, target):
        return ((target.view(pred.shape).float() - pred.float())**2).mean()

    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(2, 3)
            self.fc2 = nn.Linear(3, 2)
            self.fc3 = nn.Linear(2, 1)

        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    model_untraced = Net()

    model = torch.jit.trace(model_untraced, data)

    pred = model(data)
    loss_before = loss_fn(pred=pred, target=target)

    # Create and send train config
    train_config = sy.TrainConfig(model=model,
                                  loss_fn=loss_fn,
                                  batch_size=2,
                                  epochs=1)
    train_config.send(local_worker)

    for epoch in range(5):
        loss = local_worker.fit(dataset_key=dataset_key)
        if PRINT_IN_UNITTESTS:  # pragma: no cover
            print("-" * 50)
            print("Iteration %s: alice's loss: %s" % (epoch, loss))

    new_model = train_config.model_ptr.get()

    # assert that the new model has updated (modified) parameters
    assert not (model.fc1._parameters["weight"]
                == new_model.obj.fc1._parameters["weight"]).all()
    assert not (model.fc2._parameters["weight"]
                == new_model.obj.fc2._parameters["weight"]).all()
    assert not (model.fc3._parameters["weight"]
                == new_model.obj.fc3._parameters["weight"]).all()
    assert not (model.fc1._parameters["bias"]
                == new_model.obj.fc1._parameters["bias"]).all()
    assert not (model.fc2._parameters["bias"]
                == new_model.obj.fc2._parameters["bias"]).all()
    assert not (model.fc3._parameters["bias"]
                == new_model.obj.fc3._parameters["bias"]).all()

    new_model.obj.eval()
    pred = new_model.obj(data)
    loss_after = loss_fn(pred=pred, target=target)

    if PRINT_IN_UNITTESTS:  # pragma: no cover
        print("Loss before training: {}".format(loss_before))
        print("Loss after training: {}".format(loss_after))

    local_worker.ws.shutdown()
    del local_worker

    time.sleep(0.1)
    process_remote_worker.terminate()

    assert loss_after < loss_before
Example #24
def experiment(no_cuda):

    # Creating num_workers clients

    hook = sy.TorchHook(torch)



    # Initializing arguments, with GPU usage or not
    args = Arguments(no_cuda)

    if args.use_virtual:
        alice = VirtualWorker(id="alice", hook=hook, verbose=args.verbose)
        bob = VirtualWorker(id="bob", hook=hook, verbose=args.verbose)
        charlie = VirtualWorker(id="charlie", hook=hook, verbose=args.verbose)
    else:
        kwargs_websocket = {"host": "localhost", "hook": hook, "verbose": args.verbose}
        alice = WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
        bob = WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
        charlie = WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    if use_cuda:
        # TODO Quickhack. Actually need to fix the problem moving the model to CUDA
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    torch.manual_seed(args.seed)

    clients = [alice, bob, charlie]
    clients_mem = torch.zeros(len(clients))

    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 0, 'pin_memory': False} if use_cuda else {}


    # Federated data loader
    federated_train_loader = sy.FederatedDataLoader( # <-- this is now a FederatedDataLoader
      datasets.MNIST('../data', train=True, download=True,
                     transform=transforms.Compose([
                         transforms.ToTensor(),
                         transforms.Normalize((0.1307,), (0.3081,))
                     ]))
      .federate(clients), # <-- NEW: we distribute the dataset across all the workers, it's now a FederatedDataset
      batch_size=args.batch_size, shuffle=True, **kwargs)

    test_loader = torch.utils.data.DataLoader(
      datasets.MNIST('../data', train=False, transform=transforms.Compose([
                         transforms.ToTensor(),
                         transforms.Normalize((0.1307,), (0.3081,))
                     ])),
      batch_size=args.test_batch_size, shuffle=True, **kwargs)


    #creating the models for each client
    models,optimizers = [], []
    #print(device)
    for i in range(len(clients)):
        #print(i)
        models.append(Net1().to(device))
        models[i] = models[i].send(clients[i])
        optimizers.append(optim.SGD(params=models[i].parameters(),lr=0.1))



    start = time.time()
    #%%time
    model = Net2().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr) # TODO momentum is not supported at the moment

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, federated_train_loader, optimizer, epoch, models, optimizers,clients_mem)
        test(args, model, device, test_loader, models)
        t = time.time()
        print(t-start)
    if (args.save_model):
        torch.save(model.state_dict(), "mnist_cnn.pt")

    end = time.time()
    print(end - start)
    print("Memory exchanged : ",clients_mem)
    return clients_mem
Example #25
batch_size = 4
optimizer_args = {"lr": 0.1, "weight_decay": 0.01}
max_nr_batches = -1  # not used in this example
shuffle = True


train_config = syft.TrainConfig(model=traced_model,
                                loss_fn=loss_fn,
                                optimizer=optimizer,
                                batch_size=batch_size,
                                optimizer_args=optimizer_args,
                                epochs=5,
                                shuffle=shuffle)

arw = {"host": "10.0.0.1", "hook": hook}
h1 = WebsocketClientWorker(id="h1", port=8778, **arw)
train_config.send(h1)


message = h1.create_message_execute_command(command_name="start_monitoring", command_owner="self")
serialized_message = syft.serde.serialize(message)
h1._recv_msg(serialized_message)


time.sleep(3)
for epoch in range(10):
    time.sleep(3)
    loss = h1.fit(dataset_key="train")  # ask h1 to train using the "train" dataset
    print("-" * 50)
    print("Iteration %s: h1's loss: %s" % (epoch, loss))