async def main(): args = define_and_get_arguments() hook = sy.TorchHook(torch) kwargs_websocket = {"hook": hook, "verbose": args.verbose, "host": "0.0.0.0"} alice = websocket_client.WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket) bob = websocket_client.WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket) charlie = websocket_client.WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket) testing = websocket_client.WebsocketClientWorker(id="testing", port=8780, **kwargs_websocket) for wcw in [alice, bob, charlie, testing]: wcw.clear_objects_remote() worker_instances = [alice, bob, charlie] use_cuda = args.cuda and torch.cuda.is_available() torch.manual_seed(args.seed) device = torch.device("cuda" if use_cuda else "cpu") model = Net().to(device) traced_model = torch.jit.trace(model, torch.zeros([1, 1, 28, 28], dtype=torch.float).to(device)) learning_rate = args.lr for curr_round in range(1, args.training_rounds + 1): logger.info("Training round %s/%s", curr_round, args.training_rounds) results = await asyncio.gather( *[ fit_model_on_worker( worker=worker, traced_model=traced_model, batch_size=args.batch_size, curr_round=curr_round, max_nr_batches=args.federate_after_n_batches, lr=learning_rate, ) for worker in worker_instances ] ) models = {} loss_values = {} test_models = curr_round % 10 == 1 or curr_round == args.training_rounds if test_models: logger.info("Evaluating models") np.set_printoptions(formatter={"float": "{: .0f}".format}) for worker_id, worker_model, _ in results: evaluate_model_on_worker( model_identifier="Model update " + worker_id, worker=testing, dataset_key="mnist_testing", model=worker_model, nr_bins=10, batch_size=128, device=args.device, print_target_hist=False, ) # Federate models (note that this will also change the model in models[0] for worker_id, worker_model, worker_loss in results: if worker_model is not None: models[worker_id] = worker_model loss_values[worker_id] = worker_loss traced_model = utils.federated_avg(models) if test_models: evaluate_model_on_worker( model_identifier="Federated model", worker=testing, dataset_key="mnist_testing", model=traced_model, nr_bins=10, batch_size=128, device=args.device, print_target_hist=False, ) # decay learning rate learning_rate = max(0.98 * learning_rate, args.lr * 0.01) if args.save_model: torch.save(model.state_dict(), "mnist_cnn.pt")
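# The round loop above delegates per-worker training to fit_model_on_worker(), which is
# defined elsewhere in this script and not shown here. A minimal sketch of such a helper,
# assuming PySyft's TrainConfig / async_fit API and a dataset registered on each worker
# under the key "mnist"; `loss_fn` is a hypothetical traced loss module defined elsewhere,
# and the names and defaults below are illustrative, not the exact original:
async def fit_model_on_worker_sketch(worker, traced_model, batch_size, curr_round, max_nr_batches, lr):
    """Send a TrainConfig to `worker`, train remotely, and return (worker.id, model, loss)."""
    train_config = sy.TrainConfig(
        model=traced_model,
        loss_fn=loss_fn,               # assumption: a torch.jit-traced loss defined elsewhere
        batch_size=batch_size,
        shuffle=True,
        max_nr_batches=max_nr_batches,
        epochs=1,
        optimizer="SGD",
        optimizer_args={"lr": lr},
    )
    train_config.send(worker)
    # Run training asynchronously on the worker's local "mnist" dataset
    loss = await worker.async_fit(dataset_key="mnist", return_ids=[0])
    # Pull the updated model back from the remote worker
    updated_model = train_config.model_ptr.get().obj
    return worker.id, updated_model, loss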
import torch
import syft as sy
from syft.workers import websocket_client as wsc

hook = sy.TorchHook(torch)

data = torch.tensor([[0., 0.], [0., 2.]], requires_grad=True)
# target = torch.tensor([[0.], [0.]], requires_grad=True)

kwargs_websocket = {"hook": hook, "verbose": True, "host": "127.0.0.1"}
mbp = wsc.WebsocketClientWorker(id="mbp", port=8001, **kwargs_websocket)

print(mbp.list_objects_remote())
data.send(mbp)
print(mbp.list_objects_remote())

if __name__ == "__main__":
    pass
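# The client above assumes a websocket worker with id "mbp" is already listening on
# 127.0.0.1:8001. A minimal sketch of such a server process, assuming the matching
# PySyft WebsocketServerWorker API (run it in a separate terminal before the client):
import torch
import syft as sy
from syft.workers.websocket_server import WebsocketServerWorker

hook = sy.TorchHook(torch)

server = WebsocketServerWorker(
    hook=hook,
    id="mbp",          # must match the id used by the WebsocketClientWorker
    host="127.0.0.1",
    port=8001,
    verbose=True,
)
server.start()         # blocks and serves incoming tensors and commands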
async def main():
    args = define_and_get_arguments()

    hook = sy.TorchHook(torch)

    kwargs_websocket = {"hook": hook, "verbose": args.verbose, "host": "127.0.0.1"}
    alice = websocket_client.WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
    bob = websocket_client.WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
    charlie = websocket_client.WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)
    testing = websocket_client.WebsocketClientWorker(id="testing", port=8780, **kwargs_websocket)

    for wcw in [alice, bob, charlie, testing]:
        wcw.clear_objects_remote()

    worker_instances = [alice, bob, charlie]

    use_cuda = args.cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")

    model = Net().to(device)
    traced_model = torch.jit.trace(model, torch.zeros([1, 47], dtype=torch.long).to(device))
    learning_rate = args.lr

    for epoch in range(1, 11):
        logger.info("Training epoch %s/%s", epoch, 10)

        results = await asyncio.gather(
            *[
                fit_model_on_worker(
                    worker=worker,
                    traced_model=traced_model,
                    batch_size=args.batch_size,
                    epoch=epoch,
                    max_nr_batches=-1,
                    lr=learning_rate,
                )
                for worker in worker_instances
            ]
        )
        models = {}
        loss_values = {}

        test_models = epoch > 0 and epoch <= 10
        if test_models:
            logger.info("Evaluating models")
            np.set_printoptions(formatter={"float": "{: .0f}".format})
            for worker_id, worker_model, _ in results:
                evaluate_model_on_worker(
                    model_identifier="Model update " + worker_id,
                    worker=testing,
                    dataset_key="dga_testing",
                    model=worker_model,
                    nr_bins=2,
                    batch_size=500,
                    device=device,
                    print_target_hist=False,
                )

        for worker_id, worker_model, worker_loss in results:
            if worker_model is not None:
                models[worker_id] = worker_model
                loss_values[worker_id] = worker_loss

        traced_model = utils.federated_avg(models)

        if test_models:
            evaluate_model_on_worker(
                model_identifier="Federated model",
                worker=testing,
                dataset_key="dga_testing",
                model=traced_model,
                nr_bins=2,
                batch_size=500,
                device=device,
                print_target_hist=False,
            )

        # decay learning rate
        learning_rate = max(0.98 * learning_rate, args.lr * 0.01)
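# The trace input torch.zeros([1, 47], dtype=torch.long) implies that this variant's Net
# consumes a sequence of 47 token indices (e.g. the characters of a domain name) and emits
# two classes (nr_bins=2, benign vs. DGA). The actual Net is defined elsewhere; the class
# below is only a hypothetical, trace-compatible stand-in (vocabulary size and layer widths
# are assumptions, not values from the original):
import torch.nn as nn
import torch.nn.functional as F

class DGANetSketch(nn.Module):
    def __init__(self, vocab_size=128, embed_dim=32, hidden_dim=64, seq_len=47):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)        # map character ids to vectors
        self.fc1 = nn.Linear(seq_len * embed_dim, hidden_dim)   # operate on the flattened sequence
        self.fc2 = nn.Linear(hidden_dim, 2)                     # two output classes

    def forward(self, x):
        x = self.embed(x)                  # [batch, 47, embed_dim]
        x = x.view(x.size(0), -1)          # [batch, 47 * embed_dim]
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)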
async def main():
    args = define_and_get_arguments()

    hook = sy.TorchHook(torch)

    if args.localworkers:
        # ----------------------------- This is for localhost workers --------------------------------
        kwargs_websocket = {"hook": hook, "verbose": args.verbose, "host": "0.0.0.0"}
        alice = websocket_client.WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
        bob = websocket_client.WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
        charlie = websocket_client.WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)
        testing = websocket_client.WebsocketClientWorker(id="testing", port=8780, **kwargs_websocket)
    else:
        # ----------------------------- This is for remote workers ------------------------------------
        kwargs_websocket_alice = {"host": "128.226.78.195", "hook": hook}
        alice = websocket_client.WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket_alice)

        kwargs_websocket_bob = {"host": "128.226.77.222", "hook": hook}
        bob = websocket_client.WebsocketClientWorker(id="bob", port=8777, **kwargs_websocket_bob)

        kwargs_websocket_charlie = {"host": "128.226.88.120", "hook": hook}
        charlie = websocket_client.WebsocketClientWorker(id="charlie", port=8777, **kwargs_websocket_charlie)

        # kwargs_websocket_testing = {"host": "128.226.77.111", "hook": hook}
        kwargs_websocket_testing = {"host": "128.226.88.210", "hook": hook}
        testing = websocket_client.WebsocketClientWorker(id="testing", port=8777, **kwargs_websocket_testing)

    for wcw in [alice, bob, charlie, testing]:
        wcw.clear_objects_remote()

    worker_instances = [alice, bob, charlie]

    use_cuda = args.cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")

    model = Net().to(device)

    # Resume from a previously saved checkpoint if one exists
    if os.path.isfile("mnist_cnn_asyn.pt"):
        model.load_state_dict(torch.load("mnist_cnn_asyn.pt"))
        model.eval()

    traced_model = torch.jit.trace(model, torch.zeros([1, 1, 28, 28], dtype=torch.float).to(device))
    learning_rate = args.lr

    # Execute the training and test process rounds
    for curr_round in range(1, args.training_rounds + 1):
        logger.info("Training round %s/%s", curr_round, args.training_rounds)

        results = await asyncio.gather(
            *[
                fit_model_on_worker(
                    worker=worker,
                    traced_model=traced_model,
                    batch_size=args.batch_size,
                    curr_round=curr_round,
                    max_nr_batches=args.federate_after_n_batches,
                    lr=learning_rate,
                )
                for worker in worker_instances
            ]
        )
        models = {}
        loss_values = {}

        # Evaluate the models every 10 rounds and at the last round
        test_models = curr_round % 10 == 1 or curr_round == args.training_rounds
        if test_models:
            logger.info("Evaluating models")
            np.set_printoptions(formatter={"float": "{: .0f}".format})
            for worker_id, worker_model, _ in results:
                evaluate_model_on_worker(
                    model_identifier="Model update " + worker_id,
                    worker=testing,
                    dataset_key="mnist_testing",
                    model=worker_model,
                    nr_bins=10,
                    batch_size=128,
                    device=device,
                    print_target_hist=False,
                )

        # Federate models (note that this will also change the model in models[0])
        for worker_id, worker_model, worker_loss in results:
            if worker_model is not None:
                models[worker_id] = worker_model
                loss_values[worker_id] = worker_loss

        traced_model = utils.federated_avg(models)

        if test_models:
            evaluate_model_on_worker(
                model_identifier="Federated model",
                worker=testing,
                dataset_key="mnist_testing",
                model=traced_model,
                nr_bins=10,
                batch_size=128,
                device=device,
                print_target_hist=False,
            )

        # save intermediate model
        model_dir = "models_asyn"
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        model_name = "{}/mnist_cnn_{}.pt".format(model_dir, curr_round)
        torch.save(traced_model.state_dict(), model_name)

        # decay learning rate
        learning_rate = max(0.98 * learning_rate, args.lr * 0.01)

    if args.save_model:
        torch.save(traced_model.state_dict(), "mnist_cnn_asyn.pt")
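# define_and_get_arguments() is shared by the training scripts above but not shown. A
# minimal argparse sketch covering exactly the flags referenced in this code; the default
# values are assumptions, not the original settings:
import argparse

def define_and_get_arguments_sketch(args=None):
    parser = argparse.ArgumentParser(description="Run federated MNIST training over websocket workers.")
    parser.add_argument("--batch_size", type=int, default=64, help="batch size used on the workers")
    parser.add_argument("--lr", type=float, default=0.1, help="initial learning rate")
    parser.add_argument("--training_rounds", type=int, default=40, help="number of federated rounds")
    parser.add_argument("--federate_after_n_batches", type=int, default=10,
                        help="number of local batches before the models are averaged")
    parser.add_argument("--cuda", action="store_true", help="use CUDA if available")
    parser.add_argument("--seed", type=int, default=1, help="torch manual seed")
    parser.add_argument("--save_model", action="store_true", help="save the final federated model")
    parser.add_argument("--verbose", action="store_true", help="verbose websocket workers")
    parser.add_argument("--localworkers", action="store_true",
                        help="connect to localhost workers instead of the remote hosts")
    return parser.parse_args(args=args)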