def fetch_model_request(self, request, context):
    """
    Stream an experiment's current model to the requesting client.

    The request is rejected when the client/secret pair is not valid or,
    for every client other than 'INTERFACE', when the client is not listed
    in the experiment document's ``clients``. Because this handler is a
    generator, a rejection is delivered as a single ``Model`` message with
    ``ok=False``.

    On success exactly one ``Model`` message is yielded. It carries the
    protocol of the looked-up document together with the JSON-encoded model
    parameters and model definition returned by ``_get_model_parameters``.
    Afterwards, if the request names both an experiment and a task, the
    task is reported through ``utils.task_completion`` and the client's
    entry in ``client_is_working`` is cleared when that call succeeds.

    :param request: message providing ``client``, ``secret``,
        ``experiment_id``, ``model_id`` and ``task_id``.
    :param context: gRPC context (unused).
    :return: generator of ``globalserver_pb2.Model`` messages.
    """
    if not self.utils.client_is_valid(request.client, request.secret):
        logging.debug('Client not known....')
        # A value handed to ``return`` inside a generator is discarded, so
        # the rejection must be yielded to actually reach the client.
        yield globalserver_pb2.Model(message='Client not known....', ok=False)
        return

    logging.info(
        f"received fetch_model request from {request.client}/{request.experiment_id}/{request.model_id}"
    )
    model_parameters, experiment_id, task_id, result = self._get_model_parameters(
        request)

    if request.client != 'INTERFACE':  # todo make nicer
        experiment_documents = list(
            self.fl_db.experiment.find({
                "_id": experiment_id
            }).limit(1))[0]
        if request.client not in experiment_documents['clients']:
            logging.debug('Experiment not for client....')
            yield globalserver_pb2.Model(
                message='Experiment not for client....', ok=False)
            return
    else:
        # The interface client reads the protocol from the model collection.
        experiment_documents = list(
            self.fl_db.model.find({
                "_id": experiment_id
            }).limit(1))[0]

    # todo make proper stream
    logging.info(f'Streaming model to {request.client}...')
    yield globalserver_pb2.Model(
        message="Streaming model",
        ok=True,
        protocol=experiment_documents['protocol'],
        model_parameters=json.dumps(
            model_parameters['parameters']).encode('utf-8'),
        model_definition=json.dumps(
            model_parameters['model']).encode('utf-8'))
    logging.info(f'Streaming model to {request.client} finished')

    if request.experiment_id != '' and request.task_id != '':
        correct_task_completed = self.utils.task_completion(
            db=self.fl_db,
            task_id=task_id,
            experiment_id=experiment_id,
            client=request.client,
            result=result,
            db_session=self.db_session)
        # Guard the experiment lookup (as _finish_up_task_response does) so
        # an untracked experiment does not raise KeyError after streaming.
        is_tracked = request.experiment_id in self.client_is_working
        if (correct_task_completed and is_tracked
                and self.client_is_working[request.experiment_id].get(
                    request.client, False)):
            self.client_is_working[request.experiment_id][
                request.client] = False
def stopped_experiment_response(self, request, context):
    """
    Acknowledge a worker's report that it stopped an experiment.

    After validating the client/secret pair, the stop is logged and the
    request is handed to ``_check_if_all_clients_stopped``, which decides
    what happens to the experiment (per the original note: once every
    worker has stopped, the experiment is marked failed, not finished and
    not running - that logic lives in the helper, not here).

    :param request: message providing ``client``, ``secret`` and
        ``experiment_id``.
    :param context: gRPC context (unused).
    :return: ``DefaultResponse`` with ``ok=False`` for an unknown client,
        otherwise ``ok=True``.
    """
    client_known = self.utils.client_is_valid(request.client, request.secret)
    if client_known:
        logging.info(
            f'Client {request.experiment_id} {request.client} stopped')
        self._check_if_all_clients_stopped(request)
        reply_text = 'Hello. Thanks for following the protocol. Fetch your next task.'
        accepted = True
    else:
        reply_text = 'Client not known....'
        accepted = False
    return globalserver_pb2.DefaultResponse(message=reply_text, ok=accepted)
def _finish_up_task_response(self, request, task_id, result):
    """
    Report a task as completed and build the reply for the worker.

    ``utils.task_completion`` is called with the experiment id wrapped in
    ``ObjectId``. A falsy outcome produces a 'Wrong task finished' reply.
    Otherwise the client's entry in ``client_is_working`` is reset to
    ``False`` (only if both the experiment and the client are already
    present there) and a success reply is returned.

    :param request: message providing ``client`` and ``experiment_id``.
    :param task_id: identifier of the task being completed.
    :param result: result payload forwarded to ``task_completion``.
    :return: ``globalserver_pb2.DefaultResponse``.
    """
    completed = self.utils.task_completion(
        db=self.fl_db,
        task_id=task_id,
        experiment_id=ObjectId(request.experiment_id),
        client=request.client,
        result=result,
        db_session=self.db_session)
    if not completed:
        return globalserver_pb2.DefaultResponse(
            message='Wrong task finished', ok=False)

    experiment_key = request.experiment_id
    worker = request.client
    # Only clear flags that already exist; never create new entries here.
    if experiment_key in self.client_is_working:
        if worker in self.client_is_working[experiment_key]:
            self.client_is_working[experiment_key][worker] = False

    reply_text = f'Hello {request.client}. Thanks for following the protocol. Fetch your next task.'
    return globalserver_pb2.DefaultResponse(message=reply_text, ok=True)
def _send_loss_response(self, request, context, data_type):
    """
    Store a loss reported by a worker and mark the related task as done.

    The client/secret pair is validated first. The loss is then persisted
    through ``_save_loss`` for the given ``data_type`` and the task is
    closed via ``_finish_up_task_response``.

    :param request: message providing ``client``, ``secret`` and
        ``experiment_id``.
    :param context: gRPC context (unused).
    :param data_type: forwarded unchanged to ``_save_loss``; identifies
        which kind of loss is being reported.
    :return: ``globalserver_pb2.DefaultResponse``.
    """
    if not self.utils.client_is_valid(request.client, request.secret):
        return globalserver_pb2.DefaultResponse(
            message='Client not known....', ok=False)

    # This handler is shared by every loss type, so log the actual
    # data_type instead of a hard-coded "validation".
    logging.info(
        f"received send_loss_response ({data_type}) from {request.client}/{request.experiment_id}"
    )
    task_id, result = self._save_loss(request, data_type=data_type)
    logging.info(f'Loss from {request.client} received')
    return self._finish_up_task_response(request, task_id, result)
def send_datasets(self, request, context):
    """
    Persist the dataset description sent by a client as a JSON file.

    ``request.protocol`` is parsed as JSON and written to
    ``<client>_datasets.json`` in the current working directory. If the
    payload is not valid JSON the raw payload is logged and nothing is
    written. Storing is best effort: the reply is ``ok=True`` either way.

    :param request: message providing ``client``, ``secret`` and
        ``protocol`` (a JSON document).
    :param context: gRPC context (unused).
    :return: ``DefaultResponse`` with ``ok=False`` for an unknown client,
        otherwise an ``ExperimentResponse`` with ``ok=True``.
    """
    if not self.utils.client_is_valid(request.client, request.secret):
        return globalserver_pb2.DefaultResponse(
            message='Client not known....', ok=False)

    try:
        # Parse before opening so a bad payload never truncates an
        # already stored file.
        datasets = json.loads(request.protocol)
    except json.decoder.JSONDecodeError:
        logging.info(request.protocol)
    else:
        # The context manager guarantees the handle is flushed and closed.
        with open(f"{request.client}_datasets.json", "w+") as datasets_file:
            json.dump(datasets, datasets_file)
        logging.info("datasets stored")

    return globalserver_pb2.ExperimentResponse(message='',
                                               experiment_id='',
                                               ok=True)
def send_model_update_response(self, request, context):
    """
    Store a worker's streamed model update and mark the task as done.

    ``request`` is an iterator of messages. Its first message supplies the
    client credentials and the experiment id; the iterator itself
    (positioned after that first message) is handed to
    ``_save_model_updates`` together with the first message. The task is
    then closed via ``_finish_up_task_response``.

    :param request: iterator of request messages from the worker.
    :param context: gRPC context (unused).
    :return: ``globalserver_pb2.DefaultResponse``; ``ok=False`` when the
        stream is empty or the client is unknown.
    """
    weights_stream = request
    # An empty stream would otherwise raise StopIteration out of the handler.
    request = next(weights_stream, None)
    if request is None:
        logging.debug('Empty model update stream....')
        return globalserver_pb2.DefaultResponse(
            message='Empty model update stream....', ok=False)

    if not self.utils.client_is_valid(request.client, request.secret):
        return globalserver_pb2.DefaultResponse(
            message='Client not known....', ok=False)

    logging.info(
        f"received send_model_update_response from {request.client}/{request.experiment_id}"
    )
    task_id, result = self._save_model_updates(request, weights_stream)
    logging.info(f'Model from {request.client} received')
    return self._finish_up_task_response(request, task_id, result)
def train_model_response(self, request, context):
    """
    Record that a worker finished training and mark the task as done.

    The task id is wrapped in ``ObjectId`` when the ``SERVER`` environment
    variable (default ``1``) parses to a non-zero integer; otherwise the
    raw ``request.task_id`` is used. The request is converted to a dict
    with ``MessageToDict``, the ``secret`` entry is removed, and the rest
    is passed on as the task result.

    :param request: message providing ``client``, ``secret``,
        ``experiment_id`` and ``task_id``.
    :param context: gRPC context (unused).
    :return: ``globalserver_pb2.DefaultResponse``.
    """
    if not self.utils.client_is_valid(request.client, request.secret):
        return globalserver_pb2.DefaultResponse(
            message='Client not known....', ok=False)

    logging.info(
        f"received train_model_response from {request.client}/{request.experiment_id}"
    )

    server_flag = int(os.getenv('SERVER', 1))
    if server_flag:
        task_id = ObjectId(request.task_id)
    else:
        task_id = request.task_id

    # Never store the client's secret alongside the task result.
    result = MessageToDict(request)
    result.pop('secret', None)

    logging.info(f'{request.client} finished training')
    return self._finish_up_task_response(request, task_id, result)
def stop_experiment(self, request, context):
    """
    Tell a node controller which of its experiments it should stop.

    Among the experiments stored with ``is_running`` set to ``True``, the
    ones that list the requesting client and are flagged ``is_finished``
    or ``has_failed`` are selected. Their ids are returned as a
    JSON-encoded list of strings.

    :param request: message providing ``client`` and ``secret``.
    :param context: gRPC context (unused).
    :return: ``DefaultResponse`` with ``ok=False`` for an unknown client,
        otherwise an ``ExperimentResponse`` whose ``experiment_id`` holds
        the JSON list (possibly empty).
    """
    if not self.utils.client_is_valid(request.client, request.secret):
        return globalserver_pb2.DefaultResponse(
            message='Client not known....', ok=False)

    ids_to_stop = []
    for document in self.fl_db.experiment.find({"is_running": True}):
        if request.client not in document['clients']:
            continue
        if document.get("is_finished", False) or document.get(
                "has_failed", False):
            ids_to_stop.append(str(document['_id']))

    if ids_to_stop:
        logging.info(f"Telling {request.client} to stop {ids_to_stop}")

    return globalserver_pb2.ExperimentResponse(
        message='', experiment_id=json.dumps(ids_to_stop), ok=True)
def test_connection(request, context):
    """
    Answer a connectivity probe with a fixed successful response.

    Neither argument is inspected and no client validation is performed.

    :param request: incoming message (unused).
    :param context: gRPC context (unused).
    :return: ``globalserver_pb2.DefaultResponse`` with ``ok=True``.
    """
    # NOTE(review): unlike the neighbouring handlers this takes no ``self``.
    # If it is bound as a servicer method, confirm it is wrapped as a
    # staticmethod; otherwise calls on an instance would fail.
    greeting = 'Hello. Thanks for following the protocol. Fetch your next task.'
    return globalserver_pb2.DefaultResponse(message=greeting, ok=True)