def fetch_model_request(self, request, context):
        """
        Stream the current model to the requesting client and mark the fetch task as done.

        Obtains ``model_parameters`` (with ``'model'`` and ``'parameters'`` entries) from
        ``_get_model_parameters`` and looks up the matching document. Regular clients read it from the
        ``experiment`` collection and must be listed in its ``clients``. The ``INTERFACE`` client reads it
        from the ``model`` collection. The method then yields one ``globalserver_pb2.Model`` message. That
        message carries the document's ``protocol`` and the model definition/parameters, each JSON-encoded
        as UTF-8 bytes.

        When the request names both an experiment and a task, the task is marked as completed via
        ``utils.task_completion``. If that succeeds, the client's busy flag in ``client_is_working`` is
        cleared. Per the original description, the model for the last experiment state is stored under the
        experiment_id.

        This is a generator (server-streaming RPC). A ``return <value>`` inside a generator is silently
        discarded, so every rejection is reported by yielding a ``Model`` with ``ok=False`` and then
        ending the stream.
        """
        if not self.utils.client_is_valid(request.client, request.secret):
            logging.debug('Client not known....')
            yield globalserver_pb2.Model(message='Client not known....', ok=False)
            return

        logging.info(
            f"received fetch_model request from {request.client}/{request.experiment_id}/{request.model_id}"
        )
        model_parameters, experiment_id, task_id, result = self._get_model_parameters(
            request)
        is_interface = request.client == 'INTERFACE'  # todo make nicer
        collection = self.fl_db.model if is_interface else self.fl_db.experiment
        # next(..., None) instead of list(...)[0]: a missing document must not raise IndexError.
        experiment_documents = next(
            iter(collection.find({"_id": experiment_id}).limit(1)), None)
        if experiment_documents is None:
            logging.debug('Experiment not found....')
            yield globalserver_pb2.Model(message='Experiment not found....', ok=False)
            return
        if not is_interface and request.client not in experiment_documents['clients']:
            logging.debug('Experiment not for client....')
            yield globalserver_pb2.Model(
                message='Experiment not for client....', ok=False)
            return

        # todo make proper stream
        logging.info(f'Streaming model to {request.client}...')
        yield globalserver_pb2.Model(
            message="Streaming model",
            ok=True,
            protocol=experiment_documents['protocol'],
            model_parameters=json.dumps(
                model_parameters['parameters']).encode('utf-8'),
            model_definition=json.dumps(
                model_parameters['model']).encode('utf-8'))
        logging.info(f'Streaming model to {request.client} finished')

        if request.experiment_id != '' and request.task_id != '':
            correct_task_completed = self.utils.task_completion(
                db=self.fl_db,
                task_id=task_id,
                experiment_id=experiment_id,
                client=request.client,
                result=result,
                db_session=self.db_session)

            # Guarded lookup: experiments without a busy-flag entry must not raise KeyError.
            working_clients = self.client_is_working.get(request.experiment_id, {})
            if working_clients.get(request.client, False) and correct_task_completed:
                working_clients[request.client] = False
    def stopped_experiment_response(self, request, context):
        """
        Handle a worker's notification that it has stopped working on an experiment.

        The bookkeeping is delegated to ``_check_if_all_clients_stopped``. Per the original description,
        once every worker of the experiment has stopped, that helper marks the experiment as failed, not
        finished and not running.

        Returns a ``DefaultResponse``: ``ok=False`` for an unknown client/secret pair, otherwise ``ok=True``.
        """
        client_known = self.utils.client_is_valid(request.client, request.secret)
        if client_known:
            logging.info(f'Client {request.experiment_id} {request.client} stopped')
            self._check_if_all_clients_stopped(request)
            greeting = 'Hello. Thanks for following the protocol. Fetch your next task.'
            return globalserver_pb2.DefaultResponse(message=greeting, ok=True)
        return globalserver_pb2.DefaultResponse(message='Client not known....', ok=False)
    def _finish_up_task_response(self, request, task_id, result):
        """
        Record a worker's task as completed and build the RPC reply.

        :param request: worker request; ``experiment_id`` (converted to ``ObjectId``) and ``client`` are read.
        :param task_id: id of the task the worker reports as finished.
        :param result: payload handed to ``utils.task_completion`` together with the task.
        :return: ``DefaultResponse`` with ``ok=False`` when ``utils.task_completion`` rejects the completion.
                 Otherwise ``ok=True``, after clearing the client's busy flag.
        """
        completed = self.utils.task_completion(
            db=self.fl_db,
            task_id=task_id,
            experiment_id=ObjectId(request.experiment_id),
            client=request.client,
            result=result,
            db_session=self.db_session)
        if not completed:
            return globalserver_pb2.DefaultResponse(message='Wrong task finished', ok=False)

        # Only clear the flag of clients that are already tracked; never create new entries.
        if request.experiment_id in self.client_is_working:
            experiment_flags = self.client_is_working[request.experiment_id]
            if request.client in experiment_flags:
                experiment_flags[request.client] = False

        farewell = f'Hello {request.client}. Thanks for following the protocol. Fetch your next task.'
        return globalserver_pb2.DefaultResponse(message=farewell, ok=True)
    def _send_loss_response(self, request, context, data_type):
        """
        Store a loss report sent by a worker and mark the corresponding task as done.

        :param request: worker request carrying ``client``, ``secret`` and ``experiment_id``.
        :param context: gRPC context (unused).
        :param data_type: kind of loss being reported; forwarded to ``_save_loss``.
        :return: ``DefaultResponse`` with ``ok=False`` for unknown clients. Otherwise the response
                 produced by ``_finish_up_task_response``.
        """
        if not self.utils.client_is_valid(request.client, request.secret):
            return globalserver_pb2.DefaultResponse(
                message='Client not known....', ok=False)

        # The loss kind is a parameter, so the log line names it instead of hard-coding "validation".
        logging.info(
            f"received send_loss_response ({data_type}) from {request.client}/{request.experiment_id}"
        )
        task_id, result = self._save_loss(request, data_type=data_type)
        logging.info(f'Loss from {request.client} received')
        return self._finish_up_task_response(request, task_id, result)
 def send_datasets(self, request, context):
     """
     Save the dataset description a client sends as JSON in ``request.protocol``.

     The parsed JSON is written to ``<client>_datasets.json`` in the current working directory. Invalid
     JSON is logged and nothing is written; the reply is still ``ok=True`` (best-effort, as before).
     Unknown clients get a ``DefaultResponse`` with ``ok=False``.
     """
     if not self.utils.client_is_valid(request.client, request.secret):
         return globalserver_pb2.DefaultResponse(
             message='Client not known....', ok=False)
     try:
         datasets = json.loads(request.protocol)
     except json.decoder.JSONDecodeError as error:
         logging.warning(f"could not decode datasets from {request.client}: {error}")
         logging.info(request.protocol)
     else:
         # The context manager flushes and closes the file deterministically; the original
         # open() call inside json.dump() leaked the handle until garbage collection.
         # NOTE(review): request.client becomes part of a file path; it has passed client_is_valid above.
         with open(f"{request.client}_datasets.json", "w", encoding="utf-8") as datasets_file:
             json.dump(datasets, datasets_file)
         logging.info("datasets stored")
     return globalserver_pb2.ExperimentResponse(message='',
                                                experiment_id='',
                                                ok=True)
    def send_model_update_response(self, request, context):
        """
        Store a worker's model update and mark the task as done.

        ``request`` is the request iterator of a client-streaming RPC. The first message carries the
        metadata (``client``, ``secret``, ``experiment_id``, ...). The rest of the iterator is passed to
        ``_save_model_updates`` as the weights stream. Per the original description, updates end up under
        db/local_model_updates/experiment/task.

        :return: ``DefaultResponse`` with ``ok=False`` for an empty stream or an unknown client. Otherwise
                 the response produced by ``_finish_up_task_response``.
        """
        request_iterator = request
        # next() with a default: on an empty stream a bare next() would raise StopIteration out of the handler.
        header = next(request_iterator, None)
        if header is None:
            logging.debug('Empty model update stream....')
            return globalserver_pb2.DefaultResponse(
                message='Empty model update stream....', ok=False)
        if not self.utils.client_is_valid(header.client, header.secret):
            return globalserver_pb2.DefaultResponse(
                message='Client not known....', ok=False)

        logging.info(
            f"received send_model_update_response from {header.client}/{header.experiment_id}"
        )
        # The iterator has already been advanced past the header, so only the weight chunks remain.
        task_id, result = self._save_model_updates(header, request_iterator)
        logging.info(f'Model from {header.client} received')

        return self._finish_up_task_response(header, task_id, result)
    def train_model_response(self, request, context):
        """
        Acknowledge that a worker finished training and mark its task as done.

        The task id is wrapped in ``ObjectId`` unless the ``SERVER`` environment variable parses to ``0``.
        When the variable is unset, the default value ``1`` applies. The request, converted with
        ``MessageToDict`` and stripped of its ``secret``, is passed on as the task result.
        """
        if self.utils.client_is_valid(request.client, request.secret):
            logging.info(
                f"received train_model_response from {request.client}/{request.experiment_id}"
            )
            running_on_server = int(os.getenv('SERVER', 1))
            if running_on_server:
                task_id = ObjectId(request.task_id)
            else:
                task_id = request.task_id

            # Never persist the client's secret alongside the task result.
            result = MessageToDict(request)
            result.pop('secret', None)

            logging.info(f'{request.client} finished training')
            return self._finish_up_task_response(request, task_id, result)

        return globalserver_pb2.DefaultResponse(
            message='Client not known....', ok=False)
    def stop_experiment(self, request, context):
        """
        Tell a node controller which of its experiments it should stop.

        An experiment is reported when all of the following hold:
        - it is still flagged ``is_running`` in the db;
        - it lists the requesting client in ``clients``;
        - it is flagged ``is_finished`` or ``has_failed``.

        The ids are returned as a JSON-encoded list of strings in ``ExperimentResponse.experiment_id``.
        """
        if not self.utils.client_is_valid(request.client, request.secret):
            return globalserver_pb2.DefaultResponse(
                message='Client not known....', ok=False)

        running_documents = list(self.fl_db.experiment.find({"is_running": True}))
        to_stop = []
        for document in running_documents:
            if request.client not in document['clients']:
                continue
            if document.get("is_finished", False) or document.get("has_failed", False):
                to_stop.append(str(document['_id']))

        if to_stop:
            logging.info(f"Telling {request.client} to stop {to_stop}")
        return globalserver_pb2.ExperimentResponse(
            message='', experiment_id=json.dumps(to_stop), ok=True)
 @staticmethod
 def test_connection(request, context):
     """
     Health-check RPC: always answer ``ok=True`` with a fixed greeting.

     The function takes no ``self``. Without ``staticmethod``, calling it on an instance of the enclosing
     class would bind the instance as ``request``. The call ``servicer.test_connection(request, context)``
     would then fail with a ``TypeError`` (one positional argument too many). The decorator keeps the
     ``(request, context)`` signature working both on the class and on an instance.
     """
     return globalserver_pb2.DefaultResponse(
         message='Hello. Thanks for following the protocol. Fetch your next task.',
         ok=True)