# Code example 1
def prepare_training(hook, alice):  # pragma: no cover
    """Prepare a toy Gaussian-mixture training setup on the given worker.

    Registers the toy dataset on ``alice`` under the key
    ``"gaussian_mixture"``, builds and jit-traces a small feed-forward
    model, and computes the untrained model's loss as a baseline.

    Returns:
        Tuple of (traced model, scripted loss fn, data, target,
        loss before training, dataset key).
    """
    inputs, labels = utils.create_gaussian_mixture_toy_data(nr_samples=100)
    dataset_key = "gaussian_mixture"

    alice.add_dataset(sy.BaseDataset(inputs, labels), key=dataset_key)

    # Mean squared error between predictions and (unsqueezed) targets.
    @hook.torch.jit.script
    def loss_fn(pred, target):
        return ((pred - target.unsqueeze(1))**2).mean()

    class ToyModel(torch.nn.Module):
        """Small 2-3-2-1 feed-forward network."""

        def __init__(self):
            super(ToyModel, self).__init__()
            self.fc1 = nn.Linear(2, 3)
            self.fc2 = nn.Linear(3, 2)
            self.fc3 = nn.Linear(2, 1)

        def forward(self, x):
            hidden = F.relu(self.fc1(x))
            hidden = F.relu(self.fc2(hidden))
            return self.fc3(hidden)

    traced_model = torch.jit.trace(ToyModel(), inputs)

    # Baseline loss of the untrained model, for callers to compare against.
    initial_loss = loss_fn(target=labels, pred=traced_model(inputs))
    return traced_model, loss_fn, inputs, labels, initial_loss, dataset_key
# Code example 2
def start_websocket_server_worker(id,
                                  host,
                                  port,
                                  hook,
                                  verbose,
                                  keep_labels=None,
                                  training=True):  # pragma: no cover
    """Helper function for spinning up a websocket server and setting up the local datasets.

    Registers several toy datasets on the server worker:
    MNIST (filtered to ``keep_labels`` when ``training`` is True),
    a small vectors dataset, an XOR dataset, a gaussian-mixture
    dataset, and a partial iris dataset.

    Args:
        id: Worker id for the ``WebsocketServerWorker``.
        host: Hostname/interface the server binds to.
        port: Port the server listens on.
        hook: The PySyft torch hook instance.
        verbose: Verbosity flag forwarded to the server worker.
        keep_labels: MNIST class labels to keep when ``training`` is True.
        training: If True, register the filtered MNIST training split
            under key "mnist"; otherwise register the full test split
            under key "mnist_testing".

    Returns:
        The ``WebsocketServerWorker`` instance.
        NOTE(review): ``server.start()`` appears to block until the
        server shuts down, so the return only happens afterwards —
        confirm against WebsocketServerWorker's implementation.
    """

    server = WebsocketServerWorker(id=id,
                                   host=host,
                                   port=port,
                                   hook=hook,
                                   verbose=verbose)

    # Setup toy data (mnist example)
    mnist_dataset = datasets.MNIST(
        root="./data",
        train=training,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            # Standard MNIST normalization constants (mean, std).
            transforms.Normalize((0.1307, ), (0.3081, ))
        ]),
    )

    if training:
        # uint8 mask marking samples whose label is in keep_labels.
        indices = np.isin(mnist_dataset.targets, keep_labels).astype("uint8")
        logger.info("number of true indices: %s", indices.sum())
        # native_masked_select presumably bypasses the PySyft-hooked
        # masked_select — TODO confirm. The transpose/view/transpose
        # dance masks along the sample dimension while keeping each
        # 28x28 image intact.
        selected_data = (torch.native_masked_select(
            mnist_dataset.data.transpose(0, 2),
            torch.tensor(indices)).view(28, 28, -1).transpose(2, 0))
        logger.info("after selection: %s", selected_data.shape)
        selected_targets = torch.native_masked_select(mnist_dataset.targets,
                                                      torch.tensor(indices))

        dataset = sy.BaseDataset(data=selected_data,
                                 targets=selected_targets,
                                 transform=mnist_dataset.transform)
        key = "mnist"
    else:
        # Test split is registered unfiltered.
        dataset = sy.BaseDataset(
            data=mnist_dataset.data,
            targets=mnist_dataset.targets,
            transform=mnist_dataset.transform,
        )
        key = "mnist_testing"

    server.add_dataset(dataset, key=key)

    # Setup toy data (vectors example)
    data_vectors = torch.tensor([[-1, 2.0], [0, 1.1], [-1, 2.1], [0, 1.2]],
                                requires_grad=True)
    target_vectors = torch.tensor([[1], [0], [1], [0]])

    server.add_dataset(sy.BaseDataset(data_vectors, target_vectors),
                       key="vectors")

    # Setup toy data (xor example)
    data_xor = torch.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]],
                            requires_grad=True)
    target_xor = torch.tensor([1.0, 1.0, 0.0, 0.0], requires_grad=False)

    server.add_dataset(sy.BaseDataset(data_xor, target_xor), key="xor")

    # Setup gaussian mixture dataset
    data, target = utils.create_gaussian_mixture_toy_data(nr_samples=100)
    server.add_dataset(sy.BaseDataset(data, target), key="gaussian_mixture")

    # Setup partial iris dataset
    data, target = utils.iris_data_partial()
    dataset = sy.BaseDataset(data, target)
    dataset_key = "iris"
    server.add_dataset(dataset, key=dataset_key)

    logger.info("datasets: %s", server.datasets)
    if training:
        logger.info("len(datasets[mnist]): %s", len(server.datasets["mnist"]))

    server.start()
    return server
# Code example 3
def test_fit(fit_dataset_key, epochs, device):
    """Train a small traced model via a FederatedClient and check convergence.

    Args:
        fit_dataset_key: Dataset key requested for fitting. When it does
            not match the registered key, only the setup path is exercised.
        epochs: Number of epochs per TrainConfig.
        device: "cpu" or "cuda"; the test is skipped when cuda is
            requested but unavailable.
    """
    # Skip when a CUDA device is requested but the machine has none.
    if device == "cuda" and not torch.cuda.is_available():
        return

    data, target = utils.create_gaussian_mixture_toy_data(nr_samples=100)

    fed_client = FederatedClient()
    dataset = sy.BaseDataset(data, target)
    dataset_key = "gaussian_mixture"
    fed_client.add_dataset(dataset, key=dataset_key)

    def loss_fn(pred, target):
        return torch.nn.functional.cross_entropy(input=pred, target=target)

    class Net(torch.nn.Module):
        """2-3-2 feed-forward classifier with Xavier-initialized weights."""

        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = torch.nn.Linear(2, 3)
            self.fc2 = torch.nn.Linear(3, 2)

            torch.nn.init.xavier_normal_(self.fc1.weight)
            torch.nn.init.xavier_normal_(self.fc2.weight)

        def forward(self, x):
            x = torch.nn.functional.relu(self.fc1(x))
            x = torch.nn.functional.relu(self.fc2(x))
            return x

    # Move data and model to the target device before tracing so the
    # traced graph lives on that device.
    data_device = data.to(torch.device(device))
    target_device = target.to(torch.device(device))
    model_untraced = Net().to(torch.device(device))
    model = torch.jit.trace(model_untraced, data_device)
    model_id = 0
    model_ow = ObjectWrapper(obj=model, id=model_id)
    loss_id = 1
    loss_ow = ObjectWrapper(obj=loss_fn, id=loss_id)
    pred = model(data_device)
    loss_before = loss_fn(target=target_device, pred=pred)
    if PRINT_IN_UNITTESTS:  # pragma: no cover
        print("Loss before training: {}".format(loss_before))

    # Create and send train config. Model and loss are referenced by id
    # because they are registered on the client as ObjectWrappers below.
    train_config = sy.TrainConfig(
        batch_size=8,
        model=None,
        loss_fn=None,
        model_id=model_id,
        loss_fn_id=loss_id,
        optimizer_args={"lr": 0.05, "weight_decay": 0.01},
        epochs=epochs,
    )

    fed_client.set_obj(model_ow)
    fed_client.set_obj(loss_ow)
    fed_client.set_obj(train_config)
    fed_client.optimizer = None

    train_model(
        fed_client, fit_dataset_key, available_dataset_key=dataset_key, nr_rounds=3, device=device
    )

    if dataset_key == fit_dataset_key:
        loss_after = evaluate_model(fed_client, model_id, loss_fn, data_device, target_device)
        if PRINT_IN_UNITTESTS:  # pragma: no cover
            print("Loss after training: {}".format(loss_after))

        if loss_after >= loss_before:  # pragma: no cover
            if PRINT_IN_UNITTESTS:
                print("Loss not reduced, train more: {}".format(loss_after))

            # BUGFIX: pass device=device here as well; it was omitted in
            # this retry call (unlike the first train_model call above),
            # so retraining could run on the wrong device.
            train_model(
                fed_client,
                fit_dataset_key,
                available_dataset_key=dataset_key,
                nr_rounds=10,
                device=device,
            )
            # BUGFIX: evaluate with the on-device tensors; the original
            # passed the CPU tensors `data`/`target`, which mismatches
            # the model's device when device == "cuda".
            loss_after = evaluate_model(fed_client, model_id, loss_fn, data_device, target_device)

        assert loss_after < loss_before
# Code example 4
def test_train_config_with_jit_trace_sync(
        hook, start_remote_worker):  # pragma: no cover
    """End-to-end check that remote fitting via TrainConfig updates a
    jit-traced model and reduces its loss on the gaussian-mixture data.
    """
    data, target = utils.create_gaussian_mixture_toy_data(100)
    dataset_key = "gaussian_mixture"
    dataset = sy.BaseDataset(data, target)

    server, remote_proxy = start_remote_worker(id="sync_fit",
                                               hook=hook,
                                               port=9000,
                                               dataset=(dataset, dataset_key))

    # Mean squared error with targets reshaped to match the predictions.
    @hook.torch.jit.script
    def loss_fn(pred, target):
        return ((target.view(pred.shape).float() - pred.float())**2).mean()

    class TinyNet(torch.nn.Module):
        """2-3-2-1 feed-forward regressor."""

        def __init__(self):
            super(TinyNet, self).__init__()
            self.fc1 = nn.Linear(2, 3)
            self.fc2 = nn.Linear(3, 2)
            self.fc3 = nn.Linear(2, 1)

        def forward(self, x):
            hidden = F.relu(self.fc1(x))
            hidden = F.relu(self.fc2(hidden))
            return self.fc3(hidden)

    model = torch.jit.trace(TinyNet(), data)

    # Baseline loss of the untrained model.
    loss_before = loss_fn(pred=model(data), target=target)

    # Create and send train config
    train_config = sy.TrainConfig(model=model,
                                  loss_fn=loss_fn,
                                  batch_size=2,
                                  epochs=1)
    train_config.send(remote_proxy)

    for round_no in range(5):
        loss = remote_proxy.fit(dataset_key=dataset_key)
        if PRINT_IN_UNITTESTS:  # pragma: no cover
            print("-" * 50)
            print("Iteration %s: alice's loss: %s" % (round_no, loss))

    new_model = train_config.model_ptr.get()

    # Every weight and bias must have moved by more than the tolerance,
    # otherwise no remote training actually happened.
    for param_name in ("weight", "bias"):
        for layer_name in ("fc1", "fc2", "fc3"):
            before = getattr(model, layer_name)._parameters[param_name]
            after = getattr(new_model.obj, layer_name)._parameters[param_name]
            assert not ((before - after).abs() < 10e-3).all()

    new_model.obj.eval()
    loss_after = loss_fn(pred=new_model.obj(data), target=target)

    if PRINT_IN_UNITTESTS:  # pragma: no cover
        print("Loss before training: {}".format(loss_before))
        print("Loss after training: {}".format(loss_after))

    remote_proxy.close()
    server.terminate()

    assert loss_after < loss_before
# Code example 5
async def test_train_config_with_jit_trace_async(
        hook, start_proc):  # pragma: no cover
    """Async variant of the jit-trace TrainConfig test.

    Connects to an already-running websocket server worker on port 8777
    (see the commented-out server-startup code below for why it is not
    spawned here), trains remotely via ``async_fit``, and asserts that
    all parameters changed and that the loss decreased.
    """
    kwargs = {
        "id": "async_fit",
        "host": "localhost",
        "port": 8777,
        "hook": hook
    }
    # data = torch.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)
    # target = torch.tensor([[1.0], [1.0], [0.0], [0.0]], requires_grad=False)
    # dataset_key = "xor"
    data, target = utils.create_gaussian_mixture_toy_data(100)
    dataset_key = "gaussian_mixture"

    # Dummy input used only to drive torch.jit.trace below.
    mock_data = torch.zeros(1, 2)

    # TODO check reason for error (RuntimeError: This event loop is already running) when starting websocket server from pytest-asyncio environment
    # dataset = sy.BaseDataset(data, target)

    # server, remote_proxy = start_remote_worker(id="async_fit", port=8777, hook=hook, dataset=(dataset, dataset_key))

    # time.sleep(0.1)

    remote_proxy = WebsocketClientWorker(**kwargs)

    # Mean squared error with targets reshaped to match the predictions.
    @hook.torch.jit.script
    def loss_fn(pred, target):
        return ((target.view(pred.shape).float() - pred.float())**2).mean()

    class Net(torch.nn.Module):
        # Small 2-3-2-1 feed-forward regressor.
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(2, 3)
            self.fc2 = nn.Linear(3, 2)
            self.fc3 = nn.Linear(2, 1)

        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    model_untraced = Net()

    # Traced with mock_data (not the real data), so only the graph shape
    # matters for tracing here.
    model = torch.jit.trace(model_untraced, mock_data)

    # Baseline loss of the untrained model.
    pred = model(data)
    loss_before = loss_fn(target=target, pred=pred)

    # Create and send train config
    train_config = sy.TrainConfig(model=model,
                                  loss_fn=loss_fn,
                                  batch_size=2,
                                  optimizer="SGD",
                                  optimizer_args={"lr": 0.1})
    train_config.send(remote_proxy)

    for epoch in range(5):
        loss = await remote_proxy.async_fit(dataset_key=dataset_key)
        if PRINT_IN_UNITTESTS:  # pragma: no cover
            print("-" * 50)
            print("Iteration %s: alice's loss: %s" % (epoch, loss))

    # Pull the remotely-trained model back from the worker.
    new_model = train_config.model_ptr.get()

    # Every weight and bias must differ from the originals, otherwise no
    # remote training actually happened (exact-equality check).
    assert not (model.fc1._parameters["weight"]
                == new_model.obj.fc1._parameters["weight"]).all()
    assert not (model.fc2._parameters["weight"]
                == new_model.obj.fc2._parameters["weight"]).all()
    assert not (model.fc3._parameters["weight"]
                == new_model.obj.fc3._parameters["weight"]).all()
    assert not (model.fc1._parameters["bias"]
                == new_model.obj.fc1._parameters["bias"]).all()
    assert not (model.fc2._parameters["bias"]
                == new_model.obj.fc2._parameters["bias"]).all()
    assert not (model.fc3._parameters["bias"]
                == new_model.obj.fc3._parameters["bias"]).all()

    new_model.obj.eval()
    pred = new_model.obj(data)
    loss_after = loss_fn(target=target, pred=pred)
    if PRINT_IN_UNITTESTS:  # pragma: no cover
        print("Loss before training: {}".format(loss_before))
        print("Loss after training: {}".format(loss_after))

    remote_proxy.close()
    # server.terminate()

    assert loss_after < loss_before