def __init__(self, skills: List[Component]) -> None: self.skills = skills self.history: Dict = defaultdict(list) self.states: Dict = defaultdict(lambda: [None] * len(self.skills)) self.wrapped_skills: List[SkillWrapper] = \ [SkillWrapper(skill, skill_id, self) for skill_id, skill in enumerate(self.skills)] self.dialog_logger: DialogLogger = DialogLogger()
def __init__(self, model_config: Path, socket_type: str, port: Optional[int] = None, socket_file: Optional[Union[str, Path]] = None) -> None: """Initialize socket server. Args: model_config: Path to the config file. socket_type: Socket family. "TCP" for the AF_INET socket, "UNIX" for the AF_UNIX. port: Port number for the AF_INET address family. If parameter is not defined, the port number from the model_config is used. socket_file: Path to the file to which server of the AF_UNIX address family connects. If parameter is not defined, the path from the model_config is used. """ socket_config_path = get_settings_path() / SOCKET_CONFIG_FILENAME self._params = get_server_params(socket_config_path, model_config) self._socket_type = socket_type or self._params['socket_type'] if self._socket_type == 'TCP': host = self._params['host'] port = port or self._params['port'] self._address_family = socket.AF_INET self._launch_msg = f'{self._params["binding_message"]} http://{host}:{port}' self._bind_address = (host, port) elif self._socket_type == 'UNIX': self._address_family = socket.AF_UNIX bind_address = socket_file or self._params['unix_socket_file'] bind_address = Path(bind_address).resolve() if bind_address.exists(): bind_address.unlink() self._bind_address = str(bind_address) self._launch_msg = f'{self._params["binding_message"]} {self._bind_address}' else: raise ValueError( f'socket type "{self._socket_type}" is not supported') self._dialog_logger = DialogLogger(agent_name='dp_api') self._log = getLogger(__name__) self._loop = asyncio.get_event_loop() self._model = build_model(model_config) self._socket = socket.socket(self._address_family, socket.SOCK_STREAM) self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self._socket.setblocking(False)
class SocketServer: """Creates socket server that sends the received data to the DeepPavlov model and returns model response. The server receives dictionary serialized to JSON formatted bytes array and sends it to the model. The dictionary keys should match model arguments names, the values should be lists or tuples of inferenced values. Example: {“context”:[“Elon Musk launched his cherry Tesla roadster to the Mars orbit”]} Socket server returns dictionary {'status': status, 'payload': payload} serialized to a JSON formatted byte array, where: status (str): 'OK' if the model successfully processed the data, else - error message. payload: (Optional[List[Tuple]]): The model result if no error has occurred, otherwise None """ _address_family: socket.AddressFamily _bind_address: Union[Tuple[str, int], str] _launch_msg: str _loop: asyncio.AbstractEventLoop _model: Chainer _params: Dict _socket: socket.socket _socket_type: str def __init__(self, model_config: Path, socket_type: str, port: Optional[int] = None, socket_file: Optional[Union[str, Path]] = None) -> None: """Initialize socket server. Args: model_config: Path to the config file. socket_type: Socket family. "TCP" for the AF_INET socket, "UNIX" for the AF_UNIX. port: Port number for the AF_INET address family. If parameter is not defined, the port number from the model_config is used. socket_file: Path to the file to which server of the AF_UNIX address family connects. If parameter is not defined, the path from the model_config is used. """ socket_config_path = get_settings_path() / SOCKET_CONFIG_FILENAME self._params = get_server_params(socket_config_path, model_config) self._socket_type = socket_type or self._params['socket_type'] if self._socket_type == 'TCP': host = self._params['host'] port = port or self._params['port'] self._address_family = socket.AF_INET self._launch_msg = f'{self._params["binding_message"]} http://{host}:{port}' self._bind_address = (host, port) elif self._socket_type == 'UNIX': self._address_family = socket.AF_UNIX bind_address = socket_file or self._params['unix_socket_file'] bind_address = Path(bind_address).resolve() if bind_address.exists(): bind_address.unlink() self._bind_address = str(bind_address) self._launch_msg = f'{self._params["binding_message"]} {self._bind_address}' else: raise ValueError( f'socket type "{self._socket_type}" is not supported') self._dialog_logger = DialogLogger(agent_name='dp_api') self._log = getLogger(__name__) self._loop = asyncio.get_event_loop() self._model = build_model(model_config) self._socket = socket.socket(self._address_family, socket.SOCK_STREAM) self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self._socket.setblocking(False) def start(self) -> None: """Binds the socket to the address and enables the server to accept connections""" self._socket.bind(self._bind_address) self._socket.listen() self._log.info(self._launch_msg) try: self._loop.run_until_complete(self._server()) except Exception as e: self._log.error(f'got exception {e} while running server') finally: self._loop.close() self._socket.close() async def _server(self) -> None: while True: conn, addr = await self._loop.sock_accept(self._socket) self._loop.create_task(self._handle_connection(conn, addr)) async def _handle_connection(self, conn: socket.socket, addr: Tuple) -> None: self._log.info(f'handling connection from {addr}') conn.setblocking(False) recv_data = b'' try: while True: chunk = await self._loop.run_in_executor( None, conn.recv, self._params['bufsize']) if chunk: recv_data += chunk else: break except BlockingIOError: pass try: data = json.loads(recv_data) except ValueError: await self._wrap_error(conn, f'request "{recv_data}" type is not json') return self._dialog_logger.log_in(data) model_args = [] for param_name in self._params['model_args_names']: param_value = data.get(param_name) if param_value is None or (isinstance(param_value, list) and len(param_value) > 0): model_args.append(param_value) else: await self._wrap_error( conn, f"nonempty array expected but got '{param_name}'={repr(param_value)}" ) return lengths = {len(i) for i in model_args if i is not None} if not lengths: await self._wrap_error(conn, 'got empty request') return elif len(lengths) > 1: await self._wrap_error( conn, f'got several different batch sizes: {lengths}') return batch_size = list(lengths)[0] model_args = [arg or [None] * batch_size for arg in model_args] # in case when some parameters were not described in model_args model_args += [[None] * batch_size for _ in range(len(self._model.in_x) - len(model_args))] prediction = await self._loop.run_in_executor(None, self._model, *model_args) if len(self._model.out_params) == 1: prediction = [prediction] prediction = list(zip(*prediction)) result = await self._response('OK', prediction) self._dialog_logger.log_out(result) await self._loop.sock_sendall(conn, result) async def _wrap_error(self, conn: socket.socket, error: str) -> None: self._log.error(error) await self._loop.sock_sendall(conn, await self._response(error, None)) @staticmethod async def _response(status: str, payload: Optional[List[Tuple]]) -> bytes: """Puts arguments into dict and serialize it to JSON formatted byte array. Args: status: Response status. 'OK' if no error has occurred, otherwise error message. payload: DeepPavlov model result if no error has occurred, otherwise None. Returns: dict({'status': status, 'payload': payload}) serialized to a JSON formatted byte array. """ resp_dict = jsonify_data({'status': status, 'payload': payload}) resp_str = json.dumps(resp_dict) return resp_str.encode('utf-8')
from deeppavlov.core.commands.infer import build_model from deeppavlov.core.commands.utils import parse_config from deeppavlov.core.common.chainer import Chainer from deeppavlov.core.common.file import read_json from deeppavlov.core.common.paths import get_settings_path from deeppavlov.core.data.utils import check_nested_dict_keys, jsonify_data SERVER_CONFIG_FILENAME = 'server_config.json' log = getLogger(__name__) app = Flask(__name__) Swagger(app) CORS(app) dialog_logger = DialogLogger(agent_name='dp_api') def get_server_params(server_config_path, model_config): server_config = read_json(server_config_path) model_config = parse_config(model_config) server_params = server_config['common_defaults'] if check_nested_dict_keys(model_config, ['metadata', 'labels', 'server_utils']): model_tag = model_config['metadata']['labels']['server_utils'] if model_tag in server_config['model_defaults']: model_defaults = server_config['model_defaults'][model_tag] for param_name in model_defaults.keys(): if model_defaults[param_name]: