Ejemplo n.º 1
0
def main(number, start_slice, end_slice):
    """Start a websocket worker serving data loaded through ``TrainDataset``.

    The worker is named ``h<number>`` and listens on ``10.0.0.<number>:8778``.
    ``number``, ``start_slice`` and ``end_slice`` are passed to ``TrainDataset``
    (presumably selecting this worker's portion of the data — confirm against
    ``TrainDataset``). The loaded samples are registered under the key
    ``"targeted"`` and then the server loop is started.
    """
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])
    source = TrainDataset(transform=pipeline,
                          number=number,
                          start_slice=start_slice,
                          end_slice=end_slice)

    worker_id = 'h%s' % number
    address = '10.0.0.%s' % number

    hook = syft.TorchHook(torch)
    server = WebsocketServerWorker(id=worker_id,
                                   host=address,
                                   port=8778,
                                   hook=hook,
                                   verbose=True)

    sample_count = str(len(source.data))
    print("Worker:{}, Dataset contains {}".format(worker_id, sample_count))

    served = syft.BaseDataset(data=source.data,
                              targets=source.target,
                              transform=source.transform)
    server.add_dataset(served, key="targeted")
    server.start()
Ejemplo n.º 2
0
def start_websocket_server_worker(id, host, port, hook, verbose, keep_labels=None, training=True):
    """Helper function for spinning up a websocket server and setting up the local datasets.

    Args:
        id, host, port, hook, verbose: forwarded to ``WebsocketServerWorker``.
        keep_labels: labels to keep from the MNIST training split. ``None``
            (the default) keeps every training sample. Ignored when
            ``training`` is False.
        training: if True, register the (optionally filtered) MNIST train
            split under the key ``"mnist"``; otherwise register the full test
            split under ``"mnist_testing"``.

    Returns:
        The ``WebsocketServerWorker`` instance, once ``server.start()`` returns.
    """

    server = WebsocketServerWorker(id=id, host=host, port=port, hook=hook, verbose=verbose)

    # Setup toy data (mnist example)
    mnist_dataset = datasets.MNIST(
        root="./data",
        train=training,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    )

    if training:
        if keep_labels is None:
            # Without this guard np.isin(targets, None) matches no sample and the
            # .view(28, 28, -1) below fails on the resulting empty tensor.
            selected_data = mnist_dataset.data
            selected_targets = mnist_dataset.targets
            logger.info("number of true indices: %s", len(selected_targets))
        else:
            indices = np.isin(mnist_dataset.targets, keep_labels).astype("uint8")
            logger.info("number of true indices: %s", indices.sum())
            mask = torch.tensor(indices)
            # The mask has one entry per sample: moving the sample axis last lets
            # it broadcast over every pixel, and view/transpose restore the
            # (num_selected, 28, 28) layout.
            selected_data = (
                torch.native_masked_select(mnist_dataset.data.transpose(0, 2), mask)
                .view(28, 28, -1)
                .transpose(2, 0)
            )
            selected_targets = torch.native_masked_select(mnist_dataset.targets, mask)
        logger.info("after selection: %s", selected_data.shape)

        dataset = sy.BaseDataset(
            data=selected_data, targets=selected_targets, transform=mnist_dataset.transform
        )
        key = "mnist"
    else:
        dataset = sy.BaseDataset(
            data=mnist_dataset.data,
            targets=mnist_dataset.targets,
            transform=mnist_dataset.transform,
        )
        key = "mnist_testing"

    server.add_dataset(dataset, key=key)

    logger.info("datasets: %s", server.datasets)
    if training:
        logger.info("len(datasets[mnist]): %s", len(server.datasets["mnist"]))

    server.start()
    return server
def main(**kwargs):  # pragma: no cover
    """Start a websocket participant that serves the four-sample XOR toy data.

    All keyword arguments go straight to ``WebsocketServerWorker``. The data is
    registered under the key ``"xor"``, and the worker is returned once its
    ``start()`` call returns.
    """
    participant = WebsocketServerWorker(**kwargs)

    xor_inputs = th.tensor(
        [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True
    )
    xor_labels = th.tensor([[1.0], [1.0], [0.0], [0.0]], requires_grad=False)

    participant.add_dataset(sy.BaseDataset(xor_inputs, xor_labels), key="xor")
    participant.start()
    return participant
Ejemplo n.º 4
0
def start_websocket_server_worker(id,
                                  host,
                                  port,
                                  hook,
                                  verbose,
                                  dataset,
                                  training=True):
    """Helper function for spinning up a websocket server and setting up the local datasets.

    Args:
        id: worker id; it must be convertible to ``int`` because it selects the
            pickled split file ``./data/split/<id>``.
        host, port, hook, verbose: forwarded to ``WebsocketServerWorker``.
        dataset: the key under which the loaded data is registered on the
            worker (despite the name, this is not a dataset object).
        training: only ``True`` is supported.

    Returns:
        The ``WebsocketServerWorker`` instance, once ``server.start()`` returns.

    Raises:
        NotImplementedError: if ``training`` is False, because this helper has
            no evaluation data source.
    """
    if not training:
        # No data source is defined for evaluation; fail fast with a clear
        # message instead of hitting unbound locals further down.
        raise NotImplementedError(
            "start_websocket_server_worker only supports training=True")

    server = WebsocketServerWorker(id=id,
                                   host=host,
                                   port=port,
                                   hook=hook,
                                   verbose=verbose)
    key = dataset

    # NOTE(review): pickle.load can execute arbitrary code; only use split files
    # produced by a trusted preprocessing step.
    with open("./data/split/%d" % int(id), "rb") as fp:  # Unpickling
        data = pickle.load(fp)
    dataset_data, dataset_target = readnpy(data)
    logger.info("Number of samples for client %s is %s : ", id,
                len(dataset_data))
    local_dataset = sy.BaseDataset(data=dataset_data, targets=dataset_target)
    server.add_dataset(local_dataset, key=key)

    # Count over the labels actually present, so non-contiguous label sets
    # are reported correctly.
    logger.info("Dataset(train set) ,available numbers on %s: ", id)
    for label in torch.unique(local_dataset.targets):
        label_count = (local_dataset.targets == label).sum().item()
        logger.info("      %s: %s", label.item(), label_count)
    logger.info("datasets: %s", server.datasets)
    logger.info("len(datasets): %s", len(server.datasets[key]))

    server.start()
    return server
Ejemplo n.º 5
0
def start_websocket_server_worker(id,
                                  host,
                                  port,
                                  hook,
                                  verbose,
                                  keep_labels=None,
                                  training=True):  # pragma: no cover
    """Helper function for spinning up a websocket server and setting up the local datasets.

    Registers MNIST (key ``"mnist"`` when training, ``"mnist_testing"``
    otherwise) plus the toy datasets ``"vectors"``, ``"xor"``,
    ``"gaussian_mixture"`` and ``"iris"``.

    Args:
        id, host, port, hook, verbose: forwarded to ``WebsocketServerWorker``.
        keep_labels: labels to keep from the MNIST training split. ``None``
            (the default) keeps every training sample. Ignored when
            ``training`` is False.
        training: selects the MNIST train split (filtered) or the test split.

    Returns:
        The ``WebsocketServerWorker`` instance, once ``server.start()`` returns.
    """

    server = WebsocketServerWorker(id=id,
                                   host=host,
                                   port=port,
                                   hook=hook,
                                   verbose=verbose)

    # Setup toy data (mnist example)
    mnist_dataset = datasets.MNIST(
        root="./data",
        train=training,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ]),
    )

    if training:
        if keep_labels is None:
            # Without this guard np.isin(targets, None) matches no sample and the
            # .view(28, 28, -1) below fails on the resulting empty tensor.
            selected_data = mnist_dataset.data
            selected_targets = mnist_dataset.targets
            logger.info("number of true indices: %s", len(selected_targets))
        else:
            indices = np.isin(mnist_dataset.targets,
                              keep_labels).astype("uint8")
            logger.info("number of true indices: %s", indices.sum())
            mask = torch.tensor(indices)
            # One mask entry per sample: moving the sample axis last lets it
            # broadcast over every pixel; view/transpose restore the
            # (num_selected, 28, 28) layout.
            selected_data = (torch.native_masked_select(
                mnist_dataset.data.transpose(0, 2),
                mask).view(28, 28, -1).transpose(2, 0))
            selected_targets = torch.native_masked_select(
                mnist_dataset.targets, mask)
        logger.info("after selection: %s", selected_data.shape)

        dataset = sy.BaseDataset(data=selected_data,
                                 targets=selected_targets,
                                 transform=mnist_dataset.transform)
        key = "mnist"
    else:
        dataset = sy.BaseDataset(
            data=mnist_dataset.data,
            targets=mnist_dataset.targets,
            transform=mnist_dataset.transform,
        )
        key = "mnist_testing"

    server.add_dataset(dataset, key=key)

    # Setup toy data (vectors example)
    data_vectors = torch.tensor([[-1, 2.0], [0, 1.1], [-1, 2.1], [0, 1.2]],
                                requires_grad=True)
    target_vectors = torch.tensor([[1], [0], [1], [0]])

    server.add_dataset(sy.BaseDataset(data_vectors, target_vectors),
                       key="vectors")

    # Setup toy data (xor example)
    data_xor = torch.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]],
                            requires_grad=True)
    target_xor = torch.tensor([1.0, 1.0, 0.0, 0.0], requires_grad=False)

    server.add_dataset(sy.BaseDataset(data_xor, target_xor), key="xor")

    # Setup gaussian mixture dataset
    data, target = utils.create_gaussian_mixture_toy_data(nr_samples=100)
    server.add_dataset(sy.BaseDataset(data, target), key="gaussian_mixture")

    # Setup partial iris dataset
    data, target = utils.iris_data_partial()
    server.add_dataset(sy.BaseDataset(data, target), key="iris")

    logger.info("datasets: %s", server.datasets)
    if training:
        logger.info("len(datasets[mnist]): %s", len(server.datasets["mnist"]))

    server.start()
    return server
def start_websocket_server_worker(id,
                                  host,
                                  port,
                                  hook,
                                  verbose,
                                  keep_labels=None,
                                  training=True,
                                  pytest_testing=False):
    """Helper function for spinning up a websocket server and setting up the local datasets.

    Loads the worker's data via ``init_data(id, keep_labels)`` and registers it
    under ``"dga"`` (training) or ``"dga_testing"`` (otherwise).

    Args:
        id: worker id, passed to ``init_data`` and ``WebsocketServerWorker``.
            Any id is accepted.
        host, port, hook, verbose: forwarded to ``WebsocketServerWorker``.
        keep_labels: forwarded to ``init_data``.
        training: selects the dataset key and whether the dataset length is
            logged.
        pytest_testing: accepted for interface compatibility; currently unused.

    Returns:
        The ``WebsocketServerWorker`` instance, once ``server.start()`` returns.
    """

    server = WebsocketServerWorker(id=id,
                                   host=host,
                                   port=port,
                                   hook=hook,
                                   verbose=verbose)

    X, Y, _max_features, _max_len = init_data(id, keep_labels)
    # test_size=0.0001 holds out only a handful of samples, and the held-out
    # part is discarded, so this mainly shuffles X/Y.
    X, _x_test, Y, _y_test = train_test_split(X,
                                              Y,
                                              test_size=0.0001,
                                              shuffle=True)

    # All worker ids (alice, bob, charlie, ...) share the same conversion, so
    # there is no per-id branch that could leave these names unbound.
    selected_data = torch.LongTensor(X)
    selected_targets = torch.LongTensor(Y).squeeze(1)

    dataset = sy.BaseDataset(data=selected_data, targets=selected_targets)
    key = "dga" if training else "dga_testing"

    # Adding Dataset
    server.add_dataset(dataset, key=key)

    # Per-class sample counts for classes 0 and 1.
    for label in range(2):
        label_count = (dataset.targets == label).sum().item()
        logger.info("      %s: %s", label, label_count)

    logger.info("datasets: %s", server.datasets)
    if training:
        logger.info("Examples in local dataset: %s", len(server.datasets[key]))

    server.start()
    return server