def log_validation_per_client_centered(args, OnlineClients, online_clients, val=True, local=False):
    """Log the spread (worst/best/std) of per-client top-1 accuracy.

    Picks one of the four accuracy trackers on each online client —
    local vs. global crossed with personal-validation vs. train — and
    logs summary statistics across the online clients.

    Args:
        args: run configuration; reads local_index, epoch, graph.rank,
            rounds_comm and debug.
        OnlineClients: mapping/list of client objects indexed by id.
        online_clients: iterable of client ids that participated this round.
        val: True -> use the personal-validation tracker, else the train one.
        local: True -> use the locally-trained model's trackers, else the
            global (server) model's trackers.
    """
    # Resolve which tracker attribute to read once, instead of branching
    # inside the accumulation loop.
    if local:
        tracker_attr = 'local_personal_val_tracker' if val else 'local_val_tracker'
    else:
        tracker_attr = 'global_personal_val_tracker' if val else 'global_val_tracker'
    top1_accs = [
        getattr(OnlineClients[oc], tracker_attr)['top1'].avg
        for oc in online_clients
    ]
    log('{} per client stat for {} at batch: {}. Epoch: {}. Process: {}. Worst: {:.3f} Best: {:.3f} Var: {:.3f} Comm: {}'
        .format('Personal' if local else 'Global',
                'validation' if val else 'train',
                args.local_index, args.epoch, args.graph.rank,
                np.min(top1_accs), np.max(top1_accs), np.std(top1_accs),
                args.rounds_comm),
        debug=args.debug)
    return
def get_data_stat(args, train_loader, test_loader=None):
    """Record per-worker dataset statistics on ``args`` and log a summary.

    Sets on ``args``: num_batches_train_per_device_per_epoch,
    num_whole_train_batches_per_worker, num_warmup_train_batches_per_worker,
    num_iterations_per_worker and num_batches_val_per_device_per_epoch.

    Args:
        args: run configuration object (mutated in place).
        train_loader: this worker's training data loader; len() is the
            number of mini-batches per epoch on this device.
        test_loader: optional test/validation loader; when None the val
            batch count is recorded as 0.
    """
    # get the data statictics (on behalf of each worker) for train.
    # args.num_batches_train_per_device_per_epoch = \
    #     len(train_loader) // args.graph.n_nodes \
    #     if not args.partition_data else len(train_loader)
    args.num_batches_train_per_device_per_epoch = len(train_loader)
    args.num_whole_train_batches_per_worker = \
        args.num_batches_train_per_device_per_epoch * args.num_epochs
    args.num_warmup_train_batches_per_worker = \
        args.num_batches_train_per_device_per_epoch * args.lr_warmup_epochs
    args.num_iterations_per_worker = args.num_iterations  # // args.graph.n_nodes

    # get the data statictics (on behalf of each worker) for val.
    if test_loader is not None:
        args.num_batches_val_per_device_per_epoch = len(test_loader)
    else:
        args.num_batches_val_per_device_per_epoch = 0

    # define some parameters for training.
    # FIX: the original built this message with backslash continuations
    # *inside* a single string literal, which embedded runs of indentation
    # whitespace into the logged text. Implicit concatenation keeps the
    # message clean.
    log(
        'we have {} epochs, '
        '{} mini-batches per device for training. '
        '{} mini-batches per device for test. '
        'The batch size: {}.'.format(
            args.num_epochs,
            args.num_batches_train_per_device_per_epoch,
            args.num_batches_val_per_device_per_epoch,
            args.batch_size), args.debug)
def _check_model_at_sync(iter, gpu_id, model, is_weight=False, is_gradient=True, debug=True):
    """Debug helper: log the norm of the model's first parameter.

    Useful to eyeball whether processes agree at a sync point.

    Args:
        iter: iteration counter included in the log line (shadows the
            builtin, kept for interface compatibility).
        gpu_id: process/device identifier included in the log line.
        model: module whose first parameter is inspected.
        is_weight: when True, log the weight-tensor norm.
        is_gradient: when True, log the gradient-tensor norm (requires
            that .grad is populated, i.e. after backward()).
        debug: forwarded to log().
    """
    # NOTE: the ``iter`` parameter shadows the builtin, so we index a list
    # instead of calling iter() on the generator.
    first_param = list(model.parameters())[0]
    if is_weight:
        message = "iter:{}, check process {}'s weights for 1st variable:{}".format(
            iter, gpu_id, torch.norm(first_param.data))
        log(message, debug)
    if is_gradient:
        message = "iter:{}, check process {}'s gradients for 1st variable:{}".format(
            iter, gpu_id, torch.norm(first_param.grad.data))
        log(message, debug)
def log_validation_centered(args, val_tracker, personal=False, val=True, local=False):
    """Log averaged top-1/top-5/loss held by a validation tracker.

    Args:
        args: run configuration; reads local_index, epoch, graph.rank,
            rounds_comm and debug.
        val_tracker: dict of meters keyed by 'top1', 'top5', 'losses',
            each exposing an ``avg`` attribute.
        personal: label the line 'Personal' (also triggered by ``local``).
        val: label the line 'validation' vs. 'train'.
        local: treated the same as ``personal`` for labelling purposes.
    """
    top1 = val_tracker['top1'].avg
    top5 = val_tracker['top5'].avg
    loss_avg = val_tracker['losses'].avg
    scope_label = 'Personal' if personal or local else 'Global'
    phase_label = 'validation' if val else 'train'
    log('{} performance for {} at batch: {}. Epoch: {}. Process: {}. Prec@1: {:.3f} Prec@5: {:.3f} Loss: {:.3f} Comm: {}'
        .format(scope_label, phase_label, args.local_index, args.epoch,
                args.graph.rank, top1, top5, loss_avg, args.rounds_comm),
        debug=args.debug)
    return
def initialize(self):
    """Set up this process: config, training components, logging, and the
    global process group.

    Mutates ``self`` by attaching model, criterion, scheduler, optimizer,
    metrics, a server-model copy, and the all-clients group.
    """
    init_config(self.args)
    self.model, self.criterion, self.scheduler, self.optimizer, self.metrics = create_components(
        self.args)
    self.args.finish_one_epoch = False
    # Create a model server on each client to keep a copy of the server model at each communication round.
    self.model_server = deepcopy(self.model)
    configure_log(self.args)
    log_args(self.args, debug=self.args.debug)
    log('Rank {} with block {} on {} {}-{}'.format(
        self.args.graph.rank,
        self.args.graph.ranks_with_blocks[self.args.graph.rank],
        platform.node(),
        'GPU' if self.args.graph.on_cuda else 'CPU',
        self.args.graph.device), debug=self.args.debug)
    # Process group spanning every rank; used later for global barriers.
    self.all_clients_group = dist.new_group(self.args.graph.ranks)
def define_dataset(args, shuffle, test=True, Partitioner=None, return_partitioner=False):
    """Create this rank's train (and optionally val/test) data loaders.

    Args:
        args: run configuration; reads data, graph.rank, fed_personal,
            federated_type and debug.
        shuffle: forwarded to partition_dataset.
        test: when True, also build a test loader; otherwise test slot is None.
        Partitioner: optional pre-built partitioner to reuse.
        return_partitioner: when True, the partition call also yields the
            partitioner, which is returned alongside the loaders.

    Returns:
        A list of loaders — [train, test], plus a validation loader when
        args.fed_personal is set, plus a second validation loader for
        'perfedavg' — or a (loaders, Partitioner) tuple when
        return_partitioner is True.
    """
    log('create {} dataset for rank {}'.format(args.data, args.graph.rank),
        args.debug)
    result = partition_dataset(args, shuffle, dataset_type='train',
                               Partitioner=Partitioner,
                               return_partitioner=return_partitioner)
    if return_partitioner:
        result, Partitioner = result

    # With fed_personal the train partition call returns extra validation
    # loaders ('perfedavg' needs two of them).
    val_loader = val_loader1 = None
    if args.fed_personal:
        if args.federated_type == 'perfedavg':
            train_loader, val_loader, val_loader1 = result
        else:
            train_loader, val_loader = result
    else:
        train_loader = result

    test_loader = (partition_dataset(args, shuffle, dataset_type='test')
                   if test else None)

    get_data_stat(args, train_loader, test_loader)

    out = [train_loader, test_loader]
    if args.fed_personal:
        out.append(val_loader)
        if args.federated_type == 'perfedavg':
            out.append(val_loader1)
    return (out, Partitioner) if return_partitioner else out
def train_and_validate_drfa_centered(Clients, Server):
    """DRFA training loop, simulated in a single process ("centered").

    Each communication round: sample online clients, run local SGD on each
    (with fedgate/scaffold/fedprox/perfedavg variants), aggregate into the
    Server model, aggregate the randomly-snapshotted k-th models, and run
    the DRFA dual update on Server.lambda_vector.

    Args:
        Clients: mapping/list of client objects (model, loaders, trackers, args).
        Server: server object holding the global model, trackers and lambda.
    """
    log('start training and validation with Federated setting in a centered way.'
        )
    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.')
    # Initialize lambda proportional to each client's sample count.
    for oc in range(Server.args.graph.n_nodes):
        Server.lambda_vector[oc] = Clients[oc].args.num_samples_per_epoch
    Server.lambda_vector /= Server.lambda_vector.sum()
    # Number of communication rounds in federated setting should be defined
    for n_c in range(Server.args.num_comms):
        Server.args.rounds_comm += 1
        Server.args.local_index += 1
        Server.args.quant_error = 0.0
        # Preset variables for this round of communication
        Server.zero_grad()
        Server.reset_tracker(Server.local_val_tracker)
        Server.reset_tracker(Server.global_val_tracker)
        Server.reset_tracker(Server.global_test_tracker)
        if Server.args.fed_personal:
            Server.reset_tracker(Server.local_personal_val_tracker)
            Server.reset_tracker(Server.global_personal_val_tracker)
        # Configuring the devices for this round of communication
        log("Starting round {} of training".format(n_c + 1))
        online_clients = set_online_clients_centered(Server.args)
        # DRFA step-size decay per round.
        Server.args.drfa_gamma *= 0.9
        for oc in online_clients:
            # Start each online client from the current server model.
            Clients[oc].model.load_state_dict(Server.model.state_dict())
            Clients[oc].args.rounds_comm = Server.args.rounds_comm
            local_steps = 0
            is_sync = False
            # Evaluate the (pre-update) server model on this client's train data.
            do_validate_centered(Clients[oc].args, Server.model,
                                 Server.criterion, Server.metrics,
                                 Server.optimizer, Clients[oc].train_loader,
                                 Server.global_val_tracker,
                                 val=False, local=False)
            if Server.args.per_class_acc:
                Clients[oc].reset_tracker(Clients[oc].local_val_tracker)
                Clients[oc].reset_tracker(Clients[oc].global_val_tracker)
                if Server.args.fed_personal:
                    Clients[oc].reset_tracker(
                        Clients[oc].local_personal_val_tracker)
                    Clients[oc].reset_tracker(
                        Clients[oc].global_personal_val_tracker)
                    do_validate_centered(
                        Clients[oc].args, Server.model, Server.criterion,
                        Server.metrics, Server.optimizer,
                        Clients[oc].val_loader,
                        Clients[oc].global_personal_val_tracker,
                        val=True, local=False)
                do_validate_centered(Clients[oc].args, Server.model,
                                     Server.criterion, Server.metrics,
                                     Server.optimizer,
                                     Clients[oc].train_loader,
                                     Clients[oc].global_val_tracker,
                                     val=False, local=False)
            if Server.args.fed_personal:
                do_validate_centered(Clients[oc].args, Server.model,
                                     Server.criterion, Server.metrics,
                                     Server.optimizer, Clients[oc].val_loader,
                                     Server.global_personal_val_tracker,
                                     val=True, local=False)
            if Server.args.federated_type == 'perfedavg':
                # Grab one validation batch up front for the perfedavg step.
                for _input_val, _target_val in Clients[oc].val_loader1:
                    _input_val, _target_val = load_data_batch(
                        Clients[oc].args, _input_val, _target_val, tracker)
                    break
            # Random local step at which the k-th model snapshot is taken
            # (used by the DRFA dual update).
            k = torch.randint(low=1, high=Server.args.local_step,
                              size=(1, )).item()
            while not is_sync:
                if Server.args.arch == 'rnn':
                    Clients[oc].model.init_hidden(Server.args.batch_size)
                for _input, _target in Clients[oc].train_loader:
                    local_steps += 1
                    if k == local_steps:
                        Clients[oc].kth_model.load_state_dict(
                            Clients[oc].model.state_dict())
                    Clients[oc].model.train()
                    # update local step.
                    logging_load_time(tracker)
                    # update local index and get local step
                    Clients[oc].args.local_index += 1
                    Clients[oc].args.local_data_seen += len(_target)
                    get_current_epoch(Clients[oc].args)
                    local_step = get_current_local_step(Clients[oc].args)
                    # adjust learning rate (based on the # of accessed samples)
                    lr = adjust_learning_rate(Clients[oc].args,
                                              Clients[oc].optimizer,
                                              Clients[oc].scheduler)
                    # load data
                    _input, _target = load_data_batch(Clients[oc].args,
                                                      _input, _target,
                                                      tracker)
                    # Skip batches with one sample because of BatchNorm issue in some models!
                    if _input.size(0) == 1:
                        is_sync = is_sync_fed(Clients[oc].args)
                        break
                    # inference and get current performance.
                    Clients[oc].optimizer.zero_grad()
                    loss, performance = inference(
                        Clients[oc].model, Clients[oc].criterion,
                        Clients[oc].metrics, _input, _target,
                        rnn=Server.args.arch in ['rnn'])
                    # compute gradient and do local SGD step.
                    loss.backward()
                    if Clients[oc].args.federated_type == 'fedgate':
                        # Update gradients with control variates
                        for client_param, delta_param in zip(
                                Clients[oc].model.parameters(),
                                Clients[oc].model_delta.parameters()):
                            client_param.grad.data -= delta_param.data
                    elif Clients[oc].args.federated_type == 'scaffold':
                        # SCAFFOLD variance-reduction correction.
                        for cp, ccp, scp in zip(
                                Clients[oc].model.parameters(),
                                Clients[oc].model_client_control.parameters(),
                                Server.model_server_control.parameters()):
                            cp.grad.data += scp.data - ccp.data
                    elif Clients[oc].args.federated_type == 'fedprox':
                        # Adding proximal gradients and loss for fedprox
                        for client_param, server_param in zip(
                                Clients[oc].model.parameters(),
                                Server.model.parameters()):
                            loss += Clients[
                                oc].args.fedprox_mu / 2 * torch.norm(
                                    client_param.data - server_param.data)
                            client_param.grad.data += Clients[
                                oc].args.fedprox_mu * (client_param.data -
                                                       server_param.data)
                    if 'robust' in Clients[oc].args.arch:
                        # Ascent direction on the adversarial noise parameter.
                        Clients[oc].model.noise.grad.data *= -1
                    Clients[oc].optimizer.step(
                        apply_lr=True,
                        apply_in_momentum=Clients[oc].args.in_momentum,
                        apply_out_momentum=False)
                    if 'robust' in Clients[oc].args.arch:
                        # Project the noise back into the unit ball.
                        if torch.norm(Clients[oc].model.noise.data) > 1:
                            Clients[oc].model.noise.data /= torch.norm(
                                Clients[oc].model.noise.data)
                    if Clients[oc].args.federated_type == 'perfedavg':
                        # _input_val, _target_val = Clients[oc].load_next_val_batch()
                        # Second (meta) step on the held-out batch with beta lr.
                        lr = adjust_learning_rate(
                            Clients[oc].args, Clients[oc].optimizer,
                            Clients[oc].scheduler,
                            lr_external=Clients[oc].args.perfedavg_beta)
                        if _input_val.size(0) == 1:
                            is_sync = is_sync_fed(Clients[oc].args)
                            break
                        Clients[oc].optimizer.zero_grad()
                        loss, performance = inference(Clients[oc].model,
                                                      Clients[oc].criterion,
                                                      Clients[oc].metrics,
                                                      _input_val, _target_val)
                        loss.backward()
                        Clients[oc].optimizer.step(
                            apply_lr=True,
                            apply_in_momentum=Clients[oc].args.in_momentum,
                            apply_out_momentum=False)
                    # reset load time for the tracker.
                    tracker['start_load_time'] = time.time()
                    # model_local = deepcopy(model_client)
                    is_sync = is_sync_fed(Clients[oc].args)
                    if is_sync:
                        break
            # Evaluate the locally-trained model before aggregation.
            do_validate_centered(Clients[oc].args, Clients[oc].model,
                                 Clients[oc].criterion, Clients[oc].metrics,
                                 Clients[oc].optimizer,
                                 Clients[oc].train_loader,
                                 Server.local_val_tracker,
                                 val=False, local=True)
            if Server.args.per_class_acc:
                do_validate_centered(Clients[oc].args, Clients[oc].model,
                                     Clients[oc].criterion,
                                     Clients[oc].metrics,
                                     Clients[oc].optimizer,
                                     Clients[oc].train_loader,
                                     Clients[oc].local_val_tracker,
                                     val=False, local=True)
                if Server.args.fed_personal:
                    do_validate_centered(
                        Clients[oc].args, Clients[oc].model,
                        Clients[oc].criterion, Clients[oc].metrics,
                        Clients[oc].optimizer, Clients[oc].val_loader,
                        Clients[oc].local_personal_val_tracker,
                        val=True, local=True)
            if Server.args.fed_personal:
                do_validate_centered(Clients[oc].args, Clients[oc].model,
                                     Clients[oc].criterion,
                                     Clients[oc].metrics,
                                     Clients[oc].optimizer,
                                     Clients[oc].val_loader,
                                     Server.local_personal_val_tracker,
                                     val=True, local=True)
        # Sync the model server based on model_clients
        tracker['start_sync_time'] = time.time()
        Server.args.global_index += 1
        logging_sync_time(tracker)
        if Server.args.federated_type == 'scaffold':
            scaffold_aggregation_centered(Clients, Server, online_clients,
                                          local_steps, lr)
        elif Server.args.federated_type == 'fedgate':
            fedgate_aggregation_centered(Clients, Server, online_clients,
                                         local_steps, lr)
        elif Server.args.federated_type == 'qsparse':
            qsparse_aggregation_centered(Clients, Server, online_clients,
                                         local_steps, lr)
        else:
            fedavg_aggregation_centered(Clients, Server, online_clients,
                                        Server.lambda_vector.numpy())
        # Aggregate Kth models
        aggregate_kth_model_centered(Clients, Server, online_clients)
        # Log performance
        # Client training performance
        log_validation_centered(Server.args, Server.local_val_tracker,
                                val=False, local=True)
        # Server training performance
        log_validation_centered(Server.args, Server.global_val_tracker,
                                val=False, local=False)
        if Server.args.fed_personal:
            # Client validation performance
            log_validation_centered(Server.args,
                                    Server.local_personal_val_tracker,
                                    val=True, local=True)
            # Server validation performance
            log_validation_centered(Server.args,
                                    Server.global_personal_val_tracker,
                                    val=True, local=False)
        # Per client stats
        if Server.args.per_class_acc:
            log_validation_per_client_centered(Server.args, Clients,
                                               online_clients,
                                               val=False, local=False)
            log_validation_per_client_centered(Server.args, Clients,
                                               online_clients,
                                               val=False, local=True)
            if Server.args.fed_personal:
                log_validation_per_client_centered(Server.args, Clients,
                                                   online_clients,
                                                   val=True, local=False)
                log_validation_per_client_centered(Server.args, Clients,
                                                   online_clients,
                                                   val=True, local=True)
        # Test on server
        do_validate_centered(Server.args, Server.model, Server.criterion,
                             Server.metrics, Server.optimizer,
                             Server.test_loader, Server.global_test_tracker,
                             val=False, local=False)
        log_test_centered(Server.args, Server.global_test_tracker)
        # --- DRFA dual (lambda) update using the aggregated k-th model ---
        online_clients_lambda = set_online_clients_centered(Server.args)
        loss_tensor = torch.zeros(Server.args.graph.n_nodes)
        num_online_clients = len(online_clients_lambda)
        for ocl in online_clients_lambda:
            # One batch per sampled client is enough for the dual estimate.
            for _input, _target in Clients[ocl].train_loader:
                _input, _target = load_data_batch(Clients[ocl].args, _input,
                                                  _target, tracker)
                loss, _ = inference(Server.kth_model, Clients[ocl].criterion,
                                    Clients[ocl].metrics, _input, _target,
                                    rnn=Server.args.arch in ['rnn'])
                break
            # Unbiased scaling for partial participation.
            loss_tensor[ocl] = loss * (Server.args.graph.n_nodes /
                                       num_online_clients)
        Server.lambda_vector += (Server.args.drfa_gamma *
                                 Server.args.local_step * loss_tensor)
        lambda_vector = projection_simplex_sort(
            Server.lambda_vector.detach().numpy())
        print(lambda_vector)
        # Avoid zero probability
        lambda_zeros = np.argwhere(lambda_vector <= 1e-3)
        if len(lambda_zeros) > 0:
            # NOTE(review): only the first near-zero entry is clamped here
            # (lambda_zeros[0]); the distributed variants clamp all of them
            # — confirm whether this is intentional.
            lambda_vector[lambda_zeros[0]] = 1e-3
            lambda_vector /= np.sum(lambda_vector)
        Server.lambda_vector = torch.tensor(lambda_vector)
        # logging.
        logging_globally(tracker, start_global_time)
        # reset start round time.
        start_global_time = time.time()
    # validate the model at the server
    # if args.graph.rank == 0:
    #     do_test(args, model_server, optimizer, criterion, metrics, test_loader)
    #     do_validate_test(args, model_server, optimizer, criterion, metrics, test_loader)
    return
def train_and_validate_federated_afl(client):
    """The training scheme of Federated Learning systems.
    This the implementation of Agnostic Federated Learning
    https://arxiv.org/abs/1902.00146

    Runs on every rank; rank 0 acts as the server. Each round: sample
    online clients, broadcast the server model and lambda, run local SGD,
    aggregate weighted by lambda, then (on rank 0) do the AFL dual update.
    """
    log('start training and validation with Federated setting.',
        client.args.debug)

    if client.args.evaluate and client.args.graph.rank == 0:
        # Do the testing on the server and return
        do_validate(client.args, client.model, client.optimizer,
                    client.criterion, client.metrics, client.test_loader,
                    client.all_clients_group, data_mode='test')
        return

    # Initialize lambda variable proportianate to their sample size
    if client.args.graph.rank == 0:
        gather_list_size = [
            torch.tensor(0.0) for _ in range(client.args.graph.n_nodes)
        ]
        dist.gather(torch.tensor(client.args.num_samples_per_epoch,
                                 dtype=torch.float32),
                    gather_list=gather_list_size, dst=0)
        client.lambda_vector = torch.stack(
            gather_list_size) / client.args.train_dataset_size
    else:
        dist.gather(torch.tensor(client.args.num_samples_per_epoch,
                                 dtype=torch.float32), dst=0)
        # Placeholder; the true lambda is broadcast from rank 0 each round.
        client.lambda_vector = torch.tensor([1 / client.args.graph.n_nodes] *
                                            client.args.graph.n_nodes)

    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)

    # Number of communication rounds in federated setting should be defined
    for n_c in range(client.args.num_comms):
        client.args.rounds_comm += 1
        client.args.comm_time.append(0.0)
        # Configuring the devices for this round of communication
        # TODO: not make the server rank hard coded
        log("Starting round {} of training".format(n_c + 1),
            client.args.debug)
        online_clients = set_online_clients(client.args)
        if n_c == 0:
            # The first round server should be in the communication to initilize its own training
            online_clients = online_clients if 0 in online_clients else online_clients + [
                0
            ]
        online_clients_server = online_clients if 0 in online_clients else online_clients + [
            0
        ]
        online_clients_group = dist.new_group(online_clients_server)

        if client.args.graph.rank in online_clients_server:
            st = time.time()
            client.model_server = distribute_model_server(
                client.model_server, online_clients_group, src=0)
            dist.broadcast(client.lambda_vector, src=0,
                           group=online_clients_group)
            client.args.comm_time[-1] += time.time() - st
            client.model.load_state_dict(client.model_server.state_dict())
            # This loss tensor is for those clients not participating in the first round
            loss = torch.tensor(0.0)

            # Start running updates on local machines
            if client.args.graph.rank in online_clients:
                is_sync = False
                while not is_sync:
                    for _input, _target in client.train_loader:
                        client.model.train()
                        # update local step.
                        logging_load_time(tracker)
                        # update local index and get local step
                        client.args.local_index += 1
                        client.args.local_data_seen += len(_target)
                        get_current_epoch(client.args)
                        local_step = get_current_local_step(client.args)
                        # adjust learning rate (based on the # of accessed samples)
                        lr = adjust_learning_rate(client.args,
                                                  client.optimizer,
                                                  client.scheduler)
                        # load data
                        _input, _target = load_data_batch(
                            client.args, _input, _target, tracker)
                        # Skip batches with one sample because of BatchNorm issue in some models!
                        if _input.size(0) == 1:
                            is_sync = is_sync_fed(client.args)
                            break
                        # inference and get current performance.
                        client.optimizer.zero_grad()
                        loss, performance = inference(client.model,
                                                      client.criterion,
                                                      client.metrics, _input,
                                                      _target)
                        # compute gradient and do local SGD step.
                        loss.backward()
                        client.optimizer.step(
                            apply_lr=True,
                            apply_in_momentum=client.args.in_momentum,
                            apply_out_momentum=False)
                        # logging locally.
                        # logging_computing(tracker, loss, performance, _input, lr)
                        # display the logging info.
                        # logging_display_training(args, tracker)
                        # reset load time for the tracker.
                        tracker['start_load_time'] = time.time()
                        is_sync = is_sync_fed(client.args)
                        if is_sync:
                            break
            else:
                log("Offline in this round. Waiting on others to finish!",
                    client.args.debug)

            # Validate the local models befor sync
            do_validate(client.args, client.model, client.optimizer,
                        client.criterion, client.metrics, client.train_loader,
                        online_clients_group, data_mode='train', local=True)
            if client.args.fed_personal:
                do_validate(client.args, client.model, client.optimizer,
                            client.criterion, client.metrics,
                            client.val_loader, online_clients_group,
                            data_mode='validation', local=True)

            # Sync the model server based on client models
            log('Enter synching', client.args.debug)
            tracker['start_sync_time'] = time.time()
            client.args.global_index += 1
            client.model_server, loss_tensor_online = afl_aggregation(
                client.args, client.model_server, client.model,
                client.lambda_vector[client.args.graph.rank].item(),
                torch.tensor(loss.item()), online_clients_group,
                online_clients, client.optimizer)
            # evaluate the sync time
            logging_sync_time(tracker)

            # Do the validation on the server model
            do_validate(client.args, client.model_server, client.optimizer,
                        client.criterion, client.metrics, client.train_loader,
                        online_clients_group, data_mode='train')
            if client.args.fed_personal:
                do_validate(client.args, client.model_server,
                            client.optimizer, client.criterion,
                            client.metrics, client.val_loader,
                            online_clients_group, data_mode='validation')

            # Updating lambda variable
            if client.args.graph.rank == 0:
                num_online_clients = len(
                    online_clients
                ) if 0 in online_clients else len(online_clients) + 1
                loss_tensor = torch.zeros(client.args.graph.n_nodes)
                loss_tensor[sorted(online_clients_server)] = loss_tensor_online
                # Dual update
                client.lambda_vector += client.args.drfa_gamma * loss_tensor
                # Projection into a simplex
                client.lambda_vector = euclidean_proj_simplex(
                    client.lambda_vector)
                # Avoid zero probability
                lambda_zeros = client.lambda_vector <= 1e-3
                if lambda_zeros.sum() > 0:
                    client.lambda_vector[lambda_zeros] = 1e-3
                    client.lambda_vector /= client.lambda_vector.sum()

            # logging.
            logging_globally(tracker, start_global_time)
            # reset start round time.
            start_global_time = time.time()
            # validate the model at the server
            if client.args.graph.rank == 0:
                do_validate(client.args, client.model_server,
                            client.optimizer, client.criterion,
                            client.metrics, client.test_loader,
                            online_clients_group, data_mode='test')
            log(
                'This round communication time is: {}'.format(
                    client.args.comm_time[-1]), client.args.debug)
        else:
            log("Offline in this round. Waiting on others to finish!",
                client.args.debug)
        dist.barrier(group=client.all_clients_group)
    return
def train_and_validate_federated_drfa(client):
    """The training scheme of Distributionally Robust Federated Learning DRFA.
    paper: https://papers.nips.cc/paper/2020/hash/ac450d10e166657ec8f93a1b65ca1b14-Abstract.html

    Runs on every rank; rank 0 acts as the server. Each round: broadcast
    the server model and lambda, run local SGD (snapshotting the model at a
    random step k for the dual update), aggregate weighted by lambda, then
    update lambda on rank 0 from losses of the aggregated k-th model.

    Fixes vs. the previous revision:
      * the gathered sample-size lambda initialization is now stored on
        ``client.lambda_vector`` (it was assigned to a discarded local,
        unlike the AFL sibling implementation);
      * the non-scaffold server-model distribution result is now stored on
        ``client.model_server`` (it was assigned to a discarded local).
    """
    log('start training and validation with Federated setting.',
        client.args.debug)

    if client.args.evaluate and client.args.graph.rank == 0:
        # Do the testing on the server and return
        do_validate(client.args, client.model, client.optimizer,
                    client.criterion, client.metrics, client.test_loader,
                    client.all_clients_group, data_mode='test')
        return

    # Initialize lambda variable proportianate to their sample size
    if client.args.graph.rank == 0:
        gather_list_size = [
            torch.tensor(0.0) for _ in range(client.args.graph.n_nodes)
        ]
        dist.gather(torch.tensor(client.args.num_samples_per_epoch,
                                 dtype=torch.float32),
                    gather_list=gather_list_size, dst=0)
        # FIX: store on the client (was a discarded local variable).
        client.lambda_vector = torch.stack(
            gather_list_size) / client.args.train_dataset_size
    else:
        dist.gather(torch.tensor(client.args.num_samples_per_epoch,
                                 dtype=torch.float32), dst=0)
        # Placeholder; the true lambda is broadcast from rank 0 each round.
        client.lambda_vector = torch.tensor([1 / client.args.graph.n_nodes] *
                                            client.args.graph.n_nodes)

    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)

    # Number of communication rounds in federated setting should be defined
    for n_c in range(client.args.num_comms):
        client.args.rounds_comm += 1
        client.args.comm_time.append(0.0)
        # Configuring the devices for this round of communication
        # TODO: not make the server rank hard coded
        log("Starting round {} of training".format(n_c + 1),
            client.args.debug)
        online_clients = set_online_clients(client.args)
        if n_c == 0:
            # The first round server should be in the communication to initilize its own training
            online_clients = online_clients if 0 in online_clients else online_clients + [
                0
            ]
        online_clients_server = online_clients if 0 in online_clients else online_clients + [
            0
        ]
        online_clients_group = dist.new_group(online_clients_server)
        # DRFA step-size decay per round.
        client.args.drfa_gamma *= 0.9

        if client.args.graph.rank in online_clients_server:
            if client.args.federated_type == 'scaffold':
                st = time.time()
                client.model_server, client.model_server_control = distribute_model_server_control(
                    client.model_server, client.model_server_control,
                    online_clients_group, src=0)
                client.args.comm_time[-1] += time.time() - st
            else:
                st = time.time()
                # FIX: store on the client (was assigned to a discarded
                # local ``model_server``, dropping the broadcast result).
                client.model_server = distribute_model_server(
                    client.model_server, online_clients_group, src=0)
                client.args.comm_time[-1] += time.time() - st
            client.model.load_state_dict(client.model_server.state_dict())

            # Send related variables to drfa algorithm
            st = time.time()
            dist.broadcast(client.lambda_vector, src=0,
                           group=online_clients_group)
            # Sending the random number k to all nodes:
            # Does not fully support the epoch mode now
            k = torch.randint(low=1, high=client.args.local_step, size=(1, ))
            dist.broadcast(k, src=0, group=online_clients_group)
            client.args.comm_time[-1] += time.time() - st
            k = k.tolist()[0]
            local_steps = 0

            # Start running updates on local machines
            if client.args.graph.rank in online_clients:
                is_sync = False
                while not is_sync:
                    for _input, _target in client.train_loader:
                        local_steps += 1
                        # Getting the k-th model for dual variable update
                        if k == local_steps:
                            client.kth_model.load_state_dict(
                                client.model.state_dict())
                        client.model.train()
                        # update local step.
                        logging_load_time(tracker)
                        # update local index and get local step
                        client.args.local_index += 1
                        client.args.local_data_seen += len(_target)
                        get_current_epoch(client.args)
                        local_step = get_current_local_step(client.args)
                        # adjust learning rate (based on the # of accessed samples)
                        lr = adjust_learning_rate(client.args,
                                                  client.optimizer,
                                                  client.scheduler)
                        # load data
                        _input, _target = load_data_batch(
                            client.args, _input, _target, tracker)
                        # Skip batches with one sample because of BatchNorm issue in some models!
                        if _input.size(0) == 1:
                            is_sync = is_sync_fed(client.args)
                            break
                        # inference and get current performance.
                        client.optimizer.zero_grad()
                        loss, performance = inference(client.model,
                                                      client.criterion,
                                                      client.metrics, _input,
                                                      _target)
                        # compute gradient and do local SGD step.
                        loss.backward()
                        if client.args.federated_type == 'fedgate':
                            # Update gradients with control variates.
                            for client_param, delta_param in zip(
                                    client.model.parameters(),
                                    client.model_delta.parameters()):
                                client_param.grad.data -= delta_param.data
                        elif client.args.federated_type == 'scaffold':
                            # SCAFFOLD variance-reduction correction.
                            for cp, ccp, scp in zip(
                                    client.model.parameters(),
                                    client.model_client_control.parameters(),
                                    client.model_server_control.parameters()):
                                cp.grad.data += scp.data - ccp.data
                        client.optimizer.step(
                            apply_lr=True,
                            apply_in_momentum=client.args.in_momentum,
                            apply_out_momentum=False)
                        # logging locally.
                        # logging_computing(tracker, loss, performance, _input, lr)
                        # display the logging info.
                        # logging_display_training(client.args, tracker)
                        # reset load time for the tracker.
                        tracker['start_load_time'] = time.time()
                        is_sync = is_sync_fed(client.args)
                        if is_sync:
                            break
            else:
                log("Offline in this round. Waiting on others to finish!",
                    client.args.debug)

            # Validate the local models befor sync
            do_validate(client.args, client.model, client.optimizer,
                        client.criterion, client.metrics, client.train_loader,
                        online_clients_group, data_mode='train', local=True)
            if client.args.fed_personal:
                do_validate(client.args, client.model, client.optimizer,
                            client.criterion, client.metrics,
                            client.val_loader, online_clients_group,
                            data_mode='validation', local=True)

            # Sync the model server based on model_clients
            log('Enter synching', client.args.debug)
            tracker['start_sync_time'] = time.time()
            client.args.global_index += 1
            if client.args.federated_type == 'fedgate':
                client.model_server, client.model_delta = fedgate_aggregation(
                    client.args, client.model_server, client.model,
                    client.model_delta, client.model_memory,
                    online_clients_group, online_clients, client.optimizer,
                    lr, local_steps,
                    lambda_weight=client.lambda_vector[
                        client.args.graph.rank].item())
            elif client.args.federated_type == 'scaffold':
                client.model_server, client.model_client_control, client.model_server_control = scaffold_aggregation(
                    client.args, client.model_server, client.model,
                    client.model_server_control, client.model_client_control,
                    online_clients_group, online_clients, client.optimizer,
                    lr, local_steps,
                    lambda_weight=client.lambda_vector[
                        client.args.graph.rank].item())
            else:
                client.model_server = fedavg_aggregation(
                    client.args, client.model_server, client.model,
                    online_clients_group, online_clients, client.optimizer,
                    lambda_weight=client.lambda_vector[
                        client.args.graph.rank].item())
            # Average the kth_model
            client.kth_model = aggregate_models_virtual(
                client.args, client.kth_model, online_clients_group,
                online_clients)
            # evaluate the sync time
            logging_sync_time(tracker)

            # Do the validation on the server model
            do_validate(client.args, client.model_server, client.optimizer,
                        client.criterion, client.metrics, client.train_loader,
                        online_clients_group, data_mode='train')
            if client.args.fed_personal:
                do_validate(client.args, client.model_server,
                            client.optimizer, client.criterion,
                            client.metrics, client.val_loader,
                            online_clients_group, data_mode='validation')
            # validate the model at the server
            if client.args.graph.rank == 0:
                do_validate(client.args, client.model_server,
                            client.optimizer, client.criterion,
                            client.metrics, client.test_loader,
                            online_clients_group, data_mode='test')
        else:
            log("Offline in this round. Waiting on others to finish!",
                client.args.debug)

        # Update lambda parameters
        online_clients_lambda = set_online_clients(client.args)
        online_clients_server_lambda = online_clients_lambda if 0 in online_clients_lambda else [
            0
        ] + online_clients_lambda
        online_clients_group_lambda = dist.new_group(
            online_clients_server_lambda)

        if client.args.graph.rank in online_clients_server_lambda:
            st = time.time()
            client.kth_model = distribute_model_server(
                client.kth_model, online_clients_group_lambda, src=0)
            client.args.comm_time[-1] += time.time() - st
            loss = torch.tensor(0.0)

            if client.args.graph.rank in online_clients_lambda:
                # One batch is enough for the stochastic dual estimate.
                for _input, _target in client.train_loader:
                    _input, _target = load_data_batch(client.args, _input,
                                                      _target, tracker)
                    # Skip batches with one sample because of BatchNorm issue in some models!
                    if _input.size(0) == 1:
                        break
                    loss, _ = inference(client.kth_model, client.criterion,
                                        client.metrics, _input, _target)
                    break
            loss_tensor_online = loss_gather(
                client.args, torch.tensor(loss.item()),
                group=online_clients_group_lambda,
                online_clients=online_clients_lambda)
            if client.args.graph.rank == 0:
                num_online_clients = len(
                    online_clients_lambda
                ) if 0 in online_clients_lambda else len(
                    online_clients_lambda) + 1
                loss_tensor = torch.zeros(client.args.graph.n_nodes)
                # Unbiased scaling for partial participation.
                loss_tensor[sorted(
                    online_clients_server_lambda)] = loss_tensor_online * (
                        client.args.graph.n_nodes / num_online_clients)
                # Dual update
                client.lambda_vector += client.args.drfa_gamma * client.args.local_step * loss_tensor
                client.lambda_vector = euclidean_proj_simplex(
                    client.lambda_vector)
                # Avoid zero probability
                lambda_zeros = client.lambda_vector <= 1e-3
                if lambda_zeros.sum() > 0:
                    client.lambda_vector[lambda_zeros] = 1e-3
                    client.lambda_vector /= client.lambda_vector.sum()

        # logging.
        logging_globally(tracker, start_global_time)
        # reset start round time.
        start_global_time = time.time()
        log(
            'This round communication time is: {}'.format(
                client.args.comm_time[-1]), client.args.debug)
        dist.barrier(group=client.all_clients_group)
    return
def train_and_validate(client):
    """The training scheme of Distributed Local SGD.

    Runs the local-SGD loop on this process: train the local model on the
    local shard, and every `local_step` mini-batches aggregate the local
    model into the shared server model via `aggregate_gradients`. Rank 0
    validates on the test set every `eval_freq` epochs. The loop runs until
    `is_stop(client.args)` signals the end of training.

    Args:
        client: object holding args, models (local + server), optimizer,
            criterion, metrics, data loaders and the process groups.
    """
    log('start training and validation.', client.args.debug)

    if client.args.evaluate and client.args.graph.rank==0:
        # Evaluation-only mode: validate on the server (rank 0) and return.
        do_validate(client.args, client.model, client.optimizer,
                    client.criterion, client.metrics, client.test_loader,
                    client.all_clients_group, data_mode='test')
        return

    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)
    client.args.comm_time.append(0.0)

    # break until finish expected full epoch training.
    while True:
        # configure local step.
        for _input, _target in client.train_loader:
            client.model.train()

            # update local step.
            logging_load_time(tracker)

            # update local index and get local step
            client.args.local_index += 1
            client.args.local_data_seen += len(_target)
            get_current_epoch(client.args)
            local_step = get_current_local_step(client.args)

            # adjust learning rate (based on the # of accessed samples)
            lr = adjust_learning_rate(client.args, client.optimizer,
                                      client.scheduler)

            # load data
            _input, _target = load_data_batch(client.args, _input, _target,
                                              tracker)

            # inference and get current performance.
            client.optimizer.zero_grad()
            loss, performance = inference(client.model, client.criterion,
                                          client.metrics, _input, _target)

            # compute gradient and do local SGD step.
            loss.backward()
            client.optimizer.step(
                apply_lr=True,
                apply_in_momentum=client.args.in_momentum,
                apply_out_momentum=False
            )

            # logging locally.
            logging_computing(tracker, loss, performance, _input, lr)

            # evaluate the status: sync after every `local_step` local updates.
            is_sync = client.args.local_index % local_step == 0
            # NOTE(review): `x % 1 == 0` is always True for integral counters;
            # presumably `epoch_` is fractional progress — confirm upstream.
            if client.args.epoch_ % 1 == 0:
                client.args.finish_one_epoch = True

            # sync
            if is_sync:
                log('Enter synching', client.args.debug)
                client.args.global_index += 1

                # broadcast gradients to other nodes by using reduce_sum.
                client.model_server = aggregate_gradients(
                    client.args, client.model_server, client.model,
                    client.optimizer, is_sync)

                # evaluate the sync time
                logging_sync_time(tracker)

                # logging.
                logging_globally(tracker, start_global_time)

                # reset start round time.
                start_global_time = time.time()

            # finish one epoch training and to decide if we want to val our model.
            if client.args.finish_one_epoch:
                # Only rank 0 validates; everyone meets at the barrier so the
                # processes stay in lockstep across the epoch boundary.
                if client.args.epoch % client.args.eval_freq ==0 and client.args.graph.rank == 0:
                    do_validate(client.args, client.model, client.optimizer,
                                client.criterion, client.metrics,
                                client.test_loader, client.all_clients_group,
                                data_mode='test')
                dist.barrier(group=client.all_clients_group)

                # refresh the logging cache at the begining of each epoch.
                client.args.finish_one_epoch = False
                tracker = define_local_training_tracker()

            # determine if the training is finished.
            if is_stop(client.args):
                # Last Sync: fold the final local model into the server model
                # before the final validation and return.
                log('Enter synching', client.args.debug)
                client.args.global_index += 1

                # broadcast gradients to other nodes by using reduce_sum.
                client.model_server = aggregate_gradients(
                    client.args, client.model_server, client.model,
                    client.optimizer, is_sync)
                print("Total number of samples seen on device {} is {}".format(
                    client.args.graph.rank, client.args.local_data_seen))
                if client.args.graph.rank == 0:
                    do_validate(client.args, client.model_server,
                                client.optimizer, client.criterion,
                                client.metrics, client.test_loader,
                                client.all_clients_group, data_mode='test')
                return

            # display the logging info.
            logging_display_training(client.args, tracker)

            # reset load time for the tracker.
            tracker['start_load_time'] = time.time()

        # reshuffle the data.
        if client.args.reshuffle_per_epoch:
            log('reshuffle the dataset.', client.args.debug)
            # Drop the old loaders first so their memory can be reclaimed
            # before the reshuffled ones are built.
            del client.train_loader, client.test_loader
            gc.collect()
            log('reshuffle the dataset.', client.args.debug)
            client.load_local_dataset()
def train_and_validate_perfedme_centered(Clients, Server):
    """Centered (single-process) training scheme for PerFedMe-style
    personalized federated learning.

    Per communication round: the server pushes its weights to each sampled
    client; each client alternates SGD steps on a personal model (regularized
    toward its local copy of the global model with weight `perfedme_lambda`)
    and, every 5 local steps, pulls its local global-model copy toward the
    personal model; finally the server FedAvg-aggregates the clients' global
    copies and logs local/global validation stats.

    Args:
        Clients: mapping/list of client objects (models, optimizers, loaders).
        Server: the server object holding the global model and trackers.
    """
    log('start training and validation with Federated setting in a centered way.'
        )

    # For Sparsified SGD
    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.')

    # Number of communication rounds in federated setting should be defined
    for n_c in range(Server.args.num_comms):
        Server.args.rounds_comm += 1
        Server.args.local_index += 1
        Server.args.quant_error = 0.0

        # Preset variables for this round of communication
        Server.zero_grad()
        Server.reset_tracker(Server.local_val_tracker)
        Server.reset_tracker(Server.global_val_tracker)
        if Server.args.fed_personal:
            Server.reset_tracker(Server.local_personal_val_tracker)
            Server.reset_tracker(Server.global_personal_val_tracker)

        # Configuring the devices for this round of communication
        log("Starting round {} of training".format(n_c + 1))
        online_clients = set_online_clients_centered(Server.args)
        for oc in online_clients:
            # Client starts the round from the current server weights.
            Clients[oc].model.load_state_dict(Server.model.state_dict())
            Clients[oc].args.rounds_comm = Server.args.rounds_comm
            local_steps = 0
            is_sync = False

            # Evaluate the (pre-update) server model on this client's data.
            do_validate_centered(Clients[oc].args, Server.model,
                                 Server.criterion, Server.metrics,
                                 Server.optimizer, Clients[oc].train_loader,
                                 Server.global_val_tracker,
                                 val=False, local=False)
            if Server.args.fed_personal:
                do_validate_centered(Clients[oc].args, Server.model,
                                     Server.criterion, Server.metrics,
                                     Server.optimizer, Clients[oc].val_loader,
                                     Server.global_personal_val_tracker,
                                     val=True, local=False)

            while not is_sync:
                for _input, _target in Clients[oc].train_loader:
                    local_steps += 1
                    Clients[oc].model.train()
                    Clients[oc].model_personal.train()

                    # update local step.
                    logging_load_time(tracker)

                    # update local index and get local step
                    Clients[oc].args.local_index += 1
                    Clients[oc].args.local_data_seen += len(_target)
                    get_current_epoch(Clients[oc].args)
                    local_step = get_current_local_step(Clients[oc].args)

                    # adjust learning rate (based on the # of accessed samples)
                    lr = adjust_learning_rate(Clients[oc].args,
                                              Clients[oc].optimizer_personal,
                                              Clients[oc].scheduler)

                    # load data
                    _input, _target = load_data_batch(Clients[oc].args,
                                                      _input, _target,
                                                      tracker)

                    # Skip batches with one sample because of BatchNorm issue in some models!
                    if _input.size(0) == 1:
                        is_sync = is_sync_fed(Clients[oc].args)
                        break

                    # inference and get current performance.
                    Clients[oc].optimizer_personal.zero_grad()
                    loss, performance = inference(Clients[oc].model_personal,
                                                  Clients[oc].criterion,
                                                  Clients[oc].metrics,
                                                  _input, _target)
                    # print("loss in rank {} is {}".format(oc,loss))

                    # compute gradient and do local SGD step.
                    loss.backward()

                    # Add the proximal gradient lambda*(w_personal - w_local)
                    # that ties the personal model to the local global copy.
                    for client_param, personal_param in zip(
                            Clients[oc].model.parameters(),
                            Clients[oc].model_personal.parameters()):
                        # loss += Clients[oc].args.perfedme_lambda * torch.norm(personal_param.data - client_param.data)**2
                        personal_param.grad.data += Clients[
                            oc].args.perfedme_lambda * (personal_param.data -
                                                        client_param.data)

                    Clients[oc].optimizer_personal.step(
                        apply_lr=True,
                        apply_in_momentum=Clients[oc].args.in_momentum,
                        apply_out_momentum=False)

                    # First step only: populate gradients on the local global
                    # copy so its .grad buffers exist for later surgery.
                    if Clients[oc].args.local_index == 1:
                        Clients[oc].optimizer.zero_grad()
                        loss, performance = inference(Clients[oc].model,
                                                      Clients[oc].criterion,
                                                      Clients[oc].metrics,
                                                      _input, _target)
                        loss.backward()

                    is_sync = is_sync_fed(Clients[oc].args)

                    # Every 5 local steps (and right before sync), move the
                    # local copy of the global model toward the personal one.
                    if Clients[oc].args.local_index % 5 == 0 or is_sync:
                        log('Updating the local version of the global model',
                            Clients[oc].args.debug)
                        lr = adjust_learning_rate(Clients[oc].args,
                                                  Clients[oc].optimizer,
                                                  Clients[oc].scheduler)
                        Clients[oc].optimizer.zero_grad()
                        for client_param, personal_param in zip(
                                Clients[oc].model.parameters(),
                                Clients[oc].model_personal.parameters()):
                            client_param.grad.data = Clients[
                                oc].args.perfedme_lambda * (
                                    client_param.data - personal_param.data)
                        Clients[oc].optimizer.step(
                            apply_lr=True,
                            apply_in_momentum=Clients[oc].args.in_momentum,
                            apply_out_momentum=False)

                    # reset load time for the tracker.
                    tracker['start_load_time'] = time.time()
                    # model_local = deepcopy(model_client)

                if is_sync:
                    break

            # Evaluate the trained personal model on this client's data.
            do_validate_centered(Clients[oc].args, Clients[oc].model_personal,
                                 Clients[oc].criterion, Clients[oc].metrics,
                                 Clients[oc].optimizer_personal,
                                 Clients[oc].train_loader,
                                 Server.local_val_tracker,
                                 val=False, local=True)
            if Server.args.fed_personal:
                do_validate_centered(Clients[oc].args,
                                     Clients[oc].model_personal,
                                     Clients[oc].criterion,
                                     Clients[oc].metrics,
                                     Clients[oc].optimizer_personal,
                                     Clients[oc].val_loader,
                                     Server.local_personal_val_tracker,
                                     val=True, local=True)

        # Sync the model server based on model_clients
        tracker['start_sync_time'] = time.time()
        Server.args.global_index += 1
        logging_sync_time(tracker)
        fedavg_aggregation_centered(Clients, Server, online_clients)

        # Log local performance
        log_validation_centered(Server.args, Server.local_val_tracker,
                                val=False, local=True)
        if Server.args.fed_personal:
            log_validation_centered(Server.args,
                                    Server.local_personal_val_tracker,
                                    val=True, local=True)

        # Log server performance
        log_validation_centered(Server.args, Server.global_val_tracker,
                                val=False, local=False)
        if Server.args.fed_personal:
            log_validation_centered(Server.args,
                                    Server.global_personal_val_tracker,
                                    val=True, local=False)

        # logging.
        logging_globally(tracker, start_global_time)

        # reset start round time.
        start_global_time = time.time()

    # validate the model at the server
    # if args.graph.rank == 0:
    #     do_test(args, model_server, optimizer, criterion, metrics, test_loader)
    #     do_validate_test(args, model_server, optimizer, criterion, metrics, test_loader)
    return
def log_test_centered(args, val_tracker):
    """Log the averaged test metrics (Prec@1, Prec@5, loss) for this process.

    Args:
        args: run configuration carrying the current indices and debug flag.
        val_tracker: dict of meters with an `.avg` attribute, keyed at least
            by 'top1', 'top5' and 'losses'.
    """
    top1_avg = val_tracker['top1'].avg
    top5_avg = val_tracker['top5'].avg
    loss_avg = val_tracker['losses'].avg
    message = (
        'Test at batch: {}. Epoch: {}. Process: {}. '
        'Prec@1: {:.3f} Prec@5: {:.3f} Loss: {:.3f} Comm: {}'
    ).format(args.local_index, args.epoch, args.graph.rank,
             top1_avg, top5_avg, loss_avg, args.rounds_comm)
    log(message, debug=args.debug)
def train_and_validate_federated(client):
    """The training scheme of Federated Learning systems.
    The basic model is FedAvg https://arxiv.org/abs/1602.05629
    TODO: Merge different models under this method

    Per round: rank 0 broadcasts the server model (and control variates for
    SCAFFOLD) to the sampled clients; each online client runs local SGD with
    the per-algorithm gradient correction (fedgate / scaffold / fedprox),
    then the clients' models are aggregated back into the server model with
    the matching aggregation routine; validation happens before and after
    aggregation.

    Args:
        client: object holding args, models, optimizers, criterion, metrics,
            loaders and the distributed process groups.
    """
    log('start training and validation with Federated setting.',
        client.args.debug)

    if client.args.evaluate and client.args.graph.rank == 0:
        # Do the training on the server and return
        do_validate(client.args, client.model, client.optimizer,
                    client.criterion, client.metrics, client.test_loader,
                    client.all_clients_group, data_mode='test')
        return

    # init global variable.
    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)

    # Number of communication rounds in federated setting should be defined
    for n_c in range(client.args.num_comms):
        client.args.rounds_comm += 1
        client.args.comm_time.append(0.0)

        # Configuring the devices for this round of communication
        log("Starting round {} of training".format(n_c + 1),
            client.args.debug)
        online_clients = set_online_clients(client.args)
        # Rank 0 (the server) must be present in round 0 so the initial
        # model can be distributed.
        if (n_c == 0) and (0 not in online_clients):
            online_clients += [0]
        online_clients_server = online_clients if 0 in online_clients else online_clients + [
            0
        ]
        online_clients_group = dist.new_group(online_clients_server)

        if client.args.graph.rank in online_clients_server:
            # Broadcast the server model (timed as communication).
            if client.args.federated_type == 'scaffold':
                st = time.time()
                client.model_server, client.model_server_control = distribute_model_server_control(
                    client.model_server, client.model_server_control,
                    online_clients_group, src=0)
                client.args.comm_time[-1] += time.time() - st
            else:
                st = time.time()
                client.model_server = distribute_model_server(
                    client.model_server, online_clients_group, src=0)
                client.args.comm_time[-1] += time.time() - st
            client.model.load_state_dict(client.model_server.state_dict())
            local_steps = 0

            if client.args.graph.rank in online_clients:
                is_sync = False
                while not is_sync:
                    for _input, _target in client.train_loader:
                        local_steps += 1
                        client.model.train()

                        # update local step.
                        logging_load_time(tracker)

                        # update local index and get local step
                        client.args.local_index += 1
                        client.args.local_data_seen += len(_target)
                        get_current_epoch(client.args)
                        local_step = get_current_local_step(client.args)

                        # adjust learning rate (based on the # of accessed samples)
                        lr = adjust_learning_rate(client.args,
                                                  client.optimizer,
                                                  client.scheduler)

                        # load data
                        _input, _target = load_data_batch(
                            client.args, _input, _target, tracker)

                        # Skip batches with one sample because of BatchNorm issue in some models!
                        if _input.size(0) == 1:
                            is_sync = is_sync_fed(client.args)
                            break

                        # inference and get current performance.
                        client.optimizer.zero_grad()
                        loss, performance = inference(client.model,
                                                      client.criterion,
                                                      client.metrics, _input,
                                                      _target)

                        # compute gradient and do local SGD step.
                        loss.backward()

                        if client.args.federated_type == 'fedgate':
                            # Update gradients with control variates
                            for client_param, delta_param in zip(
                                    client.model.parameters(),
                                    client.model_delta.parameters()):
                                client_param.grad.data -= delta_param.data
                        elif client.args.federated_type == 'scaffold':
                            # SCAFFOLD drift correction: +c (server) -c_i (client).
                            for cp, ccp, scp in zip(
                                    client.model.parameters(),
                                    client.model_client_control.parameters(),
                                    client.model_server_control.parameters()):
                                cp.grad.data += scp.data - ccp.data
                        elif client.args.federated_type == 'fedprox':
                            # Adding proximal gradients and loss for fedprox
                            for client_param, server_param in zip(
                                    client.model.parameters(),
                                    client.model_server.parameters()):
                                # NOTE(review): rank-0 `print` in the inner
                                # loop looks like leftover debugging output.
                                if client.args.graph.rank == 0:
                                    print(
                                        "distance norm for prox is:{}".format(
                                            torch.norm(client_param.data -
                                                       server_param.data)))
                                # NOTE(review): the gradient term mu*(w - w_s)
                                # corresponds to the loss mu/2*||w - w_s||^2,
                                # but the tracked loss adds the *non-squared*
                                # norm — confirm whether this is intended.
                                loss += client.args.fedprox_mu / 2 * torch.norm(
                                    client_param.data - server_param.data)
                                client_param.grad.data += client.args.fedprox_mu * (
                                    client_param.data - server_param.data)

                        # For adversarially-robust models the noise variable
                        # ascends the loss, hence the gradient sign flip.
                        if 'robust' in client.args.arch:
                            client.model.noise.grad.data *= -1
                        client.optimizer.step(
                            apply_lr=True,
                            apply_in_momentum=client.args.in_momentum,
                            apply_out_momentum=False)
                        if 'robust' in client.args.arch:
                            # Project the noise back into the unit ball.
                            if torch.norm(client.model.noise.data) > 1:
                                client.model.noise.data /= torch.norm(
                                    client.model.noise.data)

                        # logging locally.
                        # logging_computing(tracker, loss_v, performance_v, _input, lr)

                        # display the logging info.
                        # logging_display_training(args, tracker)

                        # reset load time for the tracker.
                        tracker['start_load_time'] = time.time()
                        # model_local = deepcopy(model_client)
                        is_sync = is_sync_fed(client.args)
                        if is_sync:
                            break
            else:
                log("Offline in this round. Waiting on others to finish!",
                    client.args.debug)

            # Validate the local models befor sync
            do_validate(client.args, client.model, client.optimizer,
                        client.criterion, client.metrics, client.train_loader,
                        online_clients_group, data_mode='train', local=True)
            if client.args.fed_personal:
                do_validate(client.args, client.model, client.optimizer,
                            client.criterion, client.metrics,
                            client.val_loader, online_clients_group,
                            data_mode='validation', local=True)

            # Sync the model server based on client models
            log('Enter synching', client.args.debug)
            tracker['start_sync_time'] = time.time()
            client.args.global_index += 1

            # Dispatch to the aggregation routine matching the algorithm.
            if client.args.federated_type == 'fedgate':
                client.model_server, client.model_delta = fedgate_aggregation(
                    client.args, client.model_server, client.model,
                    client.model_delta, client.model_memory,
                    online_clients_group, online_clients, client.optimizer,
                    lr, local_steps)
            elif client.args.federated_type == 'scaffold':
                client.model_server, client.model_client_control, client.model_server_control = scaffold_aggregation(
                    client.args, client.model_server, client.model,
                    client.model_server_control, client.model_client_control,
                    online_clients_group, online_clients, client.optimizer,
                    lr, local_steps)
            elif client.args.federated_type == 'qsparse':
                client.model_server = qsparse_aggregation(
                    client.args, client.model_server, client.model,
                    online_clients_group, online_clients, client.optimizer,
                    client.model_memory)
            else:
                client.model_server = fedavg_aggregation(
                    client.args, client.model_server, client.model,
                    online_clients_group, online_clients, client.optimizer)

            # evaluate the sync time
            logging_sync_time(tracker)

            # Do the validation on the server model
            do_validate(client.args, client.model_server, client.optimizer,
                        client.criterion, client.metrics, client.train_loader,
                        online_clients_group, data_mode='train')
            if client.args.fed_personal:
                do_validate(client.args, client.model_server,
                            client.optimizer, client.criterion,
                            client.metrics, client.val_loader,
                            online_clients_group, data_mode='validation')

            # logging.
            logging_globally(tracker, start_global_time)

            # reset start round time.
            start_global_time = time.time()

            # validate the model at the server
            if client.args.graph.rank == 0:
                do_validate(client.args, client.model_server,
                            client.optimizer, client.criterion,
                            client.metrics, client.test_loader,
                            online_clients_group, data_mode='test')
            log(
                'This round communication time is: {}'.format(
                    client.args.comm_time[-1]), client.args.debug)
        else:
            log("Offline in this round. Waiting on others to finish!",
                client.args.debug)

        dist.barrier(group=client.all_clients_group)
    return
def partition_dataset(args,
                      shuffle,
                      dataset_type,
                      Partitioner=None,
                      return_partitioner=False):
    """ Given a dataset, partition it.

    Loads (or reuses) the dataset for `dataset_type` ('train'/'test'/...),
    partitions the training split across the `args.graph.n_nodes` workers
    (iid / non-iid / growing-batch schemes), optionally carves out personal
    validation splits for federated-personalization runs, and wraps the
    result in DataLoader(s). Mutates several fields on `args`
    (num_iterations/num_epochs/total_data_size/...).

    Args:
        args: run configuration; several counters are updated in place.
        shuffle: whether the partitioner shuffles before splitting.
        dataset_type: which split to load ('train' or a validation/test split).
        Partitioner: optional pre-built partitioner to reuse its data/splits.
        return_partitioner: if True, also return the partitioner object.

    Returns:
        A DataLoader, or a list of DataLoaders ([train, val] or
        [train, val, val1] for perfedavg), optionally with the Partitioner.

    Raises:
        ValueError / NotImplementedError: for unsupported combinations of
            dataset, iid-ness and growing-batch options.
    """
    if Partitioner is None:
        dataset = get_dataset(args, args.data, args.data_dir,
                              split=dataset_type)
    else:
        dataset = Partitioner.data

    # Federated Dataset with Validation
    # if args.data in ['emnist'] and args.fed_personal and dataset_type=='train':
    #     dataset, dataset_val = dataset
    batch_size = args.batch_size
    world_size = args.graph.n_nodes

    # partition data.
    if args.partition_data and dataset_type == 'train':
        # Choose the partition scheme name `pt` first, then partition.
        if args.iid_data:
            if args.data in [
                    'emnist', 'emnist_full', 'synthetic', 'shakespeare'
            ]:
                raise ValueError(
                    'The dataset {} does not have a structure for iid distribution of data'
                    .format(args.data))
            if args.growing_batch_size:
                pt = 'growing'
            else:
                pt = 'normal'
        else:
            if args.data not in [
                    'mnist', 'fashion_mnist', 'emnist', 'emnist_full',
                    'cifar10', 'cifar100', 'adult', 'synthetic', 'shakespeare'
            ]:
                raise NotImplementedError(
                    """Non-iid distribution of data for dataset {} is not implemented.
                    Set the distribution to iid.""".format(args.data))
            if args.growing_batch_size:
                raise ValueError(
                    'Growing Minibatch Size is not designed for non-iid data distribution'
                )
            else:
                pt = 'noniid'

        if Partitioner is None:
            if return_partitioner:
                data_to_load, Partitioner = partitioner(
                    args, dataset, shuffle, world_size, partition_type=pt,
                    return_partitioner=True)
            else:
                data_to_load = partitioner(args, dataset, shuffle, world_size,
                                           partition_type=pt)
            log('Make {} data partitions and use the subdata.'.format(pt),
                args.debug)
        else:
            # Reuse the provided partitioner: take this rank's shard.
            data_to_load = Partitioner.use(args.graph.rank)
            log('use the loaded partitioner to load the data.', args.debug)
    else:
        if Partitioner is not None:
            raise ValueError(
                'Partitioner is provided but data partition method is not defined!'
            )
        # If test dataset needs to be partitioned this should be changed
        data_to_load = dataset
        log('used whole data.', args.debug)

    # Log stats about the dataset to laod
    if dataset_type == 'train':
        args.train_dataset_size = len(dataset)
        log(' We have {} samples for {}, \
 load {} data for process (rank {}).'.format(len(dataset), dataset_type,
                                             len(data_to_load),
                                             args.graph.rank), args.debug)
    else:
        args.val_dataset_size = len(dataset)
        log(' We have {} samples for {}, \
 load {} val data for process (rank {}).'.format(len(dataset), dataset_type,
                                                 len(data_to_load),
                                                 args.graph.rank), args.debug)

    # Batching
    if (args.growing_batch_size) and (dataset_type == 'train'):
        # Growing-minibatch scheme: the sampler decides epochs/iterations.
        batch_sampler = GrowingMinibatchSampler(
            data_source=data_to_load,
            num_epochs=args.num_epochs,
            num_iterations=args.num_iterations,
            base_batch_size=args.base_batch_size,
            max_batch_size=args.max_batch_size)
        args.num_epochs = batch_sampler.num_epochs
        args.num_iterations = batch_sampler.num_iterations
        args.total_data_size = len(data_to_load)
        args.num_samples_per_epoch = len(data_to_load) / args.num_epochs
        data_loader = torch.utils.data.DataLoader(
            data_to_load,
            batch_sampler=batch_sampler,
            num_workers=args.num_workers,
            pin_memory=args.pin_memory)
        log(
            'we have {} batches for {} for rank {}.'.format(
                len(data_loader), dataset_type, args.graph.rank), args.debug)
    elif dataset_type == 'train':
        # Adjust stopping criteria: derive the missing one of
        # num_iterations/num_epochs from the other.
        if args.stop_criteria == 'epoch':
            args.num_iterations = int(
                len(data_to_load) * args.num_epochs / batch_size)
        else:
            args.num_epochs = int(args.num_iterations * batch_size /
                                  len(data_to_load))
        args.total_data_size = len(data_to_load) * args.num_epochs
        args.num_samples_per_epoch = len(data_to_load)

        # Generate validation data part
        if args.fed_personal:
            # Datasets with a dedicated 'val' split keep it; others carve
            # validation data out of the training shard.
            if args.data in ['emnist', 'emnist_full', 'shakespeare']:
                data_to_load_train = data_to_load
                data_to_load_val = get_dataset(args, args.data, args.data_dir,
                                               split='val')
            if args.federated_type == "perfedavg":
                # PerFedAvg needs a second held-out split (val1).
                if args.data in ['emnist', 'emnist_full', 'shakespeare']:
                    # TODO: make this size a paramter
                    val_size = int(0.1 * len(data_to_load))
                    data_to_load_train, data_to_load_val1 = torch.utils.data.random_split(
                        data_to_load,
                        [len(data_to_load) - val_size, val_size])
                else:
                    val_size = int(0.1 * len(data_to_load))
                    data_to_load_train, data_to_load_val, data_to_load_val1 = torch.utils.data.random_split(
                        data_to_load, [
                            len(data_to_load) - 3 * val_size, 2 * val_size,
                            val_size
                        ])
                # NOTE(review): `num_workers=5` is hard-coded here while the
                # growing-batch path uses args.num_workers — confirm intended.
                data_loader_val1 = torch.utils.data.DataLoader(
                    data_to_load_val1,
                    batch_size=batch_size,
                    shuffle=True,
                    num_workers=5,
                    pin_memory=args.pin_memory,
                    drop_last=False)
            else:
                if args.data not in ['emnist', 'emnist_full', 'shakespeare']:
                    val_size = int(0.2 * len(data_to_load))
                    data_to_load_train, data_to_load_val = torch.utils.data.random_split(
                        data_to_load,
                        [len(data_to_load) - val_size, val_size])

            # Generate data loaders
            data_loader_val = torch.utils.data.DataLoader(
                data_to_load_val,
                batch_size=batch_size,
                shuffle=True,
                num_workers=5,
                pin_memory=args.pin_memory,
                drop_last=False)
            data_loader_train = torch.utils.data.DataLoader(
                data_to_load_train,
                batch_size=batch_size,
                shuffle=True,
                num_workers=5,
                pin_memory=args.pin_memory,
                drop_last=False)
            if args.federated_type == 'perfedavg':
                data_loader = [
                    data_loader_train, data_loader_val, data_loader_val1
                ]
            else:
                data_loader = [data_loader_train, data_loader_val]
            log(
                'we have {} batches for {} for rank {}.'.format(
                    len(data_loader[0]), 'train', args.graph.rank),
                args.debug)
            log(
                'we have {} batches for {} for rank {}.'.format(
                    len(data_loader[1]), 'val', args.graph.rank), args.debug)
        else:
            data_loader = torch.utils.data.DataLoader(
                data_to_load,
                batch_size=batch_size,
                shuffle=True,
                num_workers=5,
                pin_memory=args.pin_memory,
                drop_last=False)
            log(
                'we have {} batches for {} for rank {}.'.format(
                    len(data_loader), 'train', args.graph.rank), args.debug)
    else:
        # Evaluation splits: fixed order (shuffle=False).
        data_loader = torch.utils.data.DataLoader(
            data_to_load,
            batch_size=batch_size,
            shuffle=False,
            num_workers=5,
            pin_memory=args.pin_memory,
            drop_last=False)
        log(
            'we have {} batches for {} for rank {}.'.format(
                len(data_loader), dataset_type, args.graph.rank), args.debug)

    if return_partitioner:
        return data_loader, Partitioner
    else:
        return data_loader
def train_and_validate_federated_apfl(client):
    """The training scheme of Personalized Federated Learning.
    Official implementation for https://arxiv.org/abs/2003.13461

    Per round: the server model is broadcast to the sampled clients; each
    online client alternates one SGD step on its local copy of the global
    model (`client.model`) and one on the personal model interpolated with
    mixing weight `alpha` (`client.model_personal`), optionally adapting
    alpha on the first batch of the round; the clients' global copies are
    then FedAvg-aggregated back into the server model and validated.

    Args:
        client: object holding args, `model` (local copy of the global
            model), `model_personal`, their optimizers, criterion, metrics,
            data loaders, and the distributed process groups.
    """
    log('start training and validation with Federated setting.',
        client.args.debug)

    if client.args.evaluate and client.args.graph.rank == 0:
        # Do the testing on the server and return
        do_validate(client.args, client.model, client.optimizer,
                    client.criterion, client.metrics, client.test_loader,
                    client.all_clients_group, data_mode='test')
        return

    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)

    # Number of communication rounds in federated setting should be defined
    for n_c in range(client.args.num_comms):
        client.args.rounds_comm += 1
        client.args.comm_time.append(0.0)

        # Configuring the devices for this round of communication
        # TODO: not make the server rank hard coded
        # Round numbers are logged 1-based, consistent with the other
        # federated training schemes in this file.
        log("Starting round {} of training".format(n_c + 1),
            client.args.debug)
        online_clients = set_online_clients(client.args)
        # Rank 0 (the server) must be present in round 0 so the initial
        # model can be distributed.
        if (n_c == 0) and (0 not in online_clients):
            online_clients += [0]
        online_clients_server = online_clients if 0 in online_clients else online_clients + [
            0
        ]
        online_clients_group = dist.new_group(online_clients_server)

        if client.args.graph.rank in online_clients_server:
            # Time the broadcast as communication, consistent with
            # `train_and_validate_federated`.
            st = time.time()
            client.model_server = distribute_model_server(
                client.model_server, online_clients_group, src=0)
            client.args.comm_time[-1] += time.time() - st
            client.model.load_state_dict(client.model_server.state_dict())

            if client.args.graph.rank in online_clients:
                is_sync = False
                ep = -1  # counting number of epochs
                while not is_sync:
                    ep += 1
                    for i, (_input, _target) in enumerate(
                            client.train_loader):
                        client.model.train()

                        # update local step.
                        logging_load_time(tracker)

                        # update local index and get local step
                        client.args.local_index += 1
                        client.args.local_data_seen += len(_target)
                        get_current_epoch(client.args)
                        local_step = get_current_local_step(client.args)

                        # adjust learning rate (based on the # of accessed samples)
                        lr = adjust_learning_rate(client.args,
                                                  client.optimizer,
                                                  client.scheduler)

                        # load data
                        _input, _target = load_data_batch(
                            client.args, _input, _target, tracker)

                        # Skip batches with one sample because of BatchNorm issue in some models!
                        if _input.size(0) == 1:
                            is_sync = is_sync_fed(client.args)
                            break

                        # inference and get current performance.
                        client.optimizer.zero_grad()
                        loss, performance = inference(client.model,
                                                      client.criterion,
                                                      client.metrics, _input,
                                                      _target)

                        # compute gradient and do local SGD step.
                        loss.backward()
                        client.optimizer.step(
                            apply_lr=True,
                            apply_in_momentum=client.args.in_momentum,
                            apply_out_momentum=False)

                        # One step on the personal (alpha-interpolated) model.
                        client.optimizer.zero_grad()
                        client.optimizer_personal.zero_grad()
                        loss_personal, performance_personal = inference_personal(
                            client.model_personal, client.model,
                            client.args.fed_personal_alpha, client.criterion,
                            client.metrics, _input, _target)

                        # compute gradient and do local SGD step.
                        loss_personal.backward()
                        client.optimizer_personal.step(
                            apply_lr=True,
                            apply_in_momentum=client.args.in_momentum,
                            apply_out_momentum=False)

                        # update alpha: adapt once per round, on the first
                        # batch of the first local epoch.
                        if client.args.fed_adaptive_alpha and i == 0 and ep == 0:
                            client.args.fed_personal_alpha = alpha_update(
                                client.model, client.model_personal,
                                client.args.fed_personal_alpha,
                                lr)  # 0.1/np.sqrt(1+args.local_index)
                            average_alpha = client.args.fed_personal_alpha
                            average_alpha = global_average(
                                average_alpha,
                                client.args.graph.n_nodes,
                                group=online_clients_group)
                            log(
                                "New alpha is:{}".format(
                                    average_alpha.item()),
                                client.args.debug)

                        # logging locally.
                        # logging_computing(tracker, loss, performance, _input, lr)

                        # display the logging info.
                        # logging_display_training(args, tracker)

                        # reset load time for the tracker.
                        tracker['start_load_time'] = time.time()

                        is_sync = is_sync_fed(client.args)
                        if is_sync:
                            break
            else:
                log("Offline in this round. Waiting on others to finish!",
                    client.args.debug)

            # Validate the mixed personal model on local data before syncing.
            do_validate(client.args,
                        client.model,
                        client.optimizer_personal,
                        client.criterion,
                        client.metrics,
                        client.train_loader,
                        online_clients_group,
                        data_mode='train',
                        personal=True,
                        model_personal=client.model_personal,
                        alpha=client.args.fed_personal_alpha)
            if client.args.fed_personal:
                do_validate(client.args,
                            client.model,
                            client.optimizer_personal,
                            client.criterion,
                            client.metrics,
                            client.val_loader,
                            online_clients_group,
                            data_mode='validation',
                            personal=True,
                            model_personal=client.model_personal,
                            alpha=client.args.fed_personal_alpha)

            # Sync the model server based on model_clients
            log('Enter synching', client.args.debug)
            tracker['start_sync_time'] = time.time()
            client.args.global_index += 1
            client.model_server = fedavg_aggregation(
                client.args, client.model_server, client.model,
                online_clients_group, online_clients, client.optimizer)

            # evaluate the sync time
            logging_sync_time(tracker)

            # Do the validation on the server model
            do_validate(client.args, client.model_server, client.optimizer,
                        client.criterion, client.metrics,
                        client.train_loader, online_clients_group,
                        data_mode='train')
            if client.args.fed_personal:
                do_validate(client.args, client.model_server,
                            client.optimizer, client.criterion,
                            client.metrics, client.val_loader,
                            online_clients_group, data_mode='validation')

            # logging.
            logging_globally(tracker, start_global_time)

            # reset start round time.
            start_global_time = time.time()

            # validate the models at the test data
            if client.args.fed_personal_test:
                # BUGFIX: this previously passed `client.model_client`, which
                # is never assigned in this function; the local copy of the
                # global model is `client.model` everywhere else here.
                do_validate(client.args,
                            client.model,
                            client.optimizer_personal,
                            client.criterion,
                            client.metrics,
                            client.test_loader,
                            online_clients_group,
                            data_mode='test',
                            personal=True,
                            model_personal=client.model_personal,
                            alpha=client.args.fed_personal_alpha)
            elif client.args.graph.rank == 0:
                do_validate(client.args, client.model_server,
                            client.optimizer, client.criterion,
                            client.metrics, client.test_loader,
                            online_clients_group, data_mode='test')
        else:
            log("Offline in this round. Waiting on others to finish!",
                client.args.debug)

        dist.barrier(group=client.all_clients_group)