async def send_message(self, client_id: str, msg: Dict):
    data = msg.get("data")
    if data and msg.get("type", "") == "update_weight":
        info = AggClient()
        info.num_samples = int(data["num_samples"])
        info.weights = data["weights"]
        self._client_meta[client_id].info = info
    current_clients = [
        x.info for x in self._client_meta.values() if x.info
    ]
    # Wait until every participant has reported its weights for this round.
    if len(current_clients) < self.participants_count:
        return
    self.current_round += 1
    weights = self.aggregation.aggregate(current_clients)
    exit_flag = "ok" if self.exit_check() else "continue"
    msg["type"] = "recv_weight"
    msg["round_number"] = self.current_round
    msg["data"] = {
        "total_sample": self.aggregation.total_size,
        "round_number": self.current_round,
        "weights": weights,
        "exit_flag": exit_flag
    }
    # Broadcast the aggregated weights to every connected client.
    for to_client, websocket in self._clients.items():
        try:
            await websocket.send_json(msg)
        except Exception as err:
            LOGGER.error(err)
        else:
            # Reset the per-round info so the next round starts fresh.
            self._client_meta[to_client].info = None
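# Illustrative only: the shape of the "update_weight" message the aggregation
# handler above consumes. Field names come from the code; the values are
# invented for the example.
EXAMPLE_UPDATE_WEIGHT_MSG = {
    "type": "update_weight",
    "data": {
        "num_samples": 500,   # size of the client's local training set
        "weights": [],        # serialized model weights from the client
    },
}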
def add_client(self, client_id: str, websocket: WebSocket):
    if client_id in self._clients:
        raise ValueError(f"Client {client_id} is already in the server")
    LOGGER.info(f"Adding client {client_id}")
    self._clients[client_id] = websocket
    self._client_meta[client_id] = WSClientInfo(
        client_id=client_id, connected_at=time.time(), info=None
    )
async def on_connect(self, websocket: WebSocket):
    servername = websocket.scope['path'].lstrip("/")
    LOGGER.info("Connecting new client...")
    server: Optional[Aggregator] = self.scope.get(servername)
    if server is None:
        raise RuntimeError(f"Server instance `{servername}` is unavailable!")
    self.server = server
    await websocket.accept()
def upload_file(self, files, name=""):
    if not (files and os.path.isfile(files)):
        return files
    if not name:
        name = os.path.basename(files)
    LOGGER.info(f"Try to upload file {name}")
    _url = f"{self.kbserver}/file/upload"
    with open(files, "rb") as fin:
        # Keep `files` bound to the local path: rebinding it to the
        # multipart payload would break the delete below.
        outurl = http_request(url=_url, method="POST",
                              files={"file": fin})
    if FileOps.is_remote(outurl):
        return outurl
    outurl = outurl.lstrip("/")
    # Remove the local copy once the upload has succeeded.
    FileOps.delete(files)
    return f"{self.kbserver}/{outurl}"
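def _example_upload(kb):
    """Illustrative only: upload a local file to the knowledge base.

    Assumes `kb` is an instance of the client class that defines
    `upload_file` above (class name not shown in this snippet), and that
    its `kbserver` points at a running KB REST service.
    """
    remote_url = kb.upload_file("/tmp/task_index.pkl")
    LOGGER.info(f"uploaded to {remote_url}")
    return remote_url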
def __init__(self, data=None, estimator=None, aggregation=None,
             transmitter=None, chooser=None) -> None:
    from plato.config import Config

    # Load the base configuration sections from Plato.
    server = Config().server._asdict()
    clients = Config().clients._asdict()
    datastore = Config().data._asdict()
    train = Config().trainer._asdict()

    if data is not None:
        datastore.update(data.parameters)
        Config().data = Config.namedtuple_from_dict(datastore)

    self.model = None
    if estimator is not None:
        self.model = estimator.model
        if estimator.pretrained is not None:
            Config().params['pretrained_model_dir'] = estimator.pretrained
        if estimator.saved is not None:
            Config().params['model_dir'] = estimator.saved
        train.update(estimator.hyperparameters)
        Config().trainer = Config.namedtuple_from_dict(train)

    # Bind address and port come from the environment, with safe defaults.
    server["address"] = Context.get_parameters("AGG_BIND_IP", "0.0.0.0")
    server["port"] = int(Context.get_parameters("AGG_BIND_PORT", 7363))

    if transmitter is not None:
        server.update(transmitter.parameters)

    if aggregation is not None:
        Config().algorithm = Config.namedtuple_from_dict(
            aggregation.parameters)
        if aggregation.parameters["type"] == "mistnet":
            clients["type"] = "mistnet"
            server["type"] = "mistnet"
        else:
            clients["do_test"] = True

    if chooser is not None:
        clients["per_round"] = chooser.parameters["per_round"]

    LOGGER.info("address %s, port %s", server["address"], server["port"])

    Config().server = Config.namedtuple_from_dict(server)
    Config().clients = Config.namedtuple_from_dict(clients)

    from plato.servers import registry as server_registry
    self.server = server_registry.get(model=self.model)
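# Illustrative only: the environment variables read by the constructor above,
# with example values. The assumption (suggested by the defaults in the code)
# is that `Context.get_parameters` resolves them from the process environment.
def _example_configure_bind_env():
    import os
    os.environ["AGG_BIND_IP"] = "0.0.0.0"   # address the aggregator binds to
    os.environ["AGG_BIND_PORT"] = "7363"    # port the aggregator listens on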
def update_task_status(self, tasks: str, new_status=1):
    data = {
        "tasks": tasks,
        "status": int(new_status)
    }
    _url = f"{self.kbserver}/update/status"
    try:
        outurl = http_request(url=_url, method="POST", json=data)
    except Exception as err:
        LOGGER.error(f"Update kb error: {err}")
        outurl = None
    if not outurl:
        return None
    if not FileOps.is_remote(outurl):
        outurl = outurl.lstrip("/")
        outurl = f"{self.kbserver}/{outurl}"
    return outurl
def update_db(self, task_info_file):
    _url = f"{self.kbserver}/update"
    try:
        with open(task_info_file, "rb") as fin:
            files = {"task": fin}
            outurl = http_request(url=_url, method="POST", files=files)
    except Exception as err:
        LOGGER.error(f"Update kb error: {err}")
        outurl = None
    else:
        if not FileOps.is_remote(outurl):
            outurl = outurl.lstrip("/")
            outurl = f"{self.kbserver}/{outurl}"
        FileOps.delete(task_info_file)
    return outurl
def __init__(self, estimator, callback=None, version="latest"):
    self.run_flag = True
    hot_update_conf = Context.get_parameters("MODEL_HOT_UPDATE_CONFIG")
    if not hot_update_conf:
        LOGGER.error("`MODEL_HOT_UPDATE_CONFIG` is not set, "
                     "skipping model hot update")
        self.run_flag = False
    model_check_time = int(
        Context.get_parameters("MODEL_POLL_PERIOD_SECONDS", "60"))
    if model_check_time < 1:
        LOGGER.warning("Invalid value for `MODEL_POLL_PERIOD_SECONDS`, "
                       "falling back to 60")
        model_check_time = 60
    self.hot_update_conf = hot_update_conf
    self.check_time = model_check_time
    self.production_estimator = estimator
    self.callback = callback
    self.version = version
    self.temp_path = tempfile.gettempdir()
    super(ModelLoadingThread, self).__init__()
def run(self, app, **kwargs):
    if hasattr(app, "add_middleware"):
        # Allow cross-origin requests from any host.
        app.add_middleware(
            CORSMiddleware, allow_origins=["*"], allow_credentials=True,
            allow_methods=["*"], allow_headers=["*"],
        )
    LOGGER.info(f"Start {self.server_name} server over {self.url}")
    config = uvicorn.Config(
        app,
        host=self.host,
        port=self.http_port,
        ssl_keyfile=self.keyfile,
        ssl_certfile=self.certfile,
        workers=self.workers,
        timeout_keep_alive=self.timeout,
        log_level="info",
        **kwargs)
    server = Server(config=config)
    # Run uvicorn in a background thread and block until asked to stop.
    with server.run_in_thread() as current_thread:
        return self.wait_stop(current=current_thread)
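def _example_run_server(server):
    """Illustrative only: serve a minimal FastAPI app through the `run`
    method above. Assumes `server` is an instance of the surrounding
    web-server class (its name is not shown in this snippet)."""
    from fastapi import FastAPI

    app = FastAPI()

    @app.get("/healthz")
    async def healthz():
        # Simple liveness endpoint for the example.
        return {"status": "ok"}

    return server.run(app)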
def run(self):
    while self.run_flag:
        time.sleep(self.check_time)
        conf = FileOps.download(self.hot_update_conf)
        if not (conf and FileOps.exists(conf)):
            continue
        with open(conf, "r") as fin:
            try:
                conf_msg = json.load(fin)
                model_msg = conf_msg["model_config"]
                latest_version = str(model_msg["model_update_time"])
                model = FileOps.download(
                    model_msg["model_path"],
                    FileOps.join_path(self.temp_path,
                                      f"model.{latest_version}"))
            except (json.JSONDecodeError, KeyError):
                LOGGER.error(f"failed to parse model hot update config: "
                             f"{self.hot_update_conf}")
                continue
        if not (model and FileOps.exists(model)):
            continue
        if latest_version == self.version:
            continue
        self.version = latest_version
        with self.MODEL_MANIPULATION_SEM:
            LOGGER.info(f"Update model start with version {self.version}")
            try:
                self.production_estimator.load(model)
                status = K8sResourceKindStatus.COMPLETED.value
                LOGGER.info(f"Update model complete "
                            f"with version {self.version}")
            except Exception as e:
                LOGGER.error(f"failed to update model: {e}")
                status = K8sResourceKindStatus.FAILED.value
        if self.callback:
            self.callback(task_info=None, status=status, kind="deploy")
        gc.collect()
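# Illustrative only: the JSON document `run` above expects at the location
# given by `MODEL_HOT_UPDATE_CONFIG`. Field names come from the parsing code;
# the values are invented for the example.
EXAMPLE_HOT_UPDATE_CONF = {
    "model_config": {
        "model_update_time": 1650000000,          # used as the version tag
        "model_path": "/models/latest/model.pb"   # downloaded via FileOps
    }
}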
def remove_client(self, client_id: str):
    if client_id not in self._clients:
        raise ValueError(f"Client {client_id} is not in the server")
    LOGGER.info(f"Removing client {client_id} from server")
    del self._clients[client_id]
    del self._client_meta[client_id]
async def send_message(self, client_id: str, msg: Dict):
    # Broadcast to every connected client except the sender.
    for to_client, websocket in self._clients.items():
        if to_client == client_id:
            continue
        LOGGER.info(f"Sending data to client {to_client}")
        await websocket.send_json(msg)
def train(self, train_data: BaseDataSource,
          valid_data: BaseDataSource = None,
          post_process=None, **kwargs):
    """
    Fit the estimator and update the knowledge base with the training data.

    Parameters
    ----------
    train_data : BaseDataSource
        Train data, see `sedna.datasources.BaseDataSource` for more detail.
    valid_data : BaseDataSource
        Valid data, BaseDataSource or None.
    post_process : function
        Function or a registered method, callback after `estimator` train.
    kwargs : Dict
        Parameters for `estimator` training, e.g.
        `early_stopping_rounds` in Xgboost.XGBClassifier.

    Returns
    -------
    feedback : Dict
        Contains the training result of each task.
    task_index_url : str
        Task extractor model path, used for task mining.
    """
    tasks, task_extractor, train_data = self._task_definition(train_data)
    self.extractor = task_extractor
    task_groups = self._task_relationship_discovery(tasks)
    self.models = []
    callback = None
    if isinstance(post_process, str):
        callback = ClassFactory.get_cls(ClassType.CALLBACK, post_process)()
    self.task_groups = []
    feedback = {}
    rare_task = []
    for i, task in enumerate(task_groups):
        if not isinstance(task, TaskGroup):
            rare_task.append(i)
            self.models.append(None)
            self.task_groups.append(None)
            continue
        if not (task.samples and len(task.samples) > self.min_train_sample):
            self.models.append(None)
            self.task_groups.append(None)
            rare_task.append(i)
            n = len(task.samples)
            LOGGER.info(f"Sample {n} of {task.entry} will be merged")
            continue
        LOGGER.info(f"MTL Train start {i} : {task.entry}")
        model = None
        for t in task.tasks:
            # Reuse a model that has already been trained for this task.
            if not (t.model and t.result):
                continue
            model_path = t.model.save(model_name=f"{task.entry}.model")
            t.model = model_path
            model = Model(index=i, entry=t.entry,
                          model=model_path, result=t.result)
            model.meta_attr = t.meta_attr
            break
        if not model:
            model_obj = set_backend(estimator=self.base_model)
            res = model_obj.train(train_data=task.samples, **kwargs)
            if callback:
                res = callback(model_obj, res)
            model_path = model_obj.save(model_name=f"{task.entry}.model")
            model = Model(index=i, entry=task.entry,
                          model=model_path, result=res)
            model.meta_attr = [t.meta_attr for t in task.tasks]
        task.model = model
        self.models.append(model)
        feedback[task.entry] = model.result
        self.task_groups.append(task)
    if len(rare_task):
        # Tasks with too few samples share a single global model.
        model_obj = set_backend(estimator=self.base_model)
        res = model_obj.train(train_data=train_data, **kwargs)
        model_path = model_obj.save(model_name="global.model")
        for i in rare_task:
            task = task_groups[i]
            entry = getattr(task, 'entry', "global")
            if not isinstance(task, TaskGroup):
                task = TaskGroup(entry=entry, tasks=[])
            model = Model(index=i, entry=entry,
                          model=model_path, result=res)
            model.meta_attr = [t.meta_attr for t in task.tasks]
            task.model = model
            task.samples = train_data
            self.models[i] = model
            feedback[entry] = res
            self.task_groups[i] = task
    task_index = {
        "extractor": self.extractor,
        "task_groups": self.task_groups
    }
    if valid_data:
        feedback, _ = self.evaluate(valid_data, **kwargs)
    try:
        FileOps.dump(task_index, self.task_index_url)
    except TypeError:
        return feedback, task_index
    return feedback, self.task_index_url
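def _example_mtl_train(learner, train_data, valid_data=None):
    """Illustrative only: drive the multi-task `train` method above.

    Assumes `learner` is an instance of the multi-task learning class that
    defines `train`, and `train_data`/`valid_data` are
    `sedna.datasources.BaseDataSource` objects.
    """
    feedback, task_index_url = learner.train(
        train_data=train_data, valid_data=valid_data)
    for entry, result in feedback.items():
        LOGGER.info(f"task {entry}: {result}")
    return task_index_url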
def http_request(url, method=None, timeout=None, binary=True, **kwargs):
    _maxTimeout = timeout if timeout else 300
    _method = "GET" if not method else method
    try:
        response = requests.request(method=_method, url=url,
                                    timeout=_maxTimeout, **kwargs)
        if response.status_code == 200:
            return (response.json() if binary
                    else response.content.decode("utf-8"))
        elif 200 < response.status_code < 400:
            LOGGER.info(f"Redirect_URL: {response.url}")
        LOGGER.warning(
            'Got invalid status code %s while requesting %s',
            response.status_code, url)
    except (ConnectionRefusedError, requests.exceptions.ConnectionError):
        LOGGER.warning(f'Connection refused while requesting {url}')
    except requests.exceptions.HTTPError as err:
        LOGGER.warning(f"Http Error while requesting {url} : {err}")
    except requests.exceptions.Timeout as err:
        LOGGER.warning(f"Timeout Error while requesting {url} : {err}")
    except requests.exceptions.RequestException as err:
        LOGGER.warning(f"Error occurred while requesting {url} : {err}")
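def _example_http_request():
    """Illustrative only: fetch JSON from a hypothetical endpoint with the
    `http_request` helper above. Returns None on any request failure."""
    result = http_request("http://localhost:9020/file/download",
                          method="GET", timeout=30)
    if result is None:
        LOGGER.warning("request failed, see warnings above")
    return result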
async def _send(self, data):
    try:
        await asyncio.wait_for(self.ws.send(data), self._ws_timeout)
        return True
    except Exception as err:
        LOGGER.info(f"{self.uri} failed to send data: {err}")
async def connect(self):
    LOGGER.info(f"{self.uri} connection by {self.client_id}")
    try:
        self.ws = await asyncio.wait_for(
            websockets.connect(self.uri, **self.kwargs),
            self._ws_timeout)
        await self.ws.send(json.dumps({
            'type': 'subscribe', 'client_id': self.client_id
        }))
        res = await self.ws.recv()
        return res
    except ConnectionRefusedError:
        LOGGER.warning(f"{self.uri} connection was refused by server")
    except ConnectionClosedError:
        LOGGER.warning(f"{self.uri} connection lost")
    except ConnectionClosedOK:
        LOGGER.warning(f"{self.uri} connection closed")
    except InvalidStatusCode as err:
        LOGGER.warning(
            f"{self.uri} websocket failed - "
            f"with invalid status code {err.status_code}")
    except WebSocketException as err:
        LOGGER.warning(f"{self.uri} websocket failed - with {err}")
    except OSError as err:
        LOGGER.warning(f"{self.uri} connection failed - with {err}")
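async def _example_connect(client):
    """Illustrative only: open the websocket and send one message.

    Assumes `client` is an instance of the websocket client class that
    defines `connect` and `_send` above (class name not shown here).
    """
    res = await client.connect()
    if res is None:
        LOGGER.warning("subscribe failed")
        return
    await client._send(json.dumps({"type": "ping"}))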