Beispiel #1
0
    def quit_remote(self, task_name):
        """
        Tell the server that this client is leaving.

        The Quit call is attempted at most three times. Every failed attempt
        is handed to ``self.grpc_error_handler`` and followed by a pause of
        three seconds.

        :param task_name: server task identifier, a key of ``self.servers``
        :return: the server's reply to Quit, or None if no attempt succeeded
        """
        reply = None
        with self.set_up_channel(self.servers[task_name]) as channel:
            stub = fed_service.FederatedTrainingStub(channel)
            for _ in range(3):
                try:
                    self.logger.info('Quitting server: %s', task_name)
                    reply = stub.Quit(self.client_state(task_name))
                    # The call went through, so the connection works again:
                    # clear the stopping flag.
                    self.should_stop = False

                    self.logger.info('Received comment from server: %s',
                                     reply.comment)
                except grpc.RpcError as rpc_failure:
                    self.grpc_error_handler(rpc_failure)
                    time.sleep(3)
                else:
                    # Success: no further attempts needed.
                    break
        return reply
Beispiel #2
0
    def send_heartbeat(self, task_name):
        """
        Send one heartbeat carrying the client token to the server.

        Nothing is sent while the client holds no token. The heartbeat is
        best-effort: a failed RPC is logged at debug level and not raised,
        so a missed heartbeat never interrupts the caller.

        :param task_name: should be one of the keys of `self.servers`
        """
        if not self.token:
            return

        token = fed_msg.Token()
        token.token = self.token

        with self.set_up_channel(self.servers[task_name]) as channel:
            stub = fed_service.FederatedTrainingStub(channel)
            try:
                self.logger.debug('Send %s heartbeat %s', task_name,
                                  self.token)
                stub.Heartbeat(token)
            except grpc.RpcError as grpc_error:
                # Deliberately not re-raised; record the failure so that
                # lost heartbeats can still be diagnosed from the log.
                self.logger.debug('Heartbeat for %s failed: %s', task_name,
                                  grpc_error)
Beispiel #3
0
    def push_remote_model(self, task_name):
        """
        Build the local contribution and submit it to the server.

        The Contribution message is produced by ``self.data_assembler`` from
        the model manager and the current client state, then sent with
        SubmitUpdate over the channel of ``self.servers[task_name]``.
        Failed calls are retried up to ``self.retry`` times with a pause of
        five seconds, except when the error details start with 'Contrib',
        which ends the attempts immediately.

        :param task_name: should be one of the keys of `self.servers`
        :return: the server's reply to SubmitUpdate, or None if no attempt
            succeeded
        """
        state = self.client_state(task_name)
        contrib = self.data_assembler.get_contribution_data(
            self.model_manager, task_name, state)

        reply = None
        attempts_left = self.retry
        with self.set_up_channel(self.servers[task_name]) as channel:
            stub = fed_service.FederatedTrainingStub(channel)
            while attempts_left > 0:
                try:
                    self.logger.info('Send %s at round %s', task_name,
                                     contrib.client.meta.current_round)
                    reply = stub.SubmitUpdate(contrib)
                    # The call went through, so the connection works again:
                    # clear the stopping flag.
                    self.should_stop = False

                    self.logger.info('Received comments: %s %s',
                                     reply.meta.task.name,
                                     reply.comment)
                except grpc.RpcError as rpc_failure:
                    reason = rpc_failure.details()
                    if not reason.startswith('Contrib'):
                        # Any other failure: report it, use up one attempt
                        # and wait before trying again.
                        self.grpc_error_handler(rpc_failure,
                                                verbose=self.verbose)
                        attempts_left -= 1
                        time.sleep(5)
                        continue
                    # Outdated contribution, no need to retry.
                    self.logger.info('Publish model failed: %s', reason)
                    break
                else:
                    break
        return reply
Beispiel #4
0
    def push_remote_fake_update(self, task_name):
        """
        Submit a contribution that carries client state and model meta only.

        The Contribution message is filled with a ClientState (uid, token,
        task name) and the model meta of ``task_name``; no other field is
        set here. It is sent with SubmitUpdate over the channel of
        ``self.servers[task_name]``, retrying up to ``self.retry`` times
        with a pause of five seconds. Error details starting with 'Contrib'
        end the attempts immediately.

        :param task_name: should be one of the keys of `self.servers`
        :return: the server's reply to SubmitUpdate, or None if no attempt
            succeeded
        """
        client_info = fed_msg.ClientState(uid=self.uid, token=self.token)
        client_info.meta.task.name = task_name

        # Assemble the message.
        contrib = fed_msg.Contribution()
        # Set client auth. data.
        contrib.client.CopyFrom(client_info)
        # The server needs to validate the ModelMeta information.
        contrib.client.meta.CopyFrom(self.model_manager.model_meta(task_name))

        reply = None
        attempts_left = self.retry
        with self.set_up_channel(self.servers[task_name]) as channel:
            stub = fed_service.FederatedTrainingStub(channel)
            while attempts_left > 0:
                try:
                    self.logger.info('Send fake update data')
                    reply = stub.SubmitUpdate(contrib)
                    # The call went through, so the connection works again:
                    # clear the stopping flag.
                    self.should_stop = False

                    self.logger.info('Received comments: %s %s',
                                     reply.meta.task.name,
                                     reply.comment)
                except grpc.RpcError as rpc_failure:
                    reason = rpc_failure.details()
                    if not reason.startswith('Contrib'):
                        # Any other failure: report it, use up one attempt
                        # and wait before trying again.
                        self.grpc_error_handler(rpc_failure,
                                                verbose=self.verbose)
                        attempts_left -= 1
                        time.sleep(5)
                        continue
                    # Outdated contribution, no need to retry.
                    self.logger.info('Publish fake model failed: %s', reason)
                    break
                else:
                    break
        return reply
Beispiel #5
0
    def fetch_remote_model(self, task_name):
        """
        Register with the remote server if needed, then fetch its model.

        When ``self.token`` is empty the client first calls Register (up to
        ``self.retry`` attempts) and stores the returned token. Afterwards
        GetModel is called with the current client state until it succeeds
        or the retry budget is used up.

        Side effects: may set ``self.token``; clears ``self.should_stop``
        after every successful RPC; sets ``self.train_end`` when
        ``self.should_stop`` is true after a phase has finished.

        :param task_name: server identifier string, a key of ``self.servers``
        :return: the reply of GetModel (a CurrentModel message according to
            the original documentation), or None when registration failed
            or no GetModel attempt succeeded
        """
        reg_result, m_result, retry = None, None, self.retry
        with self.set_up_channel(self.servers[task_name]) as channel:
            stub = fed_service.FederatedTrainingStub(channel)
            # Phase 1: obtain a token, only when the client has none yet.
            if not self.token:
                while retry > 0:
                    try:
                        reg_result = stub.Register(
                            self.client_registration(task_name))
                        self.token = reg_result.token
                        self.logger.info(
                            'Successfully registered client:{} for {}. Got token:{}'
                            .format(self.uid, task_name, self.token))
                        # Clear the stopping flag
                        # if the connection to server recovered.
                        self.should_stop = False
                        break
                    except grpc.RpcError as grpc_error:
                        # Every failed registration attempt costs one retry,
                        # regardless of what the handler returns.
                        self.grpc_error_handler(grpc_error,
                                                verbose=self.verbose)
                        retry -= 1
                        time.sleep(5)
                # NOTE(review): should_stop is presumably raised by
                # grpc_error_handler - confirm against its implementation.
                if self.should_stop:
                    self.train_end = True
                # No registration reply means no token: skip GetModel.
                if reg_result is None:
                    return None

            # Phase 2: fetch the model with a fresh retry budget.
            retry = self.retry
            while retry > 0:
                # get the global model
                try:
                    m_result = stub.GetModel(self.client_state(task_name))
                    # Clear the stopping flag
                    # if the connection to server recovered.
                    self.should_stop = False

                    self.logger.info(
                        'Received {} model at round {} ({} Bytes), update signal: {}'
                        .format(task_name, m_result.meta.current_round,
                                m_result.ByteSize(),
                                m_result.allowed_to_perform_update))
                    break
                except grpc.RpcError as grpc_error:
                    # Unlike registration, the handler's return value decides
                    # whether this failure counts against the retry budget;
                    # a falsy result keeps the loop going without a decrement.
                    decrease_retry = self.grpc_error_handler(
                        grpc_error, verbose=self.verbose)
                    if decrease_retry:
                        retry -= 1
                    time.sleep(5)
            if self.should_stop:
                self.train_end = True
        return m_result