Beispiel #1
0
def test_evaluate(hook, start_proc):  # pragma: no cover
    """Evaluate a traced model through a websocket client worker.

    Connects to a worker at localhost:8780 (the server is expected to be
    running already, because starting it from this test is disabled below),
    sends a TrainConfig holding a traced single-layer network and a traced
    cross-entropy loss, and checks the evaluation result for the "iris"
    dataset: 30 predictions and a target histogram of [10, 10, 10].

    Args:
        hook: Hook object forwarded to the websocket client worker.
        start_proc: Process-start helper; unused while the in-test server
            start is commented out.
    """
    # Reset the local worker and the cached hook_args lookup tables.
    sy.local_worker.clear_objects()
    sy.frameworks.torch.hook.hook_args.hook_method_args_functions = {}
    sy.frameworks.torch.hook.hook_args.hook_method_response_functions = {}
    sy.frameworks.torch.hook.hook_args.get_tensor_type_functions = {}
    sy.frameworks.torch.hook.hook_args.register_response_functions = {}

    data, target = utils.iris_data_partial()

    # Only referenced by the disabled server start below.
    dataset = sy.BaseDataset(data=data, targets=target)

    kwargs = {
        "id": "evaluate_remote",
        "host": "localhost",
        "port": 8780,
        "hook": hook,
    }
    dataset_key = "iris"
    # TODO: check why unit test sometimes fails when WebsocketServerWorker is started from the unit test. Fails when run after test_federated_client.py
    # process_remote_worker = start_proc(WebsocketServerWorker, dataset=(dataset, dataset_key), verbose=True, **kwargs)

    local_worker = instantiate_websocket_client_worker(**kwargs)

    # Everything after the connection is opened runs inside try/finally so
    # the client is closed and deregistered even when an assertion fails.
    try:

        def loss_fn(pred, target):
            return torch.nn.functional.cross_entropy(input=pred, target=target)

        class Net(torch.nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.fc1 = torch.nn.Linear(4, 3)

                torch.nn.init.xavier_normal_(self.fc1.weight)

            def forward(self, x):
                x = torch.nn.functional.relu(self.fc1(x))
                return x

        model_untraced = Net()
        model = torch.jit.trace(model_untraced, data)
        # Trace the loss with a one-sample example: 3 class scores, 1 label.
        loss_traced = torch.jit.trace(
            loss_fn, (torch.tensor([[0.3, 0.5, 0.2]]), torch.tensor([1]))
        )

        pred = model(data)
        loss_before = loss_fn(target=target, pred=pred)
        if PRINT_IN_UNITTESTS:  # pragma: no cover
            print("Loss: {}".format(loss_before))

        # Create and send train config
        train_config = sy.TrainConfig(
            batch_size=4,
            model=model,
            loss_fn=loss_traced,
            model_id=None,
            loss_fn_id=None,
            optimizer_args=None,
            epochs=1,
        )
        train_config.send(local_worker)

        result = local_worker.evaluate(
            dataset_key=dataset_key,
            calculate_histograms=True,
            nr_bins=3,
            calculate_loss=True,
        )

        # The result is unpacked as a 5-tuple; only the dataset length and
        # the target histogram are asserted on.
        _, _, len_dataset, _, hist_target = result

        if PRINT_IN_UNITTESTS:  # pragma: no cover
            print("Evaluation result before training: {}".format(result))

        assert len_dataset == 30
        assert (hist_target == [10, 10, 10]).all()
    finally:
        local_worker.close()
        local_worker.remove_worker_from_local_worker_registry()
Beispiel #2
0
def start_websocket_server_worker(
    id, host, port, hook, verbose, keep_labels=None, training=True
):  # pragma: no cover
    """Helper function for spinning up a websocket server and setting up the local datasets.

    Creates a WebsocketServerWorker, registers the datasets "mnist" (or
    "mnist_testing" when ``training`` is false), "vectors", "xor",
    "gaussian_mixture" and "iris" on it, then calls ``server.start()``.

    Args:
        id: Worker id forwarded to WebsocketServerWorker.
        host: Host forwarded to WebsocketServerWorker.
        port: Port forwarded to WebsocketServerWorker.
        hook: Hook object forwarded to WebsocketServerWorker.
        verbose: Verbosity flag forwarded to WebsocketServerWorker.
        keep_labels: Labels whose MNIST samples are kept when ``training``
            is true. ``None`` keeps the samples of every label.
        training: If true, load the MNIST training split and filter it by
            ``keep_labels``; otherwise register the unfiltered test split.

    Returns:
        The WebsocketServerWorker, once ``server.start()`` has returned.
    """

    server = WebsocketServerWorker(id=id, host=host, port=port, hook=hook, verbose=verbose)

    # Setup toy data (mnist example)
    mnist_dataset = datasets.MNIST(
        root="./data",
        train=training,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    )

    if training:
        if keep_labels is None:
            # No label filter requested: select every sample. Passing None
            # to np.isin matches no label and leaves an empty selection that
            # the view(28, 28, -1) below cannot reshape.
            indices = np.ones(len(mnist_dataset.targets), dtype="uint8")
        else:
            indices = np.isin(mnist_dataset.targets, keep_labels).astype("uint8")
        logger.info("number of true indices: %s", indices.sum())

        # Build the mask tensor once and reuse it for data and targets.
        mask = torch.tensor(indices)

        # Move the sample axis last so the per-sample mask broadcasts over
        # it, select, then restore the sample-first layout.
        # NOTE(review): native_masked_select looks like a private torch op;
        # confirm it still exists before upgrading torch.
        selected_data = (
            torch.native_masked_select(mnist_dataset.data.transpose(0, 2), mask)
            .view(28, 28, -1)
            .transpose(2, 0)
        )
        logger.info("after selection: %s", selected_data.shape)
        selected_targets = torch.native_masked_select(mnist_dataset.targets, mask)

        dataset = sy.BaseDataset(
            data=selected_data, targets=selected_targets, transform=mnist_dataset.transform
        )
        key = "mnist"
    else:
        dataset = sy.BaseDataset(
            data=mnist_dataset.data,
            targets=mnist_dataset.targets,
            transform=mnist_dataset.transform,
        )
        key = "mnist_testing"

    server.add_dataset(dataset, key=key)

    # Setup toy data (vectors example)
    data_vectors = torch.tensor([[-1, 2.0], [0, 1.1], [-1, 2.1], [0, 1.2]], requires_grad=True)
    target_vectors = torch.tensor([[1], [0], [1], [0]])

    server.add_dataset(sy.BaseDataset(data_vectors, target_vectors), key="vectors")

    # Setup toy data (xor example)
    data_xor = torch.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)
    target_xor = torch.tensor([1.0, 1.0, 0.0, 0.0], requires_grad=False)

    server.add_dataset(sy.BaseDataset(data_xor, target_xor), key="xor")

    # Setup gaussian mixture dataset
    data, target = utils.create_gaussian_mixture_toy_data(nr_samples=100)
    server.add_dataset(sy.BaseDataset(data, target), key="gaussian_mixture")

    # Setup partial iris dataset
    data, target = utils.iris_data_partial()
    dataset = sy.BaseDataset(data, target)
    dataset_key = "iris"
    server.add_dataset(dataset, key=dataset_key)

    logger.info("datasets: %s", server.datasets)
    if training:
        logger.info("len(datasets[mnist]): %s", len(server.datasets["mnist"]))

    server.start()
    return server
Beispiel #3
0
def test_evaluate():  # pragma: no cover
    """Compare FederatedClient.evaluate results before and after training.

    A single-layer network with hand-set weights is traced and registered,
    together with a cross-entropy loss, on a FederatedClient that holds the
    partial iris dataset. The client is evaluated, trained through the
    ``train_model`` helper (defined elsewhere) with ``nr_rounds=50``, and
    evaluated again. The test asserts that the number of correct predictions
    grew and that the prediction histogram moved closer to the target
    histogram.
    """
    features, labels = utils.iris_data_partial()

    iris_key = "iris"
    client = FederatedClient()
    client.add_dataset(sy.BaseDataset(features, labels), key=iris_key)

    def loss_fn(pred, target):
        return torch.nn.functional.cross_entropy(input=pred, target=target)

    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = torch.nn.Linear(4, 3)

        def forward(self, x):
            return torch.nn.functional.relu(self.fc1(x))

    def evaluate_on_iris():
        # Both evaluation passes use exactly the same arguments.
        return client.evaluate(
            dataset_key=iris_key,
            return_histograms=True,
            nr_bins=3,
            return_loss=True,
        )

    # Keys read from every evaluation result, in the order they are unpacked.
    result_keys = (
        "loss",
        "nr_correct_predictions",
        "nr_predictions",
        "histogram_predictions",
        "histogram_target",
    )

    net = Net()

    # Overwrite the random initialisation with fixed parameters.
    with torch.no_grad():
        fixed_weight = torch.tensor(
            [
                [0.0160, 1.3753, -0.1202, -0.9129],
                [0.1539, 0.3092, 0.0749, 0.2142],
                [0.0984, 0.6248, 0.0274, 0.1735],
            ]
        )
        net.fc1.weight.set_(fixed_weight)
        fixed_bias = torch.tensor([0.3477, 0.2970, -0.0799])
        net.fc1.bias.set_(fixed_bias)

    traced_model = torch.jit.trace(net, features)
    model_id, loss_id = 0, 1
    model_wrapper = ObjectWrapper(obj=traced_model, id=model_id)
    loss_wrapper = ObjectWrapper(obj=loss_fn, id=loss_id)

    initial_loss = loss_fn(target=labels, pred=traced_model(features))
    if PRINT_IN_UNITTESTS:  # pragma: no cover
        print("Loss before training: {}".format(initial_loss))

    # Config used for the first evaluation; no optimizer is set.
    evaluation_config = sy.TrainConfig(
        batch_size=8,
        model=None,
        loss_fn=None,
        model_id=model_id,
        loss_fn_id=loss_id,
        optimizer_args=None,
        epochs=1,
    )

    for registered in (model_wrapper, loss_wrapper, evaluation_config):
        client.set_obj(registered)
    client.optimizer = None

    outcome_before = evaluate_on_iris()
    _, correct_before, count_before, hist_pred_before, hist_target_before = (
        outcome_before[name] for name in result_keys
    )

    if PRINT_IN_UNITTESTS:  # pragma: no cover
        print("Evaluation result before training: {}".format(outcome_before))

    assert count_before == 30
    assert (hist_target_before == [10, 10, 10]).all()

    # Replace the config with one that enables SGD, then train.
    training_config = sy.TrainConfig(
        batch_size=8,
        model=None,
        loss_fn=None,
        model_id=model_id,
        loss_fn_id=loss_id,
        optimizer="SGD",
        optimizer_args={"lr": 0.01},
        shuffle=True,
        epochs=2,
    )
    client.set_obj(training_config)
    train_model(client, iris_key, available_dataset_key=iris_key, nr_rounds=50)

    outcome_after = evaluate_on_iris()
    _, correct_after, count_after, hist_pred_after, hist_target_after = (
        outcome_after[name] for name in result_keys
    )

    if PRINT_IN_UNITTESTS:  # pragma: no cover
        print("Evaluation result: {}".format(outcome_after))

    assert count_after == 30
    assert (hist_target_after == [10, 10, 10]).all()
    assert correct_after > correct_before

    # Both distances are measured against the post-training target histogram.
    distance_after = torch.norm(torch.tensor(hist_target_after - hist_pred_after))
    distance_before = torch.norm(torch.tensor(hist_target_after - hist_pred_before))
    assert distance_after < distance_before