def handle_message_receive_model_from_client(self, msg_params):
    """Store one client's upload and, once all uploads are in, advance a round.

    The sender id, model parameters and local sample count are read from
    ``msg_params`` and handed to ``self.aggregator``. When the aggregator
    reports that every upload has arrived, the global model is aggregated
    and tested, ``self.round_idx`` is incremented, and then either
    ``self.finish()`` is called (when ``round_idx`` reaches ``round_num``)
    or the new global model is sent to ranks ``1 .. self.size - 1`` together
    with the client index sampled for each rank.

    Args:
        msg_params: Message object exposing ``get(key)`` for the
            ``MyMessage.MSG_ARG_KEY_*`` keys.
    """
    client_rank = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER)
    uploaded_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
    # Number of local samples (the n_i term in the formula).
    sample_count = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES)

    # The aggregator records the parameters sent by the client under the
    # zero-based slot derived from the sender id.
    self.aggregator.add_local_trained_result(
        client_rank - 1, uploaded_params, sample_count)

    # Check whether every client has finished sending.
    everyone_reported = self.aggregator.check_whether_all_receive()
    logging.info("b_all_received = " + str(everyone_reported))
    if not everyone_reported:
        return

    # Start aggregation, then evaluate the round that just finished.
    new_global_params = self.aggregator.aggregate()
    self.aggregator.test_on_all_clients(self.round_idx)

    # Start the next round, or stop when the configured count is reached.
    self.round_idx += 1
    if self.round_idx == self.round_num:
        self.finish()
        return

    # Sample the clients for the next round (meant for large client
    # populations; described as uniform sampling by the original author).
    sampled_clients = self.aggregator.client_sampling(
        self.round_idx,
        self.args.client_num_in_total,
        self.args.client_num_per_round,
    )

    print("size = %d" % self.size)
    if self.args.is_mobile == 1:
        print("transform_tensor_to_list")
        new_global_params = transform_tensor_to_list(new_global_params)

    # Rank r (1-based) receives the (r - 1)-th sampled client index.
    for rank in range(1, self.size):
        assigned_client = sampled_clients[rank - 1]
        self.send_message_sync_model_to_client(
            rank, new_global_params, assigned_client)
def train(self):
    """Train ``self.model`` on ``self.train_local`` for ``self.args.epochs`` epochs.

    Each batch is moved to ``self.device``, run through the model, scored
    with ``self.criterion`` and used for one ``self.optimizer`` step. After
    every epoch that processed at least one batch, the running mean of the
    per-epoch mean losses is logged.

    Returns:
        tuple: ``(weights, self.local_sample_number)``. ``weights`` is the
        ``state_dict`` taken after moving the model to the CPU; it is passed
        through ``transform_tensor_to_list`` when ``self.args.is_mobile == 1``.
    """
    self.model.to(self.device)
    # Switch to training mode.
    self.model.train()

    epoch_means = []
    for epoch in range(self.args.epochs):
        step_losses = []
        for inputs, targets in self.train_local:
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            step_loss = self.criterion(outputs, targets)
            step_loss.backward()
            self.optimizer.step()

            step_losses.append(step_loss.item())

        # Nothing to average or report when the loader yielded no batch.
        if not step_losses:
            continue

        epoch_means.append(sum(step_losses) / len(step_losses))
        running_mean = sum(epoch_means) / len(epoch_means)
        logging.info(
            '(client {}. Local Training Epoch: {} \tLoss: {:.6f}'.format(
                self.client_index, epoch, running_mean))

    cpu_model = self.model.cpu()
    weights = cpu_model.state_dict()

    # Convert tensors to lists for mobile clients.
    if self.args.is_mobile == 1:
        weights = transform_tensor_to_list(weights)
    return weights, self.local_sample_number
def handle_message_receive_model_from_client(self, msg_params):
    """Collect a client's trained model; run aggregation once all have arrived.

    After the upload is registered with the aggregator, nothing else happens
    until the aggregator reports that every upload is in. At that point the
    global model is aggregated, evaluated on all clients and on the target
    task, and ``self.round_idx`` is incremented. If that reaches
    ``self.round_num`` the manager finishes; otherwise clients are sampled
    for the next round and the global model is sent to ranks
    ``1 .. self.size - 1``.

    Args:
        msg_params: Message object exposing ``get(key)`` for the
            ``MyMessage.MSG_ARG_KEY_*`` keys.
    """
    aggregator = self.aggregator

    sender = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER)
    client_model = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
    num_samples = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES)

    # Sender ids are shifted down by one to form the aggregator slot.
    aggregator.add_local_trained_result(sender - 1, client_model, num_samples)

    all_in = aggregator.check_whether_all_receive()
    logging.info("b_all_received = " + str(all_in))
    if not all_in:
        return

    global_weights = aggregator.aggregate()
    aggregator.test_on_all_clients(self.round_idx)
    # Measure the target task accuracy.
    aggregator.test_target_accuracy(self.round_idx)

    # Start the next round, unless this was the last one.
    self.round_idx += 1
    if self.round_idx == self.round_num:
        self.finish()
        return

    # Sample the clients that take part in the next round.
    next_clients = aggregator.client_sampling(
        self.round_idx,
        self.args.client_num_in_total,
        self.args.client_num_per_round,
    )

    print("size = %d" % self.size)
    if self.args.is_mobile == 1:
        print("transform_tensor_to_list")
        global_weights = transform_tensor_to_list(global_weights)

    receiver = 1
    while receiver < self.size:
        # Receiver r is paired with the (r - 1)-th sampled client index.
        self.send_message_sync_model_to_client(
            receiver, global_weights, next_clients[receiver - 1])
        receiver += 1
def train(self):
    """Train locally with a per-batch bit-width, then quantize the weights.

    Steps, in order:

    1. When ``args.client_num_in_total == args.client_num_per_round``, the
       residue stored in ``self.quant_residue`` is added back onto every
       state-dict entry whose key contains ``'weight'`` or ``'bias'`` but not
       ``'bn'``, and the result is installed with ``self.update_model``.
       Otherwise the compensation is skipped (and logged).
    2. The model is trained for ``args.epochs`` epochs. For every batch a
       ``num_bits`` value is chosen: 0 when either entry of
       ``args.cyclic_num_bits_schedule`` is 0, otherwise the value returned
       by ``self.cyclic_adjust_precision``. It is passed to the model's
       forward call. A failure inside an epoch is logged and the remaining
       epochs still run.
    3. The CPU ``state_dict`` is copied as ``latent_weight``. If the last
       ``num_bits`` is non-zero, weight/bias entries (excluding keys with
       ``'bn'`` or ``'downsample.1'``) are quantized and the difference
       ``latent - quantized`` is stored in ``self.quant_residue``.

    If no batch is ever processed (empty loader, zero epochs, or a failure
    before the first bit-width is chosen), ``num_bits`` stays 0 and the
    weights are returned unquantized.

    Returns:
        tuple | None: ``(weights, self.local_sample_number, num_bits,
        latent_weight)``. ``None`` is returned when an unexpected exception
        escapes the steps above; the exception is logged, not re-raised.
    """

    def _log_failure(exc):
        # Log the message, then the exception type plus the file name and
        # line number taken from the traceback of the exception being handled.
        logging.info(str(exc))
        exc_type, _, exc_tb = sys.exc_info()
        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
        logging.info(
            str(exc_type) + " " + str(fname) + " " + str(exc_tb.tb_lineno))

    # Defaults for the "no batch was processed" case: 0 bits means "do not
    # quantize", and -1 makes the batch count logged below come out as 0.
    # Without them the later reads raised NameError, which the outer handler
    # swallowed, so the method silently returned None.
    num_bits = 0
    batch_idx = -1

    try:
        if self.args.client_num_in_total == self.args.client_num_per_round:
            # Apply the residue left over from the previous quantization.
            weights = self.model.cpu().state_dict()
            for k, residue in self.quant_residue.items():
                if 'bn' not in k and ('weight' in k or 'bias' in k):
                    weights[k] = copy.deepcopy(weights[k] + residue)
            self.update_model(weights)
        else:
            logging.info('residue compensation skipped')

        self.model.to(self.device)
        # Switch to training mode.
        self.model.train()

        epoch_loss = []
        for epoch in range(self.args.epochs):
            batch_loss = []
            try:
                num_batches = len(self.train_local)
                # Loop-invariant: the total iteration count
                # (comm_round * epochs * batches) floor-divided by a
                # hard-coded 16. An earlier variant of this code divided by
                # self.cyclic_period instead.
                cyclic_period = int(
                    (self.args.comm_round * self.args.epochs * num_batches)
                    // 16)

                for batch_idx, (x, labels) in enumerate(self.train_local):
                    # Global iteration index across communication rounds.
                    _iters = self.glb_epoch * num_batches + batch_idx
                    schedule = self.args.cyclic_num_bits_schedule

                    if schedule[0] == 0 or schedule[1] == 0:
                        # A zero in the schedule disables the cyclic precision.
                        num_bits = 0
                    else:
                        offset = self.offset_finder(
                            schedule[1], cyclic_period, num_batches,
                            self.lr_steps)
                        # Clamp the shifted iteration into [0, lr_steps].
                        offseted_iters = min(
                            max(0, _iters - offset), self.lr_steps)
                        num_bits = self.cyclic_adjust_precision(
                            offseted_iters, cyclic_period)

                    x, labels = x.to(self.device), labels.to(self.device)
                    self.optimizer.zero_grad()
                    log_probs = self.model(x, num_bits=num_bits)
                    loss = self.criterion(log_probs, labels)
                    loss.backward()
                    # Gradients are clipped the same way (max norm 0.9,
                    # norm type 'inf') regardless of the participation mode.
                    nn.utils.clip_grad_norm_(
                        self.model.parameters(), 0.9, 'inf')
                    self.optimizer.step()
                    batch_loss.append(loss.item())
                    # NOTE(review): the scheduler is stepped once per batch;
                    # confirm this matches the intended LR schedule.
                    self.scheduler.step()

                self.glb_epoch += 1
                if batch_loss:
                    epoch_loss.append(sum(batch_loss) / len(batch_loss))
                    logging.info(
                        '(client {}. Local Training Epoch: {} \tLoss: {:.6f}'
                        .format(self.client_index, epoch,
                                sum(epoch_loss) / len(epoch_loss)))
            except Exception as e:
                # Best effort: a failing epoch is logged and the next one runs.
                logging.info(str(e))

        self.comm_round += 1

        # Only the first parameter group's learning rate is reported.
        first_group = next(iter(self.optimizer.param_groups), None)
        if first_group is not None:
            logging.info(
                "===current learning rate===: " + str(first_group['lr']))
        logging.info(
            "========= number of batches =======: " + str(batch_idx + 1))
        self.first_run = False

        weights = self.model.cpu().state_dict()

        # Convert tensors to lists for mobile clients.
        if self.args.is_mobile == 1:
            weights = transform_tensor_to_list(weights)

        # Keep the full-precision values so the quantization error can be
        # stored and added back at the start of the next call.
        latent_weight = copy.deepcopy(weights)

        logging.info('Quantizing model')
        if num_bits != 0:
            for k in weights:
                try:
                    if 'bn' in k or 'downsample.1' in k:
                        continue
                    if 'weight' in k:
                        print(k)
                        weight_qparams = calculate_qparams(
                            copy.deepcopy(weights[k]),
                            num_bits=num_bits,
                            flatten_dims=(1, -1),
                            reduce_dim=None)
                        weights[k] = quantize(
                            copy.deepcopy(weights[k]),
                            qparams=weight_qparams)
                    elif 'bias' in k:
                        weights[k] = quantize(
                            copy.deepcopy(weights[k]),
                            num_bits=num_bits,
                            flatten_dims=(0, -1))
                    else:
                        continue
                    self.quant_residue[k] = copy.deepcopy(
                        latent_weight[k] - weights[k])
                except Exception as e:
                    # Best effort: a key that fails to quantize is logged and
                    # the loop moves on to the next key.
                    logging.info(str(k))
                    _log_failure(e)

        logging.info('Sending model')
        return weights, self.local_sample_number, num_bits, latent_weight
    except Exception as e:
        _log_failure(e)
        # The failure is only logged; callers receive None.
        return None