def __init__(self):
    """
    Create the initial global model, write a copy of it to disk and set up
    the DCFServer that exposes it to the workers.
    """
    # worker_id -> most recent update (None until an update is received)
    self.worker_updates = {}

    self.global_model = ExampleModelClass()
    # the class writes every model it creates to disk for testing purposes
    with open("egm_global_model.torch", 'wb') as model_file:
        torch.save(self.global_model, model_file)
    self.global_model_version = 0

    # hooks through which the backend server talks to this object
    server_callbacks = {
        'register_worker_callback': self.register_worker,
        'unregister_worker_callback': self.unregister_worker,
        'return_global_model_callback': self.return_global_model,
        'is_global_model_most_recent': self.is_global_model_most_recent,
        'receive_worker_update_callback': self.receive_worker_update,
    }
    self.server = DCFServer(
        server_mode_safe=False,
        key_list_file=None,
        load_last_session_workers=False,
        **server_callbacks)
def __init__(self,
             global_model_trainer,
             key_list_file,
             update_lim=10,
             server_host_ip=None,
             server_port=8080,
             ssl_enabled=False,
             ssl_keyfile=None,
             ssl_certfile=None):
    """
    Store the trainer and aggregation settings and build the DCFServer
    that forwards worker requests to this object's callback methods.

    The server is created in safe mode exactly when a key_list_file is
    supplied; the network and SSL arguments are passed through unchanged.
    """
    model_class_name = global_model_trainer.get_model().__class__.__name__
    logger.info(
        f"Initializing FedAvg server for model class {model_class_name}")

    self.global_model_trainer = global_model_trainer
    self.update_lim = update_lim

    # worker_id -> latest update (None until an update is received)
    self.worker_updates = {}
    # a date far in the past, so every first update counts as new
    self.last_global_model_update_timestamp = datetime(1980, 10, 10)

    server_callbacks = {
        'register_worker_callback': self.register_worker,
        'unregister_worker_callback': self.unregister_worker,
        'return_global_model_callback': self.return_global_model,
        'is_global_model_most_recent': self.is_global_model_most_recent,
        'receive_worker_update_callback': self.receive_worker_update,
    }
    self.server = DCFServer(
        server_mode_safe=key_list_file is not None,
        load_last_session_workers=False,
        key_list_file=key_list_file,
        server_host_ip=server_host_ip,
        server_port=server_port,
        ssl_enabled=ssl_enabled,
        ssl_keyfile=ssl_keyfile,
        ssl_certfile=ssl_certfile,
        model_check_interval=1,
        **server_callbacks)

    self.unique_updates_since_last_agg = 0
    self.iteration = 0
    self.model_version = 0
def test_worker_persistence():
    """
    Checks that the workers known to a DCFServer survive a restart when the
    server is created with load_last_session_workers=True.

    Three public keys are pre-loaded through the key-list file and three more
    workers are registered through the admin API. The server is then shut
    down and re-created against the same keys database, after which the
    admin-registered workers are deleted again.

    The servers are always shut down and the generated files always removed,
    even when an assertion fails, so that a failure here does not leave the
    port bound or stale key files behind for the other tests.
    """
    worker_ids = []
    added_workers = []
    worker_updates = {}
    global_model_version = "1"

    os.environ[ADMIN_USERNAME] = 'admin'
    os.environ[ADMIN_PASSWORD] = 'str0ng_s3cr3t'
    admin_auth = ('admin', 'str0ng_s3cr3t')

    public_keys = []
    private_keys = []
    num_workers = 6
    num_pre_load_workers = 3
    worker_key_file_prefix = 'worker_key_file'
    worker_key_file = 'worker_public_keys.txt'
    keys_db_file = 'workers_db.json'

    def begin_server(server, server_adapter):
        server.start_server(server_adapter)

    def test_register_func_cb(id):
        worker_ids.append(id)

    def test_unregister_func_cb(id):
        worker_ids.remove(id)

    def test_ret_global_model_cb():
        return create_model_dict(msgpack.packb("Pickle dump of a string"),
                                 global_model_version)

    def is_global_model_most_recent(version):
        # global_model_version is a string, so compare as strings; comparing
        # int(version) with it could never be equal.
        return str(version) == global_model_version

    def test_rec_server_update_cb(worker_id, update):
        if worker_id in worker_ids:
            worker_updates[worker_id] = update
            return f"Update received for worker {worker_id[0:WID_LEN]}."
        else:
            return f"Unregistered worker {worker_id[0:WID_LEN]} tried to send an update."

    def get_signed_phrase(private_key, phrase=b'test phrase'):
        return SigningKey(private_key,
                          encoder=HexEncoder).sign(phrase).hex()

    def create_server():
        # both sessions use the same configuration and the same keys database
        return DCFServer(
            register_worker_callback=test_register_func_cb,
            unregister_worker_callback=test_unregister_func_cb,
            return_global_model_callback=test_ret_global_model_cb,
            is_global_model_most_recent=is_global_model_most_recent,
            receive_worker_update_callback=test_rec_server_update_cb,
            server_mode_safe=True,
            load_last_session_workers=True,
            path_to_keys_db=keys_db_file,
            key_list_file=worker_key_file)

    def server_url(server):
        return f"http://{server.server_host_ip}:{server.server_port}"

    def remove_if_present(path):
        if os.path.exists(path):
            os.remove(path)

    try:
        for n in range(num_workers):
            private_key, public_key = gen_pair(worker_key_file_prefix + f'_{n}')
            private_keys.append(
                private_key.encode(encoder=HexEncoder).decode('utf-8'))
            public_keys.append(
                public_key.encode(encoder=HexEncoder).decode('utf-8'))

        # write the pre-loaded public keys to the key-list file
        with open(worker_key_file, 'w') as f:
            for public_key in public_keys[0:num_pre_load_workers]:
                f.write(public_key + os.linesep)

        # start from an empty keys database
        remove_if_present(keys_db_file)

        # ---- first session ----
        server = create_server()

        # Reset the bookkeeping after the server has been constructed. The
        # callbacks close over these names, so they see the new objects.
        worker_updates = {}
        worker_ids = []
        added_workers = []

        stoppable_server = StoppableServer(host=get_host_ip(), port=8080)
        server_gl = Greenlet.spawn(begin_server, server, stoppable_server)
        sleep(2)
        try:
            base_url = server_url(server)
            assert len(server.worker_manager.public_keys_db) == 3

            # Register a set of workers using the admin API and test registration
            for i in range(num_pre_load_workers, num_workers):
                admin_registered_worker = {
                    PUBLIC_KEY_STR: public_keys[i],
                    REGISTRATION_STATUS_KEY: True
                }
                response = requests.post(
                    f"{base_url}/{WORKERS_ROUTE}",
                    json=admin_registered_worker,
                    auth=admin_auth)
                added_worker_dict = json.loads(response.content.decode('utf-8'))
                idx = i - num_pre_load_workers
                assert len(worker_ids) == idx + 1
                assert worker_ids[idx] == added_worker_dict[WORKER_ID_KEY]
                added_workers.append(added_worker_dict[WORKER_ID_KEY])

            assert len(server.worker_manager.public_keys_db) == 6
            for doc in server.worker_manager.public_keys_db.all():
                assert doc[PUBLIC_KEY_STR] in public_keys

            # Send updates and receive global updates for the registered
            # workers. This should succeed.
            worker_updates = {}
            packed_update = msgpack.packb("Model update!!")
            for i in range(num_pre_load_workers, num_workers):
                idx = i - num_pre_load_workers
                worker_id = added_workers[idx]

                # send updates
                signed_phrase = get_signed_phrase(
                    private_keys[i], hashlib.sha256(packed_update).digest())
                response = requests.post(
                    f"{base_url}/{RECEIVE_WORKER_UPDATE_ROUTE}/{worker_id}",
                    files={
                        WORKER_MODEL_UPDATE_KEY: zlib.compress(packed_update),
                        SIGNED_PHRASE: signed_phrase
                    }).content
                assert msgpack.unpackb(
                    worker_updates[worker_ids[idx]]) == "Model update!!"
                assert response.decode(
                    "UTF-8"
                ) == f"Update received for worker {worker_id[0:WID_LEN]}."

                # receive updates
                challenge_phrase = requests.get(
                    f"{base_url}/{CHALLENGE_PHRASE_ROUTE}/{worker_id}").content
                model_return_binary = requests.post(
                    f"{base_url}/{RETURN_GLOBAL_MODEL_ROUTE}",
                    json={
                        WORKER_ID_KEY: worker_id,
                        SIGNED_PHRASE: get_signed_phrase(private_keys[i],
                                                         challenge_phrase),
                        LAST_WORKER_MODEL_VERSION: "0"
                    }).content
                model_return = msgpack.unpackb(
                    zlib.decompress(model_return_binary))
                assert isinstance(model_return, dict)
                assert model_return[GLOBAL_MODEL_VERSION] == global_model_version
                assert msgpack.unpackb(
                    model_return[GLOBAL_MODEL]) == "Pickle dump of a string"
        finally:
            stoppable_server.shutdown()

        # ---- second session: the workers must be restored from the db ----
        worker_ids = []
        worker_updates = {}
        server = create_server()

        assert len(server.worker_manager.public_keys_db) == 6
        assert len(server.worker_manager.allowed_workers) == 6
        for doc in server.worker_manager.public_keys_db.all():
            assert doc[PUBLIC_KEY_STR] in server.worker_manager.allowed_workers

        stoppable_server = StoppableServer(host=get_host_ip(), port=8080)
        server_gl = Greenlet.spawn(begin_server, server, stoppable_server)
        sleep(2)
        try:
            base_url = server_url(server)

            # Delete the admin-registered workers and check this works.
            for worker_id in added_workers:
                response = requests.delete(
                    f"{base_url}/{WORKERS_ROUTE}/{worker_id}",
                    auth=admin_auth)
                message_dict = json.loads(response.content.decode('utf-8'))
                assert SUCCESS_MESSAGE_KEY in message_dict

            assert len(worker_ids) == 0
            assert len(server.worker_manager.public_keys_db) == 3
            assert len(server.worker_manager.allowed_workers) == 3
            for doc in server.worker_manager.public_keys_db.all():
                assert doc[PUBLIC_KEY_STR] in server.worker_manager.allowed_workers
        finally:
            stoppable_server.shutdown()
    finally:
        # delete the files
        for n in range(num_workers):
            remove_if_present(worker_key_file_prefix + f'_{n}')
            remove_if_present(worker_key_file_prefix + f'_{n}.pub')
        remove_if_present(worker_key_file)
        remove_if_present(keys_db_file)
        remove_if_present(keys_db_file + '.bak')
class FedAvgServer(object):
    """
    This class implements the server-side of the FedAvg algorithm using the
    dc_federated.backend package.

    Parameters
    ----------

    global_model_trainer: FedAvgModelTrainer
        Trainer holding the global model. It is used here through
        get_model(), load_model_from_state_dict() and test().

    update_lim: int
        Number of unique worker updates that need to be received after the
        last global update before the global model is updated again.

    key_list_file: str
        The list of public keys of valid workers. No authentication is
        performed if file not given.

    server_host_ip: str (default None)
        The hostname or IP address the server will bind to.
        If not given, it will default to the machine IP.

    server_port: int (default 8080)
        The port at which the server should listen to.

    ssl_enabled: bool (default False)
        Enable SSL/TLS for server/workers communications.

    ssl_keyfile: str
        Must be a valid path to the key file.
        This is mandatory if ssl_enabled is True, ignored otherwise.

    ssl_certfile: str
        Must be a valid path to the certificate.
        This is mandatory if ssl_enabled is True, ignored otherwise.
    """

    def __init__(self,
                 global_model_trainer,
                 key_list_file,
                 update_lim=10,
                 server_host_ip=None,
                 server_port=8080,
                 ssl_enabled=False,
                 ssl_keyfile=None,
                 ssl_certfile=None):
        logger.info(
            f"Initializing FedAvg server for model class {global_model_trainer.get_model().__class__.__name__}")

        # worker_id -> None (no update yet) or a tuple of
        # (timestamp of the update, update-size, model)
        self.worker_updates = {}
        self.global_model_trainer = global_model_trainer
        self.update_lim = update_lim

        # a date far in the past, so every first update counts as new
        self.last_global_model_update_timestamp = datetime(1980, 10, 10)

        self.server = DCFServer(
            register_worker_callback=self.register_worker,
            unregister_worker_callback=self.unregister_worker,
            return_global_model_callback=self.return_global_model,
            is_global_model_most_recent=self.is_global_model_most_recent,
            receive_worker_update_callback=self.receive_worker_update,
            server_mode_safe=key_list_file is not None,
            load_last_session_workers=False,
            key_list_file=key_list_file,
            server_host_ip=server_host_ip,
            server_port=server_port,
            ssl_enabled=ssl_enabled,
            ssl_keyfile=ssl_keyfile,
            ssl_certfile=ssl_certfile,
            model_check_interval=1
        )

        self.unique_updates_since_last_agg = 0
        self.iteration = 0
        self.model_version = 0

    def _is_fresh(self, update):
        """
        Returns True if the given worker_updates entry holds an update that
        was received after the last aggregation of the global model.
        """
        return (update is not None
                and update[0] > self.last_global_model_update_timestamp)

    def register_worker(self, worker_id):
        """
        Register the given worker_id by initializing its update to None.

        Parameters
        ----------

        worker_id: str
            The id of the new worker.
        """
        logger.info(f"Registered worker {worker_id[0:WID_LEN]}")
        self.worker_updates[worker_id] = None

    def unregister_worker(self, worker_id):
        """
        Unregister the given worker_id by removing it from updates.

        Unknown worker ids are ignored. If the worker had contributed an
        update since the last aggregation, that update no longer counts
        towards update_lim.

        Parameters
        ----------

        worker_id: str
            The id of the worker to be removed.
        """
        logger.info(f"Unregistered worker {worker_id[0:WID_LEN]}")
        removed_update = self.worker_updates.pop(worker_id, None)

        # keep the counter in step with the updates that agg_model() will
        # actually be able to use
        if self._is_fresh(removed_update) and self.unique_updates_since_last_agg > 0:
            self.unique_updates_since_last_agg -= 1

    def return_global_model(self):
        """
        Serializes the current global torch model, puts it in the proper
        dictionary, and sends it back.

        Returns
        ----------

        dict:
            A dictionary with keys:
            GLOBAL_MODEL: serialized global model.
            GLOBAL_MODEL_VERSION: version of the global model
        """
        model_data = io.BytesIO()
        torch.save(self.global_model_trainer.get_model(), model_data)
        return {
            GLOBAL_MODEL: model_data.getvalue(),
            GLOBAL_MODEL_VERSION: self.model_version
        }

    def is_global_model_most_recent(self, model_version):
        """
        Checks whether the given version is the current global model version.

        Parameters
        ----------

        model_version: int
            The version of most recent global model that
            the worker has.

        Returns
        ----------

        bool:
            True if model_version equals the server's current model version.
        """
        return self.model_version == model_version

    def receive_worker_update(self, worker_id, model_update):
        """
        Given an update for a worker, adds its update to the dictionary of
        updates. It also calls agg_model() to update the global model if
        necessary, and tests the new global model when it was updated.

        Parameters
        ----------

        worker_id: str
            The id of the worker sending the update.

        model_update: bytes
            msgpack-serialized pair of (update-size, serialized torch model).

        Returns
        ----------

        str:
            A message stating whether the update was accepted.
        """
        if worker_id not in self.worker_updates:
            logger.warning(
                f"Unregistered worker {worker_id[0:WID_LEN]} tried to send an update.")
            return "Please register before sending an update."

        # Decode first: if the payload is malformed this raises before any
        # state is touched, so a bad update is never counted.
        update_size, model_bytes = msgpack.unpackb(model_update)
        # NOTE(security): torch.load unpickles the bytes it is given. The
        # bytes come from a worker, so this is only safe when the workers
        # are trusted/authenticated.
        worker_model = torch.load(io.BytesIO(model_bytes))

        # update the number of unique updates received
        previous_update = self.worker_updates[worker_id]
        if previous_update is None or \
                previous_update[0] < self.last_global_model_update_timestamp:
            self.unique_updates_since_last_agg += 1

        self.worker_updates[worker_id] = (
            datetime.now(), update_size, worker_model)
        logger.info(f"Model update from worker {worker_id[0:WID_LEN]} accepted.")

        if self.agg_model():
            self.global_model_trainer.test()
        return f"Update received for worker {worker_id[0:WID_LEN]}"

    def agg_model(self):
        """
        Updates the global model by aggregating all the most recent updates
        from the workers, assuming that the number of unique updates received
        since the last global model update is above the threshold.

        Each parameter of the new global model is the average of the
        corresponding worker parameters weighted by the update sizes.
        Workers that are registered but have not sent an update since the
        last aggregation are skipped.

        Returns
        ----------

        bool:
            True if the global model was updated, False otherwise.
        """
        if self.unique_updates_since_last_agg < self.update_lim:
            return False

        # each item in the worker_updates dictionary is either None or a
        # (timestamp update, update-size, model) tuple
        fresh_updates = [
            update for update in self.worker_updates.values()
            if self._is_fresh(update)
        ]
        if not fresh_updates:
            # nothing to aggregate
            return False

        logger.info("Updating the global model.\n")

        update_sizes = [update[1] for update in fresh_updates]
        state_dicts = [update[2].state_dict() for update in fresh_updates]

        def agg_params(key):
            # size-weighted average of one parameter over all fresh updates
            agg_val = state_dicts[0][key] * update_sizes[0]
            for sd, sz in zip(state_dicts[1:], update_sizes[1:]):
                agg_val = agg_val + sd[key] * sz
            agg_val = agg_val / sum(update_sizes)
            return torch.tensor(agg_val.cpu().clone().numpy())

        # now update the global model
        global_model_dict = OrderedDict()
        for key in state_dicts[0].keys():
            global_model_dict[key] = agg_params(key)

        self.global_model_trainer.load_model_from_state_dict(global_model_dict)

        self.last_global_model_update_timestamp = datetime.now()
        self.unique_updates_since_last_agg = 0
        self.iteration += 1
        self.model_version += 1

        return True

    def start(self):
        """
        Starts the underlying DCFServer.
        """
        self.server.start_server()
def test_server_functionality():
    """
    Unit tests for the DCFServer and DCFWorker classes.

    Runs the same ten-phase scenario (register, update, unregister,
    re-register, delete, invalid registrations) against a server in unsafe
    mode and a server in safe mode.

    Each server is shut down even when an assertion fails, and the key files
    generated for the workers are removed at the end.
    """
    worker_ids = []
    added_workers = []
    worker_updates = {}
    global_model_version = "1"

    os.environ[ADMIN_USERNAME] = 'admin'
    os.environ[ADMIN_PASSWORD] = 'str0ng_s3cr3t'
    admin_auth = ('admin', 'str0ng_s3cr3t')

    public_keys = []
    private_keys = []
    num_workers = 3
    worker_key_file_prefix = 'worker_key_file'

    def begin_server(server, server_adapter):
        server.start_server(server_adapter)

    def test_register_func_cb(id):
        worker_ids.append(id)

    def test_unregister_func_cb(id):
        worker_ids.remove(id)

    def test_ret_global_model_cb():
        return create_model_dict(msgpack.packb("Pickle dump of a string"),
                                 global_model_version)

    def is_global_model_most_recent(version):
        # global_model_version is a string, so compare as strings; comparing
        # int(version) with it could never be equal.
        return str(version) == global_model_version

    def test_rec_server_update_cb(worker_id, update):
        if worker_id in worker_ids:
            worker_updates[worker_id] = update
            return f"Update received for worker {worker_id[0:WID_LEN]}."
        else:
            return f"Unregistered worker {worker_id[0:WID_LEN]} tried to send an update."

    def create_server(safe):
        return DCFServer(
            register_worker_callback=test_register_func_cb,
            unregister_worker_callback=test_unregister_func_cb,
            return_global_model_callback=test_ret_global_model_cb,
            is_global_model_most_recent=is_global_model_most_recent,
            receive_worker_update_callback=test_rec_server_update_cb,
            server_mode_safe=safe,
            key_list_file=None,
            load_last_session_workers=False)

    def get_worker_key(mode, i):
        if mode == 'safe':
            return public_keys[i].decode('utf-8')
        else:
            return 'dummy_public_key'

    def get_signed_phrase(mode, i, phrase=b'test phrase'):
        if mode == 'safe':
            return SigningKey(private_keys[i],
                              encoder=HexEncoder).sign(phrase).hex()
        else:
            return 'dummy_signed_phrase'

    def send_update(base_url, mode, i):
        # posts the standard model update for worker i, returns the raw reply
        packed_update = msgpack.packb("Model update!!")
        return requests.post(
            f"{base_url}/{RECEIVE_WORKER_UPDATE_ROUTE}/{added_workers[i]}",
            files={
                WORKER_MODEL_UPDATE_KEY: zlib.compress(packed_update),
                SIGNED_PHRASE: get_signed_phrase(
                    mode, i, hashlib.sha256(packed_update).digest())
            }).content

    def request_global_model(base_url, mode, i):
        # fetches a challenge phrase for worker i, signs it and requests the
        # global model; returns the raw reply
        challenge_phrase = requests.get(
            f"{base_url}/{CHALLENGE_PHRASE_ROUTE}/{added_workers[i]}").content
        return requests.post(
            f"{base_url}/{RETURN_GLOBAL_MODEL_ROUTE}",
            json={
                WORKER_ID_KEY: added_workers[i],
                SIGNED_PHRASE: get_signed_phrase(mode, i, challenge_phrase),
                LAST_WORKER_MODEL_VERSION: "0"
            }).content

    def check_successful_exchange(base_url, mode):
        # every registered worker can send an update and get the global model
        for i in range(num_workers):
            # send updates
            response = send_update(base_url, mode, i)
            assert msgpack.unpackb(
                worker_updates[worker_ids[i]]) == "Model update!!"
            assert response.decode(
                "UTF-8"
            ) == f"Update received for worker {added_workers[i][0:WID_LEN]}."

            # receive updates
            model_return = msgpack.unpackb(
                zlib.decompress(request_global_model(base_url, mode, i)))
            assert isinstance(model_return, dict)
            assert model_return[GLOBAL_MODEL_VERSION] == global_model_version
            assert msgpack.unpackb(
                model_return[GLOBAL_MODEL]) == "Pickle dump of a string"

    def set_registration_status(base_url, mode, i, status):
        # changes the registration status of worker i through the admin API
        # and returns the decoded reply
        response = requests.put(
            f"{base_url}/{WORKERS_ROUTE}/{added_workers[i]}",
            json={
                PUBLIC_KEY_STR: get_worker_key(mode, i),
                REGISTRATION_STATUS_KEY: status
            },
            auth=admin_auth)
        return json.loads(response.content.decode('utf-8'))

    try:
        for n in range(num_workers):
            private_key, public_key = gen_pair(worker_key_file_prefix + f'_{n}')
            private_keys.append(private_key.encode(encoder=HexEncoder))
            public_keys.append(public_key.encode(encoder=HexEncoder))

        dcf_server_safe = create_server(safe=True)
        dcf_server_unsafe = create_server(safe=False)

        for server, mode in zip([dcf_server_unsafe, dcf_server_safe],
                                ['unsafe', 'safe']):
            # Fresh bookkeeping for each mode. The callbacks and helpers
            # close over these names, so they see the new objects.
            worker_ids = []
            added_workers = []
            worker_updates = {}

            stoppable_server = StoppableServer(host=get_host_ip(), port=8080)
            server_gl = Greenlet.spawn(begin_server, server, stoppable_server)
            sleep(2)
            try:
                base_url = f"http://{server.server_host_ip}:{server.server_port}"

                # Phase 1: register a set of workers using the admin API and
                # test registration
                for i in range(num_workers):
                    admin_registered_worker = {
                        PUBLIC_KEY_STR: get_worker_key(mode, i),
                        REGISTRATION_STATUS_KEY: True
                    }
                    response = requests.post(
                        f"{base_url}/{WORKERS_ROUTE}",
                        json=admin_registered_worker,
                        auth=admin_auth)
                    added_worker_dict = json.loads(
                        response.content.decode('utf-8'))
                    assert len(worker_ids) == i + 1
                    assert worker_ids[i] == added_worker_dict[WORKER_ID_KEY]
                    added_workers.append(added_worker_dict[WORKER_ID_KEY])

                # Phase 2: Send updates and receive global updates for the
                # registered workers. This should succeed.
                worker_updates = {}
                check_successful_exchange(base_url, mode)

                # Phase 3: Unregister workers.
                for i in range(num_workers):
                    unreg_worker_dict = set_registration_status(
                        base_url, mode, i, False)
                    assert not unreg_worker_dict[REGISTRATION_STATUS_KEY]
                assert len(worker_ids) == 0

                # Phase 4: Try to send updates from the unregistered workers
                # - this should fail
                worker_updates = {}
                for i in range(num_workers):
                    # send updates
                    response = send_update(base_url, mode, i)
                    assert added_workers[i] not in worker_updates
                    assert response.decode('UTF-8') == UNREGISTERED_WORKER

                    # receive updates
                    # NOTE(review): this request is sent without a signed
                    # phrase and its reply is not checked. The original code
                    # re-asserted on the update reply above instead of on
                    # this reply - confirm the expected payload and assert
                    # on it.
                    requests.post(
                        f"{base_url}/{RETURN_GLOBAL_MODEL_ROUTE}",
                        json={
                            WORKER_ID_KEY: added_workers[i],
                            LAST_WORKER_MODEL_VERSION: "0"
                        })

                # Phase 5: Re-register existing workers.
                for i in range(num_workers):
                    rereg_worker_dict = set_registration_status(
                        base_url, mode, i, True)
                    assert rereg_worker_dict[REGISTRATION_STATUS_KEY]

                # Phase 6: Send updates and receive global updates for the
                # registered workers. This should succeed.
                worker_updates = {}
                check_successful_exchange(base_url, mode)

                # Phase 7: Delete existing workers.
                for i in range(num_workers):
                    response = requests.delete(
                        f"{base_url}/{WORKERS_ROUTE}/{added_workers[i]}",
                        auth=admin_auth)
                    message_dict = json.loads(response.content.decode('utf-8'))
                    assert SUCCESS_MESSAGE_KEY in message_dict
                assert len(worker_ids) == 0

                # Phase 8: Try to send updates to the deleted workers - this
                # should fail
                worker_updates = {}
                for i in range(num_workers):
                    # send updates
                    response = send_update(base_url, mode, i)
                    assert added_workers[i] not in worker_updates
                    assert response.decode('UTF-8') == INVALID_WORKER

                    # receive updates
                    # NOTE(review): as in Phase 4, the reply of the global
                    # model request is not checked; the original code
                    # re-asserted on the update reply above. Confirm the
                    # expected payload and assert on it.
                    request_global_model(base_url, mode, i)

                # Phase 9: Try to register non-existent workers using the
                # public API - this should fail in the safe mode and succeed
                # in the unsafe mode.
                for i in range(num_workers):
                    registration_data = {
                        PUBLIC_KEY_STR: get_worker_key(mode, i),
                        SIGNED_PHRASE: get_signed_phrase(mode, i)
                    }
                    response = requests.post(
                        f"{base_url}/{REGISTER_WORKER_ROUTE}",
                        json=registration_data)
                    if mode == 'safe':
                        assert response.content.decode('utf-8') == INVALID_WORKER
                    else:
                        assert 'unauthenticated' in response.content.decode('utf-8')

                # Phase 10 - for the safe mode try registering with the
                # public and admin API with invalid public keys - these
                # should both fail
                if mode == 'safe':
                    for i in range(num_workers):
                        registration_data = {
                            PUBLIC_KEY_STR: "dummy public key",
                            SIGNED_PHRASE: get_signed_phrase(mode, i)
                        }
                        response = requests.post(
                            f"{base_url}/{REGISTER_WORKER_ROUTE}",
                            json=registration_data)
                        assert response.content.decode('utf-8') == INVALID_WORKER

                        registration_data = {
                            PUBLIC_KEY_STR: get_worker_key(mode, i),
                            SIGNED_PHRASE: "dummy signed phrase key"
                        }
                        response = requests.post(
                            f"{base_url}/{REGISTER_WORKER_ROUTE}",
                            json=registration_data)
                        assert response.content.decode('utf-8') == INVALID_WORKER

                        admin_registered_worker = {
                            PUBLIC_KEY_STR: "dummy public key",
                            REGISTRATION_STATUS_KEY: True
                        }
                        response = requests.post(
                            f"{base_url}/{WORKERS_ROUTE}",
                            json=admin_registered_worker,
                            auth=admin_auth)
                        message = json.loads(response.content.decode('utf-8'))
                        assert ERROR_MESSAGE_KEY in message
                        key_short = "dummy public key"[0:WID_LEN]
                        assert message[ERROR_MESSAGE_KEY] == \
                            f"Unable to validate public key (short) {key_short} " \
                            "- worker not added."
            finally:
                stoppable_server.shutdown()
    finally:
        # delete the key files written by gen_pair
        for n in range(num_workers):
            for key_file in (worker_key_file_prefix + f'_{n}',
                             worker_key_file_prefix + f'_{n}.pub'):
                if os.path.exists(key_file):
                    os.remove(key_file)
class ExampleGlobalModel(object):
    """
    A minimal example of building a federated global model on top of the
    DCFServer class.

    For testing purposes, every model this class creates or receives is
    also written to disk.
    """

    def __init__(self):
        # worker_id -> most recent update (None until an update is received)
        self.worker_updates = {}

        self.global_model = ExampleModelClass()
        with open("egm_global_model.torch", 'wb') as model_file:
            torch.save(self.global_model, model_file)
        self.global_model_version = 0

        # hooks through which the backend server talks to this object
        server_callbacks = {
            'register_worker_callback': self.register_worker,
            'unregister_worker_callback': self.unregister_worker,
            'return_global_model_callback': self.return_global_model,
            'is_global_model_most_recent': self.is_global_model_most_recent,
            'receive_worker_update_callback': self.receive_worker_update,
        }
        self.server = DCFServer(
            server_mode_safe=False,
            key_list_file=None,
            load_last_session_workers=False,
            **server_callbacks)

    def register_worker(self, worker_id):
        """
        Adds the worker to the table of updates with no update recorded yet.

        Parameters
        ----------

        worker_id: str
            The id of the worker being registered.
        """
        short_id = worker_id[0:WID_LEN]
        logger.info(f"Example Global Model: Registering worker {short_id}")
        self.worker_updates[worker_id] = None

    def unregister_worker(self, worker_id):
        """
        Drops the worker from the table of updates.

        Parameters
        ----------

        worker_id: str
            The id of the worker being removed. A KeyError is raised if the
            worker is not in the table.
        """
        short_id = worker_id[0:WID_LEN]
        logger.info(f"Example Global Model: Unregistering worker {short_id}")
        del self.worker_updates[worker_id]

    def return_global_model(self):
        """
        Serializes the current global torch model and wraps it, together
        with its version, in a model dictionary for the worker.

        Returns
        ----------

        dict:
            The model dictionary built by create_model_dict.
        """
        logger.info("Example Global Model: returning global model")
        buffer = io.BytesIO()
        torch.save(self.global_model, buffer)
        serialized_model = buffer.getvalue()
        return create_model_dict(serialized_model, self.global_model_version)

    def is_global_model_most_recent(self, model_version):
        """
        Checks whether the given version is the current global model
        version.

        Parameters
        ----------

        model_version: int
            The version of the global model to compare against.

        Returns
        ----------

        bool:
            True if model_version equals the current global model version.
        """
        logger.info(
            "Example Global Model: checking if model version is most recent.")
        return model_version == self.global_model_version

    def receive_worker_update(self, worker_id, model_update):
        """
        Stores the update sent by a registered worker, writes it to disk and
        increments the global model version.

        Parameters
        ----------

        worker_id: str
            The id of the worker sending the update.

        model_update: bytes
            The serialized torch object sent by the worker.

        Returns
        ----------

        str:
            A message stating whether the update was accepted.
        """
        short_id = worker_id[0:WID_LEN]
        if worker_id not in self.worker_updates:
            return f"Unregistered worker {short_id} tried to send an update!!"

        received_update = torch.load(io.BytesIO(model_update))
        self.worker_updates[worker_id] = received_update
        logger.info(f"Model update received from worker {short_id}")
        logger.info(received_update)

        with open(f"egm_worker_update_{worker_id}.torch", 'wb') as update_file:
            torch.save(received_update, update_file)

        self.global_model_version += 1
        return f"Update received for worker {short_id}"

    def start(self):
        """
        Starts the underlying DCFServer.
        """
        self.server.start_server()
def test_worker_authentication():
    """
    Checks worker authentication against a DCFServer running in safe mode.

    Ten workers whose public keys are listed in the key-list file register,
    fetch the global model and send an update. A worker whose key is not in
    the list must be rejected on registration, on sending an update, on
    requesting a challenge phrase and on requesting the global model.

    The server is shut down and the generated files are removed even when an
    assertion fails.
    """
    num_workers = 10
    private_keys = []
    public_keys = []
    worker_key_file_prefix = 'worker_key_file'
    worker_key_file = 'worker_public_keys.txt'

    worker_ids = []
    worker_updates = {}
    global_model_version = "1"
    worker_global_model_version = "0"

    def test_register_func_cb(id):
        worker_ids.append(id)

    def test_unregister_func_cb(id):
        worker_ids.remove(id)

    def test_ret_global_model_cb():
        return create_model_dict(msgpack.packb("Serialized dump of a string"),
                                 global_model_version)

    def is_global_model_most_recent(version):
        return version == global_model_version

    def test_rec_server_update_cb(worker_id, update):
        if worker_id in worker_ids:
            worker_updates[worker_id] = update
            return f"Update received for worker {worker_id[0:WID_LEN]}."
        else:
            return f"Unregistered worker {worker_id[0:WID_LEN]} tried to send an update."

    def test_glob_mod_chng_cb(model_dict):
        nonlocal worker_global_model_version
        worker_global_model_version = model_dict[GLOBAL_MODEL_VERSION]

    def test_get_last_glob_model_ver():
        nonlocal worker_global_model_version
        return worker_global_model_version

    def begin_server():
        dcf_server.start_server(stoppable_server)

    def make_worker(private_key_file):
        return DCFWorker(
            server_protocol='http',
            server_host_ip=dcf_server.server_host_ip,
            server_port=dcf_server.server_port,
            global_model_version_changed_callback=test_glob_mod_chng_cb,
            get_worker_version_of_global_model=test_get_last_glob_model_ver,
            private_key_file=private_key_file)

    try:
        # Create a set of keys to be supplied to the server
        for n in range(num_workers):
            private_key, public_key = gen_pair(worker_key_file_prefix + f'_{n}')
            private_keys.append(private_key)
            public_keys.append(public_key)

        with open(worker_key_file, 'w') as f:
            for public_key in public_keys:
                f.write(
                    public_key.encode(encoder=HexEncoder).decode('utf-8') +
                    os.linesep)

        dcf_server = DCFServer(
            register_worker_callback=test_register_func_cb,
            unregister_worker_callback=test_unregister_func_cb,
            return_global_model_callback=test_ret_global_model_cb,
            is_global_model_most_recent=is_global_model_most_recent,
            receive_worker_update_callback=test_rec_server_update_cb,
            server_mode_safe=True,
            key_list_file=worker_key_file,
            load_last_session_workers=False)

        stoppable_server = StoppableServer(host=get_host_ip(), port=8080)
        server_gl = Greenlet.spawn(begin_server)
        sleep(2)
        try:
            base_url = f"http://{dcf_server.server_host_ip}:{dcf_server.server_port}"

            # create the workers
            workers = [
                make_worker(worker_key_file_prefix + f"_{n}")
                for n in range(num_workers)
            ]

            # test various worker actions
            for worker, key in zip(workers, public_keys):
                worker.register_worker()
                global_model_dict = worker.get_global_model()
                worker.send_model_update(b'model_update')
                assert is_valid_model_dict(global_model_dict)
                assert global_model_dict[GLOBAL_MODEL] == msgpack.packb(
                    "Serialized dump of a string")
                assert global_model_dict[GLOBAL_MODEL_VERSION] == global_model_version
                assert worker_updates[worker.worker_id] == b'model_update'
                assert worker.worker_id == key.encode(
                    encoder=HexEncoder).decode('utf-8')

            # try to authenticate an unregistered worker
            gen_pair('bad_worker')
            bad_worker = make_worker('bad_worker')
            try:
                bad_worker.register_worker()
            except ValueError:
                assert True
            else:
                assert False

            # try to send an update using the bad worker key
            # NOTE(review): this reads the private key file ('bad_worker')
            # and uses its content both as the signing key and as the worker
            # id in the requests below. The valid workers above use their
            # public key as id, so 'bad_worker.pub' may have been intended
            # for the id - confirm.
            with open('bad_worker', 'r') as f:
                bad_worker_key = f.read()
            bad_signing_key = SigningKey(bad_worker_key.encode(),
                                         encoder=HexEncoder)

            id_and_model_dict_good = {
                WORKER_MODEL_UPDATE_KEY:
                    zlib.compress(msgpack.packb("Bad Model update!!")),
                SIGNED_PHRASE:
                    bad_signing_key.sign(b"Bad Model update!!").hex()
            }
            response = requests.post(
                f"{base_url}/{RECEIVE_WORKER_UPDATE_ROUTE}/{bad_worker_key}",
                files=id_and_model_dict_good).content
            assert response.decode('utf-8') == INVALID_WORKER

            challenge_phrase = requests.get(
                f"{base_url}/{CHALLENGE_PHRASE_ROUTE}/{bad_worker_key}").content
            assert challenge_phrase.decode('utf-8') == INVALID_WORKER

            response = requests.post(
                f"{base_url}/{RETURN_GLOBAL_MODEL_ROUTE}",
                json={
                    WORKER_ID_KEY: bad_worker_key,
                    SIGNED_PHRASE: bad_signing_key.sign(b"Some phrase").hex()
                }).content
            assert response.decode('utf-8') == INVALID_WORKER
        finally:
            stoppable_server.shutdown()
    finally:
        # delete the files
        generated_files = [worker_key_file, "bad_worker", "bad_worker.pub"]
        for n in range(num_workers):
            generated_files.append(worker_key_file_prefix + f'_{n}')
            generated_files.append(worker_key_file_prefix + f'_{n}.pub')
        for path in generated_files:
            if os.path.exists(path):
                os.remove(path)
def test_server_functionality():
    """
    Unit tests for the DCFServer and DCFWorker classes.

    The test first checks that building a DCFServer in unsafe mode while
    also supplying a key list file raises ValueError with the expected
    message. It then starts an unsafe-mode server in a greenlet and, against
    that live server:

    * registers three workers over HTTP and checks their ids are distinct;
    * lists, adds and deletes workers through the admin workers route;
    * fetches the global model and posts a model update with raw requests,
      including an update for an id that was never registered;
    * repeats register / get model / send update through a DCFWorker.

    The server is shut down in a ``finally`` clause so that a failing
    assertion does not leave port 8080 bound for the tests that follow.
    """
    worker_ids = []
    worker_updates = {}
    global_model_version = "1"
    worker_global_model_version = "0"
    os.environ[ADMIN_USERNAME] = 'admin'
    os.environ[ADMIN_PASSWORD] = 'str0ng_s3cr3t'
    admin_auth = ('admin', 'str0ng_s3cr3t')

    stoppable_server = StoppableServer(host=get_host_ip(), port=8080)

    def begin_server():
        # dcf_server is bound later in the enclosing scope, before this runs.
        dcf_server.start_server(stoppable_server)

    def test_register_func_cb(id):
        worker_ids.append(id)

    def test_unregister_func_cb(id):
        worker_ids.remove(id)

    def test_ret_global_model_cb():
        return create_model_dict(msgpack.packb("Pickle dump of a string"),
                                 global_model_version)

    def is_global_model_most_recent(version):
        # Versions are strings in this test; comparing int(version) with the
        # string global_model_version could never be True.
        return version == global_model_version

    def test_rec_server_update_cb(worker_id, update):
        if worker_id in worker_ids:
            worker_updates[worker_id] = update
            return f"Update received for worker {worker_id[0:WID_LEN]}."
        else:
            return f"Unregistered worker {worker_id[0:WID_LEN]} tried to send an update."

    def test_glob_mod_chng_cb(model_dict):
        nonlocal worker_global_model_version
        worker_global_model_version = model_dict[GLOBAL_MODEL_VERSION]

    def test_get_last_glob_model_ver():
        return worker_global_model_version

    # try to create a server with incorrect server mode, key file
    # combination - should raise ValueError
    try:
        dcf_server = DCFServer(
            register_worker_callback=test_register_func_cb,
            unregister_worker_callback=test_unregister_func_cb,
            return_global_model_callback=test_ret_global_model_cb,
            is_global_model_most_recent=is_global_model_most_recent,
            receive_worker_update_callback=test_rec_server_update_cb,
            server_mode_safe=False,
            key_list_file="some_file_name.txt",
            load_last_session_workers=False)
    except ValueError as ve:
        error_str = "Server started in unsafe mode but list of public keys provided. " \
                    "Either explicitly start server in safe mode or do not " \
                    "supply a public key list."
        assert str(ve) == error_str
    else:
        raise AssertionError(
            "DCFServer accepted a key list file while in unsafe mode")

    # now create the actual server instance to use
    dcf_server = DCFServer(
        register_worker_callback=test_register_func_cb,
        unregister_worker_callback=test_unregister_func_cb,
        return_global_model_callback=test_ret_global_model_cb,
        is_global_model_most_recent=is_global_model_most_recent,
        receive_worker_update_callback=test_rec_server_update_cb,
        server_mode_safe=False,
        key_list_file=None)
    server_gl = Greenlet.spawn(begin_server)
    try:
        # give the server greenlet time to start before sending requests
        sleep(2)
        base_url = f"http://{dcf_server.server_host_ip}:{dcf_server.server_port}"

        # register a set of workers
        data = {
            PUBLIC_KEY_STR: "dummy public key",
            SIGNED_PHRASE: "dummy signed phrase"
        }
        for _ in range(3):
            requests.post(f"{base_url}/{REGISTER_WORKER_ROUTE}", json=data)

        assert len(worker_ids) == 3
        assert len(set(worker_ids)) == 3
        assert worker_ids[0].__class__ == worker_ids[1].__class__ == worker_ids[
            2].__class__

        response = requests.get(f"{base_url}/{WORKERS_ROUTE}",
                                auth=admin_auth).content
        workers_list = json.loads(response)
        assert all(worker[WORKER_ID_KEY] in worker_ids
                   for worker in workers_list)

        # an empty admin registration request must not add a worker
        requests.post(f"{base_url}/{WORKERS_ROUTE}", json={}, auth=admin_auth)
        assert len(worker_ids) == 3

        admin_registered_worker = {
            PUBLIC_KEY_STR: "new_public_key",
            REGISTRATION_STATUS_KEY: True
        }
        response = requests.post(f"{base_url}/{WORKERS_ROUTE}",
                                 json=admin_registered_worker,
                                 auth=admin_auth)
        added_worker_dict = json.loads(response.content.decode('utf-8'))
        assert len(worker_ids) == 4
        assert worker_ids[3] != admin_registered_worker[PUBLIC_KEY_STR]
        assert worker_ids[3] == added_worker_dict[WORKER_ID_KEY]

        requests.delete(
            f"{base_url}/{WORKERS_ROUTE}/{added_worker_dict[WORKER_ID_KEY]}",
            auth=admin_auth)
        assert len(worker_ids) == 3

        # test getting the global model
        model_return_binary = requests.post(
            f"{base_url}/{RETURN_GLOBAL_MODEL_ROUTE}",
            json={
                WORKER_ID_KEY: worker_ids[0],
                SIGNED_PHRASE: "",
                LAST_WORKER_MODEL_VERSION: "0"
            }).content
        model_return = msgpack.unpackb(zlib.decompress(model_return_binary))
        assert isinstance(model_return, dict)
        assert model_return[GLOBAL_MODEL_VERSION] == global_model_version
        assert msgpack.unpackb(
            model_return[GLOBAL_MODEL]) == "Pickle dump of a string"

        # test sending the model update
        response = requests.post(
            f"{base_url}/{RECEIVE_WORKER_UPDATE_ROUTE}/{worker_ids[1]}",
            files={
                WORKER_MODEL_UPDATE_KEY:
                zlib.compress(msgpack.packb("Model update!!")),
                SIGNED_PHRASE: ""
            }).content
        assert msgpack.unpackb(
            worker_updates[worker_ids[1]]) == "Model update!!"
        assert response.decode(
            "UTF-8") == f"Update received for worker {worker_ids[1][0:WID_LEN]}."

        # an update posted for an id that was never registered is rejected
        response = requests.post(
            f"{base_url}/{RECEIVE_WORKER_UPDATE_ROUTE}/3",
            files={
                WORKER_MODEL_UPDATE_KEY:
                zlib.compress(
                    msgpack.packb("Model update for unregistered worker!!")),
                SIGNED_PHRASE: ""
            }).content
        # the id travels in the URL as the string "3"; an int key would make
        # this check vacuous
        assert "3" not in worker_updates
        assert response.decode('UTF-8') == INVALID_WORKER

        # *********** #
        # now test a DCFWorker on the same server.
        dcf_worker = DCFWorker(
            server_protocol='http',
            server_host_ip=dcf_server.server_host_ip,
            server_port=dcf_server.server_port,
            global_model_version_changed_callback=test_glob_mod_chng_cb,
            get_worker_version_of_global_model=test_get_last_glob_model_ver,
            private_key_file=None)

        # test worker registration
        dcf_worker.register_worker()
        assert dcf_worker.worker_id == worker_ids[3]

        # test getting the global model update
        global_model_dict = dcf_worker.get_global_model()
        assert is_valid_model_dict(global_model_dict)
        assert global_model_dict[GLOBAL_MODEL_VERSION] == global_model_version
        assert msgpack.unpackb(
            global_model_dict[GLOBAL_MODEL]) == "Pickle dump of a string"

        # test sending the model update
        response = dcf_worker.send_model_update(
            msgpack.packb("DCFWorker model update"))
        assert msgpack.unpackb(
            worker_updates[worker_ids[3]]) == "DCFWorker model update"
        assert response.decode(
            "UTF-8") == f"Update received for worker {worker_ids[3][0:WID_LEN]}."
    finally:
        # always release the port, even when an assertion above failed
        stoppable_server.shutdown()
def test_long_polling():
    """
    Test long polling of the global model with many concurrent workers.

    Steps:

    1. Generate ``num_workers`` key pairs in ``keys_folder`` and write the
       hex-encoded public keys to the key list file given to the server.
    2. Start a safe-mode DCFServer with a model check interval of
       ``server_model_check_interval`` and register one SimpleLPWorker per
       key pair.
    3. Let every worker fetch the global model and check that it holds
       version "1".
    4. Spawn one greenlet per worker that requests the global model again.
       These requests are expected to stay pending while the server-side
       version still equals the worker's version.
    5. After ``halt_time`` seconds bump the version to "2" and require every
       greenlet to finish within 100 seconds.

    The server is shut down and the generated key files are removed in
    ``finally`` clauses, so a failed assertion leaves neither a bound port
    nor stale files behind.
    """
    # Create a set of keys to be supplied to the server
    num_workers = 100
    public_keys = []
    server_model_check_interval = 1
    halt_time = 10
    keys_folder = 'keys_folder'
    worker_key_file_prefix = 'worker_key_file'

    os.makedirs(keys_folder, exist_ok=True)
    try:
        for n in range(num_workers):
            _, public_key = gen_pair(
                os.path.join(keys_folder, worker_key_file_prefix + f'_{n}'))
            public_keys.append(public_key)

        worker_ids = []
        worker_updates = {}
        global_model_version = "1"

        def test_register_func_cb(id):
            worker_ids.append(id)

        def test_unregister_func_cb(id):
            worker_ids.remove(id)

        def test_ret_global_model_cb():
            return create_model_dict(msgpack.packb("Pickle dump of a string"),
                                     global_model_version)

        def is_global_model_most_recent(version):
            return version == global_model_version

        def test_rec_server_update_cb(worker_id, update):
            if worker_id in worker_ids:
                worker_updates[worker_id] = update
                return f"Update received for worker {worker_id[0:WID_LEN]}."
            else:
                return f"Unregistered worker {worker_id[0:WID_LEN]} tried to send an update."

        worker_key_file = os.path.join(keys_folder, 'worker_public_keys.txt')
        with open(worker_key_file, 'w') as key_file:
            # Text mode already translates "\n" to the platform line ending;
            # writing os.linesep here would emit "\r\r\n" on Windows.
            for public_key in public_keys:
                key_file.write(
                    public_key.encode(encoder=HexEncoder).decode('utf-8') +
                    "\n")

        dcf_server = DCFServer(
            register_worker_callback=test_register_func_cb,
            unregister_worker_callback=test_unregister_func_cb,
            return_global_model_callback=test_ret_global_model_cb,
            is_global_model_most_recent=is_global_model_most_recent,
            receive_worker_update_callback=test_rec_server_update_cb,
            server_mode_safe=True,
            key_list_file=worker_key_file,
            model_check_interval=server_model_check_interval,
            load_last_session_workers=False)

        stoppable_server = StoppableServer(host=get_host_ip(), port=8080)

        def begin_server():
            dcf_server.start_server(stoppable_server)

        server_gl = Greenlet.spawn(begin_server)
        try:
            # give the server greenlet time to start before sending requests
            sleep(2)

            # create the workers
            workers = [
                SimpleLPWorker(
                    dcf_server.server_host_ip, dcf_server.server_port,
                    os.path.join(keys_folder,
                                 worker_key_file_prefix + f"_{n}"))
                for n in range(num_workers)
            ]

            for worker in workers:
                worker.worker.register_worker()

            # get the current global model and check
            for worker in workers:
                worker.global_model_changed_callback(
                    worker.worker.get_global_model())

            for worker in workers:
                assert worker.gm_version == global_model_version

            done_count = 0

            # Each greenlet requests the global model once more and bumps
            # done_count when the request returns.
            def run_wg(gl_worker):
                nonlocal done_count
                logger.info(
                    f"Starting long poll for {gl_worker.worker.worker_id}")
                gl_worker.global_model_changed_callback(
                    gl_worker.worker.get_global_model())
                logger.info(
                    f"Long poll for {gl_worker.worker.worker_id} finished")
                done_count += 1

            # start the polls in batches of five, pausing between batches
            for started, worker in enumerate(workers, start=1):
                Greenlet.spawn(run_wg, worker)
                if started % 5 == 0:
                    sleep(0.5)

            logger.info(f"The test will halt for {halt_time} seconds now...")
            sleep(halt_time)

            # publish a new version; the closures above read this variable
            global_model_version = "2"

            start_time = datetime.now()
            # if it hasn't stopped after 100 seconds, it has failed.
            while (done_count < num_workers and
                   (datetime.now() - start_time).total_seconds() < 100):
                sleep(1)
                logger.info(
                    f"{done_count} workers have received the global model update - need to get to {num_workers}..."
                )

            # all the calls to get the global model should have succeeded by now
            assert done_count == num_workers
            logger.info("All workers have received the global model update.")
        finally:
            # always release the port, even when an assertion above failed
            stoppable_server.shutdown()
    finally:
        # remove the generated key files regardless of the test outcome
        for file_name in os.listdir(keys_folder):
            os.remove(os.path.join(keys_folder, file_name))