def train_and_validate_drfa_centered(Clients, Server):
    log('start training and validation with Federated setting in a centered way.')
    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.')

    # Initialize the lambda vector proportionally to the clients' sample sizes.
    for oc in range(Server.args.graph.n_nodes):
        Server.lambda_vector[oc] = Clients[oc].args.num_samples_per_epoch
    Server.lambda_vector /= Server.lambda_vector.sum()

    # Number of communication rounds in federated setting should be defined
    for n_c in range(Server.args.num_comms):
        Server.args.rounds_comm += 1
        Server.args.local_index += 1
        Server.args.quant_error = 0.0

        # Preset variables for this round of communication
        Server.zero_grad()
        Server.reset_tracker(Server.local_val_tracker)
        Server.reset_tracker(Server.global_val_tracker)
        Server.reset_tracker(Server.global_test_tracker)
        if Server.args.fed_personal:
            Server.reset_tracker(Server.local_personal_val_tracker)
            Server.reset_tracker(Server.global_personal_val_tracker)

        # Configuring the devices for this round of communication
        log("Starting round {} of training".format(n_c + 1))
        online_clients = set_online_clients_centered(Server.args)
        Server.args.drfa_gamma *= 0.9

        for oc in online_clients:
            Clients[oc].model.load_state_dict(Server.model.state_dict())
            Clients[oc].args.rounds_comm = Server.args.rounds_comm
            local_steps = 0
            is_sync = False

            # Evaluate the current server model on this client's data.
            do_validate_centered(Clients[oc].args, Server.model, Server.criterion, Server.metrics,
                                 Server.optimizer, Clients[oc].train_loader, Server.global_val_tracker,
                                 val=False, local=False)
            if Server.args.per_class_acc:
                Clients[oc].reset_tracker(Clients[oc].local_val_tracker)
                Clients[oc].reset_tracker(Clients[oc].global_val_tracker)
                if Server.args.fed_personal:
                    Clients[oc].reset_tracker(Clients[oc].local_personal_val_tracker)
                    Clients[oc].reset_tracker(Clients[oc].global_personal_val_tracker)
                    do_validate_centered(Clients[oc].args, Server.model, Server.criterion, Server.metrics,
                                         Server.optimizer, Clients[oc].val_loader,
                                         Clients[oc].global_personal_val_tracker, val=True, local=False)
                do_validate_centered(Clients[oc].args, Server.model, Server.criterion, Server.metrics,
                                     Server.optimizer, Clients[oc].train_loader,
                                     Clients[oc].global_val_tracker, val=False, local=False)
            if Server.args.fed_personal:
                do_validate_centered(Clients[oc].args, Server.model, Server.criterion, Server.metrics,
                                     Server.optimizer, Clients[oc].val_loader,
                                     Server.global_personal_val_tracker, val=True, local=False)

            if Server.args.federated_type == 'perfedavg':
                for _input_val, _target_val in Clients[oc].val_loader1:
                    _input_val, _target_val = load_data_batch(Clients[oc].args, _input_val, _target_val, tracker)
                    break

            # Sample the iteration index k whose model snapshot is used for the dual update.
            k = torch.randint(low=1, high=Server.args.local_step, size=(1,)).item()
            while not is_sync:
                if Server.args.arch == 'rnn':
                    Clients[oc].model.init_hidden(Server.args.batch_size)
                for _input, _target in Clients[oc].train_loader:
                    local_steps += 1
                    if k == local_steps:
                        Clients[oc].kth_model.load_state_dict(Clients[oc].model.state_dict())
                    Clients[oc].model.train()

                    # update local step.
                    logging_load_time(tracker)

                    # update local index and get local step
                    Clients[oc].args.local_index += 1
                    Clients[oc].args.local_data_seen += len(_target)
                    get_current_epoch(Clients[oc].args)
                    local_step = get_current_local_step(Clients[oc].args)

                    # adjust learning rate (based on the # of accessed samples)
                    lr = adjust_learning_rate(Clients[oc].args, Clients[oc].optimizer, Clients[oc].scheduler)

                    # load data
                    _input, _target = load_data_batch(Clients[oc].args, _input, _target, tracker)

                    # Skip batches with one sample because of BatchNorm issue in some models!
                    if _input.size(0) == 1:
                        is_sync = is_sync_fed(Clients[oc].args)
                        break

                    # inference and get current performance.
                    Clients[oc].optimizer.zero_grad()
                    loss, performance = inference(Clients[oc].model, Clients[oc].criterion, Clients[oc].metrics,
                                                  _input, _target, rnn=Server.args.arch in ['rnn'])

                    # compute gradient and do local SGD step.
                    loss.backward()

                    if Clients[oc].args.federated_type == 'fedgate':
                        # Update gradients with control variates
                        for client_param, delta_param in zip(Clients[oc].model.parameters(),
                                                             Clients[oc].model_delta.parameters()):
                            client_param.grad.data -= delta_param.data
                    elif Clients[oc].args.federated_type == 'scaffold':
                        for cp, ccp, scp in zip(Clients[oc].model.parameters(),
                                                Clients[oc].model_client_control.parameters(),
                                                Server.model_server_control.parameters()):
                            cp.grad.data += scp.data - ccp.data
                    elif Clients[oc].args.federated_type == 'fedprox':
                        # Adding proximal gradients and loss for fedprox
                        for client_param, server_param in zip(Clients[oc].model.parameters(),
                                                              Server.model.parameters()):
                            loss += Clients[oc].args.fedprox_mu / 2 * torch.norm(client_param.data - server_param.data)
                            client_param.grad.data += Clients[oc].args.fedprox_mu * (client_param.data -
                                                                                     server_param.data)

                    if 'robust' in Clients[oc].args.arch:
                        Clients[oc].model.noise.grad.data *= -1

                    Clients[oc].optimizer.step(apply_lr=True, apply_in_momentum=Clients[oc].args.in_momentum,
                                               apply_out_momentum=False)

                    if 'robust' in Clients[oc].args.arch:
                        if torch.norm(Clients[oc].model.noise.data) > 1:
                            Clients[oc].model.noise.data /= torch.norm(Clients[oc].model.noise.data)

                    if Clients[oc].args.federated_type == 'perfedavg':
                        # _input_val, _target_val = Clients[oc].load_next_val_batch()
                        lr = adjust_learning_rate(Clients[oc].args, Clients[oc].optimizer, Clients[oc].scheduler,
                                                  lr_external=Clients[oc].args.perfedavg_beta)
                        if _input_val.size(0) == 1:
                            is_sync = is_sync_fed(Clients[oc].args)
                            break
                        Clients[oc].optimizer.zero_grad()
                        loss, performance = inference(Clients[oc].model, Clients[oc].criterion,
                                                      Clients[oc].metrics, _input_val, _target_val)
                        loss.backward()
                        Clients[oc].optimizer.step(apply_lr=True,
                                                   apply_in_momentum=Clients[oc].args.in_momentum,
                                                   apply_out_momentum=False)

                    # reset load time for the tracker.
                    tracker['start_load_time'] = time.time()
                    # model_local = deepcopy(model_client)
                    is_sync = is_sync_fed(Clients[oc].args)
                    if is_sync:
                        break

            # Validate the local model before syncing.
            do_validate_centered(Clients[oc].args, Clients[oc].model, Clients[oc].criterion, Clients[oc].metrics,
                                 Clients[oc].optimizer, Clients[oc].train_loader, Server.local_val_tracker,
                                 val=False, local=True)
            if Server.args.per_class_acc:
                do_validate_centered(Clients[oc].args, Clients[oc].model, Clients[oc].criterion, Clients[oc].metrics,
                                     Clients[oc].optimizer, Clients[oc].train_loader,
                                     Clients[oc].local_val_tracker, val=False, local=True)
                if Server.args.fed_personal:
                    do_validate_centered(Clients[oc].args, Clients[oc].model, Clients[oc].criterion,
                                         Clients[oc].metrics, Clients[oc].optimizer, Clients[oc].val_loader,
                                         Clients[oc].local_personal_val_tracker, val=True, local=True)
            if Server.args.fed_personal:
                do_validate_centered(Clients[oc].args, Clients[oc].model, Clients[oc].criterion, Clients[oc].metrics,
                                     Clients[oc].optimizer, Clients[oc].val_loader,
                                     Server.local_personal_val_tracker, val=True, local=True)

        # Sync the model server based on model_clients
        tracker['start_sync_time'] = time.time()
        Server.args.global_index += 1
        logging_sync_time(tracker)
        if Server.args.federated_type == 'scaffold':
            scaffold_aggregation_centered(Clients, Server, online_clients, local_steps, lr)
        elif Server.args.federated_type == 'fedgate':
            fedgate_aggregation_centered(Clients, Server, online_clients, local_steps, lr)
        elif Server.args.federated_type == 'qsparse':
            qsparse_aggregation_centered(Clients, Server, online_clients, local_steps, lr)
        else:
            fedavg_aggregation_centered(Clients, Server, online_clients, Server.lambda_vector.numpy())

        # Aggregate Kth models
        aggregate_kth_model_centered(Clients, Server, online_clients)

        # Log performance
        # Client training performance
        log_validation_centered(Server.args, Server.local_val_tracker, val=False, local=True)
        # Server training performance
        log_validation_centered(Server.args, Server.global_val_tracker, val=False, local=False)
        if Server.args.fed_personal:
            # Client validation performance
            log_validation_centered(Server.args, Server.local_personal_val_tracker, val=True, local=True)
            # Server validation performance
            log_validation_centered(Server.args, Server.global_personal_val_tracker, val=True, local=False)
        # Per client stats
        if Server.args.per_class_acc:
            log_validation_per_client_centered(Server.args, Clients, online_clients, val=False, local=False)
            log_validation_per_client_centered(Server.args, Clients, online_clients, val=False, local=True)
            if Server.args.fed_personal:
                log_validation_per_client_centered(Server.args, Clients, online_clients, val=True, local=False)
                log_validation_per_client_centered(Server.args, Clients, online_clients, val=True, local=True)

        # Test on server
        do_validate_centered(Server.args, Server.model, Server.criterion, Server.metrics, Server.optimizer,
                             Server.test_loader, Server.global_test_tracker, val=False, local=False)
        log_test_centered(Server.args, Server.global_test_tracker)

        # Sample a fresh set of clients and evaluate the k-th model to update lambda.
        online_clients_lambda = set_online_clients_centered(Server.args)
        loss_tensor = torch.zeros(Server.args.graph.n_nodes)
        num_online_clients = len(online_clients_lambda)
        for ocl in online_clients_lambda:
            for _input, _target in Clients[ocl].train_loader:
                _input, _target = load_data_batch(Clients[ocl].args, _input, _target, tracker)
                loss, _ = inference(Server.kth_model, Clients[ocl].criterion, Clients[ocl].metrics,
                                    _input, _target, rnn=Server.args.arch in ['rnn'])
                break
            loss_tensor[ocl] = loss * (Server.args.graph.n_nodes / num_online_clients)

        # Dual update on the mixing weights, followed by projection onto the simplex.
        Server.lambda_vector += Server.args.drfa_gamma * Server.args.local_step * loss_tensor
        lambda_vector = projection_simplex_sort(Server.lambda_vector.detach().numpy())
        print(lambda_vector)
        # Avoid zero probability
        lambda_zeros = np.argwhere(lambda_vector <= 1e-3)
        if len(lambda_zeros) > 0:
            lambda_vector[lambda_zeros[0]] = 1e-3
            lambda_vector /= np.sum(lambda_vector)
        Server.lambda_vector = torch.tensor(lambda_vector)

        # logging.
        logging_globally(tracker, start_global_time)

        # reset start round time.
        start_global_time = time.time()

        # validate the model at the server
        # if args.graph.rank == 0:
        #     do_test(args, model_server, optimizer, criterion, metrics, test_loader)
        #     do_validate_test(args, model_server, optimizer, criterion, metrics, test_loader)
    return

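
# The dual update above relies on `projection_simplex_sort` to map the updated lambda
# vector back onto the probability simplex. The helper below is only an illustrative
# sketch of the standard sort-based Euclidean projection such a routine typically
# implements (its name is hypothetical); the repository's own helper may differ in details.
def _projection_simplex_sort_sketch(v, z=1.0):
    """Project a 1-D numpy array v onto the simplex {w : w >= 0, sum(w) = z}."""
    n = v.shape[0]
    u = np.sort(v)[::-1]                 # sort components in descending order
    cssv = np.cumsum(u) - z              # shifted cumulative sums
    ind = np.arange(1, n + 1)
    cond = u - cssv / ind > 0            # components that remain positive after thresholding
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / float(rho)  # threshold to subtract
    return np.maximum(v - theta, 0.0)
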
def train_and_validate_federated_drfa(client):
    """The training scheme of Distributionally Robust Federated Learning (DRFA).

    paper: https://papers.nips.cc/paper/2020/hash/ac450d10e166657ec8f93a1b65ca1b14-Abstract.html
    """
    log('start training and validation with Federated setting.', client.args.debug)

    if client.args.evaluate and client.args.graph.rank == 0:
        # Do the testing on the server and return
        do_validate(client.args, client.model, client.optimizer, client.criterion, client.metrics,
                    client.test_loader, client.all_clients_group, data_mode='test')
        return

    # Initialize the lambda variable proportionate to each client's sample size.
    if client.args.graph.rank == 0:
        gather_list_size = [torch.tensor(0.0) for _ in range(client.args.graph.n_nodes)]
        dist.gather(torch.tensor(client.args.num_samples_per_epoch, dtype=torch.float32),
                    gather_list=gather_list_size, dst=0)
        lambda_vector = torch.stack(gather_list_size) / client.args.train_dataset_size
    else:
        dist.gather(torch.tensor(client.args.num_samples_per_epoch, dtype=torch.float32), dst=0)
        lambda_vector = torch.tensor([1 / client.args.graph.n_nodes] * client.args.graph.n_nodes)

    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)

    # Number of communication rounds in federated setting should be defined
    for n_c in range(client.args.num_comms):
        client.args.rounds_comm += 1
        client.args.comm_time.append(0.0)

        # Configuring the devices for this round of communication
        # TODO: not make the server rank hard coded
        log("Starting round {} of training".format(n_c + 1), client.args.debug)
        online_clients = set_online_clients(client.args)
        if n_c == 0:
            # In the first round the server should be in the communication to initialize its own training
            online_clients = online_clients if 0 in online_clients else online_clients + [0]
        online_clients_server = online_clients if 0 in online_clients else online_clients + [0]
        online_clients_group = dist.new_group(online_clients_server)
        client.args.drfa_gamma *= 0.9

        if client.args.graph.rank in online_clients_server:
            if client.args.federated_type == 'scaffold':
                st = time.time()
                client.model_server, client.model_server_control = distribute_model_server_control(
                    client.model_server, client.model_server_control, online_clients_group, src=0)
                client.args.comm_time[-1] += time.time() - st
            else:
                st = time.time()
                model_server = distribute_model_server(client.model_server, online_clients_group, src=0)
                client.args.comm_time[-1] += time.time() - st
            client.model.load_state_dict(client.model_server.state_dict())

            # Send related variables to drfa algorithm
            st = time.time()
            dist.broadcast(client.lambda_vector, src=0, group=online_clients_group)
            # Sending the random number k to all nodes:
            # Does not fully support the epoch mode now
            k = torch.randint(low=1, high=client.args.local_step, size=(1,))
            dist.broadcast(k, src=0, group=online_clients_group)
            client.args.comm_time[-1] += time.time() - st
            k = k.tolist()[0]
            local_steps = 0

            # Start running updates on local machines
            if client.args.graph.rank in online_clients:
                is_sync = False
                while not is_sync:
                    for _input, _target in client.train_loader:
                        local_steps += 1
                        # Getting the k-th model for the dual variable update
                        if k == local_steps:
                            client.kth_model.load_state_dict(client.model.state_dict())
                        client.model.train()

                        # update local step.
                        logging_load_time(tracker)

                        # update local index and get local step
                        client.args.local_index += 1
                        client.args.local_data_seen += len(_target)
                        get_current_epoch(client.args)
                        local_step = get_current_local_step(client.args)

                        # adjust learning rate (based on the # of accessed samples)
                        lr = adjust_learning_rate(client.args, client.optimizer, client.scheduler)

                        # load data
                        _input, _target = load_data_batch(client.args, _input, _target, tracker)

                        # Skip batches with one sample because of BatchNorm issue in some models!
                        if _input.size(0) == 1:
                            is_sync = is_sync_fed(client.args)
                            break

                        # inference and get current performance.
                        client.optimizer.zero_grad()
                        loss, performance = inference(client.model, client.criterion, client.metrics, _input, _target)

                        # compute gradient and do local SGD step.
                        loss.backward()

                        if client.args.federated_type == 'fedgate':
                            for client_param, delta_param in zip(client.model.parameters(),
                                                                 client.model_delta.parameters()):
                                client_param.grad.data -= delta_param.data
                        elif client.args.federated_type == 'scaffold':
                            for cp, ccp, scp in zip(client.model.parameters(),
                                                    client.model_client_control.parameters(),
                                                    client.model_server_control.parameters()):
                                cp.grad.data += scp.data - ccp.data

                        client.optimizer.step(apply_lr=True, apply_in_momentum=client.args.in_momentum,
                                              apply_out_momentum=False)

                        # logging locally.
                        # logging_computing(tracker, loss, performance, _input, lr)

                        # display the logging info.
                        # logging_display_training(client.args, tracker)

                        # reset load time for the tracker.
                        tracker['start_load_time'] = time.time()
                        is_sync = is_sync_fed(client.args)
                        if is_sync:
                            break
            else:
                log("Offline in this round. Waiting on others to finish!", client.args.debug)

            # Validate the local models before sync
            do_validate(client.args, client.model, client.optimizer, client.criterion, client.metrics,
                        client.train_loader, online_clients_group, data_mode='train', local=True)
            if client.args.fed_personal:
                do_validate(client.args, client.model, client.optimizer, client.criterion, client.metrics,
                            client.val_loader, online_clients_group, data_mode='validation', local=True)

            # Sync the model server based on model_clients
            log('Enter synching', client.args.debug)
            tracker['start_sync_time'] = time.time()
            client.args.global_index += 1
            if client.args.federated_type == 'fedgate':
                client.model_server, client.model_delta = fedgate_aggregation(
                    client.args, client.model_server, client.model, client.model_delta, client.model_memory,
                    online_clients_group, online_clients, client.optimizer, lr, local_steps,
                    lambda_weight=client.lambda_vector[client.args.graph.rank].item())
            elif client.args.federated_type == 'scaffold':
                client.model_server, client.model_client_control, client.model_server_control = scaffold_aggregation(
                    client.args, client.model_server, client.model, client.model_server_control,
                    client.model_client_control, online_clients_group, online_clients, client.optimizer,
                    lr, local_steps, lambda_weight=client.lambda_vector[client.args.graph.rank].item())
            else:
                client.model_server = fedavg_aggregation(
                    client.args, client.model_server, client.model, online_clients_group, online_clients,
                    client.optimizer, lambda_weight=client.lambda_vector[client.args.graph.rank].item())

            # Average the kth_model
            client.kth_model = aggregate_models_virtual(client.args, client.kth_model,
                                                        online_clients_group, online_clients)
            # evaluate the sync time
            logging_sync_time(tracker)

            # Do the validation on the server model
            do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics,
                        client.train_loader, online_clients_group, data_mode='train')
            if client.args.fed_personal:
                do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics,
                            client.val_loader, online_clients_group, data_mode='validation')

            # validate the model at the server
            if client.args.graph.rank == 0:
                do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics,
                            client.test_loader, online_clients_group, data_mode='test')
        else:
            log("Offline in this round. Waiting on others to finish!", client.args.debug)

        # Update lambda parameters
        online_clients_lambda = set_online_clients(client.args)
        online_clients_server_lambda = online_clients_lambda if 0 in online_clients_lambda else [0] + online_clients_lambda
        online_clients_group_lambda = dist.new_group(online_clients_server_lambda)

        if client.args.graph.rank in online_clients_server_lambda:
            st = time.time()
            client.kth_model = distribute_model_server(client.kth_model, online_clients_group_lambda, src=0)
            client.args.comm_time[-1] += time.time() - st
            loss = torch.tensor(0.0)

            if client.args.graph.rank in online_clients_lambda:
                for _input, _target in client.train_loader:
                    _input, _target = load_data_batch(client.args, _input, _target, tracker)
                    # Skip batches with one sample because of BatchNorm issue in some models!
                    if _input.size(0) == 1:
                        break
                    loss, _ = inference(client.kth_model, client.criterion, client.metrics, _input, _target)
                    break

            loss_tensor_online = loss_gather(client.args, torch.tensor(loss.item()),
                                             group=online_clients_group_lambda,
                                             online_clients=online_clients_lambda)
            if client.args.graph.rank == 0:
                num_online_clients = len(online_clients_lambda) if 0 in online_clients_lambda \
                    else len(online_clients_lambda) + 1
                loss_tensor = torch.zeros(client.args.graph.n_nodes)
                loss_tensor[sorted(online_clients_server_lambda)] = \
                    loss_tensor_online * (client.args.graph.n_nodes / num_online_clients)

                # Dual update
                client.lambda_vector += client.args.drfa_gamma * client.args.local_step * loss_tensor
                client.lambda_vector = euclidean_proj_simplex(client.lambda_vector)

                # Avoid zero probability
                lambda_zeros = client.lambda_vector <= 1e-3
                if lambda_zeros.sum() > 0:
                    client.lambda_vector[lambda_zeros] = 1e-3
                    client.lambda_vector /= client.lambda_vector.sum()

        # logging.
        logging_globally(tracker, start_global_time)

        # reset start round time.
        start_global_time = time.time()
        log('This round communication time is: {}'.format(client.args.comm_time[-1]), client.args.debug)
        dist.barrier(group=client.all_clients_group)
    return

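
# Taken in isolation, the dual update above is a projected gradient-ascent step on the client
# mixing weights: clients whose loss on the k-th iterate is large receive more weight in the
# next round. A toy illustration with made-up numbers (reusing the module's
# euclidean_proj_simplex helper) could look like:
#
#   lambda_vec = torch.full((4,), 0.25)
#   losses = torch.tensor([0.9, 0.1, 0.4, 0.4])     # hypothetical per-client losses on the k-th model
#   lambda_vec = euclidean_proj_simplex(lambda_vec + drfa_gamma * local_step * losses)
#
# after which lambda_vec sums to one again and leans toward the high-loss clients, which is
# the distributionally robust behaviour DRFA targets.
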
def train_and_validate_federated_afl(client):
    """The training scheme of Federated Learning systems.

    This is the implementation of Agnostic Federated Learning
    https://arxiv.org/abs/1902.00146
    """
    log('start training and validation with Federated setting.', client.args.debug)

    if client.args.evaluate and client.args.graph.rank == 0:
        # Do the testing on the server and return
        do_validate(client.args, client.model, client.optimizer, client.criterion, client.metrics,
                    client.test_loader, client.all_clients_group, data_mode='test')
        return

    # Initialize the lambda variable proportionate to each client's sample size.
    if client.args.graph.rank == 0:
        gather_list_size = [torch.tensor(0.0) for _ in range(client.args.graph.n_nodes)]
        dist.gather(torch.tensor(client.args.num_samples_per_epoch, dtype=torch.float32),
                    gather_list=gather_list_size, dst=0)
        client.lambda_vector = torch.stack(gather_list_size) / client.args.train_dataset_size
    else:
        dist.gather(torch.tensor(client.args.num_samples_per_epoch, dtype=torch.float32), dst=0)
        client.lambda_vector = torch.tensor([1 / client.args.graph.n_nodes] * client.args.graph.n_nodes)

    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)

    # Number of communication rounds in federated setting should be defined
    for n_c in range(client.args.num_comms):
        client.args.rounds_comm += 1
        client.args.comm_time.append(0.0)

        # Configuring the devices for this round of communication
        # TODO: not make the server rank hard coded
        log("Starting round {} of training".format(n_c + 1), client.args.debug)
        online_clients = set_online_clients(client.args)
        if n_c == 0:
            # In the first round the server should be in the communication to initialize its own training
            online_clients = online_clients if 0 in online_clients else online_clients + [0]
        online_clients_server = online_clients if 0 in online_clients else online_clients + [0]
        online_clients_group = dist.new_group(online_clients_server)

        if client.args.graph.rank in online_clients_server:
            st = time.time()
            client.model_server = distribute_model_server(client.model_server, online_clients_group, src=0)
            dist.broadcast(client.lambda_vector, src=0, group=online_clients_group)
            client.args.comm_time[-1] += time.time() - st
            client.model.load_state_dict(client.model_server.state_dict())

            # This loss tensor is for those clients not participating in the first round
            loss = torch.tensor(0.0)

            # Start running updates on local machines
            if client.args.graph.rank in online_clients:
                is_sync = False
                while not is_sync:
                    for _input, _target in client.train_loader:
                        client.model.train()

                        # update local step.
                        logging_load_time(tracker)

                        # update local index and get local step
                        client.args.local_index += 1
                        client.args.local_data_seen += len(_target)
                        get_current_epoch(client.args)
                        local_step = get_current_local_step(client.args)

                        # adjust learning rate (based on the # of accessed samples)
                        lr = adjust_learning_rate(client.args, client.optimizer, client.scheduler)

                        # load data
                        _input, _target = load_data_batch(client.args, _input, _target, tracker)

                        # Skip batches with one sample because of BatchNorm issue in some models!
                        if _input.size(0) == 1:
                            is_sync = is_sync_fed(client.args)
                            break

                        # inference and get current performance.
                        client.optimizer.zero_grad()
                        loss, performance = inference(client.model, client.criterion, client.metrics, _input, _target)

                        # compute gradient and do local SGD step.
                        loss.backward()
                        client.optimizer.step(apply_lr=True, apply_in_momentum=client.args.in_momentum,
                                              apply_out_momentum=False)

                        # logging locally.
                        # logging_computing(tracker, loss, performance, _input, lr)

                        # display the logging info.
                        # logging_display_training(args, tracker)

                        # reset load time for the tracker.
                        tracker['start_load_time'] = time.time()
                        is_sync = is_sync_fed(client.args)
                        if is_sync:
                            break
            else:
                log("Offline in this round. Waiting on others to finish!", client.args.debug)

            # Validate the local models before sync
            do_validate(client.args, client.model, client.optimizer, client.criterion, client.metrics,
                        client.train_loader, online_clients_group, data_mode='train', local=True)
            if client.args.fed_personal:
                do_validate(client.args, client.model, client.optimizer, client.criterion, client.metrics,
                            client.val_loader, online_clients_group, data_mode='validation', local=True)

            # Sync the model server based on client models
            log('Enter synching', client.args.debug)
            tracker['start_sync_time'] = time.time()
            client.args.global_index += 1
            client.model_server, loss_tensor_online = afl_aggregation(
                client.args, client.model_server, client.model,
                client.lambda_vector[client.args.graph.rank].item(), torch.tensor(loss.item()),
                online_clients_group, online_clients, client.optimizer)

            # evaluate the sync time
            logging_sync_time(tracker)

            # Do the validation on the server model
            do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics,
                        client.train_loader, online_clients_group, data_mode='train')
            if client.args.fed_personal:
                do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics,
                            client.val_loader, online_clients_group, data_mode='validation')

            # Updating the lambda variable
            if client.args.graph.rank == 0:
                num_online_clients = len(online_clients) if 0 in online_clients else len(online_clients) + 1
                loss_tensor = torch.zeros(client.args.graph.n_nodes)
                loss_tensor[sorted(online_clients_server)] = loss_tensor_online

                # Dual update
                client.lambda_vector += client.args.drfa_gamma * loss_tensor
                # Projection onto the simplex
                client.lambda_vector = euclidean_proj_simplex(client.lambda_vector)

                # Avoid zero probability
                lambda_zeros = client.lambda_vector <= 1e-3
                if lambda_zeros.sum() > 0:
                    client.lambda_vector[lambda_zeros] = 1e-3
                    client.lambda_vector /= client.lambda_vector.sum()

            # logging.
            logging_globally(tracker, start_global_time)

            # reset start round time.
            start_global_time = time.time()

            # validate the model at the server
            if client.args.graph.rank == 0:
                do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics,
                            client.test_loader, online_clients_group, data_mode='test')
                log('This round communication time is: {}'.format(client.args.comm_time[-1]), client.args.debug)
        else:
            log("Offline in this round. Waiting on others to finish!", client.args.debug)

        dist.barrier(group=client.all_clients_group)
    return

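
# The "avoid zero probability" step above keeps every client sampleable after the simplex
# projection: weights at or below 1e-3 are floored and the vector is renormalised. A minimal
# standalone sketch of that flooring, assuming a 1-D torch tensor that already lies on the
# simplex (the helper name is illustrative, not part of this repository):
def _floor_and_renormalize_sketch(lambda_vector, floor=1e-3):
    lambda_vector = lambda_vector.clone()
    lambda_vector[lambda_vector <= floor] = floor
    return lambda_vector / lambda_vector.sum()
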
def train_and_validate(client):
    """The training scheme of Distributed Local SGD."""
    log('start training and validation.', client.args.debug)

    if client.args.evaluate and client.args.graph.rank == 0:
        # Do the testing on the server and return
        do_validate(client.args, client.model, client.optimizer, client.criterion, client.metrics,
                    client.test_loader, client.all_clients_group, data_mode='test')
        return

    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)
    client.args.comm_time.append(0.0)

    # break until finish expected full epoch training.
    while True:
        # configure local step.
        for _input, _target in client.train_loader:
            client.model.train()

            # update local step.
            logging_load_time(tracker)

            # update local index and get local step
            client.args.local_index += 1
            client.args.local_data_seen += len(_target)
            get_current_epoch(client.args)
            local_step = get_current_local_step(client.args)

            # adjust learning rate (based on the # of accessed samples)
            lr = adjust_learning_rate(client.args, client.optimizer, client.scheduler)

            # load data
            _input, _target = load_data_batch(client.args, _input, _target, tracker)

            # inference and get current performance.
            client.optimizer.zero_grad()
            loss, performance = inference(client.model, client.criterion, client.metrics, _input, _target)

            # compute gradient and do local SGD step.
            loss.backward()
            client.optimizer.step(apply_lr=True, apply_in_momentum=client.args.in_momentum,
                                  apply_out_momentum=False)

            # logging locally.
            logging_computing(tracker, loss, performance, _input, lr)

            # evaluate the status.
            is_sync = client.args.local_index % local_step == 0
            if client.args.epoch_ % 1 == 0:
                client.args.finish_one_epoch = True

            # sync
            if is_sync:
                log('Enter synching', client.args.debug)
                client.args.global_index += 1

                # broadcast gradients to other nodes by using reduce_sum.
                client.model_server = aggregate_gradients(client.args, client.model_server, client.model,
                                                          client.optimizer, is_sync)
                # evaluate the sync time
                logging_sync_time(tracker)

                # logging.
                logging_globally(tracker, start_global_time)

                # reset start round time.
                start_global_time = time.time()

            # finish one epoch of training and decide whether to validate the model.
            if client.args.finish_one_epoch:
                if client.args.epoch % client.args.eval_freq == 0 and client.args.graph.rank == 0:
                    do_validate(client.args, client.model, client.optimizer, client.criterion, client.metrics,
                                client.test_loader, client.all_clients_group, data_mode='test')
                dist.barrier(group=client.all_clients_group)

                # refresh the logging cache at the beginning of each epoch.
                client.args.finish_one_epoch = False
                tracker = define_local_training_tracker()

            # determine if the training is finished.
            if is_stop(client.args):
                # Last sync
                log('Enter synching', client.args.debug)
                client.args.global_index += 1

                # broadcast gradients to other nodes by using reduce_sum.
                client.model_server = aggregate_gradients(client.args, client.model_server, client.model,
                                                          client.optimizer, is_sync)
                print("Total number of samples seen on device {} is {}".format(client.args.graph.rank,
                                                                               client.args.local_data_seen))
                if client.args.graph.rank == 0:
                    do_validate(client.args, client.model_server, client.optimizer, client.criterion,
                                client.metrics, client.test_loader, client.all_clients_group, data_mode='test')
                return

            # display the logging info.
            logging_display_training(client.args, tracker)

            # reset load time for the tracker.
            tracker['start_load_time'] = time.time()

        # reshuffle the data.
        if client.args.reshuffle_per_epoch:
            log('reshuffle the dataset.', client.args.debug)
            del client.train_loader, client.test_loader
            gc.collect()
            log('reshuffle the dataset.', client.args.debug)
            client.load_local_dataset()

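
# In this Local SGD scheme a synchronisation (aggregate_gradients) is triggered whenever the
# number of local iterations is a multiple of the current local-step budget. With hypothetical
# numbers, independent of the trackers above:
#
#   local_step = 4
#   for local_index in range(1, 13):
#       is_sync = local_index % local_step == 0     # True at iterations 4, 8, 12
#
# so each worker performs local_step SGD updates between consecutive aggregations.
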
def train_and_validate_perfedme_centered(Clients, Server):
    log('start training and validation with Federated setting in a centered way.')
    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.')

    # Number of communication rounds in federated setting should be defined
    for n_c in range(Server.args.num_comms):
        Server.args.rounds_comm += 1
        Server.args.local_index += 1
        Server.args.quant_error = 0.0

        # Preset variables for this round of communication
        Server.zero_grad()
        Server.reset_tracker(Server.local_val_tracker)
        Server.reset_tracker(Server.global_val_tracker)
        if Server.args.fed_personal:
            Server.reset_tracker(Server.local_personal_val_tracker)
            Server.reset_tracker(Server.global_personal_val_tracker)

        # Configuring the devices for this round of communication
        log("Starting round {} of training".format(n_c + 1))
        online_clients = set_online_clients_centered(Server.args)

        for oc in online_clients:
            Clients[oc].model.load_state_dict(Server.model.state_dict())
            Clients[oc].args.rounds_comm = Server.args.rounds_comm
            local_steps = 0
            is_sync = False

            do_validate_centered(Clients[oc].args, Server.model, Server.criterion, Server.metrics, Server.optimizer,
                                 Clients[oc].train_loader, Server.global_val_tracker, val=False, local=False)
            if Server.args.fed_personal:
                do_validate_centered(Clients[oc].args, Server.model, Server.criterion, Server.metrics,
                                     Server.optimizer, Clients[oc].val_loader,
                                     Server.global_personal_val_tracker, val=True, local=False)

            while not is_sync:
                for _input, _target in Clients[oc].train_loader:
                    local_steps += 1
                    Clients[oc].model.train()
                    Clients[oc].model_personal.train()

                    # update local step.
                    logging_load_time(tracker)

                    # update local index and get local step
                    Clients[oc].args.local_index += 1
                    Clients[oc].args.local_data_seen += len(_target)
                    get_current_epoch(Clients[oc].args)
                    local_step = get_current_local_step(Clients[oc].args)

                    # adjust learning rate (based on the # of accessed samples)
                    lr = adjust_learning_rate(Clients[oc].args, Clients[oc].optimizer_personal,
                                              Clients[oc].scheduler)

                    # load data
                    _input, _target = load_data_batch(Clients[oc].args, _input, _target, tracker)

                    # Skip batches with one sample because of BatchNorm issue in some models!
                    if _input.size(0) == 1:
                        is_sync = is_sync_fed(Clients[oc].args)
                        break

                    # inference and get current performance.
                    Clients[oc].optimizer_personal.zero_grad()
                    loss, performance = inference(Clients[oc].model_personal, Clients[oc].criterion,
                                                  Clients[oc].metrics, _input, _target)
                    # print("loss in rank {} is {}".format(oc, loss))

                    # compute gradient and do local SGD step.
                    loss.backward()
                    # Add the proximal correction pulling the personal model toward the local copy
                    # of the global model.
                    for client_param, personal_param in zip(Clients[oc].model.parameters(),
                                                            Clients[oc].model_personal.parameters()):
                        # loss += Clients[oc].args.perfedme_lambda * torch.norm(personal_param.data - client_param.data)**2
                        personal_param.grad.data += Clients[oc].args.perfedme_lambda * (personal_param.data -
                                                                                        client_param.data)
                    Clients[oc].optimizer_personal.step(apply_lr=True,
                                                        apply_in_momentum=Clients[oc].args.in_momentum,
                                                        apply_out_momentum=False)

                    if Clients[oc].args.local_index == 1:
                        # On the very first local step, run a backward pass on the global-model copy
                        # so that its gradient buffers exist before being overwritten below.
                        Clients[oc].optimizer.zero_grad()
                        loss, performance = inference(Clients[oc].model, Clients[oc].criterion,
                                                      Clients[oc].metrics, _input, _target)
                        loss.backward()

                    is_sync = is_sync_fed(Clients[oc].args)
                    if Clients[oc].args.local_index % 5 == 0 or is_sync:
                        log('Updating the local version of the global model', Clients[oc].args.debug)
                        lr = adjust_learning_rate(Clients[oc].args, Clients[oc].optimizer, Clients[oc].scheduler)
                        Clients[oc].optimizer.zero_grad()
                        for client_param, personal_param in zip(Clients[oc].model.parameters(),
                                                                Clients[oc].model_personal.parameters()):
                            client_param.grad.data = Clients[oc].args.perfedme_lambda * (client_param.data -
                                                                                         personal_param.data)
                        Clients[oc].optimizer.step(apply_lr=True,
                                                   apply_in_momentum=Clients[oc].args.in_momentum,
                                                   apply_out_momentum=False)

                    # reset load time for the tracker.
                    tracker['start_load_time'] = time.time()
                    # model_local = deepcopy(model_client)
                    if is_sync:
                        break

            do_validate_centered(Clients[oc].args, Clients[oc].model_personal, Clients[oc].criterion,
                                 Clients[oc].metrics, Clients[oc].optimizer_personal, Clients[oc].train_loader,
                                 Server.local_val_tracker, val=False, local=True)
            if Server.args.fed_personal:
                do_validate_centered(Clients[oc].args, Clients[oc].model_personal, Clients[oc].criterion,
                                     Clients[oc].metrics, Clients[oc].optimizer_personal, Clients[oc].val_loader,
                                     Server.local_personal_val_tracker, val=True, local=True)

        # Sync the model server based on model_clients
        tracker['start_sync_time'] = time.time()
        Server.args.global_index += 1
        logging_sync_time(tracker)
        fedavg_aggregation_centered(Clients, Server, online_clients)

        # Log local performance
        log_validation_centered(Server.args, Server.local_val_tracker, val=False, local=True)
        if Server.args.fed_personal:
            log_validation_centered(Server.args, Server.local_personal_val_tracker, val=True, local=True)

        # Log server performance
        log_validation_centered(Server.args, Server.global_val_tracker, val=False, local=False)
        if Server.args.fed_personal:
            log_validation_centered(Server.args, Server.global_personal_val_tracker, val=True, local=False)

        # logging.
        logging_globally(tracker, start_global_time)

        # reset start round time.
        start_global_time = time.time()

        # validate the model at the server
        # if args.graph.rank == 0:
        #     do_test(args, model_server, optimizer, criterion, metrics, test_loader)
        #     do_validate_test(args, model_server, optimizer, criterion, metrics, test_loader)
    return

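
# The personalised update above follows a pFedMe-style proximal objective: each client optimises
#   f_i(theta) + (perfedme_lambda / 2) * ||theta - w||^2
# in theta, where w is the local copy of the global model. Its gradient is
#   grad f_i(theta) + perfedme_lambda * (theta - w),
# which is exactly the correction added to personal_param.grad before the personal optimizer
# step. Every 5 local iterations (and right before a sync), the local copy of the global model
# is in turn pulled toward the personalised weights with the mirrored direction
# perfedme_lambda * (w - theta).
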
def do_validate_centered(args, model, criterion, metrics, optimizer, val_loader, val_tracker,
                         personal=False, val=True, model_personal=None, alpha=0.0, local=False):
    """Evaluate the model on the validation dataset."""
    # Finding the robust loss using gradient ascent
    if 'robust' in args.arch:
        tmp_noise = torch.clone(model.noise.data)
        model.noise.data = torch.zeros_like(tmp_noise)
        for _input, _target in val_loader:
            _input, _target = _load_data_batch(args, _input, _target)
            optimizer.zero_grad()
            loss, performance = inference(model, criterion, metrics, _input, _target)
            if model.noise.grad is None:
                loss.backward()
                optimizer.zero_grad()
                loss, performance = inference(model, criterion, metrics, _input, _target)
            model.noise.grad.data = torch.autograd.grad(loss, model.noise)[0]
            optimizer.step(apply_lr=False, scale=-0.01, apply_in_momentum=False,
                           apply_out_momentum=args.out_momentum)
            if torch.norm(model.noise.data) > 1:
                model.noise.data /= torch.norm(model.noise.data)

    # switch to evaluation mode
    model.eval()
    if personal:
        if model_personal is None:
            raise ValueError("model_personal should not be None for personalized mode!")
        model_personal.eval()
        # log('Do validation on the personal models.', args.debug)
    # else:
    #     log('Do validation on the client models.', args.debug)

    for _input, _target in val_loader:
        # load data and check performance.
        _input, _target = _load_data_batch(args, _input, _target)
        if _input.size(0) == 1:
            break
        with torch.no_grad():
            if personal:
                loss, performance = inference_personal(model_personal, model, alpha, criterion, metrics,
                                                       _input, _target)
            else:
                loss, performance = inference(model, criterion, metrics, _input, _target,
                                              rnn=args.arch in ['rnn'])
            val_tracker = update_performancec_tracker(val_tracker, loss, performance, _input.size(0))
    # if personal and not val:
    #     print("acc in rank {} is {}".format(args.graph.rank, val_tracker['top1'].avg))

    if 'robust' in args.arch:
        model.noise.data = tmp_noise
    return

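
# For 'robust' architectures, do_validate_centered first takes one gradient-ascent step on the
# model's noise parameter (optimizer.step with a negative scale) and then projects the noise
# back into the unit L2 ball. That projection, taken in isolation, is simply the following
# (the helper name is illustrative, not part of this repository):
def _project_to_unit_l2_ball_sketch(noise):
    norm = torch.norm(noise)
    return noise / norm if norm > 1 else noise
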
def train_and_validate_federated(client):
    """The training scheme of Federated Learning systems.

    The basic model is FedAvg https://arxiv.org/abs/1602.05629
    TODO: Merge different models under this method
    """
    log('start training and validation with Federated setting.', client.args.debug)

    if client.args.evaluate and client.args.graph.rank == 0:
        # Do the testing on the server and return
        do_validate(client.args, client.model, client.optimizer, client.criterion, client.metrics,
                    client.test_loader, client.all_clients_group, data_mode='test')
        return

    # init global variable.
    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)

    # Number of communication rounds in federated setting should be defined
    for n_c in range(client.args.num_comms):
        client.args.rounds_comm += 1
        client.args.comm_time.append(0.0)

        # Configuring the devices for this round of communication
        log("Starting round {} of training".format(n_c + 1), client.args.debug)
        online_clients = set_online_clients(client.args)
        if (n_c == 0) and (0 not in online_clients):
            online_clients += [0]
        online_clients_server = online_clients if 0 in online_clients else online_clients + [0]
        online_clients_group = dist.new_group(online_clients_server)

        if client.args.graph.rank in online_clients_server:
            if client.args.federated_type == 'scaffold':
                st = time.time()
                client.model_server, client.model_server_control = distribute_model_server_control(
                    client.model_server, client.model_server_control, online_clients_group, src=0)
                client.args.comm_time[-1] += time.time() - st
            else:
                st = time.time()
                client.model_server = distribute_model_server(client.model_server, online_clients_group, src=0)
                client.args.comm_time[-1] += time.time() - st
            client.model.load_state_dict(client.model_server.state_dict())
            local_steps = 0

            if client.args.graph.rank in online_clients:
                is_sync = False
                while not is_sync:
                    for _input, _target in client.train_loader:
                        local_steps += 1
                        client.model.train()

                        # update local step.
                        logging_load_time(tracker)

                        # update local index and get local step
                        client.args.local_index += 1
                        client.args.local_data_seen += len(_target)
                        get_current_epoch(client.args)
                        local_step = get_current_local_step(client.args)

                        # adjust learning rate (based on the # of accessed samples)
                        lr = adjust_learning_rate(client.args, client.optimizer, client.scheduler)

                        # load data
                        _input, _target = load_data_batch(client.args, _input, _target, tracker)

                        # Skip batches with one sample because of BatchNorm issue in some models!
                        if _input.size(0) == 1:
                            is_sync = is_sync_fed(client.args)
                            break

                        # inference and get current performance.
                        client.optimizer.zero_grad()
                        loss, performance = inference(client.model, client.criterion, client.metrics, _input, _target)

                        # compute gradient and do local SGD step.
                        loss.backward()

                        if client.args.federated_type == 'fedgate':
                            # Update gradients with control variates
                            for client_param, delta_param in zip(client.model.parameters(),
                                                                 client.model_delta.parameters()):
                                client_param.grad.data -= delta_param.data
                        elif client.args.federated_type == 'scaffold':
                            for cp, ccp, scp in zip(client.model.parameters(),
                                                    client.model_client_control.parameters(),
                                                    client.model_server_control.parameters()):
                                cp.grad.data += scp.data - ccp.data
                        elif client.args.federated_type == 'fedprox':
                            # Adding proximal gradients and loss for fedprox
                            for client_param, server_param in zip(client.model.parameters(),
                                                                  client.model_server.parameters()):
                                if client.args.graph.rank == 0:
                                    print("distance norm for prox is:{}".format(
                                        torch.norm(client_param.data - server_param.data)))
                                loss += client.args.fedprox_mu / 2 * torch.norm(client_param.data - server_param.data)
                                client_param.grad.data += client.args.fedprox_mu * (client_param.data -
                                                                                    server_param.data)

                        if 'robust' in client.args.arch:
                            client.model.noise.grad.data *= -1

                        client.optimizer.step(apply_lr=True, apply_in_momentum=client.args.in_momentum,
                                              apply_out_momentum=False)

                        if 'robust' in client.args.arch:
                            if torch.norm(client.model.noise.data) > 1:
                                client.model.noise.data /= torch.norm(client.model.noise.data)

                        # logging locally.
                        # logging_computing(tracker, loss_v, performance_v, _input, lr)

                        # display the logging info.
                        # logging_display_training(args, tracker)

                        # reset load time for the tracker.
                        tracker['start_load_time'] = time.time()
                        # model_local = deepcopy(model_client)
                        is_sync = is_sync_fed(client.args)
                        if is_sync:
                            break
            else:
                log("Offline in this round. Waiting on others to finish!", client.args.debug)

            # Validate the local models before sync
            do_validate(client.args, client.model, client.optimizer, client.criterion, client.metrics,
                        client.train_loader, online_clients_group, data_mode='train', local=True)
            if client.args.fed_personal:
                do_validate(client.args, client.model, client.optimizer, client.criterion, client.metrics,
                            client.val_loader, online_clients_group, data_mode='validation', local=True)

            # Sync the model server based on client models
            log('Enter synching', client.args.debug)
            tracker['start_sync_time'] = time.time()
            client.args.global_index += 1
            if client.args.federated_type == 'fedgate':
                client.model_server, client.model_delta = fedgate_aggregation(
                    client.args, client.model_server, client.model, client.model_delta, client.model_memory,
                    online_clients_group, online_clients, client.optimizer, lr, local_steps)
            elif client.args.federated_type == 'scaffold':
                client.model_server, client.model_client_control, client.model_server_control = scaffold_aggregation(
                    client.args, client.model_server, client.model, client.model_server_control,
                    client.model_client_control, online_clients_group, online_clients, client.optimizer,
                    lr, local_steps)
            elif client.args.federated_type == 'qsparse':
                client.model_server = qsparse_aggregation(
                    client.args, client.model_server, client.model, online_clients_group, online_clients,
                    client.optimizer, client.model_memory)
            else:
                client.model_server = fedavg_aggregation(
                    client.args, client.model_server, client.model, online_clients_group, online_clients,
                    client.optimizer)

            # evaluate the sync time
            logging_sync_time(tracker)

            # Do the validation on the server model
            do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics,
                        client.train_loader, online_clients_group, data_mode='train')
            if client.args.fed_personal:
                do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics,
                            client.val_loader, online_clients_group, data_mode='validation')

            # logging.
            logging_globally(tracker, start_global_time)

            # reset start round time.
            start_global_time = time.time()

            # validate the model at the server
            if client.args.graph.rank == 0:
                do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics,
                            client.test_loader, online_clients_group, data_mode='test')
                log('This round communication time is: {}'.format(client.args.comm_time[-1]), client.args.debug)
        else:
            log("Offline in this round. Waiting on others to finish!", client.args.debug)

        dist.barrier(group=client.all_clients_group)
    return

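
# When federated_type == 'fedprox', the local objective above is augmented with FedProx's
# proximal term anchored at the server model w:
#
#   h_i(theta) = f_i(theta) + (fedprox_mu / 2) * ||theta - w||^2
#
# whose gradient contributes fedprox_mu * (theta - w); that is the correction added to
# client_param.grad in the loop above. The term added to `loss` there (which uses the
# unsquared norm) comes after loss.backward(), so it only changes the value that gets
# reported, not the gradients used for the update.
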
def train_and_validate_federated_apfl(client):
    """The training scheme of Personalized Federated Learning.

    Official implementation for https://arxiv.org/abs/2003.13461
    """
    log('start training and validation with Federated setting.', client.args.debug)

    if client.args.evaluate and client.args.graph.rank == 0:
        # Do the testing on the server and return
        do_validate(client.args, client.model, client.optimizer, client.criterion, client.metrics,
                    client.test_loader, client.all_clients_group, data_mode='test')
        return

    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)

    # Number of communication rounds in federated setting should be defined
    for n_c in range(client.args.num_comms):
        client.args.rounds_comm += 1
        client.args.comm_time.append(0.0)

        # Configuring the devices for this round of communication
        # TODO: not make the server rank hard coded
        log("Starting round {} of training".format(n_c), client.args.debug)
        online_clients = set_online_clients(client.args)
        if (n_c == 0) and (0 not in online_clients):
            online_clients += [0]
        online_clients_server = online_clients if 0 in online_clients else online_clients + [0]
        online_clients_group = dist.new_group(online_clients_server)

        if client.args.graph.rank in online_clients_server:
            client.model_server = distribute_model_server(client.model_server, online_clients_group, src=0)
            client.model.load_state_dict(client.model_server.state_dict())

            if client.args.graph.rank in online_clients:
                is_sync = False
                ep = -1  # counting number of epochs
                while not is_sync:
                    ep += 1
                    for i, (_input, _target) in enumerate(client.train_loader):
                        client.model.train()

                        # update local step.
                        logging_load_time(tracker)

                        # update local index and get local step
                        client.args.local_index += 1
                        client.args.local_data_seen += len(_target)
                        get_current_epoch(client.args)
                        local_step = get_current_local_step(client.args)

                        # adjust learning rate (based on the # of accessed samples)
                        lr = adjust_learning_rate(client.args, client.optimizer, client.scheduler)

                        # load data
                        _input, _target = load_data_batch(client.args, _input, _target, tracker)

                        # Skip batches with one sample because of BatchNorm issue in some models!
                        if _input.size(0) == 1:
                            is_sync = is_sync_fed(client.args)
                            break

                        # inference and get current performance.
                        client.optimizer.zero_grad()
                        loss, performance = inference(client.model, client.criterion, client.metrics, _input, _target)

                        # compute gradient and do local SGD step.
                        loss.backward()
                        client.optimizer.step(apply_lr=True, apply_in_momentum=client.args.in_momentum,
                                              apply_out_momentum=False)

                        client.optimizer.zero_grad()
                        client.optimizer_personal.zero_grad()
                        loss_personal, performance_personal = inference_personal(
                            client.model_personal, client.model, client.args.fed_personal_alpha,
                            client.criterion, client.metrics, _input, _target)

                        # compute gradient and do local SGD step.
                        loss_personal.backward()
                        client.optimizer_personal.step(apply_lr=True, apply_in_momentum=client.args.in_momentum,
                                                       apply_out_momentum=False)

                        # update alpha
                        if client.args.fed_adaptive_alpha and i == 0 and ep == 0:
                            client.args.fed_personal_alpha = alpha_update(
                                client.model, client.model_personal, client.args.fed_personal_alpha,
                                lr)  # 0.1/np.sqrt(1+args.local_index)
                            average_alpha = client.args.fed_personal_alpha
                            average_alpha = global_average(average_alpha, client.args.graph.n_nodes,
                                                           group=online_clients_group)
                            log("New alpha is:{}".format(average_alpha.item()), client.args.debug)

                        # logging locally.
                        # logging_computing(tracker, loss, performance, _input, lr)

                        # display the logging info.
                        # logging_display_training(args, tracker)

                        # reset load time for the tracker.
                        tracker['start_load_time'] = time.time()
                        is_sync = is_sync_fed(client.args)
                        if is_sync:
                            break
            else:
                log("Offline in this round. Waiting on others to finish!", client.args.debug)

            do_validate(client.args, client.model, client.optimizer_personal, client.criterion, client.metrics,
                        client.train_loader, online_clients_group, data_mode='train', personal=True,
                        model_personal=client.model_personal, alpha=client.args.fed_personal_alpha)
            if client.args.fed_personal:
                do_validate(client.args, client.model, client.optimizer_personal, client.criterion, client.metrics,
                            client.val_loader, online_clients_group, data_mode='validation', personal=True,
                            model_personal=client.model_personal, alpha=client.args.fed_personal_alpha)

            # Sync the model server based on model_clients
            log('Enter synching', client.args.debug)
            tracker['start_sync_time'] = time.time()
            client.args.global_index += 1
            client.model_server = fedavg_aggregation(client.args, client.model_server, client.model,
                                                     online_clients_group, online_clients, client.optimizer)
            # evaluate the sync time
            logging_sync_time(tracker)

            # Do the validation on the server model
            do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics,
                        client.train_loader, online_clients_group, data_mode='train')
            if client.args.fed_personal:
                do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics,
                            client.val_loader, online_clients_group, data_mode='validation')

            # logging.
            logging_globally(tracker, start_global_time)

            # reset start round time.
            start_global_time = time.time()

            # validate the models at the test data
            if client.args.fed_personal_test:
                do_validate(client.args, client.model_client, client.optimizer_personal, client.criterion,
                            client.metrics, client.test_loader, online_clients_group, data_mode='test',
                            personal=True, model_personal=client.model_personal,
                            alpha=client.args.fed_personal_alpha)
            elif client.args.graph.rank == 0:
                do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics,
                            client.test_loader, online_clients_group, data_mode='test')
        else:
            log("Offline in this round. Waiting on others to finish!", client.args.debug)

        dist.barrier(group=client.all_clients_group)

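
# APFL keeps two models per client: the shared `model`, synchronised with the server via
# fedavg_aggregation, and a private `model_personal`. The personal loss and personalised
# predictions come from a convex mixture of the two weighted by fed_personal_alpha (see
# inference_personal). When fed_adaptive_alpha is set, alpha is refreshed once per round on the
# first batch via alpha_update, and the "New alpha" value logged above is its average across
# clients, used for reporting only.
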