コード例 #1
0
def run_interactive_learning(stories: Text = None,
                             finetune: bool = False,
                             skip_visualization: bool = False,
                             server_args: Dict[Text, Any] = None,
                             additional_arguments: Dict[Text, Any] = None):
    """Start the interactive learning with the model of the agent.

    Args:
        stories: Forwarded unchanged to ``_serve_application``.
        finetune: Forwarded unchanged to ``_serve_application``.
        skip_visualization: When true, no visualization process is spawned.
        server_args: Server options; the keys ``"endpoints"``, ``"core"``
            and ``"nlu"`` are read here. ``None`` is treated as ``{}``.
        additional_arguments: Handed to ``train_agent_on_start`` when
            ``server_args`` has no truthy ``"core"`` entry.
    """
    server_args = server_args or {}

    if not skip_visualization:
        viz_process = Process(target=start_visualization,
                              args=("story_graph.dot", ))
        # The attribute is spelled `daemon`. The former `deamon` spelling only
        # created an unrelated attribute, so the flag was never applied.
        viz_process.daemon = True
        viz_process.start()
    else:
        viz_process = None

    try:
        app = run.configure_app(enable_api=True)
        endpoints = AvailableEndpoints.read_endpoints(
            server_args.get("endpoints"))

        # before_server_start handlers make sure the agent is loaded before
        # the interactive learning IO starts
        if server_args.get("core"):
            app.register_listener(
                partial(run.load_agent_on_start, server_args.get("core"),
                        endpoints, server_args.get("nlu")),
                'before_server_start')
        else:
            app.register_listener(
                partial(train_agent_on_start, server_args, endpoints,
                        additional_arguments), 'before_server_start')

        _serve_application(app, stories, finetune, skip_visualization)
    finally:
        # Always stop the visualization process, even when serving fails.
        if viz_process is not None:
            viz_process.terminate()
            viz_process.join()
コード例 #2
0
def run(core_dir, nlu_dir, endpoints_file):
    """Serve the bot through a Rocket.Chat input channel.

    The channel credentials come from the ``ROCKETCHAT_BOT_USERNAME``,
    ``ROCKETCHAT_BOT_PASSWORD`` and ``ROCKETCHAT_URL`` environment
    variables. The agent is loaded from ``core_dir`` with an interpreter
    created from ``nlu_dir`` and an ``ElasticTrackerStore``. Exceptions
    raised while serving are logged rather than propagated.
    """
    rocketchat_channel = RocketChatInput(
        user=os.getenv('ROCKETCHAT_BOT_USERNAME'),
        password=os.getenv('ROCKETCHAT_BOT_PASSWORD'),
        server_url=os.getenv('ROCKETCHAT_URL'),
    )

    available_endpoints = AvailableEndpoints.read_endpoints(endpoints_file)
    nlu_interpreter = NaturalLanguageInterpreter.create(
        nlu_dir, available_endpoints.nlu)
    store = ElasticTrackerStore()

    bot_agent = load_agent(
        core_dir,
        interpreter=nlu_interpreter,
        tracker_store=store,
        endpoints=available_endpoints,
    )

    # 5005 is passed as the fourth positional argument (presumably the port).
    server = start_server([rocketchat_channel], "", "", 5005, bot_agent)

    try:
        server.serve_forever()
    except Exception as error:
        logger.exception(error)
コード例 #3
0
def create_agent(model: Text, endpoints: Text = None) -> 'Agent':
    """Load an agent from the core and NLU parts of a model.

    Args:
        model: Passed to ``get_model_subdirectories`` to obtain the core
            and NLU paths.
        endpoints: Passed to ``AvailableEndpoints.read_endpoints``.

    Returns:
        The value returned by ``Agent.load`` for the core path, wired with
        the NLU interpreter (when an NLU path exists), the NLG endpoint,
        the tracker store and the action endpoint.
    """
    from rasa_core.interpreter import RasaNLUInterpreter
    from rasa_core.tracker_store import TrackerStore
    from rasa_core import broker
    from rasa_core.utils import AvailableEndpoints

    core_path, nlu_path = get_model_subdirectories(model)
    _endpoints = AvailableEndpoints.read_endpoints(endpoints)

    if os.path.exists(nlu_path):
        _interpreter = RasaNLUInterpreter(model_directory=nlu_path)
    else:
        _interpreter = None
        logging.info("No NLU model found. Running without NLU.")

    _broker = broker.from_endpoint_config(_endpoints.event_broker)

    _tracker_store = TrackerStore.find_tracker_store(None,
                                                     _endpoints.tracker_store,
                                                     _broker)

    # The interpreter used to be constructed and then dropped; hand it to the
    # agent so the loaded NLU model is actually used.
    return Agent.load(core_path,
                      interpreter=_interpreter,
                      generator=_endpoints.nlg,
                      tracker_store=_tracker_store,
                      action_endpoint=_endpoints.action)
コード例 #4
0
def serve_application(model_directory: Text,
                      nlu_model: Optional[Text] = None,
                      tracker_dump: Optional[Text] = None,
                      port: int = constants.DEFAULT_SERVER_PORT,
                      endpoints: Optional[Text] = None,
                      enable_api: bool = True):
    """Serve a loaded agent on the command line and replay a tracker dump.

    Args:
        model_directory: Passed to ``load_agent``.
        nlu_model: Passed to ``NaturalLanguageInterpreter.create``.
        tracker_dump: Passed to ``load_tracker_from_json``.
        port: Port handed to ``run.start_server`` and used to format the
            server address.
        endpoints: Passed to ``AvailableEndpoints.read_endpoints``.
        enable_api: Forwarded to ``run.start_server``.
    """
    from rasa_core import run

    available_endpoints = AvailableEndpoints.read_endpoints(endpoints)
    interpreter = NaturalLanguageInterpreter.create(
        nlu_model, available_endpoints.nlu)
    channels = run.create_http_input_channels("cmdline", None)
    loaded_agent = load_agent(model_directory,
                              interpreter=interpreter,
                              endpoints=available_endpoints)

    server = run.start_server(
        channels,
        None,
        None,
        port=port,
        initial_agent=loaded_agent,
        enable_api=enable_api,
    )

    restored_tracker = load_tracker_from_json(tracker_dump,
                                              loaded_agent.domain)

    server_address = constants.DEFAULT_SERVER_FORMAT.format(port)
    run.start_cmdline_io(server_address,
                         server.stop,
                         sender_id=restored_tracker.sender_id)

    replay_events(restored_tracker, loaded_agent)

    try:
        server.serve_forever()
    except Exception as error:
        # Serving errors are logged instead of being propagated.
        logger.exception(error)
コード例 #5
0
ファイル: run-webchat.py プロジェクト: pablodiegoss/bot
def run(core_dir, nlu_dir):
    """Serve the bot through a Socket.IO input channel.

    The tracker store is an ``ElasticTrackerStore``. It only receives
    credentials, scheme and port when ``ELASTICSEARCH_USER`` is set in the
    environment. The server port is read from ``WEBCHAT_PORT`` (default
    3000). Exceptions raised while serving are logged.
    """
    socket_channel = SocketIOInput(
        user_message_evt="user_uttered",
        bot_message_evt="bot_uttered",
        session_persistence=True,
        namespace=None,
    )

    available_endpoints = AvailableEndpoints.read_endpoints('endpoints.yml')
    nlu_interpreter = NaturalLanguageInterpreter.create(nlu_dir)

    store_options = {
        'domain': os.getenv('ELASTICSEARCH_URL', 'elasticsearch:9200'),
    }
    if os.getenv('ELASTICSEARCH_USER') is not None:
        # Authenticated setup: add credentials and connection details.
        store_options.update(
            user=os.getenv('ELASTICSEARCH_USER', 'user'),
            password=os.getenv('ELASTICSEARCH_PASSWORD', 'password'),
            scheme=os.getenv('ELASTICSEARCH_HTTP_SCHEME', 'http'),
            scheme_port=os.getenv('ELASTICSEARCH_PORT', '80'),
        )
    store = ElasticTrackerStore(**store_options)

    bot_agent = load_agent(
        core_dir,
        interpreter=nlu_interpreter,
        tracker_store=store,
        endpoints=available_endpoints,
    )

    port = int(os.getenv('WEBCHAT_PORT', 3000))
    server = start_server([socket_channel], "", "", port, bot_agent)

    try:
        server.serve_forever()
    except Exception as error:
        logger.exception(error)
コード例 #6
0
def run(core_dir, nlu_dir):
    """Serve the bot through Rocket.Chat with an in-memory tracker store.

    When the module-level ``ENABLE_ANALYTICS`` flag is truthy, a
    ``PikaProducer`` built from the module-level ``url``, ``username``,
    ``password`` and ``queue`` values is attached to the tracker store as
    its event broker; otherwise the broker is ``None``. Exceptions raised
    while serving are logged.
    """
    if ENABLE_ANALYTICS:
        event_broker = PikaProducer(url, username, password, queue=queue)
    else:
        event_broker = None

    rocketchat_channel = RocketChatInput(
        user=os.getenv('ROCKETCHAT_BOT_USERNAME'),
        password=os.getenv('ROCKETCHAT_BOT_PASSWORD'),
        server_url=os.getenv('ROCKETCHAT_URL'),
    )

    store = InMemoryTrackerStore(domain=None, event_broker=event_broker)

    available_endpoints = AvailableEndpoints.read_endpoints(None)
    nlu_interpreter = NaturalLanguageInterpreter.create(nlu_dir)

    bot_agent = load_agent(
        core_dir,
        interpreter=nlu_interpreter,
        tracker_store=store,
        endpoints=available_endpoints,
    )

    server = start_server([rocketchat_channel], "", "", 5005, bot_agent)

    try:
        server.serve_forever()
    except Exception as error:
        logger.exception(error)
コード例 #7
0
def run(core_dir, nlu_dir):
    """Serve the bot through a Telegram input channel.

    Telegram settings come from the ``TELEGRAM_ACCESS_TOKEN``, ``VERIFY``
    and ``WEBHOOK_URL`` environment variables (each defaulting to ``''``).
    The ``ElasticTrackerStore`` only receives credentials, scheme and port
    when ``ELASTICSEARCH_USER`` is set. Exceptions raised while serving are
    logged.
    """
    available_endpoints = AvailableEndpoints.read_endpoints('endpoints.yml')
    nlu_interpreter = NaturalLanguageInterpreter.create(nlu_dir)

    telegram_channel = TelegramInput(
        access_token=os.getenv('TELEGRAM_ACCESS_TOKEN', ''),
        verify=os.getenv('VERIFY', ''),
        webhook_url=os.getenv('WEBHOOK_URL', ''),
    )

    store_options = {
        'domain': os.getenv('ELASTICSEARCH_URL', 'elasticsearch:9200'),
    }
    if os.getenv('ELASTICSEARCH_USER') is not None:
        # Authenticated setup: add credentials and connection details.
        store_options.update(
            user=os.getenv('ELASTICSEARCH_USER', 'user'),
            password=os.getenv('ELASTICSEARCH_PASSWORD', 'password'),
            scheme=os.getenv('ELASTICSEARCH_HTTP_SCHEME', 'http'),
            scheme_port=os.getenv('ELASTICSEARCH_PORT', '80'),
        )
    store = ElasticTrackerStore(**store_options)

    bot_agent = load_agent(
        core_dir,
        interpreter=nlu_interpreter,
        tracker_store=store,
        endpoints=available_endpoints,
    )

    server = bot_agent.handle_channels([telegram_channel], 5001, "")

    try:
        server.serve_forever()
    except Exception as error:
        logger.exception(error)
コード例 #8
0
def run():
    """Load the dialogue agent, serve it on the ``rest`` channel, return it.

    Endpoints are read from ``config/endpoints.yml``. The interpreter is
    created from ``models/ticket/nlu_bot`` and the agent is loaded from
    ``models/dialogue``.
    """
    available_endpoints = AvailableEndpoints.read_endpoints(
        'config/endpoints.yml')
    nlu_interpreter = NaturalLanguageInterpreter.create(
        'models/ticket/nlu_bot', available_endpoints.nlu)
    dialogue_agent = load_agent(
        "models/dialogue",
        interpreter=nlu_interpreter,
        endpoints=available_endpoints,
    )

    serve_application(dialogue_agent, channel='rest')

    return dialogue_agent
コード例 #9
0
def test_formbot_example():
    """Train the formbot example and check its replies to two messages.

    The action server at ``https://abc.defg/webhooks/actions`` is stubbed
    with httpretty: first with a successful form response, then with a 400
    validation error.
    """
    sys.path.append("examples/formbot/")

    p = "examples/formbot/"
    stories = os.path.join(p, "data", "stories.md")
    endpoint = EndpointConfig("https://abc.defg/webhooks/actions")
    endpoints = AvailableEndpoints(action=endpoint)
    agent = train(os.path.join(p, "domain.yml"),
                  stories,
                  os.path.join(p, "models", "dialogue"),
                  endpoints=endpoints,
                  policy_config="rasa_core/default_config.yml")
    response = {
        'events': [{
            'event': 'form',
            'name': 'restaurant_form',
            'timestamp': None
        }, {
            'event': 'slot',
            'timestamp': None,
            'name': 'requested_slot',
            'value': 'cuisine'
        }],
        'responses': [{
            'template': 'utter_ask_cuisine'
        }]
    }

    httpretty.register_uri(httpretty.POST,
                           'https://abc.defg/webhooks/actions',
                           body=json.dumps(response))

    httpretty.enable()
    try:
        responses = agent.handle_text("/request_restaurant")
    finally:
        # Always undo the patching so a failure here cannot leak into
        # other tests.
        httpretty.disable()

    assert responses[0]['text'] == 'what cuisine?'

    response = {
        "error": "Failed to validate slot cuisine with action restaurant_form",
        "action_name": "restaurant_form"
    }

    httpretty.register_uri(httpretty.POST,
                           'https://abc.defg/webhooks/actions',
                           status=400,
                           body=json.dumps(response))

    httpretty.enable()
    try:
        responses = agent.handle_text("/chitchat")
    finally:
        httpretty.disable()

    assert responses[0]['text'] == 'chitchat'
コード例 #10
0
 def _create_agent(model_directory, interpreter):
     """Create the Rasa agent that runs when the server is started.

     Endpoints are read from ``endpoints.yml`` and the action endpoint is
     handed to ``Agent.load``.

     Returns:
         The loaded agent, or ``None`` when reading the endpoints or
         loading the model raises.
     """
     try:
         endpoints = AvailableEndpoints.read_endpoints('endpoints.yml')
         return Agent.load(model_directory, interpreter, action_endpoint=endpoints.action)
     except Exception as e:
         # Deliberate best-effort: the caller continues without a model.
         # `warning` replaces the deprecated `warn` alias.
         logger.warning("Failed to load any agent model. Running "
                        "Rasa Core server with out loaded model now. {}"
                        "".format(e))
         return None
コード例 #11
0
async def test_formbot_example():
    """Train the formbot example and check its replies to two messages.

    The action webhook is mocked with ``aioresponses``: first with a
    successful form payload, then with a ``ClientResponseError`` carrying
    a validation-error body.
    """
    formbot_dir = "examples/formbot/"
    action_url = 'https://example.com/webhooks/actions'

    sys.path.append("examples/formbot/")

    stories_path = os.path.join(formbot_dir, "data", "stories.md")
    action_endpoint = EndpointConfig("https://example.com/webhooks/actions")
    available_endpoints = AvailableEndpoints(action=action_endpoint)
    agent = await train(
        os.path.join(formbot_dir, "domain.yml"),
        stories_path,
        os.path.join(formbot_dir, "models", "dialogue"),
        endpoints=available_endpoints,
        policy_config="rasa_core/default_config.yml",
    )

    form_event = {
        'event': 'form',
        'name': 'restaurant_form',
        'timestamp': None,
    }
    slot_event = {
        'event': 'slot',
        'timestamp': None,
        'name': 'requested_slot',
        'value': 'cuisine',
    }
    success_payload = {
        'events': [form_event, slot_event],
        'responses': [{'template': 'utter_ask_cuisine'}],
    }

    with aioresponses() as mocked:
        mocked.post(action_url, payload=success_payload, repeat=True)

        replies = await agent.handle_text("/request_restaurant")

        assert replies[0]['text'] == 'what cuisine?'

    error_payload = {
        "error": "Failed to validate slot cuisine with action "
        "restaurant_form",
        "action_name": "restaurant_form",
    }

    with aioresponses() as mocked:
        # noinspection PyTypeChecker
        rejection = ClientResponseError(400, "", json.dumps(error_payload))
        mocked.post(action_url, repeat=True, exception=rejection)

        replies = await agent.handle_text("/chitchat")

        assert replies[0]['text'] == 'chitchat'
コード例 #12
0
def run_online(domain_file="config/domain.yml", stories_file="config/stories.md", output_path="models/dialogue",
               max_history=3, kwargs=None):
    """Train a dialogue model and start online learning with it.

    Args:
        domain_file: Forwarded to ``train.train_dialogue_model``.
        stories_file: Forwarded to ``train.train_dialogue_model``.
        output_path: Forwarded to ``train.train_dialogue_model``.
        max_history: Forwarded to ``train.train_dialogue_model``.
        kwargs: Training options forwarded as ``kwargs``. ``None`` selects
            ``{"batch_size": 50, "epochs": 800,
            "max_training_samples": 300}``.
    """
    if kwargs is None:
        # Build the defaults per call so one shared dict is never mutated
        # across invocations.
        kwargs = {"batch_size": 50, "epochs": 800, "max_training_samples": 300}

    interpreter = RasaNLUInterpreter("models/ticket/nlu_bot")
    agent = train.train_dialogue_model(domain_file=domain_file,
                                       interpreter=interpreter,
                                       stories_file=stories_file,
                                       output_path=output_path,
                                       max_history=max_history,
                                       endpoints=AvailableEndpoints.read_endpoints("config/endpoints.yml"),
                                       kwargs=kwargs)

    online.run_online_learning(agent)
コード例 #13
0
ファイル: agent_testor.py プロジェクト: CoderOverflow/stack
def load_agent():
    """Load the formbot dialogue model and print replies to two messages.

    The agent is loaded from ``examples/formbot/models/dialogue`` with a
    ``RegexInterpreter`` and an action endpoint at
    ``http://localhost:5055/webhook``.
    """
    model_dir = os.path.join("examples/formbot/", "models", "dialogue")
    webhook = EndpointConfig("http://localhost:5055/webhook")
    available_endpoints = AvailableEndpoints(action=webhook)

    formbot_agent = Agent.load(
        model_dir,
        interpreter=RegexInterpreter(),
        action_endpoint=available_endpoints.action,
    )

    # Only the first reply is printed for the form request.
    form_replies = formbot_agent.handle_text("/request_restaurant")
    print(form_replies[0])

    # The full reply list is printed for the chitchat message.
    chitchat_replies = formbot_agent.handle_text("/chitchat")
    print(chitchat_replies)
コード例 #14
0
ファイル: test.py プロジェクト: ng-healthpointe/rasa_core
def main():
    """Command-line entry point for evaluating a core model.

    In ``default`` mode the stories named on the command line are run
    through ``test`` against the loaded agent. In ``compare`` mode
    ``compare`` is run and a curve is plotted from ``num_stories.json``
    inside the core directory.

    Raises:
        ValueError: If no core model directory was given.
    """
    from rasa_core.agent import Agent
    from rasa_core.interpreter import NaturalLanguageInterpreter
    from rasa_core.utils import (AvailableEndpoints, set_default_subparser)
    import rasa_nlu.utils as nlu_utils
    import rasa_core.cli
    from rasa_core import utils

    event_loop = asyncio.get_event_loop()

    # Running as standalone python application
    parser = create_argument_parser()
    set_default_subparser(parser, 'default')
    args = parser.parse_args()

    logging.basicConfig(level=args.loglevel)
    available_endpoints = AvailableEndpoints.read_endpoints(args.endpoints)

    if args.output:
        nlu_utils.create_dir(args.output)

    if not args.core:
        raise ValueError("you must provide a core model directory to evaluate "
                         "using -d / --core")

    mode = args.mode
    if mode == 'compare':
        compare(args.core, args.stories, args.output)

        counts_path = os.path.join(args.core, 'num_stories.json')
        story_counts = utils.read_json_file(counts_path)
        plot_curve(args.output, story_counts)

    elif mode == 'default':
        nlu_interpreter = NaturalLanguageInterpreter.create(
            args.nlu, available_endpoints.nlu)
        agent = Agent.load(args.core, interpreter=nlu_interpreter)

        story_source = rasa_core.cli.train.stories_from_cli_args(args)
        stories = event_loop.run_until_complete(story_source)

        evaluation = test(stories, agent, args.max_stories, args.output,
                          args.fail_on_prediction_errors, args.e2e)
        event_loop.run_until_complete(evaluation)

    logger.info("Finished evaluation")
コード例 #15
0
def run(core_dir, nlu_dir):
    """Serve the bot through a Facebook input channel.

    The channel is configured from the module-level ``VERIFY``, ``SECRET``
    and ``FACEBOOK_ACCESS_TOKEN`` values. The agent handles the channel
    with ``serve_forever=True``.
    """
    available_endpoints = AvailableEndpoints.read_endpoints('endpoints.yml')
    nlu_interpreter = NaturalLanguageInterpreter.create(nlu_dir)

    facebook_channel = FacebookInput(
        # Verification token that Facebook must be told to confirm the URL.
        fb_verify=VERIFY,
        # The app secret.
        fb_secret=SECRET,
        # Token for the subscribed page.
        fb_access_token=FACEBOOK_ACCESS_TOKEN,
    )

    bot_agent = load_agent(
        core_dir,
        interpreter=nlu_interpreter,
        endpoints=available_endpoints,
    )

    bot_agent.handle_channels([facebook_channel], 5001, serve_forever=True)
コード例 #16
0
def run_evaluation(file_to_evaluate,
                   fail_on_prediction_errors=False,
                   max_stories=None,
                   use_e2e=False):
    """Evaluate the stories of one file and record the outcome.

    The result is stored in the module-level ``files_results`` under
    ``file_to_evaluate``. Every failed story is handed to
    ``process_failed_story`` and appended to the module-level
    ``all_failed_stories`` as ``"<file> - <story name>"``.

    Args:
        file_to_evaluate: Story file passed to ``_generate_trackers``.
        fail_on_prediction_errors: Forwarded to
            ``collect_story_predictions``.
        max_stories: Forwarded to ``_generate_trackers``.
        use_e2e: Forwarded to both helpers.
    """
    available_endpoints = AvailableEndpoints.read_endpoints(None)
    nlu_interpreter = NaturalLanguageInterpreter.create(NLU_DIR)
    agent = load_agent(CORE_DIR,
                       interpreter=nlu_interpreter,
                       endpoints=available_endpoints)

    trackers = _generate_trackers(file_to_evaluate, agent, max_stories,
                                  use_e2e)
    evaluation, _ = collect_story_predictions(trackers, agent,
                                              fail_on_prediction_errors,
                                              use_e2e)
    failures = evaluation.failed_stories

    result = FileResult(num_stories=len(trackers),
                        num_failed_stories=len(failures))

    banner = '#' * 80
    header = "EVALUATING STORIES FOR FILE '{}':".format(file_to_evaluate)
    utils.print_color('\n' + banner, BOLD_COLOR)
    utils.print_color(header, BOLD_COLOR)

    files_results[file_to_evaluate] = result

    if len(failures) != 0:
        for story in failures:
            process_failed_story(story.export_stories())
            # The story name is the text after the first '## ' marker.
            title = re.search('## (.*)', story.export_stories()).group(1)
            all_failed_stories.append(file_to_evaluate + ' - ' + title)
    else:
        passed = "The stories have passed for file '{}'!!" \
                 .format(file_to_evaluate)
        rule = '=' * len(passed)
        utils.print_color('\n' + rule, BLUE_COLOR)
        utils.print_color(passed, BLUE_COLOR)
        utils.print_color(rule, BLUE_COLOR)

    utils.print_color(banner + '\n', BOLD_COLOR)
コード例 #17
0
def run(core_dir, nlu_dir):
    """Serve the bot through a Telegram input channel.

    Telegram settings come from the ``TELEGRAM_ACCESS_TOKEN``, ``VERIFY``
    and ``WEBHOOK_URL`` environment variables (each defaulting to ``''``).
    Exceptions raised while serving are logged.
    """
    available_endpoints = AvailableEndpoints.read_endpoints('endpoints.yml')
    nlu_interpreter = NaturalLanguageInterpreter.create(nlu_dir)

    telegram_channel = TelegramInput(
        access_token=os.getenv('TELEGRAM_ACCESS_TOKEN', ''),
        verify=os.getenv('VERIFY', ''),
        webhook_url=os.getenv('WEBHOOK_URL', ''),
    )

    bot_agent = load_agent(
        core_dir,
        interpreter=nlu_interpreter,
        endpoints=available_endpoints,
    )

    server = bot_agent.handle_channels([telegram_channel], 5001, "")

    try:
        server.serve_forever()
    except Exception as error:
        logger.exception(error)
コード例 #18
0
ファイル: run-rocketchat.py プロジェクト: gabibguedes/tais
def run(core_dir, nlu_dir):
    """Serve the bot through Rocket.Chat with an Elasticsearch tracker store.

    Channel credentials come from the ``ROCKETCHAT_*`` environment
    variables. The ``ElasticTrackerStore`` only receives credentials,
    scheme and port when ``ELASTICSEARCH_USER`` is set. Exceptions raised
    while serving are logged.
    """
    rocketchat_channel = RocketChatInput(
        user=os.getenv('ROCKETCHAT_BOT_USERNAME'),
        password=os.getenv('ROCKETCHAT_BOT_PASSWORD'),
        server_url=os.getenv('ROCKETCHAT_URL'),
    )

    available_endpoints = AvailableEndpoints.read_endpoints(None)
    nlu_interpreter = NaturalLanguageInterpreter.create(nlu_dir)

    store_options = {
        'domain': os.getenv('ELASTICSEARCH_URL', 'elasticsearch:9200'),
    }
    if os.getenv('ELASTICSEARCH_USER') is not None:
        # Authenticated setup: add credentials and connection details.
        store_options.update(
            user=os.getenv('ELASTICSEARCH_USER', 'user'),
            password=os.getenv('ELASTICSEARCH_PASSWORD', 'password'),
            scheme=os.getenv('ELASTICSEARCH_HTTP_SCHEME', 'http'),
            scheme_port=os.getenv('ELASTICSEARCH_PORT', '80'),
        )
    store = ElasticTrackerStore(**store_options)

    bot_agent = load_agent(
        core_dir,
        interpreter=nlu_interpreter,
        tracker_store=store,
        endpoints=available_endpoints,
    )

    server = start_server([rocketchat_channel], "", "", 5005, bot_agent)

    try:
        server.serve_forever()
    except Exception as error:
        logger.exception(error)
コード例 #19
0
 def __init__(self):
     """Load configuration, endpoints, agent, executor and processor.

     The agent is loaded without an NLU interpreter, and the message
     processor likewise receives ``None`` in the interpreter position.
     """
     self.config = settings = Config()
     self.endpoints = AvailableEndpoints.read_endpoints(
         settings.RASA_CONFIG_ENDPOINTS_FILE)

     self.agent = loaded_agent = load_agent(
         settings.RASA_CONFIG_CORE_DIALOGUE_PACKAGE_NAME,
         interpreter=None,
         endpoints=self.endpoints)

     self.executor = ActionExecutor()
     self.executor.register_package(
         settings.RASA_CONFIG_ENDPOINTS_ACTION_PACKAGE_NAME)

     # The first positional argument (the interpreter slot) stays None.
     self.message_processor = MessageProcessor(
         None,
         loaded_agent.policy_ensemble,
         loaded_agent.domain,
         loaded_agent.tracker_store,
         loaded_agent.nlg,
         action_endpoint=loaded_agent.action_endpoint,
         message_preprocessor=None)
コード例 #20
0
def run(model: Text,
        endpoints: Text,
        connector: Text = None,
        credentials: Text = None,
        **kwargs: Dict):
    """Runs a Rasa model.

    The directory returned by ``get_model`` is removed again once serving
    ends, whether it returns normally or raises.

    Args:
        model: Path to model archive.
        endpoints: Path to endpoints file.
        connector: Connector which should be used (overwrites the
            `credentials` field).
        credentials: Path to channel credentials file.
        **kwargs: Additional arguments which are passed to
            `rasa_core.run.serve_application`. They are first reduced by
            `minimal_kwargs`.

    """
    import rasa_core.run
    from rasa_core.utils import AvailableEndpoints

    model_path = get_model(model)
    try:
        core_path, nlu_path = get_model_subdirectories(model_path)
        _endpoints = AvailableEndpoints.read_endpoints(endpoints)

        if not connector and not credentials:
            channel = "cmdline"
            # A space after "connect" was missing, which produced
            # "connectthe bot" in the log output.
            logger.info("No chat connector configured, falling back to the "
                        "command line. Use `rasa configure channel` to "
                        "connect the bot to e.g. facebook messenger.")
        else:
            channel = connector

        kwargs = minimal_kwargs(kwargs, rasa_core.run.serve_application)
        rasa_core.run.serve_application(core_path,
                                        nlu_path,
                                        channel=channel,
                                        credentials_file=credentials,
                                        endpoints=_endpoints,
                                        **kwargs)
    finally:
        # Clean up the model directory even when serving fails.
        shutil.rmtree(model_path)
コード例 #21
0
ファイル: up.py プロジェクト: jayd2446/rasa_stack
def start_core(platform_token):
    """Load an agent against the local platform endpoints and serve it.

    Args:
        platform_token: Token attached to the model and NLG endpoint
            configurations.
    """
    from rasa_core.utils import AvailableEndpoints
    from rasa_core import broker
    from rasa_core.tracker_store import TrackerStore
    from rasa_core.run import load_agent
    from rasa_core.run import serve_application

    # TODO: make endpoints more configurable, esp ports
    model_endpoint = EndpointConfig(
        "http://localhost:5002/api/projects/default/models/tags/production",
        token=platform_token,
        wait_time_between_pulls=1)
    broker_endpoint = EndpointConfig(**{"type": "file"})
    nlg_endpoint = EndpointConfig("http://localhost:5002/api/nlg",
                                  token=platform_token)
    core_endpoints = AvailableEndpoints(model=model_endpoint,
                                        event_broker=broker_endpoint,
                                        nlg=nlg_endpoint)

    event_broker = broker.from_endpoint_config(core_endpoints.event_broker)
    store = TrackerStore.find_tracker_store(None,
                                            core_endpoints.tracker_store,
                                            event_broker)

    core_agent = load_agent("models",
                            interpreter=None,
                            tracker_store=store,
                            endpoints=core_endpoints)
    print_success("About to start core")

    # The arguments stay positional, exactly as the original call passed them.
    serve_application(
        core_agent,
        "rasa",
        5005,
        "credentials.yml",
        "*",
        None,  # TODO: configure auth token
        True)
コード例 #22
0
def start_core(platform_token):
    """Serve the ``models`` directory against the local platform endpoints.

    Args:
        platform_token: Token attached to the model and NLG endpoint
            configurations.
    """
    from rasa_core.utils import AvailableEndpoints, EndpointConfig
    from rasa_core.run import serve_application

    # TODO: make endpoints more configurable, esp ports
    model_endpoint = EndpointConfig(
        "http://localhost:5002/api/projects/default/models/tags/production",
        token=platform_token,
        wait_time_between_pulls=1)
    broker_endpoint = EndpointConfig(**{"type": "file"})
    nlg_endpoint = EndpointConfig("http://localhost:5002/api/nlg",
                                  token=platform_token)
    core_endpoints = AvailableEndpoints(model=model_endpoint,
                                        event_broker=broker_endpoint,
                                        nlg=nlg_endpoint)

    serve_application(
        "models",
        nlu_model=None,
        channel="rasa",
        credentials_file="credentials.yml",
        cors="*",
        auth_token=None,  # TODO: configure auth token
        enable_api=True,
        endpoints=core_endpoints,
    )
コード例 #23
0
def test_formbot_example():
    """Train the formbot example and print its replies to two messages.

    The action endpoint points at ``http://localhost:5055/webhook``.
    Nothing is asserted; the replies are only printed.
    """
    formbot_dir = "examples/formbot/"
    sys.path.append("examples/formbot/")

    stories_path = os.path.join(formbot_dir, "data", "stories.md")
    webhook = EndpointConfig("http://localhost:5055/webhook")
    available_endpoints = AvailableEndpoints(action=webhook)
    agent = train_dialogue_model(
        os.path.join(formbot_dir, "domain.yml"),
        stories_path,
        os.path.join(formbot_dir, "models", "dialogue"),
        endpoints=available_endpoints,
        policy_config="rasa_core/default_config.yml",
    )
    print(type(agent.policy_ensemble))

    # The full reply list is printed for the form request.
    form_replies = agent.handle_text("/request_restaurant")
    print(form_replies)

    # Only the first reply is printed for the chitchat message.
    chitchat_replies = agent.handle_text("/chitchat")
    print(chitchat_replies[0])
コード例 #24
0
if __name__ == '__main__':
    # Running as standalone python application
    arg_parser = create_argument_parser()
    cmdline_args = arg_parser.parse_args()

    logging.getLogger('werkzeug').setLevel(logging.WARN)
    logging.getLogger('matplotlib').setLevel(logging.WARN)

    utils.configure_colored_logging(cmdline_args.loglevel)
    utils.configure_file_logging(cmdline_args.loglevel,
                                 cmdline_args.log_file)

    logger.info("Rasa process starting")

    _endpoints = AvailableEndpoints.read_endpoints(cmdline_args.endpoints)
    _interpreter = NaturalLanguageInterpreter.create(cmdline_args.nlu,
                                                     _endpoints.nlu)
    _broker = PikaProducer.from_endpoint_config(_endpoints.event_broker)

    _tracker_store = TrackerStore.find_tracker_store(
        None, _endpoints.tracker_store, _broker)
    _agent = load_agent(cmdline_args.core,
                        interpreter=_interpreter,
                        tracker_store=_tracker_store,
                        endpoints=_endpoints)
    serve_application(_agent,
                      cmdline_args.connector,
                      cmdline_args.port,
                      cmdline_args.credentials,
                      cmdline_args.cors,
コード例 #25
0
# Application object that the route decorators below register against.
app = Flask(__name__)
# NOTE(review): the secret key is hard-coded in source; consider loading it
# from configuration or the environment instead.
app.secret_key = '12345'


@app.route('/')
def hello_world():
    """Render the ``home.html`` template for the root URL."""
    home_page = render_template('home.html')
    return home_page


def get_random_response(intent):
    """Return a randomly chosen entry of ``intent_response_dict[intent]``.

    Raises:
        KeyError: If ``intent`` is not a key of ``intent_response_dict``
            (assuming it is a plain mapping).
    """
    return random.choice(intent_response_dict[intent])

# Load Rasa NLU interpreter and Rasa Core agent.
# These statements run at import time, so the models are loaded once when the
# module is first imported.
interpreter = RasaNLUInterpreter("models/livio/nlu")
_endpoints = AvailableEndpoints.read_endpoints("endpoints.yml")
# Only the action endpoint from the endpoints file is handed to the agent.
action_endpoint = _endpoints.action
# NOTE(review): the directory name is spelled "dialouge"; confirm it matches
# the directory on disk before changing it.
agent = Agent.load("models/livio/dialouge",
                   interpreter=interpreter,
                   action_endpoint=action_endpoint)


@app.route('/chat', methods=["POST"])
def chat():
    try:
        user_message = request.form["text"]
        response = requests.get("http://localhost:5000/parse",
                                params={"q": user_message})
        response = response.json()
        entities = response.get("entities")
        topresponse = response["intent"]
コード例 #26
0
from robot import app
from robot.config.setting import Config

# Module-level setup: logger, blueprint, CORS, and the agent whose
# `handle_message` is registered on `app` under "/webhooks/".
logger = logging.getLogger(__name__)
robot_api = Blueprint('robot_api', __name__)

# The resource pattern "/*" is mapped to origins "*".
CORS(app, resources={r"/*": {"origins": "*"}})
# `None or []` always evaluates to an empty list.
cors_origins = None or []
__version__ = '0.11.6'
auth_token = None
# NOTE(review): this binds the Config class itself, not an instance.
config = Config
# Initialize JWT parameters
# JWTManager(app)

# The statements below run at import time and depend on each other in order:
# endpoints -> interpreter -> agent -> channel registration.
endpoints = AvailableEndpoints.read_endpoints(
    config.RASA_CONFIG_ENDPOINTS_FILE)
interpreter = NaturalLanguageInterpreter.create(
    config.RASA_CONFIG_NLU_TRAIN_PACKAGE_NAME, endpoints.nlu)
agent = load_agent(config.RASA_CONFIG_CORE_DIALOGUE_PACKAGE_NAME,
                   interpreter=interpreter,
                   endpoints=endpoints)
input_channels = create_http_input_channels('rest', None)
rasa_core.channels.channel.register(input_channels,
                                    app,
                                    agent.handle_message,
                                    route="/webhooks/")


@robot_api.route("/", methods=['GET', 'OPTIONS'])
@cross_origin(origins=cors_origins)
def hello():