def prepare_training(hook, alice):  # pragma: no cover
    data, target = utils.create_gaussian_mixture_toy_data(nr_samples=100)
    dataset_key = "gaussian_mixture"
    dataset = sy.BaseDataset(data, target)
    alice.add_dataset(dataset, key=dataset_key)

    @hook.torch.jit.script
    def loss_fn(pred, target):
        return ((target.float() - pred.float()) ** 2).mean()

    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(2, 3)
            self.fc2 = nn.Linear(3, 2)
            self.fc3 = nn.Linear(2, 1)

            nn.init.xavier_uniform_(self.fc1.weight)
            nn.init.xavier_uniform_(self.fc2.weight)
            nn.init.xavier_uniform_(self.fc3.weight)

        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    model_untraced = Net()
    model = torch.jit.trace(model_untraced, data)

    pred = model(data)
    loss_before = loss_fn(target=target, pred=pred)

    return model, loss_fn, data, target, loss_before, dataset_key

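# A minimal usage sketch for prepare_training (an illustration, not an extra test
# case): `hook` and `alice` are assumed to be the usual TorchHook / VirtualWorker
# fixtures, and the underscore-prefixed name is hypothetical.
def _prepare_training_usage_sketch(hook, alice):  # pragma: no cover
    model, loss_fn, data, target, loss_before, dataset_key = prepare_training(hook, alice)

    # Ship model + loss to alice and run one remote fit on the registered dataset.
    train_config = sy.TrainConfig(model=model, loss_fn=loss_fn, batch_size=2, epochs=1)
    train_config.send(alice)
    loss = alice.fit(dataset_key=dataset_key)

    # The remotely trained model can then be pulled back and compared against loss_before.
    new_model = train_config.model_ptr.get()
    return loss_before, loss, new_model
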
def test_train_config_with_jit_trace_sync(hook, start_remote_worker):  # pragma: no cover
    data, target = utils.create_gaussian_mixture_toy_data(100)
    dataset = sy.BaseDataset(data, target)
    dataset_key = "gaussian_mixture"
    server, remote_proxy = start_remote_worker(
        id="sync_fit", hook=hook, port=9000, dataset=(dataset, dataset_key)
    )

    @hook.torch.jit.script
    def loss_fn(pred, target):
        return ((target.view(pred.shape).float() - pred.float()) ** 2).mean()

    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(2, 3)
            self.fc2 = nn.Linear(3, 2)
            self.fc3 = nn.Linear(2, 1)

        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    model_untraced = Net()
    model = torch.jit.trace(model_untraced, data)

    pred = model(data)
    loss_before = loss_fn(pred=pred, target=target)

    # Create and send train config
    train_config = sy.TrainConfig(model=model, loss_fn=loss_fn, batch_size=2, epochs=1)
    train_config.send(remote_proxy)

    for epoch in range(5):
        loss = remote_proxy.fit(dataset_key=dataset_key)
        if PRINT_IN_UNITTESTS:  # pragma: no cover
            print("-" * 50)
            print("Iteration %s: alice's loss: %s" % (epoch, loss))

    new_model = train_config.model_ptr.get()

    # assert that the new model has updated (modified) parameters
    assert not (
        (model.fc1._parameters["weight"] - new_model.obj.fc1._parameters["weight"]).abs() < 10e-3
    ).all()
    assert not (
        (model.fc2._parameters["weight"] - new_model.obj.fc2._parameters["weight"]).abs() < 10e-3
    ).all()
    assert not (
        (model.fc3._parameters["weight"] - new_model.obj.fc3._parameters["weight"]).abs() < 10e-3
    ).all()
    assert not (
        (model.fc1._parameters["bias"] - new_model.obj.fc1._parameters["bias"]).abs() < 10e-3
    ).all()
    assert not (
        (model.fc2._parameters["bias"] - new_model.obj.fc2._parameters["bias"]).abs() < 10e-3
    ).all()
    assert not (
        (model.fc3._parameters["bias"] - new_model.obj.fc3._parameters["bias"]).abs() < 10e-3
    ).all()

    new_model.obj.eval()
    pred = new_model.obj(data)
    loss_after = loss_fn(pred=pred, target=target)

    if PRINT_IN_UNITTESTS:  # pragma: no cover
        print("Loss before training: {}".format(loss_before))
        print("Loss after training: {}".format(loss_after))

    remote_proxy.close()
    server.terminate()

    assert loss_after < loss_before

async def test_train_config_with_jit_trace_async(hook, start_proc):  # pragma: no cover
    kwargs = {"id": "async_fit", "host": "localhost", "port": 8777, "hook": hook}
    # data = torch.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)
    # target = torch.tensor([[1.0], [1.0], [0.0], [0.0]], requires_grad=False)
    # dataset_key = "xor"
    data, target = utils.create_gaussian_mixture_toy_data(100)
    dataset_key = "gaussian_mixture"
    mock_data = torch.zeros(1, 2)

    # TODO check reason for error (RuntimeError: This event loop is already running)
    # when starting the websocket server from the pytest-asyncio environment
    # dataset = sy.BaseDataset(data, target)
    # server, remote_proxy = start_remote_worker(
    #     id="async_fit", port=8777, hook=hook, dataset=(dataset, dataset_key)
    # )
    # time.sleep(0.1)
    remote_proxy = WebsocketClientWorker(**kwargs)

    @hook.torch.jit.script
    def loss_fn(pred, target):
        return ((target.view(pred.shape).float() - pred.float()) ** 2).mean()

    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(2, 3)
            self.fc2 = nn.Linear(3, 2)
            self.fc3 = nn.Linear(2, 1)

        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    model_untraced = Net()
    model = torch.jit.trace(model_untraced, mock_data)

    pred = model(data)
    loss_before = loss_fn(target=target, pred=pred)

    # Create and send train config
    train_config = sy.TrainConfig(
        model=model, loss_fn=loss_fn, batch_size=2, optimizer="SGD", optimizer_args={"lr": 0.1}
    )
    train_config.send(remote_proxy)

    for epoch in range(5):
        loss = await remote_proxy.async_fit(dataset_key=dataset_key)
        if PRINT_IN_UNITTESTS:  # pragma: no cover
            print("-" * 50)
            print("Iteration %s: alice's loss: %s" % (epoch, loss))

    new_model = train_config.model_ptr.get()

    assert not (model.fc1._parameters["weight"] == new_model.obj.fc1._parameters["weight"]).all()
    assert not (model.fc2._parameters["weight"] == new_model.obj.fc2._parameters["weight"]).all()
    assert not (model.fc3._parameters["weight"] == new_model.obj.fc3._parameters["weight"]).all()
    assert not (model.fc1._parameters["bias"] == new_model.obj.fc1._parameters["bias"]).all()
    assert not (model.fc2._parameters["bias"] == new_model.obj.fc2._parameters["bias"]).all()
    assert not (model.fc3._parameters["bias"] == new_model.obj.fc3._parameters["bias"]).all()

    new_model.obj.eval()
    pred = new_model.obj(data)
    loss_after = loss_fn(target=target, pred=pred)

    if PRINT_IN_UNITTESTS:  # pragma: no cover
        print("Loss before training: {}".format(loss_before))
        print("Loss after training: {}".format(loss_after))

    remote_proxy.close()
    # server.terminate()

    assert loss_after < loss_before

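# A possible workaround for the TODO above (sketch only, not a verified fix):
# start the WebsocketServerWorker in a separate process so it builds its own
# event loop instead of reusing the already-running pytest-asyncio loop. The
# helper name and arguments below are hypothetical.
def _start_server_in_own_process_sketch(hook, dataset, dataset_key):  # pragma: no cover
    import multiprocessing

    def _run():
        server = WebsocketServerWorker(id="async_fit", host="localhost", port=8777, hook=hook)
        server.add_dataset(dataset, key=dataset_key)
        server.start()  # blocks inside the child process, which owns a fresh loop

    process = multiprocessing.Process(target=_run)
    process.start()
    return process  # the caller is responsible for process.terminate()
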
def start_websocket_server_worker(
    id, host, port, hook, verbose, keep_labels=None, training=True
):  # pragma: no cover
    """Helper function for spinning up a websocket server and setting up the local datasets."""

    server = WebsocketServerWorker(id=id, host=host, port=port, hook=hook, verbose=verbose)

    # Setup toy data (mnist example)
    mnist_dataset = datasets.MNIST(
        root="./data",
        train=training,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    )

    if training:
        indices = np.isin(mnist_dataset.targets, keep_labels).astype("uint8")
        logger.info("number of true indices: %s", indices.sum())
        selected_data = (
            torch.native_masked_select(mnist_dataset.data.transpose(0, 2), torch.tensor(indices))
            .view(28, 28, -1)
            .transpose(2, 0)
        )
        logger.info("after selection: %s", selected_data.shape)
        selected_targets = torch.native_masked_select(mnist_dataset.targets, torch.tensor(indices))

        dataset = sy.BaseDataset(
            data=selected_data, targets=selected_targets, transform=mnist_dataset.transform
        )
        key = "mnist"
    else:
        dataset = sy.BaseDataset(
            data=mnist_dataset.data,
            targets=mnist_dataset.targets,
            transform=mnist_dataset.transform,
        )
        key = "mnist_testing"

    server.add_dataset(dataset, key=key)

    # Setup toy data (vectors example)
    data_vectors = torch.tensor([[-1, 2.0], [0, 1.1], [-1, 2.1], [0, 1.2]], requires_grad=True)
    target_vectors = torch.tensor([[1], [0], [1], [0]])
    server.add_dataset(sy.BaseDataset(data_vectors, target_vectors), key="vectors")

    # Setup toy data (xor example)
    data_xor = torch.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)
    target_xor = torch.tensor([1.0, 1.0, 0.0, 0.0], requires_grad=False)
    server.add_dataset(sy.BaseDataset(data_xor, target_xor), key="xor")

    # Setup gaussian mixture dataset
    data, target = utils.create_gaussian_mixture_toy_data(nr_samples=100)
    server.add_dataset(sy.BaseDataset(data, target), key="gaussian_mixture")

    # Setup partial iris dataset
    data, target = utils.iris_data_partial()
    dataset = sy.BaseDataset(data, target)
    dataset_key = "iris"
    server.add_dataset(dataset, key=dataset_key)

    logger.info("datasets: %s", server.datasets)
    if training:
        logger.info("len(datasets[mnist]): %s", len(server.datasets["mnist"]))

    server.start()
    return server

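# A minimal launch sketch, assuming this helper lives in a standalone script run
# once per worker; the id, port, and keep_labels values here are illustrative.
if __name__ == "__main__":
    hook = sy.TorchHook(torch)
    start_websocket_server_worker(
        id="alice",
        host="localhost",
        port=8777,
        hook=hook,
        verbose=True,
        keep_labels=[0, 1, 2, 3],  # this worker only holds MNIST digits 0-3
        training=True,
    )
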
def test_fit():
    data, target = utils.create_gaussian_mixture_toy_data(nr_samples=100)
    fed_client = federated.FederatedClient()
    dataset = sy.BaseDataset(data, target)
    dataset_key = "gaussian_mixture"
    fed_client.add_dataset(dataset, key=dataset_key)

    def loss_fn(target, pred):
        return torch.nn.functional.cross_entropy(input=pred, target=target)

    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = torch.nn.Linear(2, 3)
            self.fc2 = torch.nn.Linear(3, 2)

            torch.nn.init.xavier_normal_(self.fc1.weight)
            torch.nn.init.xavier_normal_(self.fc2.weight)

        def forward(self, x):
            x = torch.nn.functional.relu(self.fc1(x))
            x = torch.nn.functional.relu(self.fc2(x))
            return x

    model_untraced = Net()
    model = torch.jit.trace(model_untraced, data)
    model_id = 0
    model_ow = pointers.ObjectWrapper(obj=model, id=model_id)
    loss_id = 1
    loss_ow = pointers.ObjectWrapper(obj=loss_fn, id=loss_id)

    pred = model(data)
    loss_before = loss_fn(target=target, pred=pred)
    if PRINT_IN_UNITTESTS:  # pragma: no cover
        print("Loss before training: {}".format(loss_before))

    # Create and send train config
    train_config = sy.TrainConfig(
        batch_size=8,
        model=None,
        loss_fn=None,
        model_id=model_id,
        loss_fn_id=loss_id,
        lr=0.05,
        weight_decay=0.01,
    )

    fed_client.set_obj(model_ow)
    fed_client.set_obj(loss_ow)
    fed_client.set_obj(train_config)
    fed_client.optimizer = None

    for curr_round in range(12):
        loss = fed_client.fit(dataset_key=dataset_key)
        if PRINT_IN_UNITTESTS and curr_round % 4 == 0:  # pragma: no cover
            print("-" * 50)
            print("Iteration %s: alice's loss: %s" % (curr_round, loss))

    new_model = fed_client.get_obj(model_id)
    pred = new_model.obj(data)
    loss_after = loss_fn(target=target, pred=pred)
    if PRINT_IN_UNITTESTS:  # pragma: no cover
        print("Loss after training: {}".format(loss_after))

    assert loss_after < loss_before

def test_fit(fit_dataset_key, epochs):
    data, target = utils.create_gaussian_mixture_toy_data(nr_samples=100)
    fed_client = federated.FederatedClient()
    dataset = sy.BaseDataset(data, target)
    dataset_key = "gaussian_mixture"
    fed_client.add_dataset(dataset, key=dataset_key)

    def loss_fn(target, pred):
        return torch.nn.functional.cross_entropy(input=pred, target=target)

    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = torch.nn.Linear(2, 3)
            self.fc2 = torch.nn.Linear(3, 2)

            torch.nn.init.xavier_normal_(self.fc1.weight)
            torch.nn.init.xavier_normal_(self.fc2.weight)

        def forward(self, x):
            x = torch.nn.functional.relu(self.fc1(x))
            x = torch.nn.functional.relu(self.fc2(x))
            return x

    model_untraced = Net()
    model = torch.jit.trace(model_untraced, data)
    model_id = 0
    model_ow = pointers.ObjectWrapper(obj=model, id=model_id)
    loss_id = 1
    loss_ow = pointers.ObjectWrapper(obj=loss_fn, id=loss_id)

    pred = model(data)
    loss_before = loss_fn(target=target, pred=pred)
    if PRINT_IN_UNITTESTS:  # pragma: no cover
        print("Loss before training: {}".format(loss_before))

    # Create and send train config
    train_config = sy.TrainConfig(
        batch_size=8,
        model=None,
        loss_fn=None,
        model_id=model_id,
        loss_fn_id=loss_id,
        optimizer_args={"lr": 0.05, "weight_decay": 0.01},
        epochs=epochs,
    )

    fed_client.set_obj(model_ow)
    fed_client.set_obj(loss_ow)
    fed_client.set_obj(train_config)
    fed_client.optimizer = None

    train_model(fed_client, fit_dataset_key, available_dataset_key=dataset_key, nr_rounds=3)

    if dataset_key == fit_dataset_key:
        loss_after = evaluate_model(fed_client, model_id, loss_fn, data, target)
        if PRINT_IN_UNITTESTS:  # pragma: no cover
            print("Loss after training: {}".format(loss_after))

        if loss_after >= loss_before:  # pragma: no cover
            if PRINT_IN_UNITTESTS:
                print("Loss not reduced, train more: {}".format(loss_after))

            train_model(
                fed_client, fit_dataset_key, available_dataset_key=dataset_key, nr_rounds=10
            )
            loss_after = evaluate_model(fed_client, model_id, loss_fn, data, target)

        assert loss_after < loss_before

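# train_model and evaluate_model are used above but defined elsewhere in this
# module. A sketch of the shape they are assumed to have, with hypothetical
# underscore-prefixed names so they do not shadow the real helpers (assumes
# pytest is imported at module level):
def _train_model_sketch(fed_client, fit_dataset_key, available_dataset_key, nr_rounds):
    loss = None
    for _ in range(nr_rounds):
        if fit_dataset_key == available_dataset_key:
            loss = fed_client.fit(dataset_key=fit_dataset_key)
        else:
            # fitting on a dataset key the client does not hold should raise
            with pytest.raises(ValueError):
                fed_client.fit(dataset_key=fit_dataset_key)
    return loss


def _evaluate_model_sketch(fed_client, model_id, loss_fn, data, target):
    # pull the (locally stored) trained model and recompute the loss on the toy data
    new_model = fed_client.get_obj(model_id)
    pred = new_model.obj(data)
    return loss_fn(target=target, pred=pred)
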
def test_train_config_with_jit_trace_sync(hook, start_proc):  # pragma: no cover
    kwargs = {"id": "sync_fit", "host": "localhost", "port": 9000, "hook": hook}
    # data = torch.tensor([[-1, 2.0], [0, 1.1], [-1, 2.1], [0, 1.2]], requires_grad=True)
    # target = torch.tensor([[1], [0], [1], [0]])
    data, target = utils.create_gaussian_mixture_toy_data(100)
    dataset = sy.BaseDataset(data, target)
    dataset_key = "gaussian_mixture"

    process_remote_worker = start_proc(
        WebsocketServerWorker, dataset=(dataset, dataset_key), **kwargs
    )
    time.sleep(0.1)

    local_worker = WebsocketClientWorker(**kwargs)

    @hook.torch.jit.script
    def loss_fn(pred, target):
        return ((target.view(pred.shape).float() - pred.float()) ** 2).mean()

    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(2, 3)
            self.fc2 = nn.Linear(3, 2)
            self.fc3 = nn.Linear(2, 1)

        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    model_untraced = Net()
    model = torch.jit.trace(model_untraced, data)

    pred = model(data)
    loss_before = loss_fn(pred=pred, target=target)

    # Create and send train config
    train_config = sy.TrainConfig(model=model, loss_fn=loss_fn, batch_size=2, epochs=1)
    train_config.send(local_worker)

    for epoch in range(5):
        loss = local_worker.fit(dataset_key=dataset_key)
        if PRINT_IN_UNITTESTS:  # pragma: no cover
            print("-" * 50)
            print("Iteration %s: alice's loss: %s" % (epoch, loss))

    new_model = train_config.model_ptr.get()

    # assert that the new model has updated (modified) parameters
    assert not (model.fc1._parameters["weight"] == new_model.obj.fc1._parameters["weight"]).all()
    assert not (model.fc2._parameters["weight"] == new_model.obj.fc2._parameters["weight"]).all()
    assert not (model.fc3._parameters["weight"] == new_model.obj.fc3._parameters["weight"]).all()
    assert not (model.fc1._parameters["bias"] == new_model.obj.fc1._parameters["bias"]).all()
    assert not (model.fc2._parameters["bias"] == new_model.obj.fc2._parameters["bias"]).all()
    assert not (model.fc3._parameters["bias"] == new_model.obj.fc3._parameters["bias"]).all()

    new_model.obj.eval()
    pred = new_model.obj(data)
    loss_after = loss_fn(pred=pred, target=target)

    if PRINT_IN_UNITTESTS:  # pragma: no cover
        print("Loss before training: {}".format(loss_before))
        print("Loss after training: {}".format(loss_after))

    local_worker.ws.shutdown()
    del local_worker
    time.sleep(0.1)
    process_remote_worker.terminate()

    assert loss_after < loss_before