class FLServer(object):
    """Federated-learning server: aggregates per-round client weight
    updates into a global model, tracks train/eval metrics, and
    broadcasts the best weights for final evaluation.

    NOTE(review): ``stop_and_eval`` calls ``self.emit``, which is not
    defined in this class -- presumably provided by a SocketIO-style
    transport mixin/subclass; confirm against the full file.
    """

    # Minimum connected workers required before training may proceed.
    MIN_NUM_WORKERS = 4
    # Hard cap on the number of federated rounds.
    MAX_NUM_ROUNDS = 36
    # How many clients are asked to train in each round.
    NUM_CLIENTS_CONTACTED_PER_ROUND = 4
    # run_validation is set every ROUNDS_BETWEEN_VALIDATIONS rounds.
    ROUNDS_BETWEEN_VALIDATIONS = 1

    def __init__(self, global_model, aggregation="normal_atten"):
        """
        Args:
            global_model: zero-arg factory (typically a class) producing
                the global model object; instantiated immediately.
            aggregation: "normal_atten", "atten" or "rule_out" selects an
                attention-based update; any other value is forwarded to
                the model's baseline update path.
        """
        self.global_model = global_model()
        self.ready_client_sids = set()
        self.client_resource = {}
        self.wait_time = 0
        self.model_id = str(uuid.uuid4())
        self.aggregation = aggregation
        self.attention_mechanism = Attention()

        # Training state: -1 means training has not started yet.
        self.current_round = -1
        self.current_round_client_updates = []
        self.eval_client_updates = []
        # Consecutive rounds in which aggregated train loss got worse.
        self.invalid_tolerate = 0

    def handle_client_update(self, data):
        """Fold one round of client updates into the global model.

        Args:
            data: list of dicts with keys 'weights', 'train_size',
                'train_loss', 'train_accuracy', 'attack_mode',
                'assigned_label'.
        """
        self.current_round_client_updates = data
        uploaded_weights = [x['weights'] for x in data]
        # Hoisted: this list was rebuilt three times in the original.
        train_sizes = [x['train_size'] for x in data]

        if self.aggregation in ("normal_atten", "atten", "rule_out"):
            if self.aggregation == "normal_atten":
                # Uniform (all-ones) attention over clients.
                print("### Update with normal attention mechanism! ###")
                attention = np.tile(np.array([1.0]), len(uploaded_weights))
            else:
                print("### Update with calculated attention mechanism! ###")
                attention = self.attention_mechanism.cal_weights(
                    np.array(uploaded_weights))
                print("old attention", attention)
                if self.aggregation == "rule_out":
                    # Rule out: keep only clients whose attention beats
                    # the mean (weight 1.0), zero out the rest.
                    attention = (attention > np.mean(attention)).astype(float)
                    print("new attention", attention)
            attack_label = [
                "{}_{}".format(x['attack_mode'], x['assigned_label'])
                for x in data
            ]
            self.global_model.update_weights_with_attention(
                uploaded_weights, train_sizes, attention, attack_label)
        else:
            print("### Update with baseline methods! ###")
            self.global_model.update_weights_baseline(
                uploaded_weights, train_sizes, self.aggregation)

        aggr_train_loss, aggr_train_accuracy = \
            self.global_model.aggregate_train_loss_accuracy(
                [x['train_loss'] for x in data],
                [x['train_accuracy'] for x in data],
                train_sizes,
                self.current_round)

        # Count consecutive rounds in which the loss failed to improve.
        if (self.global_model.prev_train_loss is not None
                and self.global_model.prev_train_loss < aggr_train_loss):
            self.invalid_tolerate += 1
        else:
            self.invalid_tolerate = 0
        self.global_model.prev_train_loss = aggr_train_loss

    def handle_client_eval(self, data):
        """Aggregate client-side evaluation results and print a summary.

        No-op if evaluation has already been finalized (sentinel None).

        Args:
            data: list of dicts with keys 'test_loss', 'test_accuracy',
                'test_size'.
        """
        if self.eval_client_updates is None:
            return
        self.eval_client_updates = data
        # NOTE(review): the original comment claimed "tolerate 30%
        # unresponsive clients", but no such tolerance is implemented
        # in this method -- confirm intent.
        aggr_test_loss, aggr_test_accuracy = \
            self.global_model.aggregate_loss_accuracy(
                [x['test_loss'] for x in data],
                [x['test_accuracy'] for x in data],
                [x['test_size'] for x in data],
            )
        print("\n--------Aggregating test loss---------\n")
        print("aggr_test_loss", aggr_test_loss)
        print("aggr_test_accuracy", aggr_test_accuracy)
        print("best model at round ", self.global_model.best_round,
              ", get the best loss ", self.global_model.best_loss)
        print("== done ==")
        # Special value: forbids evaluating again.
        self.eval_client_updates = None

    def train_next_round(self, clients):
        """Advance to the next round and build the training payload.

        Note: assumes #workers >= MIN_NUM_WORKERS during training.

        Args:
            clients: selected client handles (currently unused here).

        Returns:
            dict with the model id, round number, the current global
            weights (sent as-is, not pickled), and whether clients
            should run validation this round.
        """
        self.current_round += 1
        # Buffer for this round's incoming client updates.
        self.current_round_client_updates = []
        print("\n ### Round ", self.current_round, "### \n")
        return {
            'model_id': self.model_id,
            'round_number': self.current_round,
            'current_weights': self.global_model.current_weights,
            'weights_format': 'not pickle',
            'run_validation':
                self.current_round % FLServer.ROUNDS_BETWEEN_VALIDATIONS == 0,
        }

    def stop_and_eval(self):
        """Broadcast the best weights to every ready client and request
        a final evaluation pass."""
        self.eval_client_updates = []
        for rid in self.ready_client_sids:
            # self.emit is expected from the transport layer (not
            # defined in this class) -- TODO confirm against full file.
            self.emit(
                'stop_and_eval', {
                    'model_id': self.model_id,
                    'current_weights': self.global_model.best_weight,
                    'weights_format': 'not pickle'
                },
                room=rid)