def aggregate(self):
    """Aggregate cached client updates into the global model via FedAvg.

    Combines the ``self.worker_num`` client state dicts cached in
    ``self.model_dict`` using sample-count weights from
    ``self.sample_num_dict``, loads the weighted average into the
    server-side cached model, and returns the averaged parameters.

    Returns:
        dict: the sample-weighted average of the client parameters
        (note: this aliases and overwrites client 0's cached dict).
    """
    start_time = time.time()

    # Mobile clients ship parameters as plain lists; restore tensors first.
    for idx in range(self.worker_num):
        if self.args.is_mobile == 1:
            self.model_dict[idx] = transform_list_to_tensor(self.model_dict[idx])

    model_list = [
        (self.sample_num_dict[idx], self.model_dict[idx])
        for idx in range(self.worker_num)
    ]
    training_num = sum(num for num, _ in model_list)
    logging.info("len of self.model_dict[idx] = " + str(len(self.model_dict)))
    # logging.info("################aggregate: %d" % len(model_list))

    # Weighted average, accumulated in place into the first client's params.
    (num0, averaged_params) = model_list[0]
    for k in averaged_params.keys():
        for i, (local_sample_number, local_model_params) in enumerate(model_list):
            w = local_sample_number / training_num
            if i == 0:
                averaged_params[k] = local_model_params[k] * w
            else:
                averaged_params[k] += local_model_params[k] * w

    # Refresh the global model cached on the server side.
    self.model.load_state_dict(averaged_params)

    end_time = time.time()
    logging.info("aggregate time cost: %d" % (end_time - start_time))
    return averaged_params
def handle_message_init(self, msg_params):
    """Handle the server's init message and launch local round 0.

    Extracts the initial global model and the client index assigned to
    this worker from the message, installs both on the local trainer,
    resets the round counter, and starts the first training round.
    """
    init_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
    assigned_client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX)

    # Mobile clients receive the model serialized as plain lists.
    if self.args.is_mobile == 1:
        init_model_params = transform_list_to_tensor(init_model_params)

    self.trainer.update_model(init_model_params)
    self.trainer.update_dataset(int(assigned_client_index))

    self.round_idx = 0
    self.__train()
def aggregate(self):
    """Aggregate client updates with a robust-aggregation defense applied.

    Each cached client state dict is first norm-diff-clipped against the
    current global model (the clipping is also the first stage of the
    ``weak_dp`` defense), then combined by sample-count-weighted FedAvg.
    Under ``weak_dp``, Gaussian noise is additionally added to every weight
    parameter during averaging. The result is loaded into the server-side
    cached model and returned.

    Returns:
        dict: the defended, sample-weighted average of the client parameters.

    Raises:
        NotImplementedError: if ``self.robust_aggregator.defense_type`` is
            not one of the supported defenses.
    """
    start_time = time.time()

    model_list = []
    training_num = 0
    for idx in range(self.worker_num):
        if self.args.is_mobile == 1:
            self.model_dict[idx] = transform_list_to_tensor(self.model_dict[idx])
        # conduct the defense here:
        local_sample_number = self.sample_num_dict[idx]
        local_model_params = self.model_dict[idx]
        if self.robust_aggregator.defense_type in ("norm_diff_clipping", "weak_dp"):
            clipped_local_state_dict = self.robust_aggregator.norm_diff_clipping(
                local_model_params, self.model.state_dict())
        else:
            raise NotImplementedError("Non-supported Defense type ... ")
        model_list.append((local_sample_number, clipped_local_state_dict))
        training_num += self.sample_num_dict[idx]

    logging.info("len of self.model_dict[idx] = " + str(len(self.model_dict)))
    # logging.info("################aggregate: %d" % len(model_list))

    (num0, averaged_params) = model_list[0]
    for k in averaged_params.keys():
        for i in range(0, len(model_list)):
            local_sample_number, local_model_params = model_list[i]
            w = local_sample_number / training_num
            local_layer_update = local_model_params[k]
            if self.robust_aggregator.defense_type == "weak_dp":
                if is_weight_param(k):
                    local_layer_update = self.robust_aggregator.add_noise(
                        local_layer_update, self.device)
            # BUG FIX: average the defended update. The original averaged the
            # raw local_model_params[k], so the weak_dp noise computed above
            # was silently discarded.
            if i == 0:
                averaged_params[k] = local_layer_update * w
            else:
                averaged_params[k] += local_layer_update * w

    # update the global model which is cached at the server side
    self.model.load_state_dict(averaged_params)

    end_time = time.time()
    logging.info("aggregate time cost: %d" % (end_time - start_time))
    return averaged_params
def handle_message_receive_model_from_server(self, msg_params):
    """Handle an aggregated global model pushed by the server.

    Installs the received global model and assigned client index on the
    local trainer, advances the round counter, runs the next local
    training round, and tears the worker down once the final round has
    been launched.
    """
    logging.info("handle_message_receive_model_from_server.")
    received_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
    assigned_client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX)

    # Mobile clients receive the model serialized as plain lists.
    if self.args.is_mobile == 1:
        received_params = transform_list_to_tensor(received_params)

    self.trainer.update_model(received_params)
    self.trainer.update_dataset(int(assigned_client_index))

    self.round_idx += 1
    self.__train()
    # Shut down once the last round has been started.
    if self.round_idx == self.num_rounds - 1:
        self.finish()