class FedAvgServer(object): """ This class implements the server-side of the FedAvg algorithm using the dc_federated.backend package. Parameters ---------- global_model_trainer: FedAvgModelTrainer The name of the python model-class for this problem. update_lim: int Number of unique updates that needs to be received before the last global update before we update the global model. key_list_file: str The list of public keys of valid workers. No authentication is performed if file not given. server_host_ip: str (default None) The hostname or IP address the server will bind to. If not given, it will default to the machine IP. server_port: int (default 8080) The port at which the server should listen to. ssl_enabled: bool (default False) Enable SSL/TLS for server/workers communications. ssl_keyfile: str Must be a valid path to the key file. This is mandatory if ssl_enabled is True, ignored otherwise. ssl_certfile: str Must be a valid path to the certificate. This is mandatory if ssl_enabled is True, ignored otherwise. """ def __init__(self, global_model_trainer, key_list_file, update_lim=10, server_host_ip=None, server_port=8080, ssl_enabled=False, ssl_keyfile=None, ssl_certfile=None): logger.info( f"Initializing FedAvg server for model class {global_model_trainer.get_model().__class__.__name__}") self.worker_updates = {} self.global_model_trainer = global_model_trainer self.update_lim = update_lim self.last_global_model_update_timestamp = datetime(1980, 10, 10) self.server = DCFServer( register_worker_callback=self.register_worker, unregister_worker_callback=self.unregister_worker, return_global_model_callback=self.return_global_model, is_global_model_most_recent=self.is_global_model_most_recent, receive_worker_update_callback=self.receive_worker_update, server_mode_safe=key_list_file is not None, load_last_session_workers=False, key_list_file=key_list_file, server_host_ip=server_host_ip, server_port=server_port, ssl_enabled=ssl_enabled, ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile, model_check_interval = 1 ) self.unique_updates_since_last_agg = 0 self.iteration = 0 self.model_version = 0 def register_worker(self, worker_id): """ Register the given worker_id by initializing its update to None. Parameters ---------- worker_id: int The id of the new worker. """ logger.info(f"Registered worker {worker_id[0:WID_LEN]}") self.worker_updates[worker_id] = None def unregister_worker(self, worker_id): """ Unregister the given worker_id by removing it from updates. Parameters ---------- worker_id: int The id of the worker to be removed. """ logger.info(f"Unregistered worker {worker_id[0:WID_LEN]}") self.worker_updates.pop(worker_id) def return_global_model(self): """ Serializes the current global torch model, puts it in the proper dictionary, and sends it back. Returns ---------- dict: A dictionary with keys: GLOBAL_MODEL: serialized global model. GLOBAL_MODEL_VERSION: version of the global model """ model_data = io.BytesIO() torch.save(self.global_model_trainer.get_model(), model_data) return { GLOBAL_MODEL: model_data.getvalue(), GLOBAL_MODEL_VERSION: self.model_version } def is_global_model_most_recent(self, model_version): """ Returns a default model update time of 2018/10/10. Parameters ---------- model_version: int The version of most recent global model that the worker has. Returns ---------- str: String format of the last model update time. """ return self.model_version == model_version def receive_worker_update(self, worker_id, model_update): """ Given an update for a worker, adds its update to the dictionary of updates. It also agg_model() to update the global model if necessary. Returns ---------- str: String format of the last model update time. """ if worker_id in self.worker_updates: # update the number of unique updates received if self.worker_updates[worker_id] is None or \ self.worker_updates[worker_id][0] < self.last_global_model_update_timestamp: self.unique_updates_since_last_agg += 1 update_size, model_bytes = msgpack.unpackb(model_update) self.worker_updates[worker_id] = ( datetime.now(), update_size, torch.load(io.BytesIO(model_bytes)) ) logger.info(f"Model update from worker {worker_id[0:WID_LEN]} accepted.") if self.agg_model(): self.global_model_trainer.test() return f"Update received for worker {worker_id[0:WID_LEN]}" else: logger.warning( f"Unregistered worker {worker_id[0:WID_LEN]} tried to send an update.") return f"Please register before sending an update." def agg_model(self): """ Updates the global model by aggregating all the most recent updates from the workers, assuming that the number of unique updates received since the last global model update is above the threshold. """ if self.unique_updates_since_last_agg < self.update_lim: return False logger.info("Updating the global model.\n") def agg_params(key, state_dicts, update_sizes): agg_val = state_dicts[0][key] * update_sizes[0] for sd, sz in zip(state_dicts[1:], update_sizes[1:]): agg_val = agg_val + sd[key] * sz agg_val = agg_val / sum(update_sizes) return torch.tensor(agg_val.cpu().clone().numpy()) # gather the model-updates to use for the update state_dicts_to_update_with = [] update_sizes = [] # each item in the worker_updates dictionary contains a # (timestamp update, update-size, model) for wi in self.worker_updates: if self.worker_updates[wi][0] > self.last_global_model_update_timestamp: state_dicts_to_update_with.append( self.worker_updates[wi][2].state_dict()) update_sizes.append(self.worker_updates[wi][1]) # now update the global model global_model_dict = OrderedDict() for key in state_dicts_to_update_with[0].keys(): global_model_dict[key] = agg_params( key, state_dicts_to_update_with, update_sizes) self.global_model_trainer.load_model_from_state_dict(global_model_dict) self.last_global_model_update_timestamp = datetime.now() self.unique_updates_since_last_agg = 0 self.iteration += 1 self.model_version += 1 return True def start(self): self.server.start_server()
class ExampleGlobalModel(object): """ This is a simple class that illustrates how the DCFServer class may be used to implement a federated global model. For testing purposes, it writes all the models it creates and receives to disk. """ def __init__(self): self.worker_updates = {} self.global_model = ExampleModelClass() with open("egm_global_model.torch", 'wb') as f: torch.save(self.global_model, f) self.global_model_version = 0 self.server = DCFServer( register_worker_callback=self.register_worker, unregister_worker_callback=self.unregister_worker, return_global_model_callback=self.return_global_model, is_global_model_most_recent=self.is_global_model_most_recent, receive_worker_update_callback=self.receive_worker_update, server_mode_safe=False, key_list_file=None, load_last_session_workers=False) def register_worker(self, worker_id): """ Register the given worker_id by initializing its update to None. Parameters ---------- worker_id: int The id of the new worker. """ logger.info( f"Example Global Model: Registering worker {worker_id[0:WID_LEN]}") self.worker_updates[worker_id] = None def unregister_worker(self, worker_id): """ Unregister the given worker_id by removing it from updates. Parameters ---------- worker_id: int The id of the worker to be removed. """ logger.info( f"Example Global Model: Unregistering worker {worker_id[0:WID_LEN]}" ) self.worker_updates.pop(worker_id) def return_global_model(self): """ Serializes the current global torch model and sends it back to the worker. Returns ---------- dict: The model dictionary as per the specification in DCFSever """ logger.info(f"Example Global Model: returning global model") model_data = io.BytesIO() torch.save(self.global_model, model_data) return create_model_dict(model_data.getvalue(), self.global_model_version) def is_global_model_most_recent(self, model_version): """ Returns a default model update time of 2018/10/10. Parameter --------- model_version: int Returns ---------- str: String format of the last model update time. """ logger.info( f"Example Global Model: checking if model version is most recent.") return self.global_model_version == model_version def receive_worker_update(self, worker_id, model_update): """ Given an update for a worker, adds the the update to the list of updates. Returns ---------- str: String format of the last model update time. """ if worker_id in self.worker_updates: self.worker_updates[worker_id] = \ torch.load(io.BytesIO(model_update)) logger.info( f"Model update received from worker {worker_id[0:WID_LEN]}") logger.info(self.worker_updates[worker_id]) with open(f"egm_worker_update_{worker_id}.torch", 'wb') as f: torch.save(self.worker_updates[worker_id], f) self.global_model_version += 1 return f"Update received for worker {worker_id[0:WID_LEN]}" else: return f"Unregistered worker {worker_id[0:WID_LEN]} tried to send an update!!" def start(self): self.server.start_server()