Beispiel #1
0
def read_endpoints(endpoint_file):
    """Load the nlg, nlu, action and model endpoints from ``endpoint_file``.

    Each section is looked up with ``utils.read_endpoint_config`` and the
    four results are bundled into ``AvailableEndpoints`` in
    (nlg, nlu, action, model) order.
    """
    def _load(endpoint_type):
        return utils.read_endpoint_config(endpoint_file,
                                          endpoint_type=endpoint_type)

    return AvailableEndpoints(_load("nlg"),
                              _load("nlu"),
                              _load("action_endpoint"),
                              _load("models"))
Beispiel #2
0
def read_endpoints(endpoint_file):
    """Read the nlg, nlu, action and model endpoint configs from a file.

    Returns an ``AvailableEndpoints`` built positionally from the four
    configs, in that order.
    """
    endpoint_types = ("nlg", "nlu", "action_endpoint", "models")
    configs = [utils.read_endpoint_config(endpoint_file, endpoint_type=kind)
               for kind in endpoint_types]
    return AvailableEndpoints(*configs)
def test_tracker_store_from_string(default_domain):
    """The custom tracker endpoints file should resolve to ExampleTrackerStore."""
    store_config = utils.read_endpoint_config(
        "data/test_endpoints/custom_tracker_endpoints.yml", "tracker_store")

    found = TrackerStore.find_tracker_store(default_domain, store_config)

    assert isinstance(found, ExampleTrackerStore)
Beispiel #4
0
def test_file_broker_from_config():
    """The file event broker fixture builds a FileProducer with its path."""
    endpoint_path = "data/test_endpoints/event_brokers/file_endpoint.yml"
    cfg = utils.read_endpoint_config(endpoint_path, "event_broker")

    producer = broker.from_endpoint_config(cfg)

    assert isinstance(producer, FileProducer)
    assert producer.path == "rasa_event.log"
def test_tracker_store_from_invalid_module(default_domain):
    """An unimportable module path falls back to InMemoryTrackerStore."""
    store_config = utils.read_endpoint_config(
        "data/test_endpoints/custom_tracker_endpoints.yml", "tracker_store")
    store_config.type = "a.module.which.cannot.be.found"

    found = TrackerStore.find_tracker_store(default_domain, store_config)

    assert isinstance(found, InMemoryTrackerStore)
def test_tracker_store_from_invalid_string(default_domain):
    """An arbitrary, non-module ``type`` falls back to InMemoryTrackerStore."""
    store_config = utils.read_endpoint_config(
        "data/test_endpoints/custom_tracker_endpoints.yml", "tracker_store")
    store_config.type = "any string"

    found = TrackerStore.find_tracker_store(default_domain, store_config)

    assert isinstance(found, InMemoryTrackerStore)
Beispiel #7
0
def create_tracker_store(core_model, endpoints):
    """Build a tracker store for ``core_model`` from an endpoints file.

    Reads the ``tracker-store`` and ``tracker-publish`` sections of
    ``endpoints``. A ``memory`` store yields an ``InMemoryTrackerStore``,
    a ``redis`` store a ``RedisTrackerStore``; both receive the publish URL
    when a ``tracker-publish`` section is configured (otherwise ``None``).

    Returns:
        The tracker store, or ``None`` if no ``tracker-store`` section is
        configured or its ``type`` is neither ``memory`` nor ``redis``.
    """
    domain = get_domain(core_model)

    # Both sections are optional. NOTE(review): read_endpoint_config is
    # assumed to return None for a missing section/file (other callers pass
    # optional endpoints) - guard before dereferencing instead of crashing.
    tracker_publish = utils.read_endpoint_config(endpoints, "tracker-publish")
    publish_url = tracker_publish.url if tracker_publish is not None else None

    store = utils.read_endpoint_config(endpoints, "tracker-store")
    if store is None:
        return None

    if store.type == 'memory':
        return InMemoryTrackerStore(domain=domain, publish_url=publish_url)
    if store.type == 'redis':
        return RedisTrackerStore(domain=domain,
                                 host=store.host,
                                 port=store.port,
                                 db=store.db,
                                 password=store.password,
                                 publish_url=publish_url)
    # Unknown store types are not handled here.
    return None
Beispiel #8
0
def get_agent():
    """Load the sample dialogue agent together with its NLU interpreter.

    The action endpoint is read from ``sample/core_config.yml`` under
    ``current_dir``; the NLU and dialogue model paths are relative to the
    current working directory.

    Returns:
        The loaded ``Agent``.
    """
    # Imports stay function-local as in the original (presumably to defer
    # the heavy rasa imports until an agent is requested). The unused
    # ``config``, ``FallbackPolicy`` and ``KerasPolicy`` imports are dropped.
    from rasa_nlu import utils
    from rasa_core.interpreter import RasaNLUInterpreter
    from rasa_core.agent import Agent

    interpreter = RasaNLUInterpreter(
        'sample/models/current/nlu/default/default')
    # NOTE(review): only this path is anchored at current_dir while the model
    # paths depend on the working directory - confirm that is intended.
    action_endpoint_conf = utils.read_endpoint_config(
        os.path.join(current_dir, "sample/core_config.yml"),
        endpoint_type="action_endpoint")
    agent = Agent.load("sample/models/current/dialogue",
                       interpreter=interpreter,
                       action_endpoint=action_endpoint_conf)
    return agent
Beispiel #9
0
def test_broker_from_config():
    """PikaProducer.from_endpoint_config reproduces the configured host,
    credentials (username and password) and queue."""
    cfg = utils.read_endpoint_config(EVENT_BROKER_ENDPOINT_FILE,
                                     "event_broker")
    actual = PikaProducer.from_endpoint_config(cfg)

    expected = PikaProducer("localhost", "username", "password", "queue")

    assert actual.host == expected.host
    assert actual.credentials.username == expected.credentials.username
    # ``expected`` is built with a password as well; verify it round-trips
    # instead of silently ignoring it.
    assert actual.credentials.password == expected.credentials.password
    assert actual.queue == expected.queue
Beispiel #10
0
def test_pika_broker_from_config():
    """The pika fixture file yields a PikaProducer with matching settings."""
    endpoint_file = "data/test_endpoints/event_brokers/pika_endpoint.yml"
    producer = broker.from_endpoint_config(
        utils.read_endpoint_config(endpoint_file, "event_broker"))

    assert isinstance(producer, PikaProducer)
    assert producer.host == "localhost"
    assert producer.credentials.username == "username"
    assert producer.queue == "queue"
def test_tracker_store_endpoint_config_loading():
    """The default endpoints file exposes the redis tracker store config."""
    loaded = utils.read_endpoint_config(DEFAULT_ENDPOINTS_FILE,
                                        "tracker_store")

    expected = EndpointConfig.from_dict({
        "type": "redis",
        "url": "localhost",
        "port": 6379,
        "db": 0,
        "password": "******",
        "timeout": 30000
    })
    assert loaded == expected
def test_find_tracker_store(default_domain):
    """A directly built RedisTrackerStore is an instance of the type that
    find_tracker_store returns for the default endpoints' tracker store."""
    store = utils.read_endpoint_config(DEFAULT_ENDPOINTS_FILE, "tracker_store")
    redis_store = RedisTrackerStore(domain=default_domain,
                                    host="localhost",
                                    port=6379,
                                    db=0,
                                    password="******",
                                    record_exp=3000)

    found_type = type(TrackerStore.find_tracker_store(default_domain, store))
    assert isinstance(redis_store, found_type)
Beispiel #13
0
 def read_endpoints(cls, endpoint_file):
     """Create an instance from the endpoint sections of ``endpoint_file``.

     Each section below is read with ``read_endpoint_config`` and the
     results are passed positionally to ``cls`` in the listed order.
     """
     endpoint_types = (
         "nlg",
         "nlu",
         "action_endpoint",
         "models",
         "tracker_store",
         "event_broker",
         "rules",
         "credentials",
         "nlu_models_info",
     )
     configs = [read_endpoint_config(endpoint_file, endpoint_type=kind)
                for kind in endpoint_types]
     return cls(*configs)
Beispiel #14
0
def _read_endpoints(endpoint_file):
    """Read the available endpoints from a YAML file.

    Possible endpoints are:

    nlg
        url to an NLG server (``nlg`` section)
    nlu
        url to an NLU server (``nlu`` section)
    action
        url to a custom action server (``action_endpoint`` section)
    model
        url from which to fetch the core model (``models`` section; the
        legacy ``model`` section is used when ``models`` is absent)

    Parameters
    ----------
    endpoint_file : str
        Path to the YAML file which contains the endpoints.

    Returns
    -------
    AvailableEndpoints
        Namedtuple with the fields ``nlg``, ``nlu``, ``action`` and
        ``model``. NOTE(review): a field is presumably ``None`` when its
        section is not configured - depends on ``read_endpoint_config``.
    """
    AvailableEndpoints = namedtuple('AvailableEndpoints',
                                    'nlg nlu action model')

    nlg = utils.read_endpoint_config(endpoint_file, endpoint_type="nlg")
    nlu = utils.read_endpoint_config(endpoint_file, endpoint_type="nlu")
    action = utils.read_endpoint_config(endpoint_file,
                                        endpoint_type="action_endpoint")
    # The other endpoint readers use the "models" key; this function used
    # to read "model" only, so keep that as a fallback for existing files.
    model = utils.read_endpoint_config(endpoint_file, endpoint_type="models")
    if model is None:
        model = utils.read_endpoint_config(endpoint_file,
                                           endpoint_type="model")

    return AvailableEndpoints(nlg, nlu, action, model)
Beispiel #15
0
def train_dialogue_model(domain_file, stories_file, output_path,
                         nlu_model_path=None,
                         endpoints=None,
                         max_history=None,
                         dump_flattened_stories=False,
                         kwargs=None):
    """Train a dialogue agent on ``stories_file`` and persist it.

    The agent uses a FallbackPolicy, a MemoizationPolicy and a KerasPolicy.
    ``kwargs`` may carry fallback settings (``nlu_threshold``,
    ``core_threshold``, ``fallback_action_name``) and data-loading options
    (``use_story_concatenation``, ``unique_last_num_states``,
    ``augmentation_factor``, ``remove_duplicates``, ``debug_plots``);
    whatever remains is forwarded to ``agent.train``.

    Returns:
        The trained agent.
    """
    kwargs = kwargs or {}

    action_endpoint = utils.read_endpoint_config(endpoints, "action_endpoint")

    fallback_keys = {"nlu_threshold",
                     "core_threshold",
                     "fallback_action_name"}
    fallback_args, kwargs = utils.extract_args(kwargs, fallback_keys)

    nlu_threshold = fallback_args.get("nlu_threshold",
                                      DEFAULT_NLU_FALLBACK_THRESHOLD)
    core_threshold = fallback_args.get("core_threshold",
                                       DEFAULT_CORE_FALLBACK_THRESHOLD)
    fallback_action = fallback_args.get("fallback_action_name",
                                        DEFAULT_FALLBACK_ACTION)

    # Policies are constructed in the same order as the original list.
    policies = [FallbackPolicy(nlu_threshold, core_threshold, fallback_action)]
    policies.append(MemoizationPolicy(max_history=max_history))
    featurizer = MaxHistoryTrackerFeaturizer(BinarySingleStateFeaturizer(),
                                             max_history=max_history)
    policies.append(KerasPolicy(featurizer))

    agent = Agent(domain_file,
                  action_endpoint=action_endpoint,
                  interpreter=nlu_model_path,
                  policies=policies)

    data_load_keys = {"use_story_concatenation",
                      "unique_last_num_states",
                      "augmentation_factor",
                      "remove_duplicates",
                      "debug_plots"}
    data_load_args, kwargs = utils.extract_args(kwargs, data_load_keys)

    stories = agent.load_data(stories_file, **data_load_args)
    agent.train(stories, **kwargs)
    agent.persist(output_path, dump_flattened_stories)

    return agent
Beispiel #16
0
def train_dialogue_model(domain_file, stories_file, output_path,
                         nlu_model_path=None,
                         endpoints=None,
                         max_history=None,
                         dump_flattened_stories=False,
                         kwargs=None):
    """Build, train and persist a dialogue ``Agent``.

    Args:
        domain_file: domain passed to ``Agent``.
        stories_file: training stories loaded via ``agent.load_data``.
        output_path: where the trained agent is persisted.
        nlu_model_path: interpreter handed to ``Agent``.
        endpoints: endpoints file; its ``action_endpoint`` section is used.
        max_history: history length for the memoization policy and the
            Keras policy's featurizer.
        dump_flattened_stories: forwarded to ``agent.persist``.
        kwargs: extra options. Fallback keys (``nlu_threshold``,
            ``core_threshold``, ``fallback_action_name``) configure the
            FallbackPolicy, data-loading keys go to ``load_data`` and the
            rest to ``agent.train``.

    Returns:
        The trained ``Agent``.
    """
    if not kwargs:
        kwargs = {}

    action_endpoint = utils.read_endpoint_config(endpoints, "action_endpoint")

    # (option name, default) pairs in FallbackPolicy's positional order.
    fallback_options = (
        ("nlu_threshold", DEFAULT_NLU_FALLBACK_THRESHOLD),
        ("core_threshold", DEFAULT_CORE_FALLBACK_THRESHOLD),
        ("fallback_action_name", DEFAULT_FALLBACK_ACTION),
    )
    fallback_args, kwargs = utils.extract_args(
        kwargs, {option for option, _ in fallback_options})
    fallback_values = [fallback_args.get(option, default)
                       for option, default in fallback_options]

    policies = [
        FallbackPolicy(*fallback_values),
        MemoizationPolicy(max_history=max_history),
        KerasPolicy(MaxHistoryTrackerFeaturizer(
            BinarySingleStateFeaturizer(), max_history=max_history)),
    ]

    agent = Agent(domain_file,
                  action_endpoint=action_endpoint,
                  interpreter=nlu_model_path,
                  policies=policies)

    data_load_args, train_kwargs = utils.extract_args(
        kwargs, {"use_story_concatenation", "unique_last_num_states",
                 "augmentation_factor", "remove_duplicates", "debug_plots"})

    training_data = agent.load_data(stories_file, **data_load_args)
    agent.train(training_data, **train_kwargs)
    agent.persist(output_path, dump_flattened_stories)

    return agent
def test_kafka_broker_from_config():
    """KafkaProducer.from_endpoint_config mirrors the plaintext kafka fixture."""
    cfg = utils.read_endpoint_config(
        "data/test_endpoints/event_brokers/kafka_plaintext_endpoint.yml",
        "event_broker")
    actual = KafkaProducer.from_endpoint_config(cfg)

    expected = KafkaProducer("localhost",
                             "username",
                             "password",
                             topic="topic",
                             security_protocol="SASL_PLAINTEXT")

    for attribute in ("host", "sasl_username", "sasl_password", "topic"):
        assert getattr(actual, attribute) == getattr(expected, attribute)
Beispiel #18
0
def recreate_agent(
        model_directory,  # type: Text
        nlu_model=None,  # type: Optional[Text]
        tracker_dump=None,  # type: Optional[Text]
        endpoints=None):
    # type: (...) -> Tuple[Agent, DialogueStateTracker]
    """Recreate an agent and replay a dumped conversation into a tracker.

    Loads the agent from ``model_directory`` with ``nlu_model`` and the
    ``nlg`` endpoint from ``endpoints``, then rebuilds the tracker from
    ``tracker_dump`` and replays its events through the agent.

    Returns:
        ``(agent, tracker)``
    """
    generator = utils.read_endpoint_config(endpoints, "nlg")

    logger.debug("Loading Rasa Core Agent")
    agent = Agent.load(model_directory, nlu_model, generator=generator)
    logger.debug("Finished loading agent. Loading stories now.")

    tracker = load_tracker_from_json(tracker_dump, agent.domain)
    replay_events(tracker, agent)
    return agent, tracker
Beispiel #19
0
def recreate_agent(model_directory,  # type: Text
                   nlu_model=None,  # type: Optional[Text]
                   tracker_dump=None,  # type: Optional[Text]
                   endpoints=None
                   ):
    # type: (...) -> Tuple[Agent, DialogueStateTracker]
    """Load an agent and restore the conversation stored in ``tracker_dump``.

    The agent's NLG generator comes from the ``nlg`` section of
    ``endpoints``. After loading, the dumped tracker is rebuilt against the
    agent's domain and its events are replayed through the agent.

    Returns:
        A ``(agent, tracker)`` tuple.
    """
    nlg_config = utils.read_endpoint_config(endpoints, "nlg")

    logger.debug("Loading Rasa Core Agent")
    loaded_agent = Agent.load(model_directory, nlu_model,
                              generator=nlg_config)

    logger.debug("Finished loading agent. Loading stories now.")

    domain = loaded_agent.domain
    restored_tracker = load_tracker_from_json(tracker_dump, domain)
    replay_events(restored_tracker, loaded_agent)

    result = (loaded_agent, restored_tracker)
    return result
Beispiel #20
0
    log.setLevel(logging.WARN)

    logger.info("Rasa process starting")

    interpreter = interpreter_from_args(nlu_model, nlu_endpoint)
    agent = Agent.load(model_directory, interpreter, generator=nlg_endpoint)

    logger.info("Finished loading agent, starting input channel & server.")
    if channel:
        input_channel = create_input_channel(channel, port, credentials_file)
        agent.handle_channel(input_channel)

    return agent


if __name__ == '__main__':
    # Standalone entry point: parse the command line, configure console and
    # file logging, read the nlg/nlu endpoint configs and hand everything to
    # ``main``.
    arg_parser = create_argument_parser()
    cmdline_args = arg_parser.parse_args()

    log_level = cmdline_args.loglevel
    utils.configure_colored_logging(log_level)
    utils.configure_file_logging(log_level, cmdline_args.log_file)

    endpoints_file = cmdline_args.endpoints
    nlg_endpoint, nlu_endpoint = (
        utils.read_endpoint_config(endpoints_file, endpoint_kind)
        for endpoint_kind in ("nlg", "nlu"))

    main(cmdline_args.core,
         cmdline_args.nlu,
         cmdline_args.connector,
         cmdline_args.port,
         cmdline_args.credentials,
         nlg_endpoint,
         nlu_endpoint)
Beispiel #21
0
def test_nlg_endpoint_config_loading():
    """The default endpoints file points the nlg endpoint at localhost:5055."""
    cfg = utils.read_endpoint_config(DEFAULT_ENDPOINTS_FILE, "nlg")
    expected = EndpointConfig.from_dict({"url": "http://localhost:5055/nlg"})

    assert cfg == expected
Beispiel #22
0
def test_nlg_endpoint_config_loading():
    """Loading the ``nlg`` section of the default endpoints file yields an
    EndpointConfig for the local NLG server."""
    loaded = utils.read_endpoint_config(DEFAULT_ENDPOINTS_FILE, "nlg")
    expected_url = "http://localhost:5055/nlg"

    assert loaded == EndpointConfig.from_dict({"url": expected_url})
def create_app(
        model_directory,  # type: Text
        interpreter=None,  # type: Union[Text, NLI, None]
        loglevel="INFO",  # type: Optional[Text]
        logfile="rasa_core.log",  # type: Optional[Text]
        cors_origins=None,  # type: Optional[List[Text]]
        action_factory=None,  # type: Optional[Text]
        auth_token=None,  # type: Optional[Text]
        tracker_store=None,  # type: Optional[TrackerStore]
        endpoints=None):
    """Create the Flask application serving the Rasa Core HTTP API.

    Parameters:
        model_directory: directory holding the core model; ``/load``
            unzips uploaded models into it.
        interpreter: NLU interpreter (or its spec) passed, together with
            the ``nlu`` endpoint, to ``run.interpreter_from_args``.
        loglevel, logfile: forwarded to ``utils.configure_file_logging``.
        cors_origins: origins passed to ``cross_origin`` on each route.
        action_factory, tracker_store: forwarded to ``_create_agent``.
        auth_token: token required by routes using ``requires_auth``.
        endpoints: endpoints file from which the ``nlg`` and ``nlu``
            endpoint configs are read.

    Returns:
        The configured ``Flask`` application.
    """

    app = Flask(__name__)
    # NOTE(review): this enables CORS for every origin on all routes, which
    # likely makes the per-route ``cors_origins`` restriction ineffective -
    # confirm this is intended before relying on ``cors_origins``.
    CORS(app, resources={r"/*": {"origins": "*"}})
    # Setting up logfile
    utils.configure_file_logging(loglevel, logfile)

    if not cors_origins:
        cors_origins = []

    nlg_endpoint = utils.read_endpoint_config(endpoints, "nlg")
    nlu_endpoint = utils.read_endpoint_config(endpoints, "nlu")

    # Resolved once so the initial agent and any agent re-created by
    # ``/load`` use the same interpreter (including the nlu endpoint).
    _interpreter = run.interpreter_from_args(interpreter, nlu_endpoint)

    # this needs to be a list, so the nested functions can replace the agent
    _agent = [
        _create_agent(model_directory, _interpreter, action_factory,
                      tracker_store, nlg_endpoint)
    ]

    def agent():
        """Return the currently loaded agent, or None if there is none."""
        if _agent and _agent[0]:
            return _agent[0]
        else:
            return None

    def _json_error(message, status):
        """Build a JSON ``{"error": message}`` response with ``status``."""
        # jsonify already returns a complete application/json Response; set
        # the status on it instead of passing it as the body of another
        # ``Response(...)``, whose first argument expects a string/iterable.
        response = jsonify(error=message)
        response.status_code = status
        return response

    @app.route("/", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def hello():
        """Check if the server is running and responds with the version."""
        return "hello from Rasa Core: " + __version__

    @app.route("/version", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def version():
        """respond with the version number of the installed rasa core."""

        return jsonify({'version': __version__})

    # <sender_id> can be be 'default' if there's only 1 client
    @app.route("/conversations/<sender_id>/continue",
               methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def continue_predicting(sender_id):
        """Continue a prediction started with parse.

        Caller should have executed the action returned from the parse
        endpoint. The events returned from that executed action are
        passed to continue which will trigger the next action prediction.

        If continue predicts action listen, the caller should wait for the
        next user message."""

        request_params = request.get_json(force=True)
        encoded_events = request_params.get("events", [])
        executed_action = request_params.get("executed_action", None)
        evts = events.deserialise_events(encoded_events)
        try:
            response = agent().continue_message_handling(
                sender_id, executed_action, evts)
        except ValueError as e:
            # Python 3 exceptions have no ``.message`` attribute.
            return _json_error(str(e), 400)
        except Exception as e:
            logger.exception(e)
            return _json_error("Server failure. Error: {}".format(e), 500)
        return jsonify(response)

    @app.route("/conversations/<sender_id>/tracker/events",
               methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def append_events(sender_id):
        """Append a list of events to the state of a conversation"""

        request_params = request.get_json(force=True)
        evts = events.deserialise_events(request_params)
        tracker = agent().tracker_store.get_or_create_tracker(sender_id)
        for e in evts:
            tracker.update(e)
        agent().tracker_store.save(tracker)
        return jsonify(tracker.current_state())

    @app.route("/conversations", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def list_trackers():
        """List the sender ids of all conversations in the tracker store."""
        return jsonify(list(agent().tracker_store.keys()))

    @app.route("/conversations/<sender_id>/tracker",
               methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def retrieve_tracker(sender_id):
        """Get a dump of a conversations tracker including its events."""

        # parameters
        use_history = bool_arg('ignore_restarts', default=False)
        should_include_events = bool_arg('events', default=True)
        until_time = request.args.get('until', None)

        # retrieve tracker and set to requested state
        tracker = agent().tracker_store.get_or_create_tracker(sender_id)
        if until_time is not None:
            tracker = tracker.travel_back_in_time(float(until_time))

        # dump and return tracker
        state = tracker.current_state(
            should_include_events=should_include_events,
            only_events_after_latest_restart=use_history)
        return jsonify(state)

    @app.route("/conversations/<sender_id>/tracker",
               methods=['PUT', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def update_tracker(sender_id):
        """Use a list of events to set a conversations tracker to a state."""

        request_params = request.get_json(force=True)
        tracker = DialogueStateTracker.from_dict(sender_id, request_params,
                                                 agent().domain)

        # will override an existing tracker with the same id!
        # (saved once - the previous duplicate save was redundant)
        agent().tracker_store.save(tracker)
        return jsonify(tracker.current_state(should_include_events=True))

    @app.route("/domain", methods=['GET'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def get_domain():
        """Get current domain in yaml format."""
        accepts = request.headers.get("Accept", default="application/json")
        if accepts.endswith("json"):
            domain = agent().domain.as_dict()
            return jsonify(domain)
        elif accepts.endswith("yml"):
            domain_yaml = agent().domain.as_yaml()
            return Response(domain_yaml,
                            status=200,
                            content_type="application/x-yml")
        else:
            return Response("""Invalid accept header. Domain can be provided 
                    as json ("Accept: application/json")  
                    or yml ("Accept: application/x-yml"). 
                    Make sure you've set the appropriate Accept header.""",
                            status=406)

    @app.route("/conversations/<sender_id>/parse",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def parse(sender_id):
        """Predict the next action for the message in ``query``/``q``."""
        request_params = request_parameters()

        if 'query' in request_params:
            message = request_params.pop('query')
        elif 'q' in request_params:
            message = request_params.pop('q')
        else:
            return _json_error("Invalid parse parameter specified.", 400)

        try:
            # Fetches the predicted action in a json format
            response = agent().start_message_handling(message, sender_id)
            return jsonify(response)

        except Exception as e:
            logger.exception("Caught an exception during parse.")
            return _json_error("Server failure. Error: {}".format(e), 500)

    @app.route("/conversations/<sender_id>/respond",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def respond(sender_id):
        """Handle the message in ``query``/``q`` and return bot responses."""
        request_params = request_parameters()

        if 'query' in request_params:
            message = request_params.pop('query')
        elif 'q' in request_params:
            message = request_params.pop('q')
        else:
            return _json_error("Invalid respond parameter specified.", 400)

        try:
            # Set the output channel
            out = CollectingOutputChannel()
            # Fetches the appropriate bot response in a json format
            responses = agent().handle_message(message,
                                               output_channel=out,
                                               sender_id=sender_id)
            return jsonify(responses)

        except Exception as e:
            logger.exception("Caught an exception during respond.")
            return _json_error("Server failure. Error: {}".format(e), 500)

    @app.route("/load", methods=['POST', 'OPTIONS'])
    @requires_auth(auth_token)
    @cross_origin(origins=cors_origins)
    def load_model():
        """Loads a zipped model, replacing the existing one."""

        if 'model' not in request.files:
            # model file is missing
            abort(400)

        model_file = request.files['model']

        logger.info("Received new model through REST interface.")
        zipped_path = tempfile.NamedTemporaryFile(delete=False, suffix=".zip")
        zipped_path.close()

        try:
            model_file.save(zipped_path.name)
            logger.debug("Downloaded model to {}".format(zipped_path.name))

            # Context manager closes the archive even if extraction fails.
            with zipfile.ZipFile(zipped_path.name, 'r') as zip_ref:
                zip_ref.extractall(model_directory)
            logger.debug("Unzipped model to {}".format(
                os.path.abspath(model_directory)))
        finally:
            # delete=False above means nothing else removes the archive.
            os.remove(zipped_path.name)

        # Reuse the interpreter resolved at startup (which honours the nlu
        # endpoint) rather than the raw ``interpreter`` argument.
        _agent[0] = _create_agent(model_directory, _interpreter,
                                  action_factory, tracker_store, nlg_endpoint)
        logger.debug("Finished loading new agent.")
        return jsonify({'success': 1})

    return app
Beispiel #24
0
def test_no_broker_in_config():
    """No producer is built from the default endpoints' event_broker entry."""
    producer = broker.from_endpoint_config(
        utils.read_endpoint_config(DEFAULT_ENDPOINTS_FILE, "event_broker"))

    assert producer is None
Beispiel #25
0
def create_app(model_directory,  # type: Text
               interpreter=None,  # type: Union[Text, NLI, None]
               loglevel="INFO",  # type: Optional[Text]
               logfile="rasa_core.log",  # type: Optional[Text]
               cors_origins=None,  # type: Optional[List[Text]]
               action_factory=None,  # type: Optional[Text]
               auth_token=None,  # type: Optional[Text]
               tracker_store=None,  # type: Optional[TrackerStore]
               endpoints=None
               ):
    """Build the Flask app that exposes a Rasa Core agent over HTTP.

    Args:
        model_directory: directory the dialogue model is loaded from; the
            ``/load`` route also unzips uploaded models into it.
        interpreter: NLU interpreter spec, resolved together with the
            ``nlu`` endpoint via ``run.interpreter_from_args``.
        loglevel: level handed to ``utils.configure_file_logging``.
        logfile: file handed to ``utils.configure_file_logging``.
        cors_origins: origins for the per-route ``cross_origin`` decorators;
            ``None`` is treated as an empty list.
        action_factory: forwarded to ``_create_agent``.
        auth_token: token enforced by ``requires_auth`` on protected routes.
        tracker_store: forwarded to ``_create_agent``.
        endpoints: endpoint config file; its ``nlg`` and ``nlu`` sections
            are read here.

    Returns:
        The configured ``Flask`` application.
    """

    app = Flask(__name__)
    # NOTE(review): this app-wide CORS setup allows every origin, which makes
    # the per-route ``cors_origins`` restriction ineffective -- confirm intent.
    CORS(app, resources={r"/*": {"origins": "*"}})
    # Setting up logfile
    utils.configure_file_logging(loglevel, logfile)

    if not cors_origins:
        cors_origins = []

    nlg_endpoint = utils.read_endpoint_config(endpoints, "nlg")
    nlu_endpoint = utils.read_endpoint_config(endpoints, "nlu")

    _interpreter = run.interpreter_from_args(interpreter, nlu_endpoint)

    # A one-element list so the nested ``/load`` handler can replace the
    # agent in place.
    _agent = [_create_agent(model_directory, _interpreter,
                            action_factory, tracker_store, nlg_endpoint)]

    def agent():
        """Return the currently loaded agent, or ``None`` if there is none."""
        if _agent and _agent[0]:
            return _agent[0]
        return None

    def error_response(status, message):
        """Return a JSON ``{"error": message}`` response with ``status``.

        ``jsonify`` already produces a JSON ``Response``; the status is set
        on it directly. Wrapping it in another ``Response`` would make the
        inner Response object the body iterable, which cannot be served.
        """
        response = jsonify(error=message)
        response.status_code = status
        return response

    @app.route("/",
               methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def hello():
        """Check if the server is running and responds with the version."""
        return "hello from Rasa Core: " + __version__

    @app.route("/version",
               methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def version():
        """Respond with the version number of the installed rasa core."""

        return jsonify({'version': __version__})

    # <sender_id> can be 'default' if there's only 1 client
    @app.route("/conversations/<sender_id>/continue",
               methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def continue_predicting(sender_id):
        """Continue a prediction started with parse.

        Caller should have executed the action returned from the parse
        endpoint. The events returned from that executed action are
        passed to continue which will trigger the next action prediction.

        If continue predicts action listen, the caller should wait for the
        next user message."""

        request_params = request.get_json(force=True)
        encoded_events = request_params.get("events", [])
        executed_action = request_params.get("executed_action", None)
        evts = events.deserialise_events(encoded_events)
        try:
            response = agent().continue_message_handling(sender_id,
                                                         executed_action,
                                                         evts)
        except ValueError as e:
            # str(e) works on Python 2 and 3; ``e.message`` is Python 2 only.
            return error_response(400, str(e))
        except Exception as e:
            logger.exception(e)
            return error_response(500,
                                  "Server failure. Error: {}".format(e))
        return jsonify(response)

    @app.route("/conversations/<sender_id>/tracker/events",
               methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def append_events(sender_id):
        """Append a list of events to the state of a conversation"""

        request_params = request.get_json(force=True)
        evts = events.deserialise_events(request_params)
        tracker = agent().tracker_store.get_or_create_tracker(sender_id)
        for e in evts:
            tracker.update(e)
        agent().tracker_store.save(tracker)
        return jsonify(tracker.current_state())

    @app.route("/conversations",
               methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def list_trackers():
        """List the ids of all conversations in the tracker store."""
        return jsonify(list(agent().tracker_store.keys()))

    @app.route("/conversations/<sender_id>/tracker",
               methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def retrieve_tracker(sender_id):
        """Get a dump of a conversations tracker including its events."""

        # parameters
        use_history = bool_arg('ignore_restarts', default=False)
        should_include_events = bool_arg('events', default=True)
        until_time = request.args.get('until', None)

        # retrieve tracker and set to requested state
        tracker = agent().tracker_store.get_or_create_tracker(sender_id)
        if until_time is not None:
            try:
                timestamp = float(until_time)
            except ValueError:
                # a malformed client parameter is a 400, not a server error
                return error_response(
                        400,
                        "Invalid 'until' parameter: {}".format(until_time))
            tracker = tracker.travel_back_in_time(timestamp)

        # dump and return tracker
        state = tracker.current_state(
                should_include_events=should_include_events,
                only_events_after_latest_restart=use_history)
        return jsonify(state)

    @app.route("/conversations/<sender_id>/tracker",
               methods=['PUT', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def update_tracker(sender_id):
        """Use a list of events to set a conversations tracker to a state."""

        request_params = request.get_json(force=True)
        tracker = DialogueStateTracker.from_dict(sender_id,
                                                 request_params,
                                                 agent().domain)

        # will override an existing tracker with the same id!
        agent().tracker_store.save(tracker)
        return jsonify(tracker.current_state(should_include_events=True))

    @app.route("/domain",
               methods=['GET'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def get_domain():
        """Get the current domain as JSON (default) or YAML.

        The format follows the ``Accept`` header: values ending in ``json``
        give JSON, values ending in ``yml`` give YAML, anything else is
        answered with 406.
        """
        accepts = request.headers.get("Accept", default="application/json")
        if accepts.endswith("json"):
            domain = agent().domain.as_dict()
            return jsonify(domain)
        elif accepts.endswith("yml"):
            domain_yaml = agent().domain.as_yaml()
            return Response(domain_yaml,
                            status=200,
                            content_type="application/x-yml")
        else:
            return Response(
                    """Invalid accept header. Domain can be provided 
                    as json ("Accept: application/json")  
                    or yml ("Accept: application/x-yml"). 
                    Make sure you've set the appropriate Accept header.""",
                    status=406)

    @app.route("/conversations/<sender_id>/parse",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def parse(sender_id):
        """Start handling a user message and return the predicted action.

        The message is read from the ``query`` (or ``q``) request parameter.
        """
        request_params = request_parameters()

        if 'query' in request_params:
            message = request_params.pop('query')
        elif 'q' in request_params:
            message = request_params.pop('q')
        else:
            return error_response(400, "Invalid parse parameter specified.")

        try:
            # Fetches the predicted action in a json format
            response = agent().start_message_handling(message, sender_id)
            return jsonify(response)

        except Exception as e:
            logger.exception("Caught an exception during parse.")
            return error_response(500,
                                  "Server failure. Error: {}".format(e))

    @app.route("/conversations/<sender_id>/respond",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def respond(sender_id):
        """Handle a user message end-to-end and return the bot's responses.

        The message is read from the ``query`` (or ``q``) request parameter;
        responses are collected through a ``CollectingOutputChannel``.
        """
        request_params = request_parameters()

        if 'query' in request_params:
            message = request_params.pop('query')
        elif 'q' in request_params:
            message = request_params.pop('q')
        else:
            return error_response(400, "Invalid respond parameter "
                                       "specified.")

        try:
            # Set the output channel
            out = CollectingOutputChannel()
            # Fetches the appropriate bot response in a json format
            responses = agent().handle_message(message,
                                               output_channel=out,
                                               sender_id=sender_id)
            return jsonify(responses)

        except Exception as e:
            logger.exception("Caught an exception during respond.")
            return error_response(500,
                                  "Server failure. Error: {}".format(e))

    @app.route("/load", methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    def load_model():
        """Load a zipped model from the ``model`` upload, replacing the agent.

        The archive is extracted into ``model_directory`` and a new agent is
        built with the same interpreter, action factory, tracker store and
        NLG endpoint as the agent created at startup.
        """

        if 'model' not in request.files:
            # model file is missing
            abort(400)

        model_file = request.files['model']

        logger.info("Received new model through REST interface.")
        zipped_path = tempfile.NamedTemporaryFile(delete=False, suffix=".zip")
        zipped_path.close()

        try:
            model_file.save(zipped_path.name)
            logger.debug("Downloaded model to {}".format(zipped_path.name))

            with zipfile.ZipFile(zipped_path.name, 'r') as zip_ref:
                zip_ref.extractall(model_directory)
            logger.debug("Unzipped model to {}".format(
                    os.path.abspath(model_directory)))
        except zipfile.BadZipfile:
            logger.exception("Uploaded model is not a valid zip archive.")
            return error_response(400,
                                  "Uploaded model is not a valid zip archive.")
        finally:
            # The archive is only needed until it has been extracted; the
            # temp file was created with delete=False, so remove it here.
            os.remove(zipped_path.name)

        # Reuse the interpreter resolved at startup so a reloaded agent is
        # configured like the initial one.
        _agent[0] = _create_agent(model_directory, _interpreter,
                                  action_factory, tracker_store, nlg_endpoint)
        logger.debug("Finished loading new agent.")
        return jsonify({'success': 1})

    return app
Beispiel #26
0
def create_app(
        model_directory,  # type: Text
        interpreter=None,  # type: Union[Text, NLI, None]
        loglevel="INFO",  # type: Optional[Text]
        logfile="rasa_core.log",  # type: Optional[Text]
        cors_origins=None,  # type: Optional[List[Text]]
        action_factory=None,  # type: Optional[Text]
        auth_token=None,  # type: Optional[Text]
        tracker_store=None,  # type: Optional[TrackerStore]
        endpoints=None):
    """Build the Flask app serving the ``dlg-app`` chat API.

    Routes:
        ``/`` and ``/status/health.check``: liveness responses.
        ``/dlg-app/service/chat/v1/im``: validates the request fields, runs
            ``question`` through the agent, converts the agent's first reply
            into the ``{"resCode", "resMsg", "data"}`` envelope, and buffers
            ``(msgId, partyNo, question, data)`` in the module-level
            ``list_data``. The buffer is written to the ``dialogue`` table via
            the module-level ``pool`` once it holds ``flush_threshold`` rows.

    Args:
        model_directory: directory the dialogue model is loaded from.
        interpreter: NLU interpreter spec, resolved together with the
            ``nlu`` endpoint via ``run.interpreter_from_args``.
        loglevel: level handed to ``utils.configure_file_logging``.
        logfile: file handed to ``utils.configure_file_logging``.
        cors_origins: origins for the per-route ``cross_origin`` decorators;
            ``None`` is treated as an empty list.
        action_factory: forwarded to ``_create_agent``.
        auth_token: token enforced by ``requires_auth`` on the chat route.
        tracker_store: forwarded to ``_create_agent``.
        endpoints: endpoint config file; its ``nlg`` and ``nlu`` sections
            are read here.

    Returns:
        The configured ``Flask`` application.
    """
    import ast

    app = Flask(__name__)
    # NOTE(review): this app-wide CORS setup allows every origin, which makes
    # the per-route ``cors_origins`` restriction ineffective -- confirm intent.
    CORS(app, resources={r"/*": {"origins": "*"}})
    # Setting up logfile
    utils.configure_file_logging(loglevel, logfile)

    if not cors_origins:
        cors_origins = []

    nlg_endpoint = utils.read_endpoint_config(endpoints, "nlg")
    nlu_endpoint = utils.read_endpoint_config(endpoints, "nlu")

    _interpreter = run.interpreter_from_args(interpreter, nlu_endpoint)

    # A one-element list so nested functions can replace the agent in place.
    _agent = [
        _create_agent(model_directory, _interpreter, action_factory,
                      tracker_store, nlg_endpoint)
    ]

    # Buffered dialogue turns are written to MySQL once this many exist.
    flush_threshold = 10
    insert_sql = ("INSERT INTO dialogue (msgId, partyNo, question,responses) "
                  "VALUES ( %s, %s, %s, %s)")

    def agent():
        """Return the currently loaded agent, or ``None`` if there is none."""
        if _agent and _agent[0]:
            return _agent[0]
        return None

    def cs_msg(content, reply_type=1, **extra):
        """Build an ``LU:CsMsg`` payload.

        Its ``content`` is the ``str()`` of ``{"title": "", "content": ...}``
        followed by any ``extra`` keys, which is the format clients expect.
        """
        body = {"title": "", "content": content}
        body.update(extra)
        return {"type": reply_type, "msgType": "LU:CsMsg", "content": str(body)}

    def field_error(res_code, res_msg):
        """Return the envelope used for missing or empty request fields."""
        return jsonify({"resCode": res_code, "resMsg": res_msg, "data": ""})

    def to_data_json(text):
        """Convert the agent's reply text into the ``data`` payload.

        The text is either a restart marker or a Python-literal list/dict.
        Raises ValueError (or a parsing error) for an unrecognised format.
        """
        if "<utter_restart>" in text:
            return cs_msg("restart")

        # literal_eval instead of eval: replies can embed user-provided slot
        # values, and eval would execute them as code.
        parsed = ast.literal_eval(text)
        if isinstance(parsed, list):
            return cs_msg(parsed[0])
        if isinstance(parsed, dict):
            # only the first entry of the dict is used
            key, value = list(parsed.items())[0]
            if key == 'REC':
                return {"type": 1, "msgType": "LU:LucyRecommend",
                        "content": value}
            if key == 'XIAOAN':
                return {"type": 1, "msgType": "XIAOAN", "content": value}
            if key == 'CONTENT_LINK':
                link_info = ast.literal_eval(value)
                return cs_msg(link_info["content"], link=link_info["link"])
            if key == 'TO_PERSON':
                return cs_msg(value, reply_type=0)
        raise ValueError("output error format")

    def flush_dialogue_log():
        """Insert the buffered turns into MySQL and clear the buffer.

        This is best-effort: on failure the error is logged, the transaction
        is rolled back and the buffer is kept so a later request retries.
        NOTE(review): a row that always fails keeps the buffer growing, and
        ``list_data`` is shared across requests without a lock -- confirm the
        deployment model.
        """
        global list_data
        connect = None
        cursor = None
        try:
            # Every DB connection is obtained through pool.connection().
            connect = pool.connection()
            cursor = connect.cursor()
            cursor.executemany(insert_sql, list_data)
            connect.commit()
            logger.info("成功插入 {} 条数据".format(cursor.rowcount))
            list_data = []
        except Exception:
            logger.exception("Caught an exception during insert mysql.")
            if connect is not None:
                connect.rollback()  # roll back the transaction
        finally:
            if cursor is not None:
                cursor.close()
            if connect is not None:
                connect.close()

    @app.route("/", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def hello():
        """Check if the server is running and responds with the version."""
        return "hello from Rasa Core: " + __version__

    @app.route("/status/health.check", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def version():
        """Health check: respond with a fixed service/release identifier."""
        return "OK:dlg-app:reg_20181121_01"

    @app.route("/dlg-app/service/chat/v1/im",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def respond():
        """Answer one chat message (see ``create_app`` for the contract)."""
        try:
            request_params = request_parameters()

            if 'partyNo' not in request_params:
                logger.warning("partyNo does not exist")
                return field_error(400102, "partyNo [1] does not exist")
            partyNo = request_params.pop('partyNo')
            if not partyNo.strip():
                logger.warning("partyNo is empty")
                return field_error(400101, "partyNo is empty")

            # Treat a missing question like an empty one, so ``question`` is
            # always bound for the agent call and the log row below.
            question = request_params.pop('question', '')
            if not question.strip():
                logger.warning("question is empty")
                return jsonify({
                    "resCode": 0,
                    "resMsg": "question is empty",
                    "data": cs_msg("question is empty")
                })

            if 'channel' not in request_params:
                logger.warning("channel does not exist")
                return field_error(400103, "channel [3] does not exist")
            channel = request_params.pop('channel')
            if not channel.strip():
                logger.warning("channel is empty")
                return field_error(400101, "channel is empty")

            # appVersion is accepted but not used; it is optional.
            request_params.pop('appVersion', None)

            if 'msgId' not in request_params:
                logger.warning("msgId does not exist")
                return field_error(400103, "msgId does not exist")
            msgId = request_params.pop('msgId')
            if not msgId.strip():
                logger.warning("msgId is empty")
                return field_error(400101, "msgId is empty")

            # Set the output channel
            out = CollectingOutputChannel()
            try:
                responses = agent().handle_message(question,
                                                   output_channel=out,
                                                   sender_id=partyNo)
                data_json = to_data_json(str(responses[0]["text"]))
            except Exception:
                # agent failure, no reply, or an unparseable/unknown reply
                logger.exception(
                    "Caught an exception while building the reply.")
                data_json = cs_msg("test1")

            list_data.append((msgId, partyNo, question, str(data_json)))
            if len(list_data) >= flush_threshold:
                flush_dialogue_log()

            return jsonify({
                "resCode": 0,
                "resMsg": "success",
                "data": data_json
            })

        except Exception:
            logger.exception("Caught an exception during respond.")
            return jsonify({
                "resCode": 9999,
                "resMsg": "failure",
                "data": cs_msg("test3")
            })

    return app
Beispiel #27
0
    logger.info("Finished loading agent, starting input channel & server.")
    if channel:
        input_channel = create_input_channel(channel, port, credentials_file)
        agent.handle_channel(input_channel)

    return agent


if __name__ == '__main__':
    # Standalone entry point: parse the command line, configure console and
    # file logging, read the NLG/NLU endpoint sections, then hand off to main.
    arg_parser = create_argument_parser()
    cmdline_args = arg_parser.parse_args()

    log_level = cmdline_args.loglevel
    utils.configure_colored_logging(log_level)
    utils.configure_file_logging(log_level, cmdline_args.log_file)

    endpoints_file = cmdline_args.endpoints
    nlg_endpoint = utils.read_endpoint_config(endpoints_file, "nlg")
    nlu_endpoint = utils.read_endpoint_config(endpoints_file, "nlu")

    main(cmdline_args.core, cmdline_args.nlu, cmdline_args.connector,
         cmdline_args.port, cmdline_args.credentials,
         nlg_endpoint, nlu_endpoint)