Example #1
def start_model_server(model_config, https=False, ssl_key=None, ssl_cert=None, port=None):
    server_config_path = get_settings_path() / SERVER_CONFIG_FILENAME
    server_params = get_server_params(server_config_path, model_config)

    host = server_params['host']
    port = port or server_params['port']
    model_endpoint = server_params['model_endpoint']
    model_args_names = server_params['model_args_names']

    https = https or server_params['https']

    if https:
        ssh_key_path = Path(ssl_key or server_params['https_key_path']).resolve()
        if not ssh_key_path.is_file():
            e = FileNotFoundError('Ssh key file not found: please provide correct path in --key param or '
                                  'https_key_path param in server configuration file')
            log.error(e)
            raise e

        ssh_cert_path = Path(ssl_cert or server_params['https_cert_path']).resolve()
        if not ssh_cert_path.is_file():
            e = FileNotFoundError('Ssh certificate file not found: please provide correct path in --cert param or '
                                  'https_cert_path param in server configuration file')
            log.error(e)
            raise e

        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
        ssl_context.load_cert_chain(ssh_cert_path, ssh_key_path)
    else:
        ssl_context = None

    model = build_model(model_config)

    @app.route('/')
    def index():
        return redirect('/apidocs/')

    endpoint_description = {
        'description': 'A model endpoint',
        'parameters': [
            {
                'name': 'data',
                'in': 'body',
                'required': 'true',
                'example': {arg: ['value'] for arg in model_args_names}
            }
        ],
        'responses': {
            "200": {
                "description": "A model response"
            }
        }
    }

    @app.route(model_endpoint, methods=['POST'])
    @swag_from(endpoint_description)
    def answer():
        return interact(model, model_args_names)

    app.run(host=host, port=port, threaded=False, ssl_context=ssl_context)
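Once the Flask service above is running, it can be queried with a JSON body keyed by the model argument names, which is the same payload shape the riseapi test in Example #2 builds. A minimal client sketch; the host, port, endpoint path and argument name below are placeholder assumptions that depend on the server configuration file:

import requests

# Hypothetical values: the real host, port and endpoint come from 'host', 'port' and
# 'model_endpoint' in the server configuration file read by start_model_server.
url = 'http://127.0.0.1:5000/model'
payload = {'context': ['Hello, world!']}  # one list of strings per name in model_args_names

response = requests.post(url, json=payload, headers={'Accept': 'application/json'})
print(response.status_code, response.json())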
Example #2
    def interact_api(config_path):
        server_conf_file = get_settings_path() / SERVER_CONFIG_FILENAME

        server_params = get_server_params(server_conf_file, config_path)
        model_args_names = server_params['model_args_names']

        url_base = 'http://{}:{}/'.format(server_params['host'], api_port or server_params['port'])
        url = urljoin(url_base.replace('http://0.0.0.0:', 'http://127.0.0.1:'), server_params['model_endpoint'])

        post_headers = {'Accept': 'application/json'}

        post_payload = {}
        for arg_name in model_args_names:
            arg_value = str(' '.join(['qwerty'] * 10))
            post_payload[arg_name] = [arg_value]

        logfile = io.BytesIO(b'')
        args = [sys.executable, "-m", "deeppavlov", "riseapi", str(config_path)]
        if api_port:
            args += ['-p', str(api_port)]
        p = pexpect.popen_spawn.PopenSpawn(' '.join(args),
                                           timeout=None, logfile=logfile)
        try:
            p.expect(url_base)
            post_response = requests.post(url, json=post_payload, headers=post_headers)
            response_code = post_response.status_code
            assert response_code == 200, f"POST request returned error code {response_code} with {config_path}"

        except pexpect.exceptions.EOF:
            raise RuntimeError('Got unexpected EOF: \n{}'.format(logfile.getvalue().decode()))

        finally:
            p.kill(signal.SIGTERM)
            p.wait()
Example #3
def init_bot_for_model(agent: Agent, token: str, model_name: str):
    bot = telebot.TeleBot(token)

    models_info_path = Path(get_settings_path(), TELEGRAM_MODELS_INFO_FILENAME).resolve()
    models_info = read_json(str(models_info_path))
    model_info = models_info[model_name] if model_name in models_info else models_info['@default']

    @bot.message_handler(commands=['start'])
    def send_start_message(message):
        chat_id = message.chat.id
        out_message = model_info['start_message']
        bot.send_message(chat_id, out_message)

    @bot.message_handler(commands=['help'])
    def send_help_message(message):
        chat_id = message.chat.id
        out_message = model_info['help_message']
        bot.send_message(chat_id, out_message)

    @bot.message_handler()
    def handle_inference(message):
        chat_id = message.chat.id
        context = message.text

        response: RichMessage = agent([context], [chat_id])[0]
        for message in response.json():
            message_text = message['content']
            bot.send_message(chat_id, message_text)

    bot.polling()
Example #4
def run_ms_bot_framework_server(agent_generator: callable,
                                app_id: str,
                                app_secret: str,
                                multi_instance: bool = False,
                                stateful: bool = False,
                                port: Optional[int] = None):

    server_config_path = Path(get_settings_path(),
                              SERVER_CONFIG_FILENAME).resolve()
    server_params = read_json(server_config_path)

    host = server_params['common_defaults']['host']
    port = port or server_params['common_defaults']['port']

    ms_bf_server_params = server_params['ms_bot_framework_defaults']

    ms_bf_server_params['multi_instance'] = multi_instance or server_params['common_defaults']['multi_instance']
    ms_bf_server_params['stateful'] = stateful or server_params['common_defaults']['stateful']

    ms_bf_server_params['auth_url'] = AUTH_URL
    ms_bf_server_params['auth_host'] = AUTH_HOST
    ms_bf_server_params['auth_content_type'] = AUTH_CONTENT_TYPE
    ms_bf_server_params['auth_grant_type'] = AUTH_GRANT_TYPE
    ms_bf_server_params['auth_scope'] = AUTH_SCOPE

    ms_bf_server_params['auth_app_id'] = app_id or ms_bf_server_params['auth_app_id']
    if not ms_bf_server_params['auth_app_id']:
        e = ValueError(
            'Microsoft Bot Framework app id required: initiate -i param '
            'or auth_app_id param in server configuration file')
        log.error(e)
        raise e

    ms_bf_server_params['auth_app_secret'] = app_secret or ms_bf_server_params['auth_app_secret']
    if not ms_bf_server_params['auth_app_secret']:
        e = ValueError(
            'Microsoft Bot Framework app secret required: initiate -s param '
            'or auth_app_secret param in server configuration file')
        log.error(e)
        raise e

    input_q = Queue()
    bot = Bot(agent_generator, ms_bf_server_params, input_q)
    bot.start()

    @app.route('/')
    def index():
        return redirect('/apidocs/')

    @app.route('/v3/conversations', methods=['POST'])
    def handle_activity():
        activity = request.get_json()
        bot.input_queue.put(activity)
        return jsonify({}), 200

    app.run(host=host, port=port, threaded=True)
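The '/v3/conversations' route above only enqueues whatever JSON it receives and replies with an empty object, so a simple liveness check can post a minimal Bot Framework-style activity. A sketch with placeholder host, port and activity fields:

import requests

# Placeholder address and activity body: handle_activity puts the JSON on the bot's
# input queue and always answers 200 with an empty object.
activity = {'type': 'message', 'text': 'hello', 'channelId': 'emulator'}
r = requests.post('http://127.0.0.1:5000/v3/conversations', json=activity)
print(r.status_code, r.json())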
Example #6
def init_bot_for_model(agent: Agent, token: str, model_name: str):
    bot = telebot.TeleBot(token)

    models_info_path = Path(get_settings_path(),
                            TELEGRAM_MODELS_INFO_FILENAME).resolve()
    models_info = read_json(str(models_info_path))
    model_info = models_info[model_name] if model_name in models_info else models_info['@default']

    @bot.message_handler(commands=['start'])
    def send_start_message(message):
        chat_id = message.chat.id
        out_message = model_info['start_message']
        bot.send_message(chat_id, out_message)

    @bot.message_handler(commands=['help'])
    def send_help_message(message):
        chat_id = message.chat.id
        out_message = model_info['help_message']
        bot.send_message(chat_id, out_message)

    @bot.message_handler()
    def handle_inference(message):
        chat_id = message.chat.id
        context = message.text

        response: RichMessage = agent([context], [chat_id])[0]
        for message in response.json():
            message_text = message['content']
            bot.send_message(chat_id, message_text)

    bot.polling()
Example #7
def start_alice_server(model_config, https=False, ssl_key=None, ssl_cert=None, port=None):
    server_config_path = get_settings_path() / SERVER_CONFIG_FILENAME
    server_params = get_server_params(server_config_path, model_config)

    https = https or server_params['https']

    if not https:
        ssl_key = ssl_cert = None
    else:
        ssh_key = Path(ssl_key or server_params['https_key_path']).resolve()
        if not ssh_key.is_file():
            e = FileNotFoundError('Ssh key file not found: please provide correct path in --key param or '
                                  'https_key_path param in server configuration file')
            log.error(e)
            raise e

        ssh_cert = Path(ssl_cert or server_params['https_cert_path']).resolve()
        if not ssh_cert.is_file():
            e = FileNotFoundError('Ssh certificate file not found: please provide correct path in --cert param or '
                                  'https_cert_path param in server configuration file')
            log.error(e)
            raise e

    host = server_params['host']
    port = port or server_params['port']
    model_endpoint = server_params['model_endpoint']

    model = build_model(model_config)
    skill = DefaultStatelessSkill(model, lang='ru')
    agent = DefaultAgent([skill], skills_processor=DefaultRichContentWrapper())

    start_agent_server(agent, host, port, model_endpoint, ssl_key, ssl_cert)
Example #8
    def __init__(self, enabled: bool = False, agent_name: Optional[str] = None) -> None:
        self.config: dict = read_json(get_settings_path() / LOGGER_CONFIG_FILENAME)
        self.enabled: bool = enabled or self.config['enabled']

        if self.enabled:
            self.agent_name: str = agent_name or self.config['agent_name']
            self.log_max_size: int = self.config['logfile_max_size_kb']
            self.log_file = self._get_log_file()
            self.log_file.writelines('"Agent initiated"\n')
Example #9
    def __init__(self, enabled: bool = False, logger_name: Optional[str] = None) -> None:
        self.config: dict = read_json(get_settings_path() / LOGGER_CONFIG_FILENAME)
        self.enabled: bool = enabled or self.config['enabled']

        if self.enabled:
            self.logger_name: str = logger_name or self.config['logger_name']
            self.log_max_size: int = self.config['logfile_max_size_kb']
            self.log_file = self._get_log_file()
            self.log_file.writelines('"Dialog logger initiated"\n')
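For reference, the enclosing class appears elsewhere in these examples as DialogLogger (constructed as DialogLogger(agent_name='dp_api') in Example #16 and DialogLogger(logger_name='rest_api') in Example #23). A minimal instantiation sketch under that assumption:

from deeppavlov.utils.connector import DialogLogger  # import path as shown in Example #23

# Both arguments are optional: 'enabled' and 'logger_name' fall back to the values
# read from the logger settings file in __init__.
dialog_logger = DialogLogger(enabled=True, logger_name='rest_api')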
Example #10
    def interact_socket(config_path, socket_type):
        socket_conf_file = get_settings_path() / SOCKET_CONFIG_FILENAME

        socket_params = get_server_params(socket_conf_file, config_path)
        model_args_names = socket_params['model_args_names']

        host = socket_params['host']
        port = api_port or socket_params['port']

        socket_payload = {}
        for arg_name in model_args_names:
            arg_value = ' '.join(['qwerty'] * 10)
            socket_payload[arg_name] = [arg_value]
        dumped_socket_payload = json.dumps(socket_payload)

        logfile = io.BytesIO(b'')
        args = [sys.executable, "-m", "deeppavlov", "risesocket", str(config_path), '--socket-type', socket_type]
        if socket_type == 'TCP':
            args += ['-p', str(port)]
            address_family = socket.AF_INET
            connect_arg = (host, port)
        else:
            address_family = socket.AF_UNIX
            connect_arg = socket_params['unix_socket_file']
        p = pexpect.popen_spawn.PopenSpawn(' '.join(args),
                                           timeout=None, logfile=logfile)
        try:
            p.expect(socket_params['binding_message'])
            with socket.socket(address_family, socket.SOCK_STREAM) as s:
                s.connect(connect_arg)
                s.sendall(dumped_socket_payload.encode('utf-8'))
                data = b''
                try:
                    while True:
                        buf = s.recv(1024)
                        s.setblocking(False)
                        if buf:
                            data += buf
                        else:
                            break
                except BlockingIOError:
                    pass
            resp = json.loads(data)
            assert resp['status'] == 'OK', f"{socket_type} socket request returned status: {resp['status']}"\
                                           f" with {config_path}\n{logfile.getvalue().decode()}"

        except pexpect.exceptions.EOF:
            raise RuntimeError(f'Got unexpected EOF: \n{logfile.getvalue().decode()}')

        except json.JSONDecodeError:
            raise ValueError(f'Got JSON not serializable response from model: "{data}"\n{logfile.getvalue().decode()}')

        finally:
            p.kill(signal.SIGTERM)
            p.wait()
Example #11
def main():
    """DeepPavlov console configuration utility."""
    args = parser.parse_args()
    path = get_settings_path()

    if args.default:
        if populate_settings_dir(force=True):
            print(f'Populated {path} with default settings files')
        else:
            print(f'{path} is already a default settings directory')
    else:
        print(f'Current DeepPavlov settings path: {path}')
Example #13
def run_ms_bot_framework_server(agent_generator: callable, app_id: str, app_secret: str,
                                multi_instance: bool = False, stateful: bool = False, port: Optional[int] = None):

    server_config_path = Path(get_settings_path(), SERVER_CONFIG_FILENAME).resolve()
    server_params = read_json(server_config_path)

    host = server_params['common_defaults']['host']
    port = port or server_params['common_defaults']['port']

    ms_bf_server_params = server_params['ms_bot_framework_defaults']

    ms_bf_server_params['multi_instance'] = multi_instance or server_params['common_defaults']['multi_instance']
    ms_bf_server_params['stateful'] = stateful or server_params['common_defaults']['stateful']

    ms_bf_server_params['auth_url'] = AUTH_URL
    ms_bf_server_params['auth_host'] = AUTH_HOST
    ms_bf_server_params['auth_content_type'] = AUTH_CONTENT_TYPE
    ms_bf_server_params['auth_grant_type'] = AUTH_GRANT_TYPE
    ms_bf_server_params['auth_scope'] = AUTH_SCOPE

    ms_bf_server_params['auth_app_id'] = app_id or ms_bf_server_params['auth_app_id']
    if not ms_bf_server_params['auth_app_id']:
        e = ValueError('Microsoft Bot Framework app id required: initiate -i param '
                       'or auth_app_id param in server configuration file')
        log.error(e)
        raise e

    ms_bf_server_params['auth_app_secret'] = app_secret or ms_bf_server_params['auth_app_secret']
    if not ms_bf_server_params['auth_app_secret']:
        e = ValueError('Microsoft Bot Framework app secret required: initiate -s param '
                       'or auth_app_secret param in server configuration file')
        log.error(e)
        raise e

    input_q = Queue()
    bot = Bot(agent_generator, ms_bf_server_params, input_q)
    bot.start()

    @app.route('/')
    def index():
        return redirect('/apidocs/')

    @app.route('/v3/conversations', methods=['POST'])
    def handle_activity():
        activity = request.get_json()
        bot.input_queue.put(activity)
        return jsonify({}), 200

    app.run(host=host, port=port, threaded=True)
Example #14
def interact_model_by_telegram(config, token=None):
    server_config_path = Path(get_settings_path(), SERVER_CONFIG_FILENAME)
    server_config = read_json(server_config_path)
    token = token if token else server_config['telegram_defaults']['token']
    if not token:
        e = ValueError('Telegram token required: initiate -t param or telegram_defaults/token '
                       'in server configuration file')
        log.error(e)
        raise e

    model = build_model(config)
    model_name = type(model.get_main_component()).__name__
    skill = DefaultStatelessSkill(model)
    agent = DefaultAgent([skill], skills_processor=DefaultRichContentWrapper())
    init_bot_for_model(agent, token, model_name)
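A minimal invocation sketch; the config path and token below are placeholders, and omitting the token falls back to telegram_defaults/token from the server configuration file:

# Placeholder values for illustration only. The call blocks, since the bot
# ends up in bot.polling() inside init_bot_for_model.
interact_model_by_telegram('path/to/model_config.json', token='123456:ABC-DEF')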
Example #15
def interact_model_by_telegram(config, token=None):
    server_config_path = Path(get_settings_path(), SERVER_CONFIG_FILENAME)
    server_config = read_json(server_config_path)
    token = token if token else server_config['telegram_defaults']['token']
    if not token:
        e = ValueError(
            'Telegram token required: initiate -t param or telegram_defaults/token '
            'in server configuration file')
        log.error(e)
        raise e

    model = build_model(config)
    model_name = type(model.get_main_component()).__name__
    skill = DefaultStatelessSkill(model)
    agent = DefaultAgent([skill], skills_processor=DefaultRichContentWrapper())
    init_bot_for_model(agent, token, model_name)
Example #16
    def __init__(self,
                 model_config: Path,
                 socket_type: str,
                 port: Optional[int] = None,
                 socket_file: Optional[Union[str, Path]] = None) -> None:
        """Initialize socket server.

        Args:
            model_config: Path to the config file.
            socket_type: Socket family. "TCP" for the AF_INET socket, "UNIX" for the AF_UNIX.
            port: Port number for the AF_INET address family. If parameter is not defined, the port number from the
                model_config is used.
            socket_file: Path to the file to which server of the AF_UNIX address family connects. If parameter
                is not defined, the path from the model_config is used.

        """
        socket_config_path = get_settings_path() / SOCKET_CONFIG_FILENAME
        self._params = get_server_params(socket_config_path, model_config)
        self._socket_type = socket_type or self._params['socket_type']

        if self._socket_type == 'TCP':
            host = self._params['host']
            port = port or self._params['port']
            self._address_family = socket.AF_INET
            self._launch_msg = f'{self._params["binding_message"]} http://{host}:{port}'
            self._bind_address = (host, port)
        elif self._socket_type == 'UNIX':
            self._address_family = socket.AF_UNIX
            bind_address = socket_file or self._params['unix_socket_file']
            bind_address = Path(bind_address).resolve()
            if bind_address.exists():
                bind_address.unlink()
            self._bind_address = str(bind_address)
            self._launch_msg = f'{self._params["binding_message"]} {self._bind_address}'
        else:
            raise ValueError(
                f'socket type "{self._socket_type}" is not supported')

        self._dialog_logger = DialogLogger(agent_name='dp_api')
        self._log = getLogger(__name__)
        self._loop = asyncio.get_event_loop()
        self._model = build_model(model_config)
        self._socket = socket.socket(self._address_family, socket.SOCK_STREAM)

        self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self._socket.setblocking(False)
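The snippet above only shows the initializer; a construction sketch, assuming the enclosing class is named SocketServer (the actual class name is not shown in this excerpt):

from pathlib import Path

# Hypothetical class name and values: a TCP server on an explicit port; passing
# port=None would fall back to the port from the socket settings file instead.
server = SocketServer(model_config=Path('path/to/model_config.json'),
                      socket_type='TCP',
                      port=5000)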
Example #17
    def _get_connector_params(self) -> dict:
        """Reads bot and conversation default params from connector config file.

         Returns:
             connector_defaults: Dictionary containing bot defaults and conversation defaults dicts.

        """
        connector_config_path = get_settings_path() / CONNECTOR_CONFIG_FILENAME
        connector_config: dict = read_json(connector_config_path)

        bot_name = type(self).__name__
        conversation_defaults = connector_config['telegram']
        bot_defaults = connector_config['deprecated'].get(
            bot_name, conversation_defaults)

        connector_defaults = {
            'bot_defaults': bot_defaults,
            'conversation_defaults': conversation_defaults
        }

        return connector_defaults
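A short usage sketch, assuming it is called from another method of the same connector class:

# Hypothetical call site: split the returned dictionary into its two parts.
connector_params = self._get_connector_params()
bot_defaults = connector_params['bot_defaults']
conversation_defaults = connector_params['conversation_defaults']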
Example #18
def start_rabbit_service(model_config: Union[str, Path],
                         service_name: Optional[str] = None,
                         agent_namespace: Optional[str] = None,
                         batch_size: Optional[int] = None,
                         utterance_lifetime_sec: Optional[int] = None,
                         rabbit_host: Optional[str] = None,
                         rabbit_port: Optional[int] = None,
                         rabbit_login: Optional[str] = None,
                         rabbit_password: Optional[str] = None,
                         rabbit_virtualhost: Optional[str] = None) -> None:
    """Launches DeepPavlov model receiving utterances and sending responses via RabbitMQ message broker.

    Args:
        model_config: Path to DeepPavlov model to be launched.
        service_name: Service name set in DeepPavlov Agent config. Used to format RabbitMQ exchanges, queues and routing
            keys names.
        agent_namespace: Service processes messages only from agents with the same namespace value.
        batch_size: Limits the maximum number of utterances to be processed by service at one inference.
        utterance_lifetime_sec: RabbitMQ message expiration time in seconds.
        rabbit_host: RabbitMQ server host name.
        rabbit_port: RabbitMQ server port number.
        rabbit_login: RabbitMQ server administrator username.
        rabbit_password: RabbitMQ server administrator password.
        rabbit_virtualhost: RabbitMQ server virtualhost name.

    """
    service_config_path = get_settings_path() / CONNECTOR_CONFIG_FILENAME
    service_config: dict = read_json(service_config_path)['agent-rabbit']

    service_name = service_name or service_config['service_name']
    agent_namespace = agent_namespace or service_config['agent_namespace']
    batch_size = batch_size or service_config['batch_size']
    utterance_lifetime_sec = utterance_lifetime_sec or service_config['utterance_lifetime_sec']
    rabbit_host = rabbit_host or service_config['rabbit_host']
    rabbit_port = rabbit_port or service_config['rabbit_port']
    rabbit_login = rabbit_login or service_config['rabbit_login']
    rabbit_password = rabbit_password or service_config['rabbit_password']
    rabbit_virtualhost = rabbit_virtualhost or service_config['rabbit_virtualhost']

    loop = asyncio.get_event_loop()

    gateway = RabbitMQServiceGateway(
        model_config=model_config,
        service_name=service_name,
        agent_namespace=agent_namespace,
        batch_size=batch_size,
        utterance_lifetime_sec=utterance_lifetime_sec,
        rabbit_host=rabbit_host,
        rabbit_port=rabbit_port,
        rabbit_login=rabbit_login,
        rabbit_password=rabbit_password,
        rabbit_virtualhost=rabbit_virtualhost,
        loop=loop
    )

    try:
        loop.run_forever()
    except KeyboardInterrupt:
        pass
    finally:
        gateway.disconnect()
        loop.stop()
        loop.close()
        logging.shutdown()
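A minimal invocation sketch; every value below is a placeholder, and any argument left as None falls back to the 'agent-rabbit' section of the connector settings file:

# Placeholder connection parameters for illustration only. The call blocks until
# interrupted, since start_rabbit_service runs loop.run_forever() internally.
start_rabbit_service(model_config='path/to/model_config.json',
                     service_name='my_service',
                     agent_namespace='deeppavlov_agent',
                     batch_size=1,
                     utterance_lifetime_sec=120,
                     rabbit_host='localhost',
                     rabbit_port=5672,
                     rabbit_login='guest',
                     rabbit_password='guest',
                     rabbit_virtualhost='/')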
Example #19
def run_alexa_server(agent_generator: callable,
                     multi_instance: bool = False,
                     stateful: bool = False,
                     port: Optional[int] = None,
                     https: bool = False,
                     ssl_key: str = None,
                     ssl_cert: str = None) -> None:
    """Initiates Flask web service with Alexa skill.

    Args:
        agent_generator: Callback Alexa agents factory.
        multi_instance: Multi instance mode flag.
        stateful: Stateful mode flag.
        port: Flask web service port.
        https: Flag for running Alexa skill service in https mode.
        ssl_key: SSL key file path.
        ssl_cert: SSL certificate file path.
    """
    server_config_path = Path(get_settings_path(),
                              SERVER_CONFIG_FILENAME).resolve()
    server_params = read_json(server_config_path)

    host = server_params['common_defaults']['host']
    port = port or server_params['common_defaults']['port']

    alexa_server_params = server_params['alexa_defaults']

    alexa_server_params['multi_instance'] = multi_instance or server_params['common_defaults']['multi_instance']
    alexa_server_params['stateful'] = stateful or server_params['common_defaults']['stateful']
    alexa_server_params['amazon_cert_lifetime'] = AMAZON_CERTIFICATE_LIFETIME

    if https:
        ssh_key_path = Path(ssl_key
                            or server_params['https_key_path']).resolve()
        if not ssh_key_path.is_file():
            e = FileNotFoundError(
                'Ssh key file not found: please provide correct path in --key param or '
                'https_key_path param in server configuration file')
            log.error(e)
            raise e

        ssh_cert_path = Path(ssl_cert
                             or server_params['https_cert_path']).resolve()
        if not ssh_cert_path.is_file():
            e = FileNotFoundError(
                'Ssh certificate file not found: please provide correct path in --cert param or '
                'https_cert_path param in server configuration file')
            log.error(e)
            raise e

        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
        ssl_context.load_cert_chain(ssh_cert_path, ssh_key_path)
    else:
        ssl_context = None

    input_q = Queue()
    output_q = Queue()

    bot = Bot(agent_generator, alexa_server_params, input_q, output_q)
    bot.start()

    endpoint_description = {
        'description':
        'Amazon Alexa custom service endpoint',
        'parameters': [{
            'name':
            'Signature',
            'in':
            'header',
            'required':
            'true',
            'type':
            'string',
            'example':
            'Z5H5wqd06ExFVPNfJiqhKvAFjkf+cTVodOUirucHGcEVAMO1LfvgqWUkZ/X1ITDZbI0w+SMwVkEQZlkeThbVS/54M22StNDUtfz4Ua20xNDpIPwcWIACAmZ38XxbbTEFJI5WwqrbilNcfzqiGrIPfdO5rl+/xUjHFUdcJdUY/QzBxXsceytVYfEiR9MzOCN2m4C0XnpThUavAu159KrLj8AkuzN0JF87iXv+zOEeZRgEuwmsAnJrRUwkJ4yWokEPnSVdjF0D6f6CscfyvRe9nsWShq7/zRTa41meweh+n006zvf58MbzRdXPB22RI4AN0ksWW7hSC8/QLAKQE+lvaw==',
        }, {
            'name':
            'Signaturecertchainurl',
            'in':
            'header',
            'required':
            'true',
            'type':
            'string',
            'example':
            'https://s3.amazonaws.com/echo.api/echo-api-cert-6-ats.pem',
        }, {
            'name': 'data',
            'in': 'body',
            'required': 'true',
            'example': {
                'version': '1.0',
                'session': {
                    'new': False,
                    'sessionId':
                    'amzn1.echo-api.session.3c6ebffd-55b9-4e1a-bf3c-c921c1801b63',
                    'application': {
                        'applicationId':
                        'amzn1.ask.skill.8b17a5de-3749-4919-aa1f-e0bbaf8a46a6'
                    },
                    'attributes': {
                        'sessionId':
                        'amzn1.echo-api.session.3c6ebffd-55b9-4e1a-bf3c-c921c1801b63'
                    },
                    'user': {
                        'userId':
                        'amzn1.ask.account.AGR4R2LOVHMNMNOGROBVNLU7CL4C57X465XJF2T2F55OUXNTLCXDQP3I55UXZIALEKKZJ6Q2MA5MEFSMZVPEL5NVZS6FZLEU444BVOLPB5WVH5CHYTQAKGD7VFLGPRFZVHHH2NIB4HKNHHGX6HM6S6QDWCKXWOIZL7ONNQSBUCVPMZQKMCYXRG5BA2POYEXFDXRXCGEVDWVSMPQ'
                    }
                },
                'context': {
                    'System': {
                        'application': {
                            'applicationId':
                            'amzn1.ask.skill.8b17a5de-3749-4919-aa1f-e0bbaf8a46a6'
                        },
                        'user': {
                            'userId':
                            'amzn1.ask.account.AGR4R2LOVHMNMNOGROBVNLU7CL4C57X465XJF2T2F55OUXNTLCXDQP3I55UXZIALEKKZJ6Q2MA5MEFSMZVPEL5NVZS6FZLEU444BVOLPB5WVH5CHYTQAKGD7VFLGPRFZVHHH2NIB4HKNHHGX6HM6S6QDWCKXWOIZL7ONNQSBUCVPMZQKMCYXRG5BA2POYEXFDXRXCGEVDWVSMPQ'
                        },
                        'device': {
                            'deviceId':
                            'amzn1.ask.device.AFQAMLYOYQUUACSE7HFVYS4ZI2KUB35JPHQRUPKTDCAU3A47WESP5L57KSWT5L6RT3FVXWH4OA2DNPJRMZ2VGEIACF3PJEIDCOUWUBC4W5RPJNUB3ZVT22J4UJN5UL3T2UBP36RVHFJ5P4IPT2HUY3P2YOY33IOU4O33HUAG7R2BUNROEH4T2',
                            'supportedInterfaces': {}
                        },
                        'apiEndpoint':
                        'https://api.amazonalexa.com',
                        'apiAccessToken':
                        'eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ.eyJhdWQiOiJodHRwczovL2FwaS5hbWF6b25hbGV4YS5jb20iLCJpc3MiOiJBbGV4YVNraWxsS2l0Iiwic3ViIjoiYW16bjEuYXNrLnNraWxsLjhiMTdhNWRlLTM3NDktNDkxOS1hYTFmLWUwYmJhZjhhNDZhNiIsImV4cCI6MTU0NTIyMzY1OCwiaWF0IjoxNTQ1MjIwMDU4LCJuYmYiOjE1NDUyMjAwNTgsInByaXZhdGVDbGFpbXMiOnsiY29uc2VudFRva2VuIjpudWxsLCJkZXZpY2VJZCI6ImFtem4xLmFzay5kZXZpY2UuQUZRQU1MWU9ZUVVVQUNTRTdIRlZZUzRaSTJLVUIzNUpQSFFSVVBLVERDQVUzQTQ3V0VTUDVMNTdLU1dUNUw2UlQzRlZYV0g0T0EyRE5QSlJNWjJWR0VJQUNGM1BKRUlEQ09VV1VCQzRXNVJQSk5VQjNaVlQyMko0VUpONVVMM1QyVUJQMzZSVkhGSjVQNElQVDJIVVkzUDJZT1kzM0lPVTRPMzNIVUFHN1IyQlVOUk9FSDRUMiIsInVzZXJJZCI6ImFtem4xLmFzay5hY2NvdW50LkFHUjRSMkxPVkhNTk1OT0dST0JWTkxVN0NMNEM1N1g0NjVYSkYyVDJGNTVPVVhOVExDWERRUDNJNTVVWFpJQUxFS0taSjZRMk1BNU1FRlNNWlZQRUw1TlZaUzZGWkxFVTQ0NEJWT0xQQjVXVkg1Q0hZVFFBS0dEN1ZGTEdQUkZaVkhISDJOSUI0SEtOSEhHWDZITTZTNlFEV0NLWFdPSVpMN09OTlFTQlVDVlBNWlFLTUNZWFJHNUJBMlBPWUVYRkRYUlhDR0VWRFdWU01QUSJ9fQ.jcomYhBhU485T4uoe2NyhWnL-kZHoPQKpcycFqa-1sy_lSIitfFGup9DKrf2NkN-I9lZ3xwq9llqx9WRN78fVJjN6GLcDhBDH0irPwt3n9_V7_5bfB6KARv5ZG-JKOmZlLBqQbnln0DAJ10D8HNiytMARNEwduMBVDNK0A5z6YxtRcLYYFD2-Ieg_V8Qx90eE2pd2U5xOuIEL0pXfSoiJ8vpxb8BKwaMO47tdE4qhg_k7v8ClwyXg3EMEhZFjixYNqdW1tCrwDGj58IWMXDyzZhIlRMh6uudMOT6scSzcNVD0v42IOTZ3S_X6rG01B7xhUDlZXMqkrCuzOyqctGaPw'
                    },
                    'Viewport': {
                        'experiences': [{
                            'arcMinuteWidth': 246,
                            'arcMinuteHeight': 144,
                            'canRotate': False,
                            'canResize': False
                        }],
                        'shape':
                        'RECTANGLE',
                        'pixelWidth':
                        1024,
                        'pixelHeight':
                        600,
                        'dpi':
                        160,
                        'currentPixelWidth':
                        1024,
                        'currentPixelHeight':
                        600,
                        'touch': ['SINGLE']
                    }
                },
                'request': {
                    'type': 'IntentRequest',
                    'requestId':
                    'amzn1.echo-api.request.388d0f6e-04b9-4450-a687-b9abaa73ac6a',
                    'timestamp': '2018-12-19T11:47:38Z',
                    'locale': 'en-US',
                    'intent': {
                        'name': 'AskDeepPavlov',
                        'confirmationStatus': 'NONE',
                        'slots': {
                            'raw_input': {
                                'name': 'raw_input',
                                'value': 'my beautiful sandbox skill',
                                'resolutions': {
                                    'resolutionsPerAuthority': [{
                                        'authority':
                                        'amzn1.er-authority.echo-sdk.amzn1.ask.skill.8b17a5de-3749-4919-aa1f-e0bbaf8a46a6.GetInput',
                                        'status': {
                                            'code': 'ER_SUCCESS_NO_MATCH'
                                        }
                                    }]
                                },
                                'confirmationStatus': 'NONE',
                                'source': 'USER'
                            }
                        }
                    }
                }
            }
        }],
        'responses': {
            "200": {
                "description": "A model response"
            }
        }
    }

    @app.route('/')
    def index():
        return redirect('/apidocs/')

    @app.route('/interact', methods=['POST'])
    @swag_from(endpoint_description)
    def handle_request():
        request_body: bytes = request.get_data()
        signature_chain_url: str = request.headers.get('Signaturecertchainurl')
        signature: str = request.headers.get('Signature')
        alexa_request: dict = request.get_json()

        request_dict = {
            'request_body': request_body,
            'signature_chain_url': signature_chain_url,
            'signature': signature,
            'alexa_request': alexa_request
        }

        bot.input_queue.put(request_dict)
        response: dict = bot.output_queue.get()
        response_code = 400 if 'error' in response.keys() else 200

        return jsonify(response), response_code

    app.run(host=host, port=port, threaded=True, ssl_context=ssl_context)
Example #20
def start_model_server(model_config: Path,
                       https: bool = False,
                       ssl_key: Optional[str] = None,
                       ssl_cert: Optional[str] = None,
                       port: Optional[int] = None) -> None:
    server_config_path = get_settings_path() / SERVER_CONFIG_FILENAME
    server_params = get_server_params(server_config_path, model_config)

    host = server_params['host']
    port = port or server_params['port']
    model_endpoint = server_params['model_endpoint']
    docs_endpoint = server_params['docs_endpoint']
    model_args_names = server_params['model_args_names']

    https = https or server_params['https']

    if https:
        ssh_key_path = Path(ssl_key
                            or server_params['https_key_path']).resolve()
        if not ssh_key_path.is_file():
            e = FileNotFoundError(
                'Ssh key file not found: please provide correct path in --key param or '
                'https_key_path param in server configuration file')
            log.error(e)
            raise e

        ssh_cert_path = Path(ssl_cert
                             or server_params['https_cert_path']).resolve()
        if not ssh_cert_path.is_file():
            e = FileNotFoundError(
                'Ssh certificate file not found: please provide correct path in --cert param or '
                'https_cert_path param in server configuration file')
            log.error(e)
            raise e

        ssl_version = PROTOCOL_TLSv1_2
        ssl_keyfile = str(ssh_key_path)
        ssl_certfile = str(ssh_cert_path)
    else:
        ssl_version = None
        ssl_keyfile = None
        ssl_certfile = None

    model = build_model(model_config)

    def batch_decorator(cls: MetaModel) -> MetaModel:
        cls.__annotations__ = {
            arg_name: List[str]
            for arg_name in model_args_names
        }
        cls.__fields__ = {
            arg_name: Field(name=arg_name,
                            type_=List[str],
                            class_validators=None,
                            model_config=BaseConfig,
                            required=False,
                            schema=Schema(None))
            for arg_name in model_args_names
        }
        return cls

    @batch_decorator
    class Batch(BaseModel):
        pass

    @app.get('/', include_in_schema=False)
    async def redirect_to_docs() -> RedirectResponse:
        operation_id = generate_operation_id_for_path(name='answer',
                                                      path=model_endpoint,
                                                      method='post')
        response = RedirectResponse(
            url=f'{docs_endpoint}#/default/{operation_id}')
        return response

    @app.post(model_endpoint, status_code=200, summary='A model endpoint')
    async def answer(item: Batch) -> JSONResponse:
        return interact(model, item.dict())

    @app.post('/probe', status_code=200, include_in_schema=False)
    async def probe(item: Batch) -> JSONResponse:
        return test_interact(model, item.dict())

    @app.get('/api', status_code=200, summary='Model argument names')
    async def api() -> JSONResponse:
        return JSONResponse(model_args_names)

    uvicorn.run(app,
                host=host,
                port=port,
                logger=uvicorn_log,
                ssl_version=ssl_version,
                ssl_keyfile=ssl_keyfile,
                ssl_certfile=ssl_certfile)
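Once the FastAPI/uvicorn service above is running, the '/api' route returns the model argument names, which can then be used to build the POST payload for the model endpoint. A client sketch with placeholder host, port and endpoint path:

import requests

# Placeholder base URL and endpoint path: the real values come from 'host', 'port' and
# 'model_endpoint' in the server configuration file.
base = 'http://127.0.0.1:5000'

arg_names = requests.get(f'{base}/api').json()   # e.g. ['x'] for a single-argument model
payload = {name: ['some input text'] for name in arg_names}

print(requests.post(f'{base}/model', json=payload).json())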
Example #21
def run_ms_bot_framework_server(agent_generator: callable,
                                app_id: str,
                                app_secret: str,
                                multi_instance: bool = False,
                                stateful: bool = False,
                                port: Optional[int] = None,
                                https: bool = False,
                                ssl_key: str = None,
                                ssl_cert: str = None):

    server_config_path = Path(get_settings_path(), SERVER_CONFIG_FILENAME).resolve()
    server_params = read_json(server_config_path)

    host = server_params['common_defaults']['host']
    port = port or server_params['common_defaults']['port']

    ms_bf_server_params = server_params['ms_bot_framework_defaults']

    ms_bf_server_params['multi_instance'] = multi_instance or server_params['common_defaults']['multi_instance']
    ms_bf_server_params['stateful'] = stateful or server_params['common_defaults']['stateful']

    ms_bf_server_params['auth_url'] = AUTH_URL
    ms_bf_server_params['auth_host'] = AUTH_HOST
    ms_bf_server_params['auth_content_type'] = AUTH_CONTENT_TYPE
    ms_bf_server_params['auth_grant_type'] = AUTH_GRANT_TYPE
    ms_bf_server_params['auth_scope'] = AUTH_SCOPE

    ms_bf_server_params['auth_app_id'] = app_id or ms_bf_server_params['auth_app_id']
    if not ms_bf_server_params['auth_app_id']:
        e = ValueError('Microsoft Bot Framework app id required: initiate -i param '
                       'or auth_app_id param in server configuration file')
        log.error(e)
        raise e

    ms_bf_server_params['auth_app_secret'] = app_secret or ms_bf_server_params['auth_app_secret']
    if not ms_bf_server_params['auth_app_secret']:
        e = ValueError('Microsoft Bot Framework app secret required: initiate -s param '
                       'or auth_app_secret param in server configuration file')
        log.error(e)
        raise e

    if https:
        ssh_key_path = Path(ssl_key or server_params['https_key_path']).resolve()
        if not ssh_key_path.is_file():
            e = FileNotFoundError('Ssh key file not found: please provide correct path in --key param or '
                                  'https_key_path param in server configuration file')
            log.error(e)
            raise e

        ssh_cert_path = Path(ssl_cert or server_params['https_cert_path']).resolve()
        if not ssh_cert_path.is_file():
            e = FileNotFoundError('Ssh certificate file not found: please provide correct path in --cert param or '
                                  'https_cert_path param in server configuration file')
            log.error(e)
            raise e

        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
        ssl_context.load_cert_chain(ssh_cert_path, ssh_key_path)
    else:
        ssl_context = None

    input_q = Queue()
    bot = Bot(agent_generator, ms_bf_server_params, input_q)
    bot.start()

    @app.route('/')
    def index():
        return redirect('/apidocs/')

    @app.route('/v3/conversations', methods=['POST'])
    def handle_activity():
        activity = request.get_json()
        bot.input_queue.put(activity)
        return jsonify({}), 200

    app.run(host=host, port=port, threaded=True, ssl_context=ssl_context)
Example #22
def run_alexa_server(agent_generator: callable, multi_instance: bool = False,
                     stateful: bool = False, port: Optional[int] = None, https: bool = False,
                     ssl_key: str = None, ssl_cert: str = None) -> None:
    """Initiates Flask web service with Alexa skill.

    Args:
        agent_generator: Callback Alexa agents factory.
        multi_instance: Multi instance mode flag.
        stateful: Stateful mode flag.
        port: Flask web service port.
        https: Flag for running Alexa skill service in https mode.
        ssl_key: SSL key file path.
        ssl_cert: SSL certificate file path.
    """
    server_config_path = Path(get_settings_path(), SERVER_CONFIG_FILENAME).resolve()
    server_params = read_json(server_config_path)

    host = server_params['common_defaults']['host']
    port = port or server_params['common_defaults']['port']

    alexa_server_params = server_params['alexa_defaults']

    alexa_server_params['multi_instance'] = multi_instance or server_params['common_defaults']['multi_instance']
    alexa_server_params['stateful'] = stateful or server_params['common_defaults']['stateful']
    alexa_server_params['amazon_cert_lifetime'] = AMAZON_CERTIFICATE_LIFETIME

    if https:
        ssh_key_path = Path(ssl_key or server_params['https_key_path']).resolve()
        if not ssh_key_path.is_file():
            e = FileNotFoundError('Ssh key file not found: please provide correct path in --key param or '
                                  'https_key_path param in server configuration file')
            log.error(e)
            raise e

        ssh_cert_path = Path(ssl_cert or server_params['https_cert_path']).resolve()
        if not ssh_cert_path.is_file():
            e = FileNotFoundError('Ssh certificate file not found: please provide correct path in --cert param or '
                                  'https_cert_path param in server configuration file')
            log.error(e)
            raise e

        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
        ssl_context.load_cert_chain(ssh_cert_path, ssh_key_path)
    else:
        ssl_context = None

    input_q = Queue()
    output_q = Queue()

    bot = Bot(agent_generator, alexa_server_params, input_q, output_q)
    bot.start()

    endpoint_description = {
        'description': 'Amazon Alexa custom service endpoint',
        'parameters': [
            {
                'name': 'Signature',
                'in': 'header',
                'required': 'true',
                'type': 'string',
                'example': 'Z5H5wqd06ExFVPNfJiqhKvAFjkf+cTVodOUirucHGcEVAMO1LfvgqWUkZ/X1ITDZbI0w+SMwVkEQZlkeThbVS/54M22StNDUtfz4Ua20xNDpIPwcWIACAmZ38XxbbTEFJI5WwqrbilNcfzqiGrIPfdO5rl+/xUjHFUdcJdUY/QzBxXsceytVYfEiR9MzOCN2m4C0XnpThUavAu159KrLj8AkuzN0JF87iXv+zOEeZRgEuwmsAnJrRUwkJ4yWokEPnSVdjF0D6f6CscfyvRe9nsWShq7/zRTa41meweh+n006zvf58MbzRdXPB22RI4AN0ksWW7hSC8/QLAKQE+lvaw==',
            },
            {
                'name': 'Signaturecertchainurl',
                'in': 'header',
                'required': 'true',
                'type': 'string',
                'example': 'https://s3.amazonaws.com/echo.api/echo-api-cert-6-ats.pem',
            },
            {
                'name': 'data',
                'in': 'body',
                'required': 'true',
                'example': {
                    'version': '1.0',
                    'session': {
                        'new': False,
                        'sessionId': 'amzn1.echo-api.session.3c6ebffd-55b9-4e1a-bf3c-c921c1801b63',
                        'application': {
                            'applicationId': 'amzn1.ask.skill.8b17a5de-3749-4919-aa1f-e0bbaf8a46a6'
                        },
                        'attributes': {
                            'sessionId': 'amzn1.echo-api.session.3c6ebffd-55b9-4e1a-bf3c-c921c1801b63'
                        },
                        'user': {
                            'userId': 'amzn1.ask.account.AGR4R2LOVHMNMNOGROBVNLU7CL4C57X465XJF2T2F55OUXNTLCXDQP3I55UXZIALEKKZJ6Q2MA5MEFSMZVPEL5NVZS6FZLEU444BVOLPB5WVH5CHYTQAKGD7VFLGPRFZVHHH2NIB4HKNHHGX6HM6S6QDWCKXWOIZL7ONNQSBUCVPMZQKMCYXRG5BA2POYEXFDXRXCGEVDWVSMPQ'
                        }
                    },
                    'context': {
                        'System': {
                            'application': {
                                'applicationId': 'amzn1.ask.skill.8b17a5de-3749-4919-aa1f-e0bbaf8a46a6'
                            },
                            'user': {
                                'userId': 'amzn1.ask.account.AGR4R2LOVHMNMNOGROBVNLU7CL4C57X465XJF2T2F55OUXNTLCXDQP3I55UXZIALEKKZJ6Q2MA5MEFSMZVPEL5NVZS6FZLEU444BVOLPB5WVH5CHYTQAKGD7VFLGPRFZVHHH2NIB4HKNHHGX6HM6S6QDWCKXWOIZL7ONNQSBUCVPMZQKMCYXRG5BA2POYEXFDXRXCGEVDWVSMPQ'
                            },
                            'device': {
                                'deviceId': 'amzn1.ask.device.AFQAMLYOYQUUACSE7HFVYS4ZI2KUB35JPHQRUPKTDCAU3A47WESP5L57KSWT5L6RT3FVXWH4OA2DNPJRMZ2VGEIACF3PJEIDCOUWUBC4W5RPJNUB3ZVT22J4UJN5UL3T2UBP36RVHFJ5P4IPT2HUY3P2YOY33IOU4O33HUAG7R2BUNROEH4T2',
                                'supportedInterfaces': {}
                            },
                            'apiEndpoint': 'https://api.amazonalexa.com',
                            'apiAccessToken': 'eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ.eyJhdWQiOiJodHRwczovL2FwaS5hbWF6b25hbGV4YS5jb20iLCJpc3MiOiJBbGV4YVNraWxsS2l0Iiwic3ViIjoiYW16bjEuYXNrLnNraWxsLjhiMTdhNWRlLTM3NDktNDkxOS1hYTFmLWUwYmJhZjhhNDZhNiIsImV4cCI6MTU0NTIyMzY1OCwiaWF0IjoxNTQ1MjIwMDU4LCJuYmYiOjE1NDUyMjAwNTgsInByaXZhdGVDbGFpbXMiOnsiY29uc2VudFRva2VuIjpudWxsLCJkZXZpY2VJZCI6ImFtem4xLmFzay5kZXZpY2UuQUZRQU1MWU9ZUVVVQUNTRTdIRlZZUzRaSTJLVUIzNUpQSFFSVVBLVERDQVUzQTQ3V0VTUDVMNTdLU1dUNUw2UlQzRlZYV0g0T0EyRE5QSlJNWjJWR0VJQUNGM1BKRUlEQ09VV1VCQzRXNVJQSk5VQjNaVlQyMko0VUpONVVMM1QyVUJQMzZSVkhGSjVQNElQVDJIVVkzUDJZT1kzM0lPVTRPMzNIVUFHN1IyQlVOUk9FSDRUMiIsInVzZXJJZCI6ImFtem4xLmFzay5hY2NvdW50LkFHUjRSMkxPVkhNTk1OT0dST0JWTkxVN0NMNEM1N1g0NjVYSkYyVDJGNTVPVVhOVExDWERRUDNJNTVVWFpJQUxFS0taSjZRMk1BNU1FRlNNWlZQRUw1TlZaUzZGWkxFVTQ0NEJWT0xQQjVXVkg1Q0hZVFFBS0dEN1ZGTEdQUkZaVkhISDJOSUI0SEtOSEhHWDZITTZTNlFEV0NLWFdPSVpMN09OTlFTQlVDVlBNWlFLTUNZWFJHNUJBMlBPWUVYRkRYUlhDR0VWRFdWU01QUSJ9fQ.jcomYhBhU485T4uoe2NyhWnL-kZHoPQKpcycFqa-1sy_lSIitfFGup9DKrf2NkN-I9lZ3xwq9llqx9WRN78fVJjN6GLcDhBDH0irPwt3n9_V7_5bfB6KARv5ZG-JKOmZlLBqQbnln0DAJ10D8HNiytMARNEwduMBVDNK0A5z6YxtRcLYYFD2-Ieg_V8Qx90eE2pd2U5xOuIEL0pXfSoiJ8vpxb8BKwaMO47tdE4qhg_k7v8ClwyXg3EMEhZFjixYNqdW1tCrwDGj58IWMXDyzZhIlRMh6uudMOT6scSzcNVD0v42IOTZ3S_X6rG01B7xhUDlZXMqkrCuzOyqctGaPw'
                        },
                        'Viewport': {
                            'experiences': [
                                {
                                    'arcMinuteWidth': 246,
                                    'arcMinuteHeight': 144,
                                    'canRotate': False,
                                    'canResize': False
                                }
                            ],
                            'shape': 'RECTANGLE',
                            'pixelWidth': 1024,
                            'pixelHeight': 600,
                            'dpi': 160,
                            'currentPixelWidth': 1024,
                            'currentPixelHeight': 600,
                            'touch': [
                                'SINGLE'
                            ]
                        }
                    },
                    'request': {
                        'type': 'IntentRequest',
                        'requestId': 'amzn1.echo-api.request.388d0f6e-04b9-4450-a687-b9abaa73ac6a',
                        'timestamp': '2018-12-19T11:47:38Z',
                        'locale': 'en-US',
                        'intent': {
                            'name': 'AskDeepPavlov',
                            'confirmationStatus': 'NONE',
                            'slots': {
                                'raw_input': {
                                    'name': 'raw_input',
                                    'value': 'my beautiful sandbox skill',
                                    'resolutions': {
                                        'resolutionsPerAuthority': [
                                            {
                                                'authority': 'amzn1.er-authority.echo-sdk.amzn1.ask.skill.8b17a5de-3749-4919-aa1f-e0bbaf8a46a6.GetInput',
                                                'status': {
                                                    'code': 'ER_SUCCESS_NO_MATCH'
                                                }
                                            }
                                        ]
                                    },
                                    'confirmationStatus': 'NONE',
                                    'source': 'USER'
                                }
                            }
                        }
                    }
                }
            }
        ],
        'responses': {
            "200": {
                "description": "A model response"
            }
        }
    }

    @app.route('/')
    def index():
        return redirect('/apidocs/')

    @app.route('/interact', methods=['POST'])
    @swag_from(endpoint_description)
    def handle_request():
        request_body: bytes = request.get_data()
        signature_chain_url: str = request.headers.get('Signaturecertchainurl')
        signature: str = request.headers.get('Signature')
        alexa_request: dict = request.get_json()

        request_dict = {
            'request_body': request_body,
            'signature_chain_url': signature_chain_url,
            'signature': signature,
            'alexa_request': alexa_request
        }

        bot.input_queue.put(request_dict)
        response: dict = bot.output_queue.get()
        response_code = 400 if 'error' in response.keys() else 200

        return jsonify(response), response_code

    app.run(host=host, port=port, threaded=True, ssl_context=ssl_context)
Example #23
from collections import namedtuple
from logging import getLogger
from pathlib import Path
from typing import Dict, Union

from fastapi import FastAPI
from pydantic import BaseConfig, BaseModel
from pydantic.fields import Field, ModelField
from pydantic.main import ModelMetaclass
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse

from deeppavlov.core.commands.infer import build_model
from deeppavlov.core.commands.utils import parse_config
from deeppavlov.core.common.chainer import Chainer
from deeppavlov.core.common.file import read_json
from deeppavlov.core.common.log import log_config
from deeppavlov.core.common.paths import get_settings_path
from deeppavlov.core.data.utils import check_nested_dict_keys, jsonify_data
from deeppavlov.utils.connector import DialogLogger

SERVER_CONFIG_PATH = get_settings_path() / 'server_config.json'
SSLConfig = namedtuple('SSLConfig', ['version', 'keyfile', 'certfile'])

log = getLogger(__name__)
dialog_logger = DialogLogger(logger_name='rest_api')

app = FastAPI(__file__)

app.add_middleware(CORSMiddleware,
                   allow_origins=['*'],
                   allow_credentials=True,
                   allow_methods=['*'],
                   allow_headers=['*'])


def get_server_params(model_config: Union[str, Path]) -> Dict:
Example #24
def start_model_server(model_config,
                       https=False,
                       ssl_key=None,
                       ssl_cert=None,
                       port=None):
    server_config_path = get_settings_path() / SERVER_CONFIG_FILENAME
    server_params = get_server_params(server_config_path, model_config)

    host = server_params['host']
    port = port or server_params['port']
    model_endpoint = server_params['model_endpoint']
    model_args_names = server_params['model_args_names']

    https = https or server_params['https']

    if https:
        ssh_key_path = Path(ssl_key
                            or server_params['https_key_path']).resolve()
        if not ssh_key_path.is_file():
            e = FileNotFoundError(
                'Ssh key file not found: please provide correct path in --key param or '
                'https_key_path param in server configuration file')
            log.error(e)
            raise e

        ssh_cert_path = Path(ssl_cert
                             or server_params['https_cert_path']).resolve()
        if not ssh_cert_path.is_file():
            e = FileNotFoundError(
                'Ssh certificate file not found: please provide correct path in --cert param or '
                'https_cert_path param in server configuration file')
            log.error(e)
            raise e

        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
        ssl_context.load_cert_chain(ssh_cert_path, ssh_key_path)
    else:
        ssl_context = None

    model = build_model(model_config)

    @app.route('/')
    def index():
        return redirect('/apidocs/')

    endpoint_description = {
        'description':
        'A model endpoint',
        'parameters': [{
            'name': 'data',
            'in': 'body',
            'required': 'true',
            'example': {arg: ['value']
                        for arg in model_args_names}
        }],
        'responses': {
            "200": {
                "description": "A model response"
            }
        }
    }

    @app.route(model_endpoint, methods=['POST'])
    @swag_from(endpoint_description)
    def answer():
        return interact(model, model_args_names)

    app.run(host=host, port=port, threaded=False, ssl_context=ssl_context)