Пример #1
0
def main(number, start_slice, end_slice):
    """Start a websocket worker serving one slice of the training data.

    The worker id is ``h<number>`` and it listens on ``10.0.0.<number>``,
    port 8778. Its dataset is registered under the key ``"targeted"``.

    Args:
        number: Host index; used for the worker id and the last IP octet,
            and forwarded to ``TrainDataset``.
        start_slice: Forwarded to ``TrainDataset``.
        end_slice: Forwarded to ``TrainDataset``.
    """
    normalize = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])
    source = TrainDataset(transform=normalize,
                          number=number,
                          start_slice=start_slice,
                          end_slice=end_slice)

    worker_id = 'h%s' % number
    address = '10.0.0.%s' % number

    hook = syft.TorchHook(torch)
    server = WebsocketServerWorker(id=worker_id, host=address, port=8778,
                                   hook=hook, verbose=True)

    sample_count = str(len(source.data))
    print("Worker:{}, Dataset contains {}".format(worker_id, sample_count))

    served = syft.BaseDataset(data=source.data,
                              targets=source.target,
                              transform=source.transform)
    server.add_dataset(served, key="targeted")
    server.start()
Пример #2
0
def start_websocket_server_worker(id, host, port, hook, verbose, keep_labels=None, training=True):
    """Spin up a websocket server worker that serves an MNIST dataset.

    Args:
        id, host, port, hook, verbose: Forwarded to ``WebsocketServerWorker``.
        keep_labels: Labels to keep from the training split. ``None`` (the
            default) keeps every label. Not used when ``training`` is False.
        training: If True, serve the (optionally label-filtered) MNIST
            training split under the key ``"mnist"``; otherwise serve the full
            test split under ``"mnist_testing"``.

    Returns:
        The server, after ``server.start()`` has returned.
    """
    server = WebsocketServerWorker(id=id, host=host, port=port, hook=hook, verbose=verbose)

    # Setup toy data (mnist example)
    mnist_dataset = datasets.MNIST(
        root="./data",
        train=training,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    )

    if training:
        if keep_labels is None:
            # No filter requested: keep everything. np.isin(targets, None)
            # would select nothing, and view(28, 28, -1) below cannot infer
            # the -1 dimension of an empty tensor.
            selected_data = mnist_dataset.data
            selected_targets = mnist_dataset.targets
            logger.info("number of true indices: %s", len(selected_targets))
        else:
            indices = np.isin(mnist_dataset.targets, keep_labels).astype("uint8")
            logger.info("number of true indices: %s", indices.sum())
            # Move the sample axis last so the 1-D mask broadcasts over it,
            # then restore (samples, 28, 28).
            selected_data = (
                torch.native_masked_select(mnist_dataset.data.transpose(0, 2), torch.tensor(indices))
                .view(28, 28, -1)
                .transpose(2, 0)
            )
            selected_targets = torch.native_masked_select(mnist_dataset.targets, torch.tensor(indices))
        logger.info("after selection: %s", selected_data.shape)

        dataset = sy.BaseDataset(
            data=selected_data, targets=selected_targets, transform=mnist_dataset.transform
        )
        key = "mnist"
    else:
        dataset = sy.BaseDataset(
            data=mnist_dataset.data,
            targets=mnist_dataset.targets,
            transform=mnist_dataset.transform,
        )
        key = "mnist_testing"

    server.add_dataset(dataset, key=key)

    logger.info("datasets: %s", server.datasets)
    if training:
        logger.info("len(datasets[mnist]): %s", len(server.datasets["mnist"]))

    server.start()
    return server
Пример #3
0
def main():
    """Parse CLI options, start a websocket server worker and a loop thread.

    The created worker is published through the module-level global
    ``local_worker``. A fresh asyncio event loop is then handed to
    ``start_background_loop`` on a daemon thread.
    """
    global local_worker

    parser = argparse.ArgumentParser(description="Run websocket server worker.")
    parser.add_argument("--port", "-p", type=int,
                        help="port number of the websocket server worker, e.g. --port 8777")
    parser.add_argument("--host", type=str, default="localhost",
                        help="host for the connection")
    parser.add_argument("--id", type=str,
                        help="name (id) of the websocket server worker, e.g. --id alice")
    parser.add_argument("--verbose", "-v", action="store_true",
                        help="if set, websocket server worker will be started in verbose mode")
    args = parser.parse_args()

    hook = sy.TorchHook(torch)
    worker_options = dict(
        id=str(args.id),
        host=args.host,
        port=args.port,
        hook=hook,
        verbose=args.verbose,
    )

    local_worker = WebsocketServerWorker(**worker_options)
    local_worker.start()

    # NOTE(review): if start() runs the server's event loop until shutdown,
    # the thread below is only reached after the server stops — confirm.
    background_loop = asyncio.new_event_loop()
    loop_thread = Thread(target=start_background_loop,
                         args=(background_loop, ),
                         daemon=True)
    loop_thread.start()
def main(**kwargs):  # pragma: no cover
    """Start a websocket participant serving the XOR toy dataset.

    All keyword arguments are forwarded to ``WebsocketServerWorker``. The
    dataset is registered under the key ``"xor"``.

    Returns:
        The worker, after ``start()`` has returned.
    """
    participant = WebsocketServerWorker(**kwargs)

    # XOR truth table: each input pair and the value of a XOR b.
    inputs = th.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)
    labels = th.tensor([[1.0], [1.0], [0.0], [0.0]], requires_grad=False)

    participant.add_dataset(sy.BaseDataset(inputs, labels), key="xor")
    participant.start()
    return participant
def start_websocket_server_worker(id, host, port, hook, verbose):
    """Spin up a websocket server worker without registering any dataset.

    Args:
        id: Worker id.
        host: Host/interface to bind.
        port: Port to listen on.
        hook: syft ``TorchHook`` used by the worker.
        verbose: Whether the worker logs verbosely.

    Returns:
        The server, after ``server.start()`` has returned.
    """
    server = WebsocketServerWorker(id=id,
                                   host=host,
                                   port=port,
                                   hook=hook,
                                   verbose=verbose)
    server.start()
    return server
Пример #6
0
def start_websocket_server_worker(id,
                                  host,
                                  port,
                                  hook,
                                  verbose,
                                  dataset,
                                  training=True):
    """Spin up a websocket server worker serving this client's data split.

    Args:
        id: Worker id. It must be convertible to ``int`` because it selects
            the pickled split file ``./data/split/<id>``.
        host, port, hook, verbose: Forwarded to ``WebsocketServerWorker``.
        dataset: Key under which the loaded dataset is registered.
        training: Only ``True`` is supported. No evaluation split is loaded
            here.

    Returns:
        The server, after ``server.start()`` has returned.

    Raises:
        NotImplementedError: if ``training`` is False.
    """
    if not training:
        # Fail before creating a server: nothing would be loaded for this mode.
        raise NotImplementedError(
            "start_websocket_server_worker only supports training=True")

    key = dataset
    server = WebsocketServerWorker(id=id,
                                   host=host,
                                   port=port,
                                   hook=hook,
                                   verbose=verbose)

    # SECURITY: pickle.load can execute arbitrary code; only use this with
    # trusted, locally generated split files.
    with open("./data/split/%d" % int(id), "rb") as fp:  # Unpickling
        data = pickle.load(fp)
    dataset_data, dataset_target = readnpy(data)
    logger.info("Number of samples for client %s is %s : ", id,
                len(dataset_data))

    train_set = sy.BaseDataset(data=dataset_data, targets=dataset_target)
    server.add_dataset(train_set, key=key)

    logger.info("Dataset(train set) ,available numbers on %s: ", id)
    # Count the labels actually present; they need not be 0..k-1.
    for label in torch.unique(train_set.targets):
        count = (train_set.targets == label).sum().item()
        logger.info("      %s: %s", label.item(), count)
    logger.info("datasets: %s", server.datasets)
    logger.info("len(datasets): %s", len(server.datasets[key]))

    server.start()
    return server
Пример #7
0
parser.add_argument(
    "--id",
    type=str,
    help="name (id) of the websocket server worker, e.g. --id alice")

parser.add_argument(
    "--verbose",
    "-v",
    action="store_true",
    help="if set, websocket server worker will be started in verbose mode",
)

args = parser.parse_args()

# Worker configuration: values given on the command line take precedence,
# and the hard-coded values are used only for options that were not supplied.
kwargs = {
    "id": args.id if args.id is not None else "worker",
    "host": args.host if args.host is not None else "localhost",
    "port": args.port if args.port is not None else 8777,
    "hook": hook,
    "verbose": args.verbose,
}

if os.name != "nt":
    # Non-Windows: launch the worker via start_proc.
    print("not nt")
    server = start_proc(WebsocketServerWorker, kwargs)
else:
    # Windows: construct and start the worker directly in this process.
    print("nt")
    server = WebsocketServerWorker(**kwargs)
    server.start()
Пример #8
0
from syft.workers.websocket_server import WebsocketServerWorker
import syft as sy
import torch
hook = sy.TorchHook(torch)

# Settings for the local "bob" worker.
kwargs = dict(
    id="bob",
    host="localhost",
    port=8778,
    hook=hook,
    verbose=True,
)

server = WebsocketServerWorker(**kwargs)
# Print whatever start() returns, then the objects the worker holds.
print(server.start())
print(server.list_objects())
Пример #9
0
def start_websocket_server_worker(id,
                                  host,
                                  port,
                                  hook,
                                  verbose,
                                  keep_labels=None,
                                  training=True):  # pragma: no cover
    """Spin up a websocket server worker with MNIST and several toy datasets.

    Registered keys: ``"mnist"`` (or ``"mnist_testing"``), ``"vectors"``,
    ``"xor"``, ``"gaussian_mixture"`` and ``"iris"``.

    Args:
        id, host, port, hook, verbose: Forwarded to ``WebsocketServerWorker``.
        keep_labels: Labels to keep from the MNIST training split. ``None``
            (the default) keeps every label. Not used when ``training`` is
            False.
        training: Serve the MNIST training split (key ``"mnist"``) if True,
            else the full test split (key ``"mnist_testing"``).

    Returns:
        The server, after ``server.start()`` has returned.
    """
    server = WebsocketServerWorker(id=id,
                                   host=host,
                                   port=port,
                                   hook=hook,
                                   verbose=verbose)

    # Setup toy data (mnist example)
    mnist_dataset = datasets.MNIST(
        root="./data",
        train=training,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ]),
    )

    if training:
        if keep_labels is None:
            # No filter requested: keep everything. np.isin(targets, None)
            # would select nothing, and view(28, 28, -1) below cannot infer
            # the -1 dimension of an empty tensor.
            selected_data = mnist_dataset.data
            selected_targets = mnist_dataset.targets
            logger.info("number of true indices: %s", len(selected_targets))
        else:
            indices = np.isin(mnist_dataset.targets,
                              keep_labels).astype("uint8")
            logger.info("number of true indices: %s", indices.sum())
            # Move the sample axis last so the 1-D mask broadcasts over it,
            # then restore (samples, 28, 28).
            selected_data = (torch.native_masked_select(
                mnist_dataset.data.transpose(0, 2),
                torch.tensor(indices)).view(28, 28, -1).transpose(2, 0))
            selected_targets = torch.native_masked_select(
                mnist_dataset.targets, torch.tensor(indices))
        logger.info("after selection: %s", selected_data.shape)

        dataset = sy.BaseDataset(data=selected_data,
                                 targets=selected_targets,
                                 transform=mnist_dataset.transform)
        key = "mnist"
    else:
        dataset = sy.BaseDataset(
            data=mnist_dataset.data,
            targets=mnist_dataset.targets,
            transform=mnist_dataset.transform,
        )
        key = "mnist_testing"

    server.add_dataset(dataset, key=key)

    # Setup toy data (vectors example)
    data_vectors = torch.tensor([[-1, 2.0], [0, 1.1], [-1, 2.1], [0, 1.2]],
                                requires_grad=True)
    target_vectors = torch.tensor([[1], [0], [1], [0]])

    server.add_dataset(sy.BaseDataset(data_vectors, target_vectors),
                       key="vectors")

    # Setup toy data (xor example)
    data_xor = torch.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]],
                            requires_grad=True)
    target_xor = torch.tensor([1.0, 1.0, 0.0, 0.0], requires_grad=False)

    server.add_dataset(sy.BaseDataset(data_xor, target_xor), key="xor")

    # Setup gaussian mixture dataset
    data, target = utils.create_gaussian_mixture_toy_data(nr_samples=100)
    server.add_dataset(sy.BaseDataset(data, target), key="gaussian_mixture")

    # Setup partial iris dataset
    data, target = utils.iris_data_partial()
    server.add_dataset(sy.BaseDataset(data, target), key="iris")

    logger.info("datasets: %s", server.datasets)
    if training:
        logger.info("len(datasets[mnist]): %s", len(server.datasets["mnist"]))

    server.start()
    return server
Пример #10
0
import torch
import syft

from syft.workers.websocket_server import WebsocketServerWorker

# Hook and start server
# Hook torch, then create a server worker holding one tagged tensor.
hook = syft.TorchHook(torch)
server_worker = WebsocketServerWorker(
    id="bad", host="localhost", port=8777, hook=hook)

# Small tensor carrying the tag "test", stored on the worker.
test_data = torch.tensor([1, 2, 3])
test_data = test_data.tag("test")
server_worker.set_obj(test_data)

print("Bad server started.")
server_worker.start()
Пример #11
0
import syft as sy
from syft.workers.websocket_server import WebsocketServerWorker
import torch

hook = sy.TorchHook(torch)

# Settings for the "mbp2" worker.
# NOTE(review): the port is the string '8002' here, while other workers use
# an int — confirm WebsocketServerWorker accepts both.
kwargs = dict(
    id='mbp2',
    host='localhost',
    port='8002',
    hook=hook,
    verbose=True,
)

if __name__ == "__main__":
    server2 = WebsocketServerWorker(**kwargs)
    # Load a single one-element tensor onto the worker, show its object
    # store, then serve.
    tensor = torch.tensor([2])
    server2.load_data([tensor])
    print(server2._objects)
    server2.start()
Пример #12
0
class Worker:
    """Face-capture client that talks to a conductor and hosts a syft worker.

    The worker collects face crops (from a webcam, a Pi camera or an LFW
    loader) into ``self.dataset``. It exchanges JSON messages with a
    conductor over websockets and, once activated, runs a
    ``WebsocketServerWorker``. ``draw`` renders the camera feed and a status
    dashboard onto the module-level ``gui`` image.
    """

    class Mode(Enum):
        """Status shown on the dashboard."""
        COLLECTING = 1
        READY_TO_UPDATE = 2
        WAITING = 3

    def __init__(self, name, worker_ip, conductor_ip, batch_size, hook,
                 faceCascade, verbose):
        """Store configuration and initialise runtime state.

        Args:
            name: Worker name sent to the conductor and used as the syft id.
            worker_ip: Address shown on the dashboard.
            conductor_ip: Conductor websocket URI, used as ``socket_uri``.
            batch_size: Number of face samples that make a full batch.
            hook: syft ``TorchHook`` handed to the websocket server worker.
            faceCascade: Face detector passed to ``search_for_face``.
            verbose: Logging switches; most methods read
                ``verbose.capture``, ``verbose.draw`` or ``verbose.event``.
        """
        self.name = name
        self.worker_ip = worker_ip
        self.conductor_ip = conductor_ip
        self.syft_port = None  # set by get_port_assignment()
        self.num_peers = 0
        self.batch_size = batch_size
        self.verbose = verbose
        self.hook = hook
        self.registered = False
        self.active = False
        self.dataset = []
        self.current_capture = None  # last raw frame stored by capture()
        self.current_face = None
        self.reconstructed_capture = None
        self.reconstruction_loss = 0
        self.kl_loss = 0
        self.random_image = None
        self.socket_uri = conductor_ip
        self.faceCascade = faceCascade
        self.waiting_for_conductor = False
        self.last_updated_time = None
        self.num_updates = 0
        # draw() reads self.mode, so it needs a defined starting value; a new
        # worker has no samples yet, i.e. it is collecting.
        self.mode = Worker.Mode.COLLECTING

    def set_mode(self, mode):
        """Set the dashboard status (a ``Worker.Mode`` member)."""
        self.mode = mode

    def setup_data_source(self, data_input, image_size):
        """Select the frame source ('cam', 'picam' or 'lfw') and crop size."""
        self.image_size = image_size
        self.input = data_input
        if self.input == 'cam':
            self.camera = cv2.VideoCapture(0)

    def setup_lfw_loader(self, lfw_loader):
        """Attach the loader used when the input is 'lfw'."""
        self.lfw_loader = lfw_loader

    def ready_to_update(self):
        """Return True once exactly ``batch_size`` samples are collected."""
        ready = (self.num_samples() == self.batch_size)
        return ready

    def get_next_batch(self):
        """Average the collected faces over the last axis and send them.

        The result is tagged ``#x`` and sent to ``self.local_worker``; the
        returned pointer is stored in ``self.x_ptr``.
        """
        # NOTE(review): passes self.verbose itself, unlike the other methods
        # that pass a channel such as verbose.capture — confirm intent.
        log('Get next batch...', self.verbose)
        data = np.array(self.dataset)
        data = torch.tensor(np.mean(data, -1))
        log('Found batch of size %s' % str(data.shape), self.verbose)
        self.x_ptr = data.tag('#x').send(self.local_worker)

    def activate(self):
        """Create and start the syft websocket server on ``syft_port``."""
        kwargs = {
            "hook": self.hook,
            "id": self.name,
            "host": "0.0.0.0",
            "port": self.syft_port
        }
        self.active = True
        self.local_worker = WebsocketServerWorker(**kwargs)
        self.local_worker.start()

    async def capture(self):
        """Read one frame from the configured input into ``current_capture``.

        Raises:
            ValueError: if ``self.input`` is not 'cam', 'picam' or 'lfw'.
        """
        import asyncio  # only needed for the non-blocking retry delay

        if self.input == 'cam':
            log('Get camera input', self.verbose.capture)
            if not self.camera.isOpened():
                log('Unable to load camera.', self.verbose.capture)
                # time.sleep() here would block the whole event loop.
                await asyncio.sleep(2)
                return
            ok, image = self.camera.read()
            if not ok:
                # Keep the previous frame instead of storing None.
                log('Unable to read camera frame.', self.verbose.capture)
                return

        elif self.input == 'picam':
            log('Get picamera input', self.verbose.capture)
            stream = io.BytesIO()
            with picamera.PiCamera() as camera:
                camera.start_preview()
                camera.capture(stream, format='jpeg')
            stream.seek(0)
            image = np.array(Image.open(stream))

        elif self.input == 'lfw':
            log('Get LFW input', self.verbose.capture)
            image = self.lfw_loader.get_random_image()

        else:
            raise ValueError(
                "Unknown data input %r; expected 'cam', 'picam' or 'lfw'" %
                (self.input, ))

        self.current_capture = image

    async def process_capture(self, faceCascade):
        """Detect a face in ``current_capture`` and add its crop to the batch.

        The ``faceCascade`` argument is not used; ``self.faceCascade`` is.

        Returns:
            True if a face was cropped and appended to ``self.dataset``,
            False otherwise (batch full, waiting for the conductor, no face,
            or face too close to the frame edge).
        """
        if self.waiting_for_conductor or self.num_samples() >= self.batch_size:
            self.current_face = None
            return False

        face = search_for_face(self.current_capture, self.faceCascade)
        if face is None:
            log('No face found', self.verbose.capture)
            self.current_face = None
            return False

        ih, iw = self.current_capture.shape[0:2]
        x, y, w, h = face
        # Expand the detection box by a third on each side.
        x1, x2 = int(x - 0.33 * w), int(x + 1.33 * w)
        y1, y2 = int(y - 0.33 * h), int(y + 1.33 * h)

        if x1 < 0 or x2 >= iw or y1 < 0 or y2 >= ih:
            log('Face not fully inside frame', self.verbose.capture)
            self.current_face = None
            return False

        log('Found face inside [%d, %d, %d, %d]' % (y1, y2, x1, x2),
            self.verbose.capture)

        # crop out face image
        face_image = self.current_capture[y1:y2, x1:x2]
        face_image = cv2.resize(face_image,
                                (self.image_size, self.image_size),
                                interpolation=cv2.INTER_CUBIC)
        face_image = np.array(face_image).astype(np.float32) / 255.

        log('Image reshaped to %s' % str(face_image.shape),
            self.verbose.capture)
        self.current_face = np.expand_dims(face_image, 0)
        self.dataset.append(face_image)
        return True

    def draw(self):
        """Render the camera feed, face insets and dashboard onto ``gui``.

        Uses module-level drawing globals (``gui``, ``ctx``, fonts, layout
        constants) and ``t0`` for the running time.

        Returns:
            The module-level ``gui`` image.
        """
        log('Draw loop', self.verbose.draw)

        if self.last_updated_time:
            last_updated_str = time.strftime(
                "%H:%M:%S", time.gmtime(self.last_updated_time))
            last_updated_ago = time.time() - self.last_updated_time
        else:
            last_updated_str = 'N/A'
            last_updated_ago = 0

        model_type, num_params = 'fully-connected VAE', 48812  # 'convolutional'

        # draw images
        im_frame = Image.fromarray(self.current_capture)
        log(
            'Resize frame_image from %dx%d to %dx%d' %
            (im_frame.width, im_frame.height, FRAME_W, FRAME_H),
            self.verbose.draw)
        im_frame = im_frame.resize((FRAME_W, FRAME_H), Image.BICUBIC)
        log('Drawing...', self.verbose.draw)
        ctx.rectangle((0, 0, GUI_W, GUI_H), fill='#000', outline='#000')
        gui.paste(im_frame, (FRAME_X, FRAME_Y))

        if self.current_face is None and not self.waiting_for_conductor:
            ctx.text((FRAME_X + 3, FRAME_Y + 2),
                     "Place your face inside the red ellipse :)",
                     font=font2,
                     fill='#00f')
            ctx.ellipse((FRAME_X + 20, FRAME_Y + 25, FRAME_X + FRAME_W - 20,
                         FRAME_Y + FRAME_H - 20),
                        width=12,
                        fill=None,
                        outline='#0000ff44')
        elif not self.waiting_for_conductor:
            ctx.rectangle(
                (FRAME_X, FRAME_Y, FRAME_X + FRAME_W, FRAME_Y + FRAME_H),
                width=8,
                fill=None,
                outline='#0f0')

        if self.current_face is not None:
            img_capture = Image.fromarray(
                (255 * self.current_face[0]).astype(np.uint8))
            img_capture = img_capture.resize((INSET_DIM, INSET_DIM),
                                             Image.NEAREST)
            gui.paste(img_capture, (FACE_X, FACE_Y))
        if self.reconstructed_capture is not None:
            img_reconstructed = Image.fromarray(
                self.reconstructed_capture.astype(np.uint8))
            img_reconstructed = img_reconstructed.resize(
                (INSET_DIM, INSET_DIM), Image.NEAREST)
            gui.paste(img_reconstructed, (FACER_X, FACER_Y))

        currTime = time.strftime('%l:%M%p')
        ctx.text((FRAME_X, FRAME_Y - FONT_SIZE_1 - 5),
                 "Camera feed %s" % currTime,
                 font=font1,
                 fill=(0, 255, 0, 255))
        ctx.text((FACE_X, FACE_Y - FONT_SIZE_2 - 5),
                 "detected",
                 font=font2,
                 fill=(0, 255, 0, 255))
        ctx.text((FACER_X, FACER_Y - FONT_SIZE_2 - 5),
                 "reconstructed",
                 font=font2,
                 fill=(0, 255, 0, 255))
        ctx.text((FACEG_X, FACEG_Y - FONT_SIZE_2 - 5),
                 "randomly generated",
                 font=font2,
                 fill=(0, 255, 0, 255))

        # draw dashboard
        log('Draw dashboard', self.verbose.draw)
        waiting = self.mode == Worker.Mode.WAITING
        gui_color = '#0ff' if waiting else '#0f0'
        gui_color_rgba = (0, 255, 255, 255) if waiting else (0, 255, 0, 255)
        dashboard = Image.new('RGB', (DASH_W, DASH_H))
        ctx_dash = ImageDraw.Draw(dashboard)
        ctx_dash.rectangle((0, 0, DASH_W - 1, DASH_H - 1),
                           fill=None,
                           outline=gui_color)
        ctx_dash.rectangle((0, 0, DASH_W - 1, FONT_SIZE_2 + 8),
                           fill=None,
                           outline=gui_color)
        ctx_dash.text((3, 3), "Dashboard", font=font2, fill=gui_color_rgba)
        time_running = str(round(time.time() - t0)) + ' sec'

        status_labels = {
            Worker.Mode.COLLECTING: 'Collecting samples',
            Worker.Mode.READY_TO_UPDATE: 'Ready to update',
            Worker.Mode.WAITING: 'Waiting for conductor',
        }
        # Fall back instead of raising UnboundLocalError on an unexpected mode.
        current_status = status_labels.get(self.mode, 'Unknown')

        if last_updated_ago <= 60:
            last_updated_ago_str = '%d sec' % last_updated_ago
        else:
            last_updated_ago_str = '%d min' % int(last_updated_ago / 60)

        lines = [
            'Current status:  %s' % current_status,
            '',
            'Time running:  %s' % time_running,
            'Number of peers:  %d' % self.num_peers,
            'My name:  %s' % self.name,
            # %s rather than %d: syft_port is None until a port is assigned.
            'My location:  %s:%s' % (self.worker_ip, self.syft_port),
            'Conductor location:  %s' % self.conductor_ip,
            '',
            'Model:  %s (%d parameters, batch size %d)' %
            (model_type, num_params, self.batch_size),
            'Number of local updates:  %d' % self.num_updates,
            'Last model update:  %s (%s ago)' %
            (last_updated_str, last_updated_ago_str),
            'Current reconstruction loss:  %0.2f' % self.reconstruction_loss,
            'Current KL-divergence loss:  %0.2f' % self.kl_loss,
            '',
            'Current batch num samples:  %d' % self.num_samples(),
        ]
        for row, line in enumerate(lines):
            ctx_dash.text((6, 15 + FONT_SIZE_2 + row * FONT_SIZE_3 * 1.4),
                          line,
                          font=font3,
                          fill=gui_color_rgba)
        gui.paste(dashboard, (DASH_X, DASH_Y))

        return gui

    def num_samples(self):
        """Return how many face samples are in the current batch."""
        return len(self.dataset)

    async def get_port_assignment(self):
        """Ask the conductor for a free syft port and store it in ``syft_port``."""
        log('Get port assignment from conductor', self.verbose.event)
        async with websockets.connect(self.socket_uri) as websocket:
            message = json.dumps({'action': 'get_available_port'})
            await websocket.send(message)
            result = json.loads(await websocket.recv())
            if result['success']:
                log(
                    'Got available syft port from conductor: %d' %
                    result['syft_port'], self.verbose.event)
                self.syft_port = result['syft_port']

    async def try_register(self):
        """Register name, syft port and local IP with the conductor."""
        log('Register with conductor', self.verbose.event)
        async with websockets.connect(self.socket_uri) as websocket:
            message = json.dumps({
                'action': 'register',
                'name': self.name,
                'syft_port': self.syft_port,
                'host': get_local_ip_address()
            })
            await websocket.send(message)
            result = json.loads(await websocket.recv())
            if result['success']:
                log('Registration successful.', self.verbose.event)
                self.num_peers = result['num_peers']
                self.registered = True

    async def request_update(self):
        """Ask the conductor for the model and wait for its answer."""
        log('Request to make a model update', self.verbose.event)
        async with websockets.connect(self.socket_uri) as websocket:
            message = json.dumps({
                'action': 'request_model',
                'name': self.name
            })
            # Keep the flag set until the conductor has replied:
            # process_capture() and draw() check it to pause collection.
            self.waiting_for_conductor = True
            try:
                await websocket.send(message)
                result = json.loads(await websocket.recv())
            finally:
                self.waiting_for_conductor = False
            if result['success']:
                log('Ready to make an update...', self.verbose.event)

    async def ping_conductor(self):
        """Send a ping to the conductor and return its decoded JSON reply."""
        async with websockets.connect(self.socket_uri) as websocket:
            message = json.dumps({
                'action': 'ping_conductor',
                'name': self.name
            })
            await websocket.send(message)
            result = json.loads(await websocket.recv())
            return result
def start_websocket_server_worker(id,
                                  host,
                                  port,
                                  hook,
                                  verbose,
                                  keep_labels=None,
                                  training=True,
                                  pytest_testing=False):
    """Spin up a websocket server worker serving this client's DGA dataset.

    Args:
        id: Worker id; also passed to ``init_data`` to select the data.
        host, port, hook, verbose: Forwarded to ``WebsocketServerWorker``.
        keep_labels: Forwarded to ``init_data``.
        training: Register the dataset under ``"dga"`` if True, otherwise
            under ``"dga_testing"``.
        pytest_testing: Unused; kept for backward compatibility.

    Returns:
        The server, after ``server.start()`` has returned.
    """
    server = WebsocketServerWorker(id=id,
                                   host=host,
                                   port=port,
                                   hook=hook,
                                   verbose=verbose)

    X, Y, _max_features, _max_len = init_data(id, keep_labels)
    # Used to shuffle the data; the tiny held-out part (test_size=0.0001)
    # is discarded.
    X, _x_test, Y, _y_test = train_test_split(X,
                                              Y,
                                              test_size=0.0001,
                                              shuffle=True)

    # Every worker id gets the same treatment. The former per-id branches
    # (alice/bob/charlie) were identical and left these names unbound for
    # any other id.
    selected_data = torch.LongTensor(X)
    # presumably Y has shape (n, 1); squeeze(1) flattens it to (n,)
    selected_targets = torch.LongTensor(Y).squeeze(1)

    dataset = sy.BaseDataset(data=selected_data, targets=selected_targets)
    key = "dga" if training else "dga_testing"

    # Adding Dataset
    server.add_dataset(dataset, key=key)

    # Per-class sample counts for labels 0 and 1.
    for label in range(2):
        count = (dataset.targets == label).sum().item()
        logger.info("      %s: %s", label, count)

    logger.info("datasets: %s", server.datasets)
    if training:
        logger.info("Examples in local dataset: %s", len(server.datasets[key]))

    server.start()
    return server
Пример #14
0
    id=id,
    port=port)
# Toy data held by the server: inputs tagged ("toy", "data") and targets
# tagged ("toy", "target").
x = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1.]], requires_grad=True)
x = x.tag("toy", "data")
y = torch.tensor([[0], [0], [1], [1.]], requires_grad=True)
y = y.tag("toy", "target")

# Send both tensors to the server worker and keep the returned pointers.
x_ptr, y_ptr = x.send(server_worker), y.send(server_worker)
print(x_ptr, y_ptr)

# State before serving.
print('>>> server_worker:', server_worker)
print('>>> server_worker.list_objects():', server_worker.list_objects())
print('>>> server_worker.list_tensors():', server_worker.list_tensors())

# Might need to interrupt with `CTRL-C` or some other means.
server_worker.start()

# State after start() returns.
print('>>> server_worker.list_objects()', server_worker.list_objects())
print('>>> server_worker.objects_count()', server_worker.objects_count())
print('>>> server_worker.list_tensors():', server_worker.list_tensors())
print('>>> server_worker.host', server_worker.host)
print('>>> server_worker.port', server_worker.port)