예제 #1
0
    async def send_message(self, client_id: str, msg: Dict):
        """
        Record a client's weight update and broadcast `msg` to every client.

        When `msg` is an ``update_weight`` message carrying data, the sender's
        sample count and weights are stored. Nothing is sent until
        ``participants_count`` clients have reported. Once they have, the
        round counter is advanced, the updates are aggregated and `msg` is
        rewritten in place into a ``recv_weight`` message holding the result.
        Any other message is forwarded unchanged.

        Parameters
        ----------
        client_id : str
            Identifier of the client that sent `msg`.
        msg : Dict
            Incoming message; mutated in place when a round completes.
        """
        data = msg.get("data")
        if data and msg.get("type", "") == "update_weight":
            info = AggClient()
            info.num_samples = int(data["num_samples"])
            info.weights = data["weights"]
            self._client_meta[client_id].info = info
            current_clients = [
                x.info for x in self._client_meta.values() if x.info
            ]
            # Keep waiting until enough clients have reported their update.
            if len(current_clients) < self.participants_count:
                return
            self.current_round += 1
            weights = self.aggregation.aggregate(current_clients)
            exit_flag = "ok" if self.exit_check() else "continue"

            msg["type"] = "recv_weight"
            msg["round_number"] = self.current_round
            msg["data"] = {
                "total_sample": self.aggregation.total_size,
                "round_number": self.current_round,
                "weights": weights,
                "exit_flag": exit_flag
            }
        # A forwarded message is not guaranteed to carry a "type" key, so
        # read it with .get() instead of indexing.
        round_finished = msg.get("type") == "recv_weight"
        for to_client, websocket in self._clients.items():
            try:
                await websocket.send_json(msg)
            except Exception as err:
                # Best effort: one failing client must not block the others.
                LOGGER.error(err)
            else:
                if round_finished:
                    # Delivered: drop the update so the next round starts
                    # clean for this client.
                    self._client_meta[to_client].info = None
예제 #2
0
 def add_client(self, client_id: str, websocket: WebSocket):
     """
     Register a newly connected client.

     Parameters
     ----------
     client_id : str
         Identifier of the client; must not be registered yet.
     websocket : WebSocket
         Connection stored for this client.

     Raises
     ------
     ValueError
         If `client_id` is already registered.
     """
     already_registered = client_id in self._clients
     if already_registered:
         raise ValueError(f"Client {client_id} is already in the server")

     LOGGER.info(f"Adding client {client_id}")
     self._clients[client_id] = websocket
     # Fresh metadata record; no update has been received from it yet.
     meta = WSClientInfo(
         client_id=client_id,
         connected_at=time.time(),
         info=None,
     )
     self._client_meta[client_id] = meta
예제 #3
0
 async def on_connect(self, websocket: WebSocket):
     """
     Accept an incoming connection once its target server is resolved.

     The request path, without leading slashes, is used as the key under
     which the server instance is looked up in ``self.scope``.

     Raises
     ------
     RuntimeError
         If no server is stored under that key.
     """
     request_path = websocket.scope['path']
     servername = request_path.lstrip("/")
     LOGGER.info("Connecting new client...")
     resolved: Optional[Aggregator] = self.scope.get(servername)
     if resolved is None:
         raise RuntimeError("HOST `client` instance unavailable!")
     self.server = resolved
     await websocket.accept()
예제 #4
0
 def upload_file(self, files, name=""):
     """
     Upload a local file to the knowledge-base server.

     Parameters
     ----------
     files : str
         Path of the local file to upload.
     name : str
         Name used in the log message; defaults to the base name of `files`.

     Returns
     -------
     str or None
         `files` unchanged when it is empty or not an existing local file;
         the URL reported by the server when it is already remote;
         otherwise the server path prefixed with ``self.kbserver``.
         None when the request produced no result.
     """
     if not (files and os.path.isfile(files)):
         return files
     if not name:
         name = os.path.basename(files)
     LOGGER.info(f"Try to upload file {name}")
     _url = f"{self.kbserver}/file/upload"
     with open(files, "rb") as fin:
         # Keep the payload in its own name: `files` must still hold the
         # local path when it is deleted below.
         payload = {"file": fin}
         outurl = http_request(url=_url, method="POST", files=payload)
     if not outurl:
         # The request failed or returned nothing; keep the local file.
         LOGGER.error(f"Upload file {name} failed")
         return None
     if FileOps.is_remote(outurl):
         return outurl
     outurl = outurl.lstrip("/")
     # The upload succeeded, so the local copy is no longer needed.
     FileOps.delete(files)
     return f"{self.kbserver}/{outurl}"
예제 #5
0
    def __init__(self, data=None, estimator=None,
                 aggregation=None, transmitter=None,
                 chooser=None) -> None:
        """
        Merge the given components into the plato `Config` and build the
        server.

        Parameters
        ----------
        data : object, optional
            Its ``parameters`` are merged into the data section.
        estimator : object, optional
            Provides ``model``, ``pretrained``, ``saved`` and
            ``hyperparameters`` (merged into the trainer section).
        aggregation : object, optional
            Its ``parameters`` become the algorithm section; the type
            ``"mistnet"`` switches clients and server to that type,
            any other type enables client-side testing.
        transmitter : object, optional
            Its ``parameters`` are merged into the server section.
        chooser : object, optional
            Its ``parameters["per_round"]`` sets the clients per round.
        """
        from plato.config import Config

        # Editable copies of the current configuration sections.
        server_cfg = Config().server._asdict()
        clients_cfg = Config().clients._asdict()
        data_cfg = Config().data._asdict()
        trainer_cfg = Config().trainer._asdict()

        if data is not None:
            data_cfg.update(data.parameters)
            Config().data = Config.namedtuple_from_dict(data_cfg)

        self.model = None
        if estimator is not None:
            self.model = estimator.model
            pretrained_dir = estimator.pretrained
            if pretrained_dir is not None:
                Config().params['pretrained_model_dir'] = pretrained_dir
            saved_dir = estimator.saved
            if saved_dir is not None:
                Config().params['model_dir'] = saved_dir
            trainer_cfg.update(estimator.hyperparameters)
            Config().trainer = Config.namedtuple_from_dict(trainer_cfg)

        # Bind address and port come from the context; the transmitter
        # parameters are applied afterwards.
        server_cfg["address"] = Context.get_parameters(
            "AGG_BIND_IP", "0.0.0.0")
        server_cfg["port"] = int(
            Context.get_parameters("AGG_BIND_PORT", 7363))
        if transmitter is not None:
            server_cfg.update(transmitter.parameters)

        if aggregation is not None:
            algorithm = aggregation.parameters
            Config().algorithm = Config.namedtuple_from_dict(algorithm)
            if aggregation.parameters["type"] != "mistnet":
                clients_cfg["do_test"] = True
            else:
                clients_cfg["type"] = server_cfg["type"] = "mistnet"

        if chooser is not None:
            clients_cfg["per_round"] = chooser.parameters["per_round"]

        LOGGER.info("address %s, port %s",
                    server_cfg["address"], server_cfg["port"])

        Config().server = Config.namedtuple_from_dict(server_cfg)
        Config().clients = Config.namedtuple_from_dict(clients_cfg)

        from plato.servers import registry as server_registry
        self.server = server_registry.get(model=self.model)
예제 #6
0
 def update_task_status(self, tasks: str, new_status=1):
     """
     Post a status change for `tasks` to the knowledge-base server.

     Parameters
     ----------
     tasks : str
         Value sent under the ``"tasks"`` key.
     new_status : int
         Status sent under the ``"status"`` key, converted with ``int``.

     Returns
     -------
     str or None
         The URL reported by the server, prefixed with ``self.kbserver``
         when it is not already remote; None when the request raised or
         returned nothing.
     """
     payload = {"tasks": tasks, "status": int(new_status)}
     endpoint = f"{self.kbserver}/update/status"
     try:
         result = http_request(url=endpoint, method="POST", json=payload)
     except Exception as err:
         LOGGER.error(f"Update kb error: {err}")
         return None
     if not result:
         return None
     if FileOps.is_remote(result):
         return result
     relative = result.lstrip("/")
     return f"{self.kbserver}/{relative}"
예제 #7
0
    def update_db(self, task_info_file):
        """
        Upload a task-info file to the knowledge-base server.

        Parameters
        ----------
        task_info_file : str
            Path of the local file posted under the ``"task"`` field.

        Returns
        -------
        str or None
            The URL reported by the server, prefixed with ``self.kbserver``
            when it is not already remote; None when the request raised or
            returned nothing. The local file is deleted only after a
            successful upload.
        """
        _url = f"{self.kbserver}/update"

        try:
            with open(task_info_file, "rb") as fin:
                files = {"task": fin}
                outurl = http_request(url=_url, method="POST", files=files)
        except Exception as err:
            LOGGER.error(f"Update kb error: {err}")
            return None

        if not outurl:
            # Nothing came back from the request: there is no URL to
            # normalise, and the local file is kept.
            return None
        if not FileOps.is_remote(outurl):
            outurl = outurl.lstrip("/")
            outurl = f"{self.kbserver}/{outurl}"
        FileOps.delete(task_info_file)
        return outurl
예제 #8
0
파일: base.py 프로젝트: XinYao1994/sedna
 def __init__(self, estimator, callback=None, version="latest"):
     """
     Prepare the model hot-update polling thread.

     Parameters
     ----------
     estimator : object
         Stored as ``production_estimator``.
     callback : callable, optional
         Stored as ``callback``.
     version : str
         Version label the thread starts with.

     Notes
     -----
     ``MODEL_HOT_UPDATE_CONFIG`` supplies the config location; when it is
     unset, ``run_flag`` is cleared. ``MODEL_POLL_PERIOD_SECONDS`` supplies
     the polling period and falls back to 60 when it is not an integer or
     is smaller than 1.
     """
     self.run_flag = True
     hot_update_conf = Context.get_parameters("MODEL_HOT_UPDATE_CONFIG")
     if not hot_update_conf:
         LOGGER.error("As `MODEL_HOT_UPDATE_CONFIG` unset a value, skipped")
         self.run_flag = False
     try:
         model_check_time = int(
             Context.get_parameters("MODEL_POLL_PERIOD_SECONDS", "60"))
     except (TypeError, ValueError):
         # A non-numeric value is as abnormal as a non-positive one;
         # route it to the same fallback below.
         model_check_time = 0
     if model_check_time < 1:
         LOGGER.warning("Catch an abnormal value in "
                        "`MODEL_POLL_PERIOD_SECONDS`, fallback with 60")
         model_check_time = 60
     self.hot_update_conf = hot_update_conf
     self.check_time = model_check_time
     self.production_estimator = estimator
     self.callback = callback
     self.version = version
     self.temp_path = tempfile.gettempdir()
     super(ModelLoadingThread, self).__init__()
예제 #9
0
    def run(self, app, **kwargs):
        """
        Serve `app` with uvicorn in a background thread and wait for stop.

        Parameters
        ----------
        app : object
            Application to serve; CORS middleware is added when it exposes
            ``add_middleware``.
        kwargs : Dict
            Extra keyword arguments forwarded to ``uvicorn.Config``.

        Returns
        -------
        Whatever ``self.wait_stop`` returns for the serving thread.
        """
        if hasattr(app, "add_middleware"):
            # NOTE(review): every origin is allowed together with
            # credentials; confirm this is intended for deployments.
            cors_options = dict(
                allow_origins=["*"],
                allow_credentials=True,
                allow_methods=["*"],
                allow_headers=["*"],
            )
            app.add_middleware(CORSMiddleware, **cors_options)

        LOGGER.info(f"Start {self.server_name} server over {self.url}")

        uvicorn_config = uvicorn.Config(
            app,
            host=self.host,
            port=self.http_port,
            ssl_keyfile=self.keyfile,
            ssl_certfile=self.certfile,
            workers=self.workers,
            timeout_keep_alive=self.timeout,
            log_level="info",
            **kwargs)
        serving = Server(config=uvicorn_config)
        with serving.run_in_thread() as current_thread:
            return self.wait_stop(current=current_thread)
예제 #10
0
파일: base.py 프로젝트: XinYao1994/sedna
 def run(self):
     """
     Poll the hot-update config and reload the model when it changes.

     Each cycle sleeps ``check_time`` seconds, downloads the config and
     reads ``model_config.model_update_time`` and
     ``model_config.model_path`` from it. The model is downloaded and
     loaded into ``production_estimator`` only when the version differs
     from ``self.version``; the result is reported through ``callback``
     when one is set. The loop runs while ``run_flag`` is true.
     """
     while self.run_flag:
         time.sleep(self.check_time)
         conf = FileOps.download(self.hot_update_conf)
         if not (conf and FileOps.exists(conf)):
             continue
         try:
             with open(conf, "r") as fin:
                 conf_msg = json.load(fin)
             model_msg = conf_msg["model_config"]
             latest_version = str(model_msg["model_update_time"])
             model_path = model_msg["model_path"]
         except (json.JSONDecodeError, KeyError, TypeError):
             # TypeError covers a config whose JSON is not a mapping.
             LOGGER.error(f"fail to parse model hot update config: "
                          f"{self.hot_update_conf}")
             continue
         # Compare versions first so an unchanged model is not downloaded
         # again on every poll.
         if latest_version == self.version:
             continue
         model = FileOps.download(
             model_path,
             FileOps.join_path(self.temp_path,
                               f"model.{latest_version}"))
         if not (model and FileOps.exists(model)):
             continue
         # Recorded before loading: a version that fails to load is not
         # retried until the config announces a newer one.
         self.version = latest_version
         with self.MODEL_MANIPULATION_SEM:
             LOGGER.info(f"Update model start with version {self.version}")
             try:
                 self.production_estimator.load(model)
                 status = K8sResourceKindStatus.COMPLETED.value
                 LOGGER.info(f"Update model complete "
                             f"with version {self.version}")
             except Exception as e:
                 LOGGER.error(f"fail to update model: {e}")
                 status = K8sResourceKindStatus.FAILED.value
             if self.callback:
                 self.callback(task_info=None, status=status, kind="deploy")
         gc.collect()
예제 #11
0
 def remove_client(self, client_id: str):
     """
     Forget a client and its metadata.

     Parameters
     ----------
     client_id : str
         Identifier of a registered client.

     Raises
     ------
     ValueError
         If `client_id` is not registered.
     """
     registered = client_id in self._clients
     if not registered:
         raise ValueError(f"Client {client_id} is not in the server")
     LOGGER.info(f"Removing Client {client_id} from server")
     # Connection first, then metadata.
     for registry in (self._clients, self._client_meta):
         del registry[client_id]
예제 #12
0
 async def send_message(self, client_id: str, msg: Dict):
     """
     Forward `msg` to every connected client except the sender.

     Parameters
     ----------
     client_id : str
         Identifier of the sending client, which is skipped.
     msg : Dict
         Message sent as JSON to each remaining client.
     """
     recipients = (
         (peer_id, peer_socket)
         for peer_id, peer_socket in self._clients.items()
         if peer_id != client_id
     )
     for to_client, websocket in recipients:
         LOGGER.info(f"send data to Client {to_client} from server")
         await websocket.send_json(msg)
예제 #13
0
    def train(self,
              train_data: BaseDataSource,
              valid_data: BaseDataSource = None,
              post_process=None,
              **kwargs):
        """
        fit for update the knowledge based on training data.

        Task groups holding more than ``min_train_sample`` samples get a
        model of their own. All remaining entries share one "global" model
        trained on the whole training data.

        Parameters
        ----------
        train_data : BaseDataSource
            Train data, see `sedna.datasources.BaseDataSource` for more detail.
        valid_data : BaseDataSource
            Valid data, BaseDataSource or None.
        post_process : function
            function or a registered method, callback after `estimator` train.
            A string is resolved through the callback registry; a callable
            is used as given. It is called as ``callback(model, result)``.
        kwargs : Dict
            parameters for `estimator` training, Like:
            `early_stopping_rounds` in Xgboost.XGBClassifier

        Returns
        -------
        feedback : Dict
            contain all training result in each tasks.
        task_index_url : str
            task extractor model path, used for task mining. The task
            index itself is returned instead when it cannot be dumped.
        """

        tasks, task_extractor, train_data = self._task_definition(train_data)
        self.extractor = task_extractor
        task_groups = self._task_relationship_discovery(tasks)
        self.models = []
        callback = None
        if isinstance(post_process, str):
            callback = ClassFactory.get_cls(ClassType.CALLBACK, post_process)()
        elif callable(post_process):
            # The documented contract also allows passing a function.
            callback = post_process
        self.task_groups = []
        feedback = {}
        rare_task = []
        for i, task in enumerate(task_groups):
            if not isinstance(task, TaskGroup):
                rare_task.append(i)
                self.models.append(None)
                self.task_groups.append(None)
                continue
            if not (task.samples
                    and len(task.samples) > self.min_train_sample):
                self.models.append(None)
                self.task_groups.append(None)
                rare_task.append(i)
                # `samples` may be missing entirely, not just too small.
                n = len(task.samples) if task.samples else 0
                LOGGER.info(f"Sample {n} of {task.entry} will be merge")
                continue
            LOGGER.info(f"MTL Train start {i} : {task.entry}")

            model = None
            for t in task.tasks:  # if model has train in tasks
                if not (t.model and t.result):
                    continue
                model_path = t.model.save(model_name=f"{task.entry}.model")
                t.model = model_path
                model = Model(index=i,
                              entry=t.entry,
                              model=model_path,
                              result=t.result)
                model.meta_attr = t.meta_attr
                break
            if not model:
                # No member task carries a trained model: train one for
                # the whole group.
                model_obj = set_backend(estimator=self.base_model)
                res = model_obj.train(train_data=task.samples, **kwargs)
                if callback:
                    res = callback(model_obj, res)
                model_path = model_obj.save(model_name=f"{task.entry}.model")
                model = Model(index=i,
                              entry=task.entry,
                              model=model_path,
                              result=res)

                model.meta_attr = [t.meta_attr for t in task.tasks]
            task.model = model
            self.models.append(model)
            feedback[task.entry] = model.result
            self.task_groups.append(task)

        if rare_task:
            # One shared model, trained on all data, fills every slot that
            # was skipped above.
            model_obj = set_backend(estimator=self.base_model)
            res = model_obj.train(train_data=train_data, **kwargs)
            model_path = model_obj.save(model_name="global.model")
            for i in rare_task:
                task = task_groups[i]
                entry = getattr(task, 'entry', "global")
                if not isinstance(task, TaskGroup):
                    task = TaskGroup(entry=entry, tasks=[])
                model = Model(index=i,
                              entry=entry,
                              model=model_path,
                              result=res)
                model.meta_attr = [t.meta_attr for t in task.tasks]
                task.model = model
                task.samples = train_data
                self.models[i] = model
                feedback[entry] = res
                self.task_groups[i] = task

        task_index = {
            "extractor": self.extractor,
            "task_groups": self.task_groups
        }
        if valid_data:
            feedback, _ = self.evaluate(valid_data, **kwargs)
        try:
            FileOps.dump(task_index, self.task_index_url)
        except TypeError:
            return feedback, task_index
        return feedback, self.task_index_url
예제 #14
0
def http_request(url, method=None, timeout=None, binary=True, **kwargs):
    """
    Send an HTTP request and return the decoded body of a 200 response.

    Parameters
    ----------
    url : str
        Target URL.
    method : str
        HTTP method; "GET" when omitted.
    timeout : float
        Request timeout in seconds; 300 when omitted or falsy.
    binary : bool
        When true the body is parsed as JSON, otherwise it is decoded as
        UTF-8 text.
    kwargs : Dict
        Extra keyword arguments forwarded to `requests.request`.

    Returns
    -------
    The decoded body for a 200 response. None for any other status code
    or when the request fails; both cases are logged as warnings.
    """
    _maxTimeout = timeout if timeout else 300
    _method = "GET" if not method else method
    try:
        # Pass the timeout on: without it the request can block forever.
        response = requests.request(
            method=_method, url=url, timeout=_maxTimeout, **kwargs)
        if response.status_code == 200:
            return (response.json() if binary else
                    response.content.decode("utf-8"))
        elif 200 < response.status_code < 400:
            LOGGER.info(f"Redirect_URL: {response.url}")
        LOGGER.warning(
            'Get invalid status code %s while request %s',
            response.status_code,
            url)
    except (ConnectionRefusedError, requests.exceptions.ConnectionError):
        LOGGER.warning(f'Connection refused while request {url}')
    except requests.exceptions.HTTPError as err:
        LOGGER.warning(f"Http Error while request {url} : {err}")
    except requests.exceptions.Timeout as err:
        LOGGER.warning(f"Timeout Error while request {url} : {err}")
    except requests.exceptions.RequestException as err:
        LOGGER.warning(f"Error occurred while request {url} : {err}")
    return None
예제 #15
0
 async def _send(self, data):
     try:
         await asyncio.wait_for(self.ws.send(data), self._ws_timeout)
         return True
     except Exception as err:
         LOGGER.info(f"{self.uri} send data failed - with {err}")
예제 #16
0
 async def connect(self):
     """
     Open the websocket, subscribe this client and return the first reply.

     The connection attempt is bounded by ``_ws_timeout``. The connection
     and websocket errors handled below are logged as warnings; in those
     cases nothing is returned.
     """
     LOGGER.info(f"{self.uri} connection by {self.client_id}")
     try:
         pending = websockets.connect(self.uri, **self.kwargs)
         self.ws = await asyncio.wait_for(pending, self._ws_timeout)
         subscription = json.dumps({'type': 'subscribe',
                                    'client_id': self.client_id})
         await self.ws.send(subscription)
         return await self.ws.recv()
     except ConnectionRefusedError:
         LOGGER.warning(f"{self.uri} connection was refused by server")
     except ConnectionClosedError:
         LOGGER.warning(f"{self.uri} connection lost")
     except ConnectionClosedOK:
         LOGGER.warning(f"{self.uri} connection closed")
     except InvalidStatusCode as err:
         LOGGER.warning(
             f"{self.uri} websocket failed - "
             f"with invalid status code {err.status_code}")
     except WebSocketException as err:
         LOGGER.warning(f"{self.uri} websocket failed - with {err}")
     except OSError as err:
         LOGGER.warning(f"{self.uri} connection failed - with {err}")