def test_evaluate(hook, start_proc):  # pragma: no cover
    """Evaluate a traced model on the remote "iris" dataset through a websocket client.

    Resets syft's hook_args function tables, sends a TrainConfig holding a traced
    model and a traced loss to the remote worker, requests an evaluation with
    histograms (3 bins) and loss, and checks the dataset size and the target
    histogram.

    NOTE(review): the server worker is NOT started here (``start_proc`` is unused,
    see the TODO below). The test therefore assumes a worker with id
    "evaluate_remote" is already reachable on localhost:8780 and holds the partial
    iris dataset under the key "iris". Confirm this against the test setup.

    The client connection is always closed and deregistered, even when the
    evaluation or an assertion fails, so no worker state leaks into later tests.
    """
    sy.local_worker.clear_objects()
    # Replace syft's hook_args lookup tables with empty dicts (presumably to drop
    # entries cached by previously run tests).
    sy.frameworks.torch.hook.hook_args.hook_method_args_functions = {}
    sy.frameworks.torch.hook.hook_args.hook_method_response_functions = {}
    sy.frameworks.torch.hook.hook_args.get_tensor_type_functions = {}
    sy.frameworks.torch.hook.hook_args.register_response_functions = {}

    data, target = utils.iris_data_partial()

    dataset = sy.BaseDataset(data=data, targets=target)

    kwargs = {"id": "evaluate_remote", "host": "localhost", "port": 8780, "hook": hook}
    dataset_key = "iris"
    # TODO: check why unit test sometimes fails when WebsocketServerWorker is started from
    # the unit test. Fails when run after test_federated_client.py
    # process_remote_worker = start_proc(WebsocketServerWorker, dataset=(dataset, dataset_key), verbose=True, **kwargs)

    local_worker = instantiate_websocket_client_worker(**kwargs)

    def loss_fn(pred, target):
        return torch.nn.functional.cross_entropy(input=pred, target=target)

    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = torch.nn.Linear(4, 3)
            torch.nn.init.xavier_normal_(self.fc1.weight)

        def forward(self, x):
            x = torch.nn.functional.relu(self.fc1(x))
            return x

    try:
        model_untraced = Net()
        model = torch.jit.trace(model_untraced, data)
        # Trace the loss with a dummy (1 x 3 logits, 1 label) example input.
        loss_traced = torch.jit.trace(
            loss_fn, (torch.tensor([[0.3, 0.5, 0.2]]), torch.tensor([1]))
        )

        pred = model(data)
        loss_before = loss_fn(target=target, pred=pred)
        if PRINT_IN_UNITTESTS:  # pragma: no cover
            print("Loss: {}".format(loss_before))

        # Create and send train config
        train_config = sy.TrainConfig(
            batch_size=4,
            model=model,
            loss_fn=loss_traced,
            model_id=None,
            loss_fn_id=None,
            optimizer_args=None,
            epochs=1,
        )
        train_config.send(local_worker)

        result = local_worker.evaluate(
            dataset_key=dataset_key, calculate_histograms=True, nr_bins=3, calculate_loss=True
        )

        # NOTE(review): assumes evaluate() returns a 5-tuple in exactly this order;
        # confirm against the client API in use.
        test_loss_before, correct_before, len_dataset, hist_pred_before, hist_target = result

        if PRINT_IN_UNITTESTS:  # pragma: no cover
            print("Evaluation result before training: {}".format(result))

        assert len_dataset == 30
        assert (hist_target == [10, 10, 10]).all()
    finally:
        # Always release the connection and registry entry, even on failure.
        local_worker.close()
        local_worker.remove_worker_from_local_worker_registry()
def start_websocket_server_worker(
    id, host, port, hook, verbose, keep_labels=None, training=True
):  # pragma: no cover
    """Helper function for spinning up a websocket server and setting up the local datasets.

    Creates a WebsocketServerWorker, registers several toy datasets on it and then
    calls ``server.start()``.

    Datasets registered (key -> contents):
        * "mnist" (only when ``training`` is True): the MNIST training split,
          restricted to samples whose label is in ``keep_labels``. All samples are
          kept when ``keep_labels`` is None.
        * "mnist_testing" (only when ``training`` is False): the full MNIST test split.
        * "vectors": four 2-element float vectors with 0/1 targets.
        * "xor": the four XOR input/output pairs.
        * "gaussian_mixture": ``utils.create_gaussian_mixture_toy_data(nr_samples=100)``.
        * "iris": ``utils.iris_data_partial()``.

    MNIST is loaded from ``./data`` and downloaded if necessary. It is normalized
    with mean 0.1307 and std 0.3081.

    Args:
        id: Worker id passed to WebsocketServerWorker.
        host: Host the server worker is created with.
        port: Port the server worker is created with.
        hook: syft hook passed to WebsocketServerWorker.
        verbose: Passed through to WebsocketServerWorker.
        keep_labels: Labels to keep in the "mnist" training dataset. None (the
            default) keeps every label. Ignored when ``training`` is False.
        training: If True, load the MNIST training split under key "mnist".
            Otherwise, load the test split under key "mnist_testing".

    Returns:
        The WebsocketServerWorker instance. It is returned only after
        ``server.start()`` returns; whether ``start()`` blocks is not visible here.
    """
    server = WebsocketServerWorker(id=id, host=host, port=port, hook=hook, verbose=verbose)

    # Setup toy data (mnist example)
    mnist_dataset = datasets.MNIST(
        root="./data",
        train=training,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    )

    if training:
        if keep_labels is None:
            # No label filter requested: keep every training sample. Without this
            # branch np.isin(targets, None) matches nothing, and .view() below
            # raises on the resulting empty tensor.
            selected_data = mnist_dataset.data
            selected_targets = mnist_dataset.targets
        else:
            indices = np.isin(mnist_dataset.targets, keep_labels).astype("uint8")
            logger.info("number of true indices: %s", indices.sum())
            mask = torch.tensor(indices)  # one entry per sample; reused for data and targets
            # transpose(0, 2) moves the sample axis last so the per-sample mask
            # broadcasts over it. view/transpose then restore a (samples, 28, 28)
            # layout (assumes 28x28 MNIST images, as the view requires).
            selected_data = (
                torch.native_masked_select(mnist_dataset.data.transpose(0, 2), mask)
                .view(28, 28, -1)
                .transpose(2, 0)
            )
            logger.info("after selection: %s", selected_data.shape)
            selected_targets = torch.native_masked_select(mnist_dataset.targets, mask)
        dataset = sy.BaseDataset(
            data=selected_data, targets=selected_targets, transform=mnist_dataset.transform
        )
        key = "mnist"
    else:
        dataset = sy.BaseDataset(
            data=mnist_dataset.data,
            targets=mnist_dataset.targets,
            transform=mnist_dataset.transform,
        )
        key = "mnist_testing"

    server.add_dataset(dataset, key=key)

    # Setup toy data (vectors example)
    data_vectors = torch.tensor([[-1, 2.0], [0, 1.1], [-1, 2.1], [0, 1.2]], requires_grad=True)
    target_vectors = torch.tensor([[1], [0], [1], [0]])

    server.add_dataset(sy.BaseDataset(data_vectors, target_vectors), key="vectors")

    # Setup toy data (xor example)
    data_xor = torch.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)
    target_xor = torch.tensor([1.0, 1.0, 0.0, 0.0], requires_grad=False)

    server.add_dataset(sy.BaseDataset(data_xor, target_xor), key="xor")

    # Setup gaussian mixture dataset
    data, target = utils.create_gaussian_mixture_toy_data(nr_samples=100)
    server.add_dataset(sy.BaseDataset(data, target), key="gaussian_mixture")

    # Setup partial iris dataset
    data, target = utils.iris_data_partial()
    dataset = sy.BaseDataset(data, target)
    dataset_key = "iris"
    server.add_dataset(dataset, key=dataset_key)

    logger.info("datasets: %s", server.datasets)
    if training:
        logger.info("len(datasets[mnist]): %s", len(server.datasets["mnist"]))

    server.start()
    return server
def test_evaluate():  # pragma: no cover
    """Check that training a FederatedClient on partial iris data improves its evaluation.

    The model starts from fixed weights. The client is evaluated, trained with SGD
    via ``train_model`` for 50 rounds, and evaluated again. Both evaluations must
    report 30 predictions and a balanced target histogram of [10, 10, 10]. After
    training there must be more correct predictions, and the prediction histogram
    must be closer (L2 norm) to the target histogram than before.
    """
    data, target = utils.iris_data_partial()
    fed_client = FederatedClient()
    dataset_key = "iris"
    fed_client.add_dataset(sy.BaseDataset(data, target), key=dataset_key)

    def loss_fn(pred, target):
        return torch.nn.functional.cross_entropy(input=pred, target=target)

    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(4, 3)

        def forward(self, x):
            return torch.nn.functional.relu(self.fc1(x))

    net = Net()
    # Overwrite the random init with fixed parameters so the starting point is reproducible.
    with torch.no_grad():
        net.fc1.weight.set_(
            torch.tensor(
                [
                    [0.0160, 1.3753, -0.1202, -0.9129],
                    [0.1539, 0.3092, 0.0749, 0.2142],
                    [0.0984, 0.6248, 0.0274, 0.1735],
                ]
            )
        )
        net.fc1.bias.set_(torch.tensor([0.3477, 0.2970, -0.0799]))

    traced = torch.jit.trace(net, data)

    model_id = 0
    wrapped_model = ObjectWrapper(obj=traced, id=model_id)
    loss_id = 1
    wrapped_loss = ObjectWrapper(obj=loss_fn, id=loss_id)

    initial_loss = loss_fn(target=target, pred=traced(data))
    if PRINT_IN_UNITTESTS:  # pragma: no cover
        print("Loss before training: {}".format(initial_loss))

    def evaluate_and_check(message):
        """Evaluate the client on the iris data, print the result dict, and check its size/targets.

        Returns (loss, nr_correct, nr_predictions, histogram_predictions, histogram_target).
        """
        outcome = fed_client.evaluate(
            dataset_key=dataset_key, return_histograms=True, nr_bins=3, return_loss=True
        )
        unpacked = (
            outcome["loss"],
            outcome["nr_correct_predictions"],
            outcome["nr_predictions"],
            outcome["histogram_predictions"],
            outcome["histogram_target"],
        )
        if PRINT_IN_UNITTESTS:  # pragma: no cover
            print(message.format(outcome))
        _, _, n_predictions, _, target_hist = unpacked
        assert n_predictions == 30
        assert (target_hist == [10, 10, 10]).all()
        return unpacked

    # Register model, loss and an evaluation-only config (no optimizer) on the client.
    fed_client.set_obj(wrapped_model)
    fed_client.set_obj(wrapped_loss)
    fed_client.set_obj(
        sy.TrainConfig(
            batch_size=8,
            model=None,
            loss_fn=None,
            model_id=model_id,
            loss_fn_id=loss_id,
            optimizer_args=None,
            epochs=1,
        )
    )
    fed_client.optimizer = None

    _, correct_before, _, hist_pred_before, _ = evaluate_and_check(
        "Evaluation result before training: {}"
    )

    # Switch to an SGD training config and train.
    fed_client.set_obj(
        sy.TrainConfig(
            batch_size=8,
            model=None,
            loss_fn=None,
            model_id=model_id,
            loss_fn_id=loss_id,
            optimizer="SGD",
            optimizer_args={"lr": 0.01},
            shuffle=True,
            epochs=2,
        )
    )
    train_model(fed_client, dataset_key, available_dataset_key=dataset_key, nr_rounds=50)

    _, correct_after, _, hist_pred_after, hist_target = evaluate_and_check(
        "Evaluation result: {}"
    )

    assert correct_after > correct_before
    gap_after = torch.norm(torch.tensor(hist_target - hist_pred_after))
    gap_before = torch.norm(torch.tensor(hist_target - hist_pred_before))
    assert gap_after < gap_before