Exemplo n.º 1
0
    def __init__(self):
        """
        Build the initial global model, write a snapshot of it to
        "egm_global_model.torch", and create a DCFServer wired to this
        object's callbacks with server_mode_safe disabled.
        """
        self.worker_updates = {}
        self.global_model = ExampleModelClass()
        # Persist the freshly constructed model so it can be inspected later.
        with open("egm_global_model.torch", 'wb') as model_file:
            torch.save(self.global_model, model_file)

        self.global_model_version = 0

        callbacks = dict(
            register_worker_callback=self.register_worker,
            unregister_worker_callback=self.unregister_worker,
            return_global_model_callback=self.return_global_model,
            is_global_model_most_recent=self.is_global_model_most_recent,
            receive_worker_update_callback=self.receive_worker_update,
        )
        self.server = DCFServer(
            server_mode_safe=False,
            key_list_file=None,
            load_last_session_workers=False,
            **callbacks)
Exemplo n.º 2
0
    def __init__(self,
                 global_model_trainer,
                 key_list_file,
                 update_lim=10,
                 server_host_ip=None,
                 server_port=8080,
                 ssl_enabled=False,
                 ssl_keyfile=None,
                 ssl_certfile=None):
        """
        Set up the FedAvg bookkeeping and the underlying DCFServer.

        Authentication (server_mode_safe) is enabled exactly when a
        key_list_file is supplied. The remaining network/SSL arguments are
        forwarded unchanged to DCFServer.
        """
        model_class_name = global_model_trainer.get_model().__class__.__name__
        logger.info(
            f"Initializing FedAvg server for model class {model_class_name}")

        self.worker_updates = {}
        self.global_model_trainer = global_model_trainer
        self.update_lim = update_lim

        # Initial "last aggregation" time, set far in the past.
        self.last_global_model_update_timestamp = datetime(1980, 10, 10)

        authenticate_workers = key_list_file is not None
        self.server = DCFServer(
            register_worker_callback=self.register_worker,
            unregister_worker_callback=self.unregister_worker,
            return_global_model_callback=self.return_global_model,
            is_global_model_most_recent=self.is_global_model_most_recent,
            receive_worker_update_callback=self.receive_worker_update,
            server_mode_safe=authenticate_workers,
            load_last_session_workers=False,
            key_list_file=key_list_file,
            server_host_ip=server_host_ip,
            server_port=server_port,
            ssl_enabled=ssl_enabled,
            ssl_keyfile=ssl_keyfile,
            ssl_certfile=ssl_certfile,
            model_check_interval=1,
        )

        # Aggregation round counters.
        self.unique_updates_since_last_agg = 0
        self.iteration = 0
        self.model_version = 0
Exemplo n.º 3
0
def test_worker_persistence():
    """
    Check that workers registered through the admin API are stored in the
    workers database and are reloaded by a new DCFServer created with
    load_last_session_workers=True, and that deletions are persisted too.

    All generated key files and database files are removed, and any running
    server is shut down, even when an assertion fails.
    """
    worker_ids = []
    added_workers = []
    worker_updates = {}

    global_model_version = "1"
    os.environ[ADMIN_USERNAME] = 'admin'
    os.environ[ADMIN_PASSWORD] = 'str0ng_s3cr3t'
    admin_auth = ('admin', 'str0ng_s3cr3t')

    public_keys = []
    private_keys = []
    num_workers = 6
    num_pre_load_workers = 3
    worker_key_file_prefix = 'worker_key_file'
    worker_key_file = 'worker_public_keys.txt'
    workers_db_file = 'workers_db.json'

    def begin_server(server, server_adapter):
        server.start_server(server_adapter)

    def test_register_func_cb(worker_id):
        worker_ids.append(worker_id)

    def test_unregister_func_cb(worker_id):
        worker_ids.remove(worker_id)

    def test_ret_global_model_cb():
        return create_model_dict(msgpack.packb("Pickle dump of a string"),
                                 global_model_version)

    def is_global_model_most_recent(version):
        # The version may arrive as a string or an int; global_model_version
        # is a string, so compare string forms (int vs str is never equal).
        return str(version) == global_model_version

    def test_rec_server_update_cb(worker_id, update):
        if worker_id in worker_ids:
            worker_updates[worker_id] = update
            return f"Update received for worker {worker_id[0:WID_LEN]}."
        else:
            return f"Unregistered worker {worker_id[0:WID_LEN]} tried to send an update."

    def get_signed_phrase(private_key, phrase=b'test phrase'):
        return SigningKey(private_key, encoder=HexEncoder).sign(phrase).hex()

    def remove_if_exists(path):
        """Best-effort cleanup that must not mask an earlier failure."""
        if os.path.exists(path):
            os.remove(path)

    def make_server():
        return DCFServer(
            register_worker_callback=test_register_func_cb,
            unregister_worker_callback=test_unregister_func_cb,
            return_global_model_callback=test_ret_global_model_cb,
            is_global_model_most_recent=is_global_model_most_recent,
            receive_worker_update_callback=test_rec_server_update_cb,
            server_mode_safe=True,
            load_last_session_workers=True,
            path_to_keys_db=workers_db_file,
            key_list_file=worker_key_file)

    remove_if_exists(workers_db_file)

    stoppable_server = None
    try:
        for n in range(num_workers):
            private_key, public_key = gen_pair(worker_key_file_prefix + f'_{n}')
            private_keys.append(
                private_key.encode(encoder=HexEncoder).decode('utf-8'))
            public_keys.append(
                public_key.encode(encoder=HexEncoder).decode('utf-8'))

        # Write the public keys of the pre-loaded workers to the key list file.
        # Use '\n' in text mode: writing os.linesep would produce '\r\r\n'
        # on Windows because text mode already translates '\n'.
        with open(worker_key_file, 'w') as f:
            for public_key in public_keys[:num_pre_load_workers]:
                f.write(public_key + '\n')

        server = make_server()

        # Reset the callback state after construction so only workers added
        # below through the admin API are tracked.
        worker_updates = {}
        worker_ids = []
        added_workers = []
        stoppable_server = StoppableServer(host=get_host_ip(), port=8080)
        Greenlet.spawn(begin_server, server, stoppable_server)
        sleep(2)
        base_url = f"http://{server.server_host_ip}:{server.server_port}"

        assert len(server.worker_manager.public_keys_db) == num_pre_load_workers

        # Register the remaining workers using the admin API and test registration.
        for idx, public_key in enumerate(public_keys[num_pre_load_workers:]):
            response = requests.post(
                f"{base_url}/{WORKERS_ROUTE}",
                json={
                    PUBLIC_KEY_STR: public_key,
                    REGISTRATION_STATUS_KEY: True
                },
                auth=admin_auth)

            added_worker_dict = json.loads(response.content.decode('utf-8'))
            assert len(worker_ids) == idx + 1
            assert worker_ids[idx] == added_worker_dict[WORKER_ID_KEY]
            added_workers.append(added_worker_dict[WORKER_ID_KEY])

        assert len(server.worker_manager.public_keys_db) == num_workers
        for doc in server.worker_manager.public_keys_db.all():
            assert doc[PUBLIC_KEY_STR] in public_keys

        # Send updates and receive global updates for the registered workers.
        # This should succeed.
        worker_updates = {}
        update_payload = msgpack.packb("Model update!!")
        for idx, worker_id in enumerate(added_workers):
            private_key = private_keys[num_pre_load_workers + idx]

            # send updates
            signed_phrase = get_signed_phrase(
                private_key, hashlib.sha256(update_payload).digest())
            response = requests.post(
                f"{base_url}/{RECEIVE_WORKER_UPDATE_ROUTE}/{worker_id}",
                files={
                    WORKER_MODEL_UPDATE_KEY: zlib.compress(update_payload),
                    SIGNED_PHRASE: signed_phrase
                }).content
            assert msgpack.unpackb(worker_updates[worker_ids[idx]]) == "Model update!!"
            assert response.decode(
                "UTF-8"
            ) == f"Update received for worker {worker_id[0:WID_LEN]}."

            # receive updates
            challenge_phrase = requests.get(
                f"{base_url}/{CHALLENGE_PHRASE_ROUTE}/{worker_id}").content
            model_return_binary = requests.post(
                f"{base_url}/{RETURN_GLOBAL_MODEL_ROUTE}",
                json={
                    WORKER_ID_KEY: worker_id,
                    SIGNED_PHRASE: get_signed_phrase(private_key, challenge_phrase),
                    LAST_WORKER_MODEL_VERSION: "0"
                }).content
            model_return = msgpack.unpackb(zlib.decompress(model_return_binary))
            assert isinstance(model_return, dict)
            assert model_return[GLOBAL_MODEL_VERSION] == global_model_version
            assert msgpack.unpackb(
                model_return[GLOBAL_MODEL]) == "Pickle dump of a string"

        stoppable_server.shutdown()
        stoppable_server = None

        # Restart: the workers should be reloaded from the database.
        worker_ids = []
        worker_updates = {}
        server = make_server()

        assert len(server.worker_manager.public_keys_db) == num_workers
        assert len(server.worker_manager.allowed_workers) == num_workers
        for doc in server.worker_manager.public_keys_db.all():
            assert doc[PUBLIC_KEY_STR] in server.worker_manager.allowed_workers

        stoppable_server = StoppableServer(host=get_host_ip(), port=8080)
        Greenlet.spawn(begin_server, server, stoppable_server)
        sleep(2)
        base_url = f"http://{server.server_host_ip}:{server.server_port}"

        # Delete the workers added through the admin API and check this works.
        for worker_id in added_workers:
            response = requests.delete(
                f"{base_url}/{WORKERS_ROUTE}/{worker_id}",
                auth=admin_auth)
            message_dict = json.loads(response.content.decode('utf-8'))
            assert SUCCESS_MESSAGE_KEY in message_dict
        assert len(worker_ids) == 0

        num_remaining = num_workers - len(added_workers)
        assert len(server.worker_manager.public_keys_db) == num_remaining
        assert len(server.worker_manager.allowed_workers) == num_remaining
        for doc in server.worker_manager.public_keys_db.all():
            assert doc[PUBLIC_KEY_STR] in server.worker_manager.allowed_workers

        stoppable_server.shutdown()
        stoppable_server = None
    finally:
        if stoppable_server is not None:
            stoppable_server.shutdown()
        # delete the files
        for n in range(num_workers):
            remove_if_exists(worker_key_file_prefix + f'_{n}')
            remove_if_exists(worker_key_file_prefix + f'_{n}.pub')
        remove_if_exists(worker_key_file)
        remove_if_exists(workers_db_file)
        remove_if_exists(workers_db_file + '.bak')
Exemplo n.º 4
0
class FedAvgServer(object):
    """
    This class implements the server-side of the FedAvg algorithm using the
    dc_federated.backend package.

    Parameters
    ----------

    global_model_trainer: FedAvgModelTrainer
        The trainer that owns the global model. It is used through
        get_model(), load_model_from_state_dict() and test().

    key_list_file: str
        Path to the file listing the public keys of valid workers. If None,
        the server runs with server_mode_safe disabled and no authentication
        is performed.

    update_lim: int (default 10)
        Number of unique worker updates that must be received since the last
        global model update before the global model is re-aggregated.

    server_host_ip: str (default None)
        The hostname or IP address the server will bind to.
        If not given, it will default to the machine IP.

    server_port: int (default 8080)
        The port at which the server should listen to.

    ssl_enabled: bool (default False)
        Enable SSL/TLS for server/workers communications.

    ssl_keyfile: str
        Must be a valid path to the key file.
        This is mandatory if ssl_enabled is True, ignored otherwise.

    ssl_certfile: str
        Must be a valid path to the certificate.
        This is mandatory if ssl_enabled is True, ignored otherwise.
    """

    def __init__(self,
                 global_model_trainer,
                 key_list_file,
                 update_lim=10,
                 server_host_ip=None,
                 server_port=8080,
                 ssl_enabled=False,
                 ssl_keyfile=None,
                 ssl_certfile=None):
        logger.info(
            f"Initializing FedAvg server for model class {global_model_trainer.get_model().__class__.__name__}")

        self.worker_updates = {}
        self.global_model_trainer = global_model_trainer
        self.update_lim = update_lim

        # Initial "last aggregation" time, set far in the past.
        self.last_global_model_update_timestamp = datetime(1980, 10, 10)
        self.server = DCFServer(
            register_worker_callback=self.register_worker,
            unregister_worker_callback=self.unregister_worker,
            return_global_model_callback=self.return_global_model,
            is_global_model_most_recent=self.is_global_model_most_recent,
            receive_worker_update_callback=self.receive_worker_update,
            server_mode_safe=key_list_file is not None,
            load_last_session_workers=False,
            key_list_file=key_list_file,
            server_host_ip=server_host_ip,
            server_port=server_port,
            ssl_enabled=ssl_enabled,
            ssl_keyfile=ssl_keyfile,
            ssl_certfile=ssl_certfile,
            model_check_interval=1
        )

        self.unique_updates_since_last_agg = 0
        self.iteration = 0
        self.model_version = 0

    def register_worker(self, worker_id):
        """
        Register the given worker_id by initializing its update to None.

        Parameters
        ----------

        worker_id: str
            The id of the new worker.
        """
        logger.info(f"Registered worker {worker_id[0:WID_LEN]}")
        self.worker_updates[worker_id] = None

    def unregister_worker(self, worker_id):
        """
        Unregister the given worker_id by removing it from updates.

        Parameters
        ----------

        worker_id: str
            The id of the worker to be removed.

        Raises
        ------

        KeyError
            If worker_id was never registered.
        """
        logger.info(f"Unregistered worker {worker_id[0:WID_LEN]}")
        self.worker_updates.pop(worker_id)

    def return_global_model(self):
        """
        Serializes the current global torch model, puts it in the proper
        dictionary, and sends it back.

        Returns
        ----------

        dict:
            A dictionary with keys:
            GLOBAL_MODEL: serialized global model.
            GLOBAL_MODEL_VERSION: version of the global model
        """
        model_data = io.BytesIO()
        torch.save(self.global_model_trainer.get_model(), model_data)

        return {
            GLOBAL_MODEL: model_data.getvalue(),
            GLOBAL_MODEL_VERSION: self.model_version
        }

    def is_global_model_most_recent(self, model_version):
        """
        Check whether a worker already holds the current global model.

        Parameters
        ----------

        model_version: int or str
            The version of the most recent global model that the worker has.
            It may arrive as a string (e.g. from a JSON request), so the
            comparison is done on string forms.

        Returns
        ----------

        bool:
            True if model_version equals the current global model version.
        """
        return str(model_version) == str(self.model_version)

    def receive_worker_update(self, worker_id, model_update):
        """
        Store an update sent by a worker, then call agg_model() to update the
        global model if enough unique updates have arrived. After a successful
        aggregation the trainer's test() is run.

        Parameters
        ----------

        worker_id: str
            The id of the worker sending the update.

        model_update: bytes
            msgpack-encoded pair (update_size, torch-serialized model).

        Returns
        ----------

        str:
            A message stating whether the update was accepted.
        """
        if worker_id not in self.worker_updates:
            logger.warning(
                f"Unregistered worker {worker_id[0:WID_LEN]} tried to send an update.")
            return "Please register before sending an update."

        # Count a worker at most once per aggregation round: only if it has
        # not sent anything yet, or its last update predates the last aggregation.
        previous_update = self.worker_updates[worker_id]
        if previous_update is None or \
                previous_update[0] < self.last_global_model_update_timestamp:
            self.unique_updates_since_last_agg += 1

        update_size, model_bytes = msgpack.unpackb(model_update)
        # SECURITY NOTE(review): torch.load unpickles bytes received from a
        # worker; a malicious worker could execute code here. Consider
        # weights-only loading if the installed torch version supports it.
        self.worker_updates[worker_id] = (
            datetime.now(),
            update_size,
            torch.load(io.BytesIO(model_bytes))
        )
        logger.info(f"Model update from worker {worker_id[0:WID_LEN]} accepted.")
        if self.agg_model():
            self.global_model_trainer.test()
        return f"Update received for worker {worker_id[0:WID_LEN]}"

    def agg_model(self):
        """
        Updates the global model by aggregating all the most recent updates
        from the workers, assuming that the number of unique updates received
        since the last global model update is above the threshold.

        Returns
        ----------

        bool:
            True if the global model was updated, False otherwise.
        """
        if self.unique_updates_since_last_agg < self.update_lim:
            return False

        state_dicts, update_sizes = self._collect_fresh_updates()
        if not state_dicts:
            # Nothing newer than the last aggregation; keep the current model.
            return False

        logger.info("Updating the global model.\n")
        global_model_dict = self._weighted_average(state_dicts, update_sizes)
        self.global_model_trainer.load_model_from_state_dict(global_model_dict)

        self.last_global_model_update_timestamp = datetime.now()
        self.unique_updates_since_last_agg = 0
        self.iteration += 1
        self.model_version += 1

        return True

    def _collect_fresh_updates(self):
        """Return (state_dicts, update_sizes) for updates newer than the last aggregation."""
        state_dicts = []
        update_sizes = []
        # Each entry is None (registered, no update yet) or a
        # (timestamp, update_size, model) tuple.
        for update in self.worker_updates.values():
            if update is None:
                continue
            timestamp, update_size, model = update
            if timestamp > self.last_global_model_update_timestamp:
                state_dicts.append(model.state_dict())
                update_sizes.append(update_size)
        return state_dicts, update_sizes

    @staticmethod
    def _weighted_average(state_dicts, update_sizes):
        """Average every entry of state_dicts, weighting each by its update size."""
        total_size = sum(update_sizes)
        averaged = OrderedDict()
        for key in state_dicts[0].keys():
            agg_val = state_dicts[0][key] * update_sizes[0]
            for sd, sz in zip(state_dicts[1:], update_sizes[1:]):
                agg_val = agg_val + sd[key] * sz
            # Detached CPU copy, independent of the workers' tensors.
            averaged[key] = (agg_val / total_size).detach().cpu().clone()
        return averaged

    def start(self):
        """Start the underlying DCFServer."""
        self.server.start_server()
Exemplo n.º 5
0
def test_server_functionality():
    """
    Integration test for the DCFServer HTTP API.

    Runs the same scenario against an unsafe and a safe (signature-checked)
    DCFServer: admin registration, update/model round trips,
    unregistration, re-registration, deletion, and public-API registration
    with valid and invalid credentials. Any running server is shut down,
    and the generated key files are removed, even if an assertion fails.
    """
    worker_ids = []
    added_workers = []
    worker_updates = {}
    global_model_version = "1"
    os.environ[ADMIN_USERNAME] = 'admin'
    os.environ[ADMIN_PASSWORD] = 'str0ng_s3cr3t'
    admin_auth = ('admin', 'str0ng_s3cr3t')

    public_keys = []
    private_keys = []
    num_workers = 3
    worker_key_file_prefix = 'worker_key_file'

    def begin_server(server, server_adapter):
        server.start_server(server_adapter)

    def test_register_func_cb(worker_id):
        worker_ids.append(worker_id)

    def test_unregister_func_cb(worker_id):
        worker_ids.remove(worker_id)

    def test_ret_global_model_cb():
        return create_model_dict(msgpack.packb("Pickle dump of a string"),
                                 global_model_version)

    def is_global_model_most_recent(version):
        # The version may arrive as a string or an int; global_model_version
        # is a string, so compare string forms (int vs str is never equal).
        return str(version) == global_model_version

    def test_rec_server_update_cb(worker_id, update):
        if worker_id in worker_ids:
            worker_updates[worker_id] = update
            return f"Update received for worker {worker_id[0:WID_LEN]}."
        else:
            return f"Unregistered worker {worker_id[0:WID_LEN]} tried to send an update."

    def get_worker_key(mode, i):
        if mode == 'safe':
            return public_keys[i].decode('utf-8')
        return 'dummy_public_key'

    def get_signed_phrase(mode, i, phrase=b'test phrase'):
        if mode == 'safe':
            return SigningKey(private_keys[i],
                              encoder=HexEncoder).sign(phrase).hex()
        return 'dummy_signed_phrase'

    def remove_if_exists(path):
        """Best-effort cleanup that must not mask an earlier failure."""
        if os.path.exists(path):
            os.remove(path)

    def base_url(server):
        return f"http://{server.server_host_ip}:{server.server_port}"

    def send_update(server, mode, i):
        """POST the fixed model update for worker i; return the raw response body."""
        update = msgpack.packb("Model update!!")
        signed_phrase = get_signed_phrase(mode, i, hashlib.sha256(update).digest())
        return requests.post(
            f"{base_url(server)}/{RECEIVE_WORKER_UPDATE_ROUTE}/{added_workers[i]}",
            files={
                WORKER_MODEL_UPDATE_KEY: zlib.compress(update),
                SIGNED_PHRASE: signed_phrase
            }).content

    def fetch_challenge_phrase(server, i):
        """GET the challenge phrase for worker i."""
        return requests.get(
            f"{base_url(server)}/{CHALLENGE_PHRASE_ROUTE}/{added_workers[i]}").content

    def assert_round_trips_succeed(server, mode):
        """Every added worker can send an update and then fetch the global model."""
        for i in range(num_workers):
            # send updates
            response = send_update(server, mode, i)
            assert msgpack.unpackb(
                worker_updates[worker_ids[i]]) == "Model update!!"
            assert response.decode(
                "UTF-8"
            ) == f"Update received for worker {added_workers[i][0:WID_LEN]}."

            # receive updates
            challenge_phrase = fetch_challenge_phrase(server, i)
            model_return_binary = requests.post(
                f"{base_url(server)}/{RETURN_GLOBAL_MODEL_ROUTE}",
                json={
                    WORKER_ID_KEY: added_workers[i],
                    SIGNED_PHRASE: get_signed_phrase(mode, i, challenge_phrase),
                    LAST_WORKER_MODEL_VERSION: "0"
                }).content
            model_return = msgpack.unpackb(
                zlib.decompress(model_return_binary))
            assert isinstance(model_return, dict)
            assert model_return[GLOBAL_MODEL_VERSION] == global_model_version
            assert msgpack.unpackb(
                model_return[GLOBAL_MODEL]) == "Pickle dump of a string"

    def set_registration_status(server, mode, status):
        """Use the admin API to set every added worker's registration status."""
        for i in range(num_workers):
            response = requests.put(
                f"{base_url(server)}/{WORKERS_ROUTE}/{added_workers[i]}",
                json={
                    PUBLIC_KEY_STR: get_worker_key(mode, i),
                    REGISTRATION_STATUS_KEY: status
                },
                auth=admin_auth)
            worker_dict = json.loads(response.content.decode('utf-8'))
            assert bool(worker_dict[REGISTRATION_STATUS_KEY]) == status

    stoppable_server = None
    try:
        for n in range(num_workers):
            private_key, public_key = gen_pair(worker_key_file_prefix + f'_{n}')
            private_keys.append(private_key.encode(encoder=HexEncoder))
            public_keys.append(public_key.encode(encoder=HexEncoder))

        dcf_server_safe = DCFServer(
            register_worker_callback=test_register_func_cb,
            unregister_worker_callback=test_unregister_func_cb,
            return_global_model_callback=test_ret_global_model_cb,
            is_global_model_most_recent=is_global_model_most_recent,
            receive_worker_update_callback=test_rec_server_update_cb,
            server_mode_safe=True,
            key_list_file=None,
            load_last_session_workers=False)

        dcf_server_unsafe = DCFServer(
            register_worker_callback=test_register_func_cb,
            unregister_worker_callback=test_unregister_func_cb,
            return_global_model_callback=test_ret_global_model_cb,
            is_global_model_most_recent=is_global_model_most_recent,
            receive_worker_update_callback=test_rec_server_update_cb,
            server_mode_safe=False,
            key_list_file=None,
            load_last_session_workers=False)

        for server, mode in zip([dcf_server_unsafe, dcf_server_safe],
                                ['unsafe', 'safe']):
            worker_ids = []
            added_workers = []
            worker_updates = {}

            stoppable_server = StoppableServer(host=get_host_ip(), port=8080)
            Greenlet.spawn(begin_server, server, stoppable_server)
            sleep(2)

            # Phase 1: register a set of workers using the admin API and test registration
            for i in range(num_workers):
                response = requests.post(
                    f"{base_url(server)}/{WORKERS_ROUTE}",
                    json={
                        PUBLIC_KEY_STR: get_worker_key(mode, i),
                        REGISTRATION_STATUS_KEY: True
                    },
                    auth=admin_auth)

                added_worker_dict = json.loads(response.content.decode('utf-8'))
                assert len(worker_ids) == i + 1
                assert worker_ids[i] == added_worker_dict[WORKER_ID_KEY]
                added_workers.append(added_worker_dict[WORKER_ID_KEY])

            # Phase 2: Send updates and receive global updates for the registered workers
            # This should succeed
            worker_updates = {}
            assert_round_trips_succeed(server, mode)

            # Phase 3: Unregister workers.
            set_registration_status(server, mode, False)
            assert len(worker_ids) == 0

            # Phase 4: Try to send updates from the unregistered workers - this should fail
            worker_updates = {}
            for i in range(num_workers):
                response = send_update(server, mode, i)
                assert added_workers[i] not in worker_updates
                assert response.decode('UTF-8') == UNREGISTERED_WORKER

                # receive updates
                # NOTE(review): the reply to this request is not checked;
                # confirm which error the server returns here and assert on it.
                requests.post(
                    f"{base_url(server)}/{RETURN_GLOBAL_MODEL_ROUTE}",
                    json={
                        WORKER_ID_KEY: added_workers[i],
                        LAST_WORKER_MODEL_VERSION: "0"
                    })

            # Phase 5: Re-register existing workers.
            set_registration_status(server, mode, True)

            # Phase 6: Send updates and receive global updates for the registered workers
            # This should succeed
            worker_updates = {}
            assert_round_trips_succeed(server, mode)

            # Phase 7: Delete existing workers.
            for i in range(num_workers):
                response = requests.delete(
                    f"{base_url(server)}/{WORKERS_ROUTE}/{added_workers[i]}",
                    auth=admin_auth)
                message_dict = json.loads(response.content.decode('utf-8'))
                assert SUCCESS_MESSAGE_KEY in message_dict
            assert len(worker_ids) == 0

            # Phase 8: Try to send updates to the deleted workers - this should fail
            worker_updates = {}
            for i in range(num_workers):
                response = send_update(server, mode, i)
                assert added_workers[i] not in worker_updates
                assert response.decode('UTF-8') == INVALID_WORKER

                # receive updates
                challenge_phrase = fetch_challenge_phrase(server, i)
                # NOTE(review): the reply to this request is not checked;
                # confirm which error the server returns here and assert on it.
                requests.post(
                    f"{base_url(server)}/{RETURN_GLOBAL_MODEL_ROUTE}",
                    json={
                        WORKER_ID_KEY: added_workers[i],
                        SIGNED_PHRASE: get_signed_phrase(mode, i, challenge_phrase),
                        LAST_WORKER_MODEL_VERSION: "0"
                    })

            # Phase 9: Try to register non-existent workers using the public API
            # - this should fail in the safe mode and succeed in the unsafe mode.
            for i in range(num_workers):
                response = requests.post(
                    f"{base_url(server)}/{REGISTER_WORKER_ROUTE}",
                    json={
                        PUBLIC_KEY_STR: get_worker_key(mode, i),
                        SIGNED_PHRASE: get_signed_phrase(mode, i)
                    })
                if mode == 'safe':
                    assert response.content.decode('utf-8') == INVALID_WORKER
                else:
                    assert 'unauthenticated' in response.content.decode('utf-8')

            # Phase 10 - for the safe mode try registering with the public and admin API
            # with invalid public keys - these should both fail
            if mode == 'safe':
                for i in range(num_workers):
                    response = requests.post(
                        f"{base_url(server)}/{REGISTER_WORKER_ROUTE}",
                        json={
                            PUBLIC_KEY_STR: "dummy public key",
                            SIGNED_PHRASE: get_signed_phrase(mode, i)
                        })
                    assert response.content.decode('utf-8') == INVALID_WORKER

                    response = requests.post(
                        f"{base_url(server)}/{REGISTER_WORKER_ROUTE}",
                        json={
                            PUBLIC_KEY_STR: get_worker_key(mode, i),
                            SIGNED_PHRASE: "dummy signed phrase key"
                        })
                    assert response.content.decode('utf-8') == INVALID_WORKER

                    response = requests.post(
                        f"{base_url(server)}/{WORKERS_ROUTE}",
                        json={
                            PUBLIC_KEY_STR: "dummy public key",
                            REGISTRATION_STATUS_KEY: True
                        },
                        auth=admin_auth)
                    message = json.loads(response.content.decode('utf-8'))
                    assert ERROR_MESSAGE_KEY in message
                    key_short = "dummy public key"[0:WID_LEN]
                    assert message[ERROR_MESSAGE_KEY] == \
                           f"Unable to validate public key (short) {key_short} " \
                           "- worker not added."

            stoppable_server.shutdown()
            stoppable_server = None
    finally:
        if stoppable_server is not None:
            stoppable_server.shutdown()
        # Remove the key files written by gen_pair.
        for n in range(num_workers):
            remove_if_exists(worker_key_file_prefix + f'_{n}')
            remove_if_exists(worker_key_file_prefix + f'_{n}.pub')
Exemplo n.º 6
0
class ExampleGlobalModel(object):
    """
    This is a simple class that illustrates how the DCFServer class may be used to
    implement a federated global model. For testing purposes, it writes all the
    models it creates and receives to disk: the initial global model is saved to
    ``egm_global_model.torch`` and every accepted worker update to
    ``egm_worker_update_<worker_id>.torch`` in the current directory.
    """
    def __init__(self):
        # worker_id -> most recent deserialized update (None until one arrives).
        self.worker_updates = {}
        self.global_model = ExampleModelClass()
        with open("egm_global_model.torch", 'wb') as f:
            torch.save(self.global_model, f)

        # Incremented on every accepted worker update (see receive_worker_update).
        self.global_model_version = 0

        # Unsafe mode: no public-key list is supplied and no workers are
        # restored from a previous session.
        self.server = DCFServer(
            register_worker_callback=self.register_worker,
            unregister_worker_callback=self.unregister_worker,
            return_global_model_callback=self.return_global_model,
            is_global_model_most_recent=self.is_global_model_most_recent,
            receive_worker_update_callback=self.receive_worker_update,
            server_mode_safe=False,
            key_list_file=None,
            load_last_session_workers=False)

    def register_worker(self, worker_id):
        """
        Register the given worker_id by initializing its update to None.

        Parameters
        ----------

        worker_id: str
            The id of the new worker (only its first WID_LEN characters are logged).
        """
        logger.info(
            f"Example Global Model: Registering worker {worker_id[0:WID_LEN]}")
        self.worker_updates[worker_id] = None

    def unregister_worker(self, worker_id):
        """
        Unregister the given worker_id by removing it from updates.

        Parameters
        ----------

        worker_id: str
            The id of the worker to be removed.

        Raises
        ------

        KeyError
            If worker_id was never registered (or was already removed).
        """
        logger.info(
            f"Example Global Model: Unregistering worker {worker_id[0:WID_LEN]}"
        )
        self.worker_updates.pop(worker_id)

    def return_global_model(self):
        """
        Serialize the current global torch model for sending back to a worker.

        Returns
        ----------

        dict:
            The model dictionary built by create_model_dict from the
            torch-serialized model bytes and the current global model version,
            as per the specification in DCFServer.
        """
        logger.info("Example Global Model: returning global model")
        model_data = io.BytesIO()
        torch.save(self.global_model, model_data)
        return create_model_dict(model_data.getvalue(),
                                 self.global_model_version)

    def is_global_model_most_recent(self, model_version):
        """
        Check whether a worker's copy of the global model is up to date.

        Parameters
        ---------

        model_version: int
            The global model version the worker currently holds, presumably as
            previously returned by return_global_model. NOTE(review): the check
            uses ``==`` against an int, so a str version would never match -
            confirm the type DCFServer passes in.

        Returns
        ----------

        bool:
            True if model_version equals the current global model version.
        """
        logger.info(
            "Example Global Model: checking if model version is most recent.")
        return self.global_model_version == model_version

    def receive_worker_update(self, worker_id, model_update):
        """
        Store a registered worker's model update and bump the global model version.

        The update is deserialized with torch.load, logged, and written to
        ``egm_worker_update_<worker_id>.torch``. Updates from unregistered
        workers are ignored.

        Parameters
        ----------

        worker_id: str
            Id of the worker sending the update.
        model_update: bytes
            A torch.save-serialized object.

        Returns
        ----------

        str:
            A human-readable status message: an acknowledgement for registered
            workers, or a rejection message for unregistered ones.
        """
        if worker_id not in self.worker_updates:
            return f"Unregistered worker {worker_id[0:WID_LEN]} tried to send an update!!"

        # SECURITY NOTE: torch.load unpickles its input and can execute
        # arbitrary code if model_update is attacker-controlled. This is only
        # acceptable for this local example; never use it on untrusted data.
        self.worker_updates[worker_id] = \
            torch.load(io.BytesIO(model_update))
        logger.info(
            f"Model update received from worker {worker_id[0:WID_LEN]}")
        logger.info(self.worker_updates[worker_id])
        with open(f"egm_worker_update_{worker_id}.torch", 'wb') as f:
            torch.save(self.worker_updates[worker_id], f)
        self.global_model_version += 1
        return f"Update received for worker {worker_id[0:WID_LEN]}"

    def start(self):
        """Start the underlying DCFServer."""
        self.server.start_server()
Exemplo n.º 7
0
def test_worker_authentication():
    """
    End-to-end test of a safe-mode DCFServer that has a public-key list.

    Workers whose public keys are in the server's key list must be able to
    register, fetch the global model and send updates. A worker whose key is
    not in the list must be rejected on the registration, update,
    challenge-phrase and global-model routes.

    The server is shut down and every generated key file is removed in a
    ``finally`` clause. As a result, a failing assertion does not leave port
    8080 bound or key material on disk for later tests.
    """
    num_workers = 10
    worker_key_file_prefix = 'worker_key_file'
    worker_key_file = 'worker_public_keys.txt'
    bad_worker_key_file = 'bad_worker'

    worker_ids = []
    worker_updates = {}
    global_model_version = "1"
    worker_global_model_version = "0"

    # Every file this test creates; deleted in the ``finally`` clause below.
    created_files = []
    stoppable_server = None

    def remove_created_files():
        # A file may be missing if the test failed before creating it.
        for path in created_files:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

    def test_register_func_cb(id):
        worker_ids.append(id)

    def test_unregister_func_cb(id):
        worker_ids.remove(id)

    def test_ret_global_model_cb():
        return create_model_dict(msgpack.packb("Serialized dump of a string"),
                                 global_model_version)

    def is_global_model_most_recent(version):
        return version == global_model_version

    def test_rec_server_update_cb(worker_id, update):
        if worker_id in worker_ids:
            worker_updates[worker_id] = update
            return f"Update received for worker {worker_id[0:WID_LEN]}."
        else:
            return f"Unregistered worker {worker_id[0:WID_LEN]} tried to send an update."

    def test_glob_mod_chng_cb(model_dict):
        nonlocal worker_global_model_version
        worker_global_model_version = model_dict[GLOBAL_MODEL_VERSION]

    def test_get_last_glob_model_ver():
        return worker_global_model_version

    try:
        # Create a key pair per worker; the public keys form the server's
        # allow-list.
        public_keys = []
        for n in range(num_workers):
            key_file = worker_key_file_prefix + f'_{n}'
            created_files.extend([key_file, key_file + '.pub'])
            _, public_key = gen_pair(key_file)
            public_keys.append(public_key)

        created_files.append(worker_key_file)
        with open(worker_key_file, 'w') as f:
            for public_key in public_keys:
                f.write(
                    public_key.encode(encoder=HexEncoder).decode('utf-8') +
                    os.linesep)

        dcf_server = DCFServer(
            register_worker_callback=test_register_func_cb,
            unregister_worker_callback=test_unregister_func_cb,
            return_global_model_callback=test_ret_global_model_cb,
            is_global_model_most_recent=is_global_model_most_recent,
            receive_worker_update_callback=test_rec_server_update_cb,
            server_mode_safe=True,
            key_list_file=worker_key_file,
            load_last_session_workers=False)
        stoppable_server = StoppableServer(host=get_host_ip(), port=8080)

        def begin_server():
            dcf_server.start_server(stoppable_server)

        Greenlet.spawn(begin_server)
        sleep(2)
        base_url = f"http://{dcf_server.server_host_ip}:{dcf_server.server_port}"

        # Create the workers, one per generated private key file.
        workers = [
            DCFWorker(
                server_protocol='http',
                server_host_ip=dcf_server.server_host_ip,
                server_port=dcf_server.server_port,
                global_model_version_changed_callback=test_glob_mod_chng_cb,
                get_worker_version_of_global_model=test_get_last_glob_model_ver,
                private_key_file=worker_key_file_prefix + f"_{n}")
            for n in range(num_workers)
        ]

        # Each allow-listed worker can register, fetch the model and send an
        # update.
        for worker, key in zip(workers, public_keys):
            worker.register_worker()
            global_model_dict = worker.get_global_model()
            worker.send_model_update(b'model_update')
            assert is_valid_model_dict(global_model_dict)
            assert global_model_dict[GLOBAL_MODEL] == msgpack.packb(
                "Serialized dump of a string")
            assert global_model_dict[GLOBAL_MODEL_VERSION] == global_model_version
            assert worker_updates[worker.worker_id] == b'model_update'
            # In safe mode the worker id is the hex-encoded public key.
            assert worker.worker_id == key.encode(
                encoder=HexEncoder).decode('utf-8')

        # A worker whose key is not in the allow-list must fail to register.
        created_files.extend(
            [bad_worker_key_file, bad_worker_key_file + '.pub'])
        gen_pair(bad_worker_key_file)
        bad_worker = DCFWorker(
            server_protocol='http',
            server_host_ip=dcf_server.server_host_ip,
            server_port=dcf_server.server_port,
            global_model_version_changed_callback=test_glob_mod_chng_cb,
            get_worker_version_of_global_model=test_get_last_glob_model_ver,
            private_key_file=bad_worker_key_file)
        try:
            bad_worker.register_worker()
        except ValueError:
            pass
        else:
            assert False, "registering an unknown worker should raise ValueError"

        # Requests made directly with the unknown worker's key must all be
        # rejected.
        # NOTE(review): bad_worker_key_file holds the signing (private) key,
        # since SigningKey is built from it below, and the public key is in the
        # '.pub' file. So the id sent here is not the worker's public key. It is
        # still not in the allow-list, so INVALID_WORKER is expected; confirm
        # whether the public key was intended.
        with open(bad_worker_key_file, 'r') as f:
            bad_worker_key = f.read()
        bad_signing_key = SigningKey(bad_worker_key.encode(),
                                     encoder=HexEncoder)

        bad_worker_update = {
            WORKER_MODEL_UPDATE_KEY:
            zlib.compress(msgpack.packb("Bad Model update!!")),
            SIGNED_PHRASE:
            bad_signing_key.sign(b"Bad Model update!!").hex()
        }
        response = requests.post(
            f"{base_url}/{RECEIVE_WORKER_UPDATE_ROUTE}/{bad_worker_key}",
            files=bad_worker_update).content
        assert response.decode('utf-8') == INVALID_WORKER

        challenge_phrase = requests.get(
            f"{base_url}/{CHALLENGE_PHRASE_ROUTE}/{bad_worker_key}").content
        assert challenge_phrase.decode('utf-8') == INVALID_WORKER

        response = requests.post(
            f"{base_url}/{RETURN_GLOBAL_MODEL_ROUTE}",
            json={
                WORKER_ID_KEY: bad_worker_key,
                SIGNED_PHRASE: bad_signing_key.sign(b"Some phrase").hex()
            }).content
        assert response.decode('utf-8') == INVALID_WORKER
    finally:
        try:
            if stoppable_server is not None:
                stoppable_server.shutdown()
        finally:
            remove_created_files()
Exemplo n.º 8
0
def test_server_functionality():
    """
    Unit tests for the DCFServer and DCFWorker classes.

    Runs an unsafe-mode DCFServer (no public-key list) and checks the following:
    constructor validation of the mode / key-file combination, worker
    registration through the public and admin APIs, admin deletion,
    global-model retrieval, and model-update submission. These are exercised
    first with raw HTTP requests and then through a DCFWorker.

    The server is shut down in a ``finally`` clause, so a failing assertion does
    not leave port 8080 bound for later tests.
    """
    worker_ids = []
    worker_updates = {}
    global_model_version = "1"
    worker_global_model_version = "0"
    # NOTE: these admin credentials are left in os.environ after the test.
    os.environ[ADMIN_USERNAME] = 'admin'
    os.environ[ADMIN_PASSWORD] = 'str0ng_s3cr3t'
    admin_auth = ('admin', 'str0ng_s3cr3t')

    stoppable_server = StoppableServer(host=get_host_ip(), port=8080)

    def begin_server():
        dcf_server.start_server(stoppable_server)

    def test_register_func_cb(id):
        worker_ids.append(id)

    def test_unregister_func_cb(id):
        worker_ids.remove(id)

    def test_ret_global_model_cb():
        return create_model_dict(msgpack.packb("Pickle dump of a string"),
                                 global_model_version)

    def is_global_model_most_recent(version):
        # global_model_version is a str, so compare string forms. An
        # int-vs-str comparison (int(version) == "1") is always False.
        return str(version) == global_model_version

    def test_rec_server_update_cb(worker_id, update):
        if worker_id in worker_ids:
            worker_updates[worker_id] = update
            return f"Update received for worker {worker_id[0:WID_LEN]}."
        else:
            return f"Unregistered worker {worker_id[0:WID_LEN]} tried to send an update."

    def test_glob_mod_chng_cb(model_dict):
        nonlocal worker_global_model_version
        worker_global_model_version = model_dict[GLOBAL_MODEL_VERSION]

    def test_get_last_glob_model_ver():
        return worker_global_model_version

    # A key list supplied in unsafe mode must be rejected by the constructor.
    try:
        DCFServer(
            register_worker_callback=test_register_func_cb,
            unregister_worker_callback=test_unregister_func_cb,
            return_global_model_callback=test_ret_global_model_cb,
            is_global_model_most_recent=is_global_model_most_recent,
            receive_worker_update_callback=test_rec_server_update_cb,
            server_mode_safe=False,
            key_list_file="some_file_name.txt",
            load_last_session_workers=False)
    except ValueError as ve:
        error_str = "Server started in unsafe mode but list of public keys provided. " \
                    "Either explicitly start server in safe mode or do not " \
                    "supply a public key list."
        assert str(ve) == error_str
    else:
        assert False, "DCFServer accepted a key list in unsafe mode"

    # now create the actual server instance to use
    dcf_server = DCFServer(
        register_worker_callback=test_register_func_cb,
        unregister_worker_callback=test_unregister_func_cb,
        return_global_model_callback=test_ret_global_model_cb,
        is_global_model_most_recent=is_global_model_most_recent,
        receive_worker_update_callback=test_rec_server_update_cb,
        server_mode_safe=False,
        key_list_file=None)
    Greenlet.spawn(begin_server)
    try:
        sleep(2)
        base_url = f"http://{dcf_server.server_host_ip}:{dcf_server.server_port}"

        # In unsafe mode, registrations with dummy credentials succeed and
        # each one gets a distinct id.
        data = {
            PUBLIC_KEY_STR: "dummy public key",
            SIGNED_PHRASE: "dummy signed phrase"
        }
        for _ in range(3):
            requests.post(f"{base_url}/{REGISTER_WORKER_ROUTE}", json=data)

        assert len(worker_ids) == 3
        assert len(set(worker_ids)) == 3
        assert worker_ids[0].__class__ == worker_ids[1].__class__ == worker_ids[
            2].__class__

        # The admin API lists exactly the registered workers.
        response = requests.get(f"{base_url}/{WORKERS_ROUTE}",
                                auth=admin_auth).content
        workers_list = json.loads(response)
        assert all(worker[WORKER_ID_KEY] in worker_ids
                   for worker in workers_list)

        # An empty admin request must not register anyone.
        requests.post(f"{base_url}/{WORKERS_ROUTE}", json={}, auth=admin_auth)
        assert len(worker_ids) == 3

        # Admin registration: the returned worker id differs from the key given.
        admin_registered_worker = {
            PUBLIC_KEY_STR: "new_public_key",
            REGISTRATION_STATUS_KEY: True
        }
        response = requests.post(f"{base_url}/{WORKERS_ROUTE}",
                                 json=admin_registered_worker,
                                 auth=admin_auth)
        added_worker_dict = json.loads(response.content.decode('utf-8'))

        assert len(worker_ids) == 4
        assert worker_ids[3] != admin_registered_worker[PUBLIC_KEY_STR]
        assert worker_ids[3] == added_worker_dict[WORKER_ID_KEY]

        requests.delete(
            f"{base_url}/{WORKERS_ROUTE}/{added_worker_dict[WORKER_ID_KEY]}",
            auth=admin_auth)
        assert len(worker_ids) == 3

        # test getting the global model
        model_return_binary = requests.post(
            f"{base_url}/{RETURN_GLOBAL_MODEL_ROUTE}",
            json={
                WORKER_ID_KEY: worker_ids[0],
                SIGNED_PHRASE: "",
                LAST_WORKER_MODEL_VERSION: "0"
            }).content
        model_return = msgpack.unpackb(zlib.decompress(model_return_binary))
        assert isinstance(model_return, dict)
        assert model_return[GLOBAL_MODEL_VERSION] == global_model_version
        assert msgpack.unpackb(
            model_return[GLOBAL_MODEL]) == "Pickle dump of a string"

        # test sending the model update
        response = requests.post(
            f"{base_url}/{RECEIVE_WORKER_UPDATE_ROUTE}/{worker_ids[1]}",
            files={
                WORKER_MODEL_UPDATE_KEY:
                zlib.compress(msgpack.packb("Model update!!")),
                SIGNED_PHRASE: ""
            }).content

        assert msgpack.unpackb(worker_updates[worker_ids[1]]) == "Model update!!"
        assert response.decode(
            "UTF-8") == f"Update received for worker {worker_ids[1][0:WID_LEN]}."

        # An update for an unknown worker id "3" is rejected and not stored.
        response = requests.post(
            f"{base_url}/{RECEIVE_WORKER_UPDATE_ROUTE}/3",
            files={
                WORKER_MODEL_UPDATE_KEY:
                zlib.compress(
                    msgpack.packb("Model update for unregistered worker!!")),
                SIGNED_PHRASE:
                ""
            }).content

        # Worker ids are strings (the id comes from the URL path), so check the
        # str key. An int 3 could never be a key of worker_updates.
        assert "3" not in worker_updates
        assert response.decode('UTF-8') == INVALID_WORKER

        # *********** #
        # now test a DCFWorker on the same server.
        dcf_worker = DCFWorker(
            server_protocol='http',
            server_host_ip=dcf_server.server_host_ip,
            server_port=dcf_server.server_port,
            global_model_version_changed_callback=test_glob_mod_chng_cb,
            get_worker_version_of_global_model=test_get_last_glob_model_ver,
            private_key_file=None)

        # test worker registration
        dcf_worker.register_worker()
        assert dcf_worker.worker_id == worker_ids[3]

        # test getting the global model update
        global_model_dict = dcf_worker.get_global_model()
        assert is_valid_model_dict(global_model_dict)
        assert global_model_dict[GLOBAL_MODEL_VERSION] == global_model_version
        assert msgpack.unpackb(
            global_model_dict[GLOBAL_MODEL]) == "Pickle dump of a string"

        # test sending the model update
        response = dcf_worker.send_model_update(
            msgpack.packb("DCFWorker model update"))
        assert msgpack.unpackb(
            worker_updates[worker_ids[3]]) == "DCFWorker model update"
        assert response.decode(
            "UTF-8") == f"Update received for worker {worker_ids[3][0:WID_LEN]}."
    finally:
        stoppable_server.shutdown()
Exemplo n.º 9
0
def test_long_polling():
    """
    Long-polling test for global-model requests on a safe-mode DCFServer.

    100 allow-listed workers each fetch the global model once (version "1").
    They then all issue a second fetch while already holding the latest
    version. After halt_time seconds the test changes the version to "2" and
    waits (up to ~100 s) for all of those second fetches to return.
    Presumably the server holds each request, re-checking every
    server_model_check_interval seconds, until is_global_model_most_recent
    returns False.

    NOTE(review): shutdown and key-file cleanup are not in a try/finally, so
    they are skipped if an assertion fails.
    """
    # Create a set of keys to be supplied to the server
    num_workers = 100
    private_keys = []  # NOTE(review): filled below but never read.
    public_keys = []
    server_model_check_interval = 1
    halt_time = 10

    keys_folder = 'keys_folder'
    if not os.path.exists(keys_folder):
        os.mkdir(keys_folder)
    worker_key_file_prefix = 'worker_key_file'

    for n in range(num_workers):
        private_key, public_key = gen_pair(
            os.path.join(keys_folder, worker_key_file_prefix + f'_{n}'))
        private_keys.append(private_key)
        public_keys.append(public_key)

    worker_ids = []
    worker_updates = {}
    # Read at call time by the callbacks below, so the later rebinding to "2"
    # is visible to them.
    global_model_version = "1"

    def test_register_func_cb(id):
        worker_ids.append(id)

    def test_unregister_func_cb(id):
        worker_ids.remove(id)

    def test_ret_global_model_cb():
        return create_model_dict(msgpack.packb("Pickle dump of a string"),
                                 global_model_version)

    def is_global_model_most_recent(version):
        return version == global_model_version

    def test_rec_server_update_cb(worker_id, update):
        if worker_id in worker_ids:
            worker_updates[worker_id] = update
            return f"Update received for worker {worker_id[0:WID_LEN]}."
        else:
            return f"Unregistered worker {worker_id[0:WID_LEN]} tried to send an update."

    # Allow-list file: one hex-encoded public key per line.
    worker_key_file = os.path.join(keys_folder, 'worker_public_keys.txt')
    with open(worker_key_file, 'w') as f:
        for public_key in public_keys[:-1]:
            f.write(
                public_key.encode(encoder=HexEncoder).decode('utf-8') +
                os.linesep)
        f.write(public_keys[-1].encode(encoder=HexEncoder).decode('utf-8') +
                os.linesep)

    dcf_server = DCFServer(
        register_worker_callback=test_register_func_cb,
        unregister_worker_callback=test_unregister_func_cb,
        return_global_model_callback=test_ret_global_model_cb,
        is_global_model_most_recent=is_global_model_most_recent,
        receive_worker_update_callback=test_rec_server_update_cb,
        server_mode_safe=True,
        key_list_file=worker_key_file,
        model_check_interval=server_model_check_interval,
        load_last_session_workers=False)

    stoppable_server = StoppableServer(host=get_host_ip(), port=8080)

    def begin_server():
        dcf_server.start_server(stoppable_server)

    server_gl = Greenlet.spawn(begin_server)  # NOTE(review): handle unused.
    sleep(2)  # give the server greenlet time to start listening

    # create the workers
    workers = [
        SimpleLPWorker(
            dcf_server.server_host_ip, dcf_server.server_port,
            os.path.join(keys_folder, worker_key_file_prefix + f"_{n}"))
        for n in range(num_workers)
    ]

    # NOTE(review): `key` is unused in this loop.
    for worker, key in zip(workers, public_keys):
        worker.worker.register_worker()

    # get the current global model and check
    # (presumably global_model_changed_callback records the received version in
    # gm_version - confirm against SimpleLPWorker)
    for worker in workers:
        worker.global_model_changed_callback(worker.worker.get_global_model())

    for worker in workers:
        assert worker.gm_version == global_model_version

    # Number of long-poll requests that have returned; incremented by run_wg.
    done_count = 0

    # Each greenlet requests the global model while already holding the latest
    # version, so the request is expected to stay pending until the version
    # changes after halt_time seconds.
    def run_wg(gl_worker):
        logger.info(f"Starting long poll for {gl_worker.worker.worker_id}")
        gl_worker.global_model_changed_callback(
            gl_worker.worker.get_global_model())
        logger.info(f"Long poll for {gl_worker.worker.worker_id} finished")
        nonlocal done_count
        done_count += 1

    # Spawn the pollers in batches of 5 with a 0.5 s pause between batches
    # (presumably to avoid opening all connections at once).
    for i, worker in enumerate(workers):
        Greenlet.spawn(run_wg, worker)
        if (i + 1) % 5 == 0:
            sleep(0.5)

    logger.info(f"The test will halt for {halt_time} seconds now...")

    sleep(halt_time)
    # Rebinding the enclosing variable changes what the server callbacks
    # return, which should release the pending long polls.
    global_model_version = "2"

    start_time = datetime.now()
    # if it hasn't stopped after 100 seconds, it has failed.
    # (timedelta.seconds is only the seconds component, which equals the
    # elapsed time here because it stays far below one day.)
    while done_count < num_workers and (datetime.now() -
                                        start_time).seconds < 100:
        sleep(1)
        logger.info(
            f"{done_count} workers have received the global model update - need to get to {num_workers}..."
        )
    # all the calls to get the global model should have succeeded by now

    assert done_count == num_workers
    logger.info(f"All workers have received the global model update.")

    stoppable_server.shutdown()

    # Remove the generated key files.
    for f in os.listdir(keys_folder):
        os.remove(os.path.join(keys_folder, f))