def train(model, device, federated_train_loader):
    # federate_after_n_batches and learning_rate are module-level
    # settings in this example
    model.train()

    nr_batches = federate_after_n_batches

    models = {}
    loss_values = {}

    iter(federated_train_loader)  # initialize iterators
    batches = get_next_batches(federated_train_loader, nr_batches)
    counter = 0

    while True:
        logger.debug("Starting training round, batches [{}, {}]".format(
            counter, counter + nr_batches))

        data_for_all_workers = True
        for worker in batches:
            curr_batches = batches[worker]
            if curr_batches:
                models[worker], loss_values[worker] = train_on_batches(
                    worker, curr_batches, model, device, learning_rate)
            else:
                data_for_all_workers = False
        counter += nr_batches
        if not data_for_all_workers:
            logger.debug("At least one worker ran out of data, stopping.")
            break

        model = utils.federated_avg(models)
        batches = get_next_batches(federated_train_loader, nr_batches)
    return model
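Several of these training loops rely on a `get_next_batches` helper that is defined elsewhere in the example scripts. A minimal sketch of what such a helper can look like, assuming the loader exposes its participating workers and that `next(fdataloader)` yields a `{worker: (data, target)}` dict, as PySyft's `FederatedDataLoader` does:

def get_next_batches(fdataloader, nr_batches):
    # Sketch under the assumptions above, not the exact original helper.
    # Pre-populate with empty lists so a worker that runs out of data
    # shows up with an empty batch list, which the callers check for.
    batches = {worker: [] for worker in fdataloader.workers}
    try:
        for _ in range(nr_batches):
            next_batches = next(fdataloader)
            for worker, batch in next_batches.items():
                batches[worker].append(batch)
    except StopIteration:
        pass
    return batches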
def train():
    for data_index in range(len(remote_dataset[0]) - 1):
        for remote_index in range(len(compute_nodes)):
            data, target = remote_dataset[remote_index][data_index]
            models[remote_index] = update(data, target, models[remote_index],
                                          optimizers[remote_index])

    for model in models:
        model.get()

    return utils.federated_avg({
        "sm1": models[0],
        "sm2": models[1],
    })
def train(self):
    for data_index in range(len(self.remote_dataset[0]) - 1):
        # update remote models
        for remote_index in range(len(self.compute_nodes)):
            data, target = self.remote_dataset[remote_index][data_index]
            self.models[remote_index] = train_process(
                data, target, self.models[remote_index],
                self.optimizers[remote_index])

    for model in self.models:
        model.get()

    return utils.federated_avg({
        "bob": self.models[0],
        "alice": self.models[1],
    })
def train_on_devices(remote_dataset, devices, models, optimizers):
    # iterate through each worker's dataset separately
    for data_index in range(len(remote_dataset[0]) - 1):
        for device_index in range(len(devices)):
            data, target = remote_dataset[device_index][data_index]
            models[device_index] = update_model(data, target,
                                                models[device_index],
                                                optimizers[device_index])

    for model in models:
        model.get()

    return utils.federated_avg({'bob': models[0], 'alice': models[1]})
def fed_avg_every_n_iters(model_pointers, iter, federate_after_n_batches):
    models_local = {}
    if iter % federate_after_n_batches == 0:
        for worker_name, model_pointer in model_pointers.items():
            # need to assign the model to the worker it belongs to
            models_local[worker_name] = model_pointer.copy().get()
        model_avg = utils.federated_avg(models_local)
        for worker in workers_virtual:
            model_copied_avg = model_avg.copy()
            model_ptr = model_copied_avg.send(worker)
            model_pointers[worker.id] = model_ptr
    return model_pointers
async def train(train_loader, valid_loader, model, architect, w_optim,
                alpha_optim, lr, epoch):
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    losses = utils.AverageMeter()

    cur_step = epoch * len(train_loader)
    writer.add_scalar('train/lr', lr, cur_step)

    model.train()

    for step in range(len(remote_train_data[0]) - 1):
        results = await asyncio.gather(*[
            update(step, i, model, alpha_optim, w_optim, architect, lr)
            for i in range(len(workers))
        ])

        models = {}
        # each result r: (worker_id, model, loss, prec1, prec5, batch_size),
        # inferred from how the fields are used below
        for i, r in enumerate(results):
            models[r[0]] = r[1]
            losses.update(r[2].item(), r[5])
            top1.update(r[3].item(), r[5])
            top5.update(r[4].item(), r[5])
            writer.add_scalar('train/loss ' + str(i), r[2].item(), cur_step)
            writer.add_scalar('train/top1 ' + str(i), r[3].item(), cur_step)
            writer.add_scalar('train/top5 ' + str(i), r[4].item(), cur_step)

        model = federated_avg(models)

        if step % config.print_freq == 0 or step == len(train_loader) - 1:
            logger.info(
                "Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                    epoch + 1, config.epochs, step, len(train_loader) - 1,
                    losses=losses, top1=top1, top5=top5))

        cur_step += 1

    logger.info("Train: [{:2d}/{}] Final Prec@1 {:.4%}".format(
        epoch + 1, config.epochs, top1.avg))
def test_federated_avg():
    class Net(th.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = th.nn.Linear(2, 2)

    net1 = Net()
    net2 = Net()
    net3 = Net()

    models = {}
    models[0] = net1
    models[1] = net2
    models[2] = net3

    avg_model = utils.federated_avg(models)

    assert avg_model != net1
    assert (avg_model.fc1.weight.data != net1.fc1.weight.data).all()
    assert (avg_model.fc1.bias.data != net1.fc1.bias.data).all()
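All of these examples delegate the averaging itself to `utils.federated_avg`, which takes a dict of models and returns a single model with averaged parameters. A minimal sketch of uniform parameter averaging in that spirit, in plain PyTorch (an approximation for illustration, not PySyft's exact implementation):

import copy
from typing import Dict

import torch


def federated_avg(models: Dict[str, torch.nn.Module]) -> torch.nn.Module:
    """Return a new model whose parameters are the element-wise mean of
    the parameters of all input models (assumed to share one architecture)."""
    model_list = list(models.values())
    avg_model = copy.deepcopy(model_list[0])
    avg_state = avg_model.state_dict()
    with torch.no_grad():
        for key in avg_state:
            # stack this parameter across all models, then average
            avg_state[key] = torch.stack(
                [m.state_dict()[key].float() for m in model_list]
            ).mean(dim=0)
    avg_model.load_state_dict(avg_state)
    return avg_model

With three randomly initialized nets as in the test above, the averaged weights differ from each individual model's weights, which is exactly what the assertions check.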
def train(model, device, federated_train_loader, lr, federate_after_n_batches,
          abort_after_one=False):
    model.train()

    nr_batches = federate_after_n_batches

    models = {}
    loss_values = {}

    iter(federated_train_loader)  # initialize iterators
    batches = get_next_batches(federated_train_loader, nr_batches)
    counter = 0

    while True:
        print(
            f"Starting training round, batches [{counter}, {counter + nr_batches}]"
        )

        data_for_all_workers = True
        for worker in batches:
            curr_batches = batches[worker]
            if curr_batches:
                models[worker], loss_values[worker] = train_on_batches(
                    worker, curr_batches, model, device, lr)
            else:
                data_for_all_workers = False
        counter += nr_batches
        if not data_for_all_workers:
            print("At least one worker ran out of data, stopping.")
            break

        model = utils.federated_avg(models)
        batches = get_next_batches(federated_train_loader, nr_batches)

        if abort_after_one:
            break
    return model
def FederatedTrainer(model, device, fed_data_loader, lr, fed_after_n_batches):
    model.train()
    nr_batches = fed_after_n_batches
    data_for_all_workers = True
    batch_counter = 0
    models = {}
    losses = {}

    iter(fed_data_loader)
    batches = GetNextBatch(fed_data_loader, nr_batches)

    while True:
        mylogger.logger.debug(
            "Starting training round, batches [{}, {}]".format(
                batch_counter, batch_counter + nr_batches))
        for worker in batches:
            curr_batches = batches[worker]
            if curr_batches:
                models[worker], losses[worker] = TrainOnBatches(
                    worker, curr_batches, model, device, lr)
            else:
                data_for_all_workers = False
        batch_counter += 1
        if not data_for_all_workers:
            mylogger.logger.debug("stopping.")
            break

        model = SyftUtils.federated_avg(models)
        batches = GetNextBatch(fed_data_loader, nr_batches)
    return model
async def main():
    args = define_and_get_arguments()

    hook = sy.TorchHook(torch)

    kwargs_websocket = {"hook": hook, "verbose": args.verbose, "host": "0.0.0.0"}
    alice = websocket_client.WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
    bob = websocket_client.WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
    charlie = websocket_client.WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)
    testing = websocket_client.WebsocketClientWorker(id="testing", port=8780, **kwargs_websocket)

    for wcw in [alice, bob, charlie, testing]:
        wcw.clear_objects_remote()

    worker_instances = [alice, bob, charlie]

    use_cuda = args.cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")

    model = Net().to(device)

    traced_model = torch.jit.trace(
        model, torch.zeros([1, 1, 28, 28], dtype=torch.float).to(device))
    learning_rate = args.lr

    for curr_round in range(1, args.training_rounds + 1):
        logger.info("Training round %s/%s", curr_round, args.training_rounds)

        results = await asyncio.gather(*[
            fit_model_on_worker(
                worker=worker,
                traced_model=traced_model,
                batch_size=args.batch_size,
                curr_round=curr_round,
                max_nr_batches=args.federate_after_n_batches,
                lr=learning_rate,
            ) for worker in worker_instances
        ])
        models = {}
        loss_values = {}

        test_models = curr_round % 10 == 1 or curr_round == args.training_rounds
        if test_models:
            logger.info("Evaluating models")
            np.set_printoptions(formatter={"float": "{: .0f}".format})
            for worker_id, worker_model, _ in results:
                evaluate_model_on_worker(
                    model_identifier="Model update " + worker_id,
                    worker=testing,
                    dataset_key="mnist_testing",
                    model=worker_model,
                    nr_bins=10,
                    batch_size=128,
                    device=args.device,
                    print_target_hist=False,
                )

        # Federate models (note that this will also change the model in models[0])
        for worker_id, worker_model, worker_loss in results:
            if worker_model is not None:
                models[worker_id] = worker_model
                loss_values[worker_id] = worker_loss

        traced_model = utils.federated_avg(models)

        if test_models:
            evaluate_model_on_worker(
                model_identifier="Federated model",
                worker=testing,
                dataset_key="mnist_testing",
                model=traced_model,
                nr_bins=10,
                batch_size=128,
                device=args.device,
                print_target_hist=False,
            )

        # decay learning rate
        learning_rate = max(0.98 * learning_rate, args.lr * 0.01)

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
r = open("memory_cpu.txt", "a+") r.write("Iteration Number " + str(iteration) + '\n') r.write('physical memory use: (in MB)' + str(p.memory_info()[0] / 2.**20) + '\n') r.write('physical memory use: (in MB)' + str(p.memory_percent()) + '\n') r.write('percentage utilization of this process in the system' + str(p.cpu_percent(interval=None)) + '\n') r.close() #Loading all the updated models received from the clients -we have for beacause we trained with 4 clients- for i in range(int(0.7 * len(clients))): print(i) models[i] = pickle.load(open('model' + str(i) + '.sav', 'rb')) #doing the federated avg federated_model = utils.federated_avg(models) #Saving the global model to be sent again to the clients filename = 'initial_model.sav' pickle.dump(federated_model, open(filename, 'wb')) #When all communication rounds end, the final model is sent to all the clients that are still connected for client in clients: t = Thread(target=last_model, args=(client[0], client[1])) trds.append(t) t.start() for tr in trds: tr.join()
def async_train(model, device, federated_train_loader, lr,
                federate_after_n_batches, abort_after_one=False):
    model.train()

    nr_batches = federate_after_n_batches

    models = {}
    loss_values = {}

    # Create a queue to collect results from the worker threads
    ret_queue = queue.Queue()

    iter(federated_train_loader)  # initialize iterators
    batches = get_next_batches(federated_train_loader, nr_batches)
    counter = 0

    while True:
        logger.debug(
            f"Starting training round, batches [{counter}, {counter + nr_batches}]"
        )

        data_for_all_workers = True

        # Create thread pool
        threads_pool = []
        # Clear the result queue
        ret_queue.queue.clear()

        # Start one training thread per worker
        for worker in batches:
            curr_batches = batches[worker]
            if curr_batches:
                # Create a new thread; start() invokes the thread's run method
                p_thread = TrainThread(
                    [ret_queue, [worker, curr_batches, model, device, lr]])
                threads_pool.append(p_thread)
                p_thread.start()
            else:
                data_for_all_workers = False

        # join() waits for all threads to terminate
        for p_thread in threads_pool:
            p_thread.join()

        # Collect all results from the queue
        while not ret_queue.empty():
            q_data = ret_queue.get()
            models[q_data[0]] = q_data[1]
            loss_values[q_data[0]] = q_data[2]

        counter += nr_batches

        if not data_for_all_workers:
            logger.debug("At least one worker ran out of data, stopping.")
            break

        logger.info("Execute federated avg.")
        model = utils.federated_avg(models)
        logger.info("Get next batches.")
        batches = get_next_batches(federated_train_loader, nr_batches)

        if abort_after_one:
            break
    return model
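`TrainThread` is defined elsewhere in that script; the essential contract, read off the loop above, is that each thread trains one worker's batches and pushes `[worker, model, loss]` onto the shared queue. A hypothetical sketch matching that usage (the argument packing and the `train_on_batches` helper are assumptions taken from the surrounding examples):

import threading


class TrainThread(threading.Thread):
    """Hypothetical sketch: trains one worker's batches in a thread and
    reports [worker, model, loss] through the shared result queue."""

    def __init__(self, args):
        super().__init__()
        # unpack [ret_queue, [worker, curr_batches, model, device, lr]]
        self.ret_queue = args[0]
        self.worker, self.batches, self.model, self.device, self.lr = args[1]

    def run(self):
        # train_on_batches is the per-worker helper used throughout these
        # examples; it is assumed to return (updated_model, loss)
        model, loss = train_on_batches(
            self.worker, self.batches, self.model, self.device, self.lr)
        self.ret_queue.put([self.worker, model, loss])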
async def main():
    args = define_and_get_arguments()

    hook = sy.TorchHook(torch)

    kwargs_websocket = {
        "hook": hook,
        "verbose": args.verbose,
        "host": "127.0.0.1"
    }
    alice = websocket_client.WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
    bob = websocket_client.WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
    charlie = websocket_client.WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)
    testing = websocket_client.WebsocketClientWorker(id="testing", port=8780, **kwargs_websocket)

    for wcw in [alice, bob, charlie, testing]:
        wcw.clear_objects_remote()

    worker_instances = [alice, bob, charlie]

    use_cuda = args.cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")

    model = Net().to(device)

    traced_model = torch.jit.trace(
        model, torch.zeros([1, 47], dtype=torch.long).to(device))
    learning_rate = args.lr

    for epoch in range(1, 11):
        logger.info("Training epoch %s/%s", epoch, 10)

        results = await asyncio.gather(*[
            fit_model_on_worker(
                worker=worker,
                traced_model=traced_model,
                batch_size=args.batch_size,
                epoch=epoch,
                max_nr_batches=-1,
                lr=learning_rate,
            ) for worker in worker_instances
        ])
        models = {}
        loss_values = {}

        test_models = epoch > 0 and epoch <= 10
        if test_models:
            logger.info("Evaluating models")
            np.set_printoptions(formatter={"float": "{: .0f}".format})
            for worker_id, worker_model, _ in results:
                evaluate_model_on_worker(
                    model_identifier="Model update " + worker_id,
                    worker=testing,
                    dataset_key="dga_testing",
                    model=worker_model,
                    nr_bins=2,
                    batch_size=500,
                    device=device,
                    print_target_hist=False,
                )

        for worker_id, worker_model, worker_loss in results:
            if worker_model is not None:
                models[worker_id] = worker_model
                loss_values[worker_id] = worker_loss

        traced_model = utils.federated_avg(models)

        if test_models:
            evaluate_model_on_worker(
                model_identifier="Federated model",
                worker=testing,
                dataset_key="dga_testing",
                model=traced_model,
                nr_bins=2,
                batch_size=500,
                device=device,
                print_target_hist=False,
            )

        # decay learning rate
        learning_rate = max(0.98 * learning_rate, args.lr * 0.01)
async def main():
    args = define_and_get_arguments()

    hook = sy.TorchHook(torch)

    if args.localworkers:
        # ----------------------------- This is for localhost workers --------------------------------
        kwargs_websocket = {
            "hook": hook,
            "verbose": args.verbose,
            "host": "0.0.0.0"
        }
        alice = websocket_client.WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
        bob = websocket_client.WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
        charlie = websocket_client.WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)
        testing = websocket_client.WebsocketClientWorker(id="testing", port=8780, **kwargs_websocket)
    else:
        # ----------------------------- This is for remote workers ------------------------------------
        kwargs_websocket_alice = {"host": "128.226.78.195", "hook": hook}
        alice = websocket_client.WebsocketClientWorker(
            id="alice", port=8777, **kwargs_websocket_alice)

        kwargs_websocket_bob = {"host": "128.226.77.222", "hook": hook}
        bob = websocket_client.WebsocketClientWorker(
            id="bob", port=8777, **kwargs_websocket_bob)

        kwargs_websocket_charlie = {"host": "128.226.88.120", "hook": hook}
        charlie = websocket_client.WebsocketClientWorker(
            id="charlie", port=8777, **kwargs_websocket_charlie)

        # kwargs_websocket_testing = {"host": "128.226.77.111", "hook": hook}
        kwargs_websocket_testing = {"host": "128.226.88.210", "hook": hook}
        testing = websocket_client.WebsocketClientWorker(
            id="testing", port=8777, **kwargs_websocket_testing)

    for wcw in [alice, bob, charlie, testing]:
        wcw.clear_objects_remote()

    worker_instances = [alice, bob, charlie]

    use_cuda = args.cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")

    model = Net().to(device)
    if os.path.isfile('mnist_cnn_asyn.pt'):
        model.load_state_dict(torch.load("mnist_cnn_asyn.pt"))
        model.eval()

    traced_model = torch.jit.trace(
        model, torch.zeros([1, 1, 28, 28], dtype=torch.float).to(device))
    learning_rate = args.lr

    # Execute the training and test process rounds
    for curr_round in range(1, args.training_rounds + 1):
        logger.info("Training round %s/%s", curr_round, args.training_rounds)

        results = await asyncio.gather(*[
            fit_model_on_worker(
                worker=worker,
                traced_model=traced_model,
                batch_size=args.batch_size,
                curr_round=curr_round,
                max_nr_batches=args.federate_after_n_batches,
                lr=learning_rate,
            ) for worker in worker_instances
        ])
        models = {}
        loss_values = {}

        # Evaluate the models every 10 rounds and at the last round
        test_models = curr_round % 10 == 1 or curr_round == args.training_rounds
        if test_models:
            logger.info("Evaluating models")
            np.set_printoptions(formatter={"float": "{: .0f}".format})
            for worker_id, worker_model, _ in results:
                evaluate_model_on_worker(
                    model_identifier="Model update " + worker_id,
                    worker=testing,
                    dataset_key="mnist_testing",
                    model=worker_model,
                    nr_bins=10,
                    batch_size=128,
                    device=device,
                    print_target_hist=False,
                )

        # Federate models (note that this will also change the model in models[0])
        for worker_id, worker_model, worker_loss in results:
            if worker_model is not None:
                models[worker_id] = worker_model
                loss_values[worker_id] = worker_loss

        traced_model = utils.federated_avg(models)

        if test_models:
            evaluate_model_on_worker(
                model_identifier="Federated model",
                worker=testing,
                dataset_key="mnist_testing",
                model=traced_model,
                nr_bins=10,
                batch_size=128,
                device=device,
                print_target_hist=False,
            )

        # save intermediate model
        model_dir = "models_asyn"
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        model_name = "{}/mnist_cnn_{}.pt".format(model_dir, curr_round)
        torch.save(traced_model.state_dict(), model_name)

        # decay learning rate
        learning_rate = max(0.98 * learning_rate, args.lr * 0.01)

    if args.save_model:
        torch.save(traced_model.state_dict(), "mnist_cnn_asyn.pt")