Ejemplo n.º 1
0
 def create_tracker_store(store, domain):
     # type: (Optional[TrackerStore], Domain) -> TrackerStore
     """Return a tracker store bound to ``domain``.

     When ``store`` is given, it is reused: its ``domain`` attribute is
     overwritten and the same object is returned. Otherwise a new
     ``RedisTrackerStore`` is built for ``domain`` with its default settings.
     """
     if store is None:
         # Modified by fangning on 2018-11-14: default to a Redis-backed store.
         return RedisTrackerStore(domain)
     store.domain = domain
     return store
Ejemplo n.º 2
0
def test_find_tracker_store(default_domain):
    """The store resolved from the endpoints file matches a RedisTrackerStore."""
    endpoint_config = utils.read_endpoint_config(DEFAULT_ENDPOINTS_FILE,
                                                 "tracker_store")
    expected_store = RedisTrackerStore(
        domain=default_domain,
        host="localhost",
        port=6379,
        db=0,
        password="******",
        record_exp=3000)

    found_store = TrackerStore.find_tracker_store(default_domain,
                                                  endpoint_config)
    assert isinstance(expected_store, type(found_store))
Ejemplo n.º 3
0
def create_tracker_store(core_model, endpoints):
    """Create the tracker store described by the ``tracker-store`` endpoint.

    :param core_model: model passed to ``get_domain`` to obtain the domain.
    :param endpoints: endpoints configuration read through
        ``utils.read_endpoint_config``.
    :return: an ``InMemoryTrackerStore`` when the store type is ``memory``,
        a ``RedisTrackerStore`` when it is ``redis``, otherwise ``None``
        (including when no ``tracker-store`` section is configured).
    """
    domain = get_domain(core_model)

    # Publishing is optional. NOTE(review): assumes read_endpoint_config
    # returns None for a missing section and that both stores accept
    # publish_url=None -- confirm.
    tracker_publish = utils.read_endpoint_config(endpoints, "tracker-publish")
    publish_url = tracker_publish.url if tracker_publish is not None else None

    # Setup tracker store
    store = utils.read_endpoint_config(endpoints, "tracker-store")
    if store is None:
        # Treat "no store configured" like an unknown type: return None
        # instead of crashing on ``store.type``.
        return None

    if store.type == 'memory':
        return InMemoryTrackerStore(domain=domain, publish_url=publish_url)
    if store.type == 'redis':
        return RedisTrackerStore(domain=domain,
                                 host=store.host,
                                 port=store.port,
                                 db=store.db,
                                 password=store.password,
                                 publish_url=publish_url)
    return None
Ejemplo n.º 4
0
def stores_to_be_tested():
    """Build one instance of each tracker store implementation under test.

    Both stores use the module-level ``domain``; the Redis store is created
    with ``mock=True``.
    """
    redis_store = RedisTrackerStore(domain, mock=True)
    memory_store = InMemoryTrackerStore(domain)
    return [redis_store, memory_store]
Ejemplo n.º 5
0
 def save(self, tracker, timeout=None):
     """Serialise ``tracker`` and store it in Redis under its sender id.

     :param tracker: tracker to persist; ``tracker.sender_id`` is the key.
     :param timeout: expiry passed as ``ex`` to ``self.red.set``; when
         omitted (``None``) the store-wide ``self.timeout`` is used.
     """
     if timeout is None:
         # Only fall back to the store default when the caller gave no
         # explicit timeout, so a per-call value is actually honoured.
         timeout = self.timeout
     serialised_tracker = RedisTrackerStore.serialise_tracker(tracker)
     self.red.set(tracker.sender_id, serialised_tracker, ex=timeout)
Ejemplo n.º 6
0
import pymongo
from pymongo import MongoClient

# Service hosts come from the environment; a missing variable raises
# KeyError when this module runs.
mongo_service = os.environ['MONGO_SERVICE_HOST']
redis_service = os.environ['REDIS_SERVICE_HOST']

client = MongoClient('mongodb://%s:27017/' % mongo_service)
db = client.simba

# Configuration shared by both agents below.
interpreter = RasaNLUInterpreter('./models/nlu/Auto/Auto')
domain = TemplateDomain.load("models/dialogue/domain.yml")
action_endpoint = EndpointConfig(url="http://localhost:5055/webhook")

# NOTE(review): neither Redis-backed store (``tracker_store``, ``redis_store``)
# is handed to an agent here; both agents use ``mongo_tracker_store``.
tracker_store = RedisTrackerStore(domain, host="localhost", port=6379, db=13)

mongo_tracker_store = MongoTrackerStore(domain=domain,
                                        host="mongodb://%s" % mongo_service,
                                        db="main")
redis_store = RedisTrackerStore(domain, host=redis_service, port=6379)

# Both agents load the same model path with identical settings and share
# the same Mongo tracker store.
_agent_kwargs = dict(interpreter=interpreter,
                     tracker_store=mongo_tracker_store,
                     action_endpoint=action_endpoint)

agent_one = Agent.load('domain', **_agent_kwargs)
agent_two = Agent.load('domain', **_agent_kwargs)
Ejemplo n.º 7
0
def create_app(model_directory,
               interpreter=None,
               loglevel="INFO",
               logfile="rasa_core.log",
               cors_origins=None,
               action_factory=None,
               auth_token=None,
               tracker_store=None):
    """Create the Flask application exposing the Rasa Core HTTP API.

    :param model_directory: directory holding the model and its
        ``domain.yml``; ``/load`` extracts uploaded models into it.
    :param interpreter: NLU interpreter handed to ``_create_agent``.
    :param loglevel: level passed to ``utils.configure_file_logging``.
    :param logfile: file passed to ``utils.configure_file_logging``.
    :param cors_origins: origins given to the per-route ``cross_origin``
        decorators; a falsy value becomes an empty list.
    :param action_factory: passed to ``TemplateDomain.load`` and
        ``_create_agent``.
    :param auth_token: token checked by ``requires_auth`` on protected routes.
    :param tracker_store: tracker store for the agent. When ``None``, a
        ``RedisTrackerStore`` is created for the host in ``$REDIS_HOST``.
    :return: the configured ``Flask`` application.
    """

    app = Flask(__name__)
    # NOTE(review): this app-wide rule allows every origin on every route,
    # so ``cors_origins`` does not actually restrict anything. Confirm this is
    # intended before exposing the server publicly.
    CORS(app, resources={r"/*": {"origins": "*"}})

    utils.configure_file_logging(loglevel, logfile)

    if not cors_origins:
        cors_origins = []

    if tracker_store is None:
        # Respect a caller-supplied store; only fall back to Redis otherwise.
        domain = TemplateDomain.load(
            os.path.join(model_directory, "domain.yml"), action_factory)
        tracker_store = RedisTrackerStore(domain,
                                          host=os.environ["REDIS_HOST"])

    # this needs to be an array, so we can modify it in the nested functions...
    _agent = [
        _create_agent(model_directory, interpreter, action_factory,
                      tracker_store)
    ]

    def agent():
        if _agent and _agent[0]:
            return _agent[0]
        else:
            return None

    def _error_response(message, status):
        """Return a JSON ``{"error": message}`` response with ``status``."""
        # jsonify already builds a Response; returning (response, status)
        # lets Flask set the code instead of nesting one Response in another.
        return jsonify(error=message), status

    @app.route("/", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def hello():
        """Check if the server is running and responds with the version."""
        return "hello from Rasa Core: " + __version__

    @app.route("/version", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def version():
        """respond with the version number of the installed rasa core."""

        return jsonify({'version': __version__})

    @app.route("/conversations/<sender_id>/continue",
               methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def continue_predicting(sender_id):
        """Continue a prediction started with parse.

        Caller should have executed the action returned from the parse
        endpoint. The events returned from that executed action are
        passed to continue which will trigger the next action prediction.

        If continue predicts action listen, the caller should wait for the
        next user message."""

        request_params = request.get_json(force=True)
        encoded_events = request_params.get("events", [])
        executed_action = request_params.get("executed_action", None)
        evts = events.deserialise_events(encoded_events)
        try:
            response = agent().continue_message_handling(
                sender_id, executed_action, evts)
        except ValueError as e:
            # str(e) works on Python 2 and 3; ``e.message`` is Python 2 only.
            return _error_response(str(e), 400)
        except Exception as e:
            logger.exception(e)
            return _error_response("Server failure. Error: {}".format(e), 500)
        return jsonify(response)

    @app.route("/conversations/<sender_id>/tracker/events",
               methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def append_events(sender_id):
        """Append a list of events to the state of a conversation"""

        request_params = request.get_json(force=True)
        evts = events.deserialise_events(request_params)
        tracker = agent().tracker_store.get_or_create_tracker(sender_id)
        for e in evts:
            tracker.update(e)
        agent().tracker_store.save(tracker)
        return jsonify(tracker.current_state())

    @app.route("/conversations", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def list_trackers():
        return jsonify(list(agent().tracker_store.keys()))

    @app.route("/conversations/<sender_id>/tracker",
               methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def retrieve_tracker(sender_id):
        """Get a dump of a conversations tracker including its events."""

        # parameters
        use_history = bool_arg('ignore_restarts', default=False)
        should_include_events = bool_arg('events', default=True)
        until_time = request.args.get('until', None)

        # retrieve tracker and set to requested state
        tracker = agent().tracker_store.get_or_create_tracker(sender_id)
        if until_time is not None:
            tracker = tracker.travel_back_in_time(float(until_time))

        # dump and return tracker
        state = tracker.current_state(
            should_include_events=should_include_events,
            only_events_after_latest_restart=use_history)
        return jsonify(state)

    @app.route("/conversations/<sender_id>/tracker",
               methods=['PUT', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def update_tracker(sender_id):
        """Use a list of events to set a conversations tracker to a state."""

        request_params = request.get_json(force=True)
        tracker = DialogueStateTracker.from_dict(sender_id, request_params,
                                                 agent().domain)

        # will override an existing tracker with the same id!
        agent().tracker_store.save(tracker)
        return jsonify(tracker.current_state(should_include_events=True))

    @app.route("/conversations/<sender_id>/parse",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def parse(sender_id):
        request_params = request_parameters()

        if 'query' in request_params:
            message = request_params.pop('query')
        elif 'q' in request_params:
            message = request_params.pop('q')
        else:
            return _error_response("Invalid parse parameter specified.", 400)

        try:
            response = agent().start_message_handling(message, sender_id)
            return jsonify(response)
        except Exception as e:
            logger.exception("Caught an exception during parse.")
            return _error_response("Server failure. Error: {}".format(e), 500)

    @app.route("/conversations/<sender_id>/respond",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def respond(sender_id):
        request_params = request_parameters()

        if 'query' in request_params:
            message = request_params.get('query')
        elif 'q' in request_params:
            message = request_params.pop('q')
        else:
            return _error_response("Invalid respond parameter specified.",
                                   400)
        # Debug output goes through the logger rather than stdout.
        logger.debug("Respond request parameters: {}".format(request_params))
        try:
            out = CollectingOutputChannel()
            responses = agent().handle_message(message,
                                               output_channel=out,
                                               sender_id=sender_id)
            return jsonify(responses)

        except Exception as e:
            logger.exception("Caught an exception during respond.")
            return _error_response("Server failure. Error: {}".format(e), 500)

    @app.route("/load", methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    def load_model():
        """Loads a zipped model, replacing the existing one."""
        # cross_origin is outermost, as on every other route, so CORS
        # preflight (OPTIONS) requests are answered without the auth token.

        if 'model' not in request.files:
            # model file is missing
            abort(400)

        model_file = request.files['model']

        logger.info("Received new model through REST interface.")
        zipped_path = tempfile.NamedTemporaryFile(delete=False, suffix=".zip")
        zipped_path.close()

        try:
            model_file.save(zipped_path.name)
            logger.debug("Downloaded model to {}".format(zipped_path.name))

            # The context manager closes the archive even if extraction fails.
            with zipfile.ZipFile(zipped_path.name, 'r') as zip_ref:
                zip_ref.extractall(model_directory)
            logger.debug("Unzipped model to {}".format(
                os.path.abspath(model_directory)))
        finally:
            # The temporary archive is only needed until it is extracted.
            os.remove(zipped_path.name)

        _agent[0] = _create_agent(model_directory, interpreter, action_factory,
                                  tracker_store)
        logger.debug("Finished loading new agent.")
        return jsonify({'success': 1})

    return app
Ejemplo n.º 8
0
import os

redis_url = os.getenv('REDIS_URL', None)

# Fallbacks for a REDIS_URL that omits the host or port; they match
# redis-py's documented client defaults.
_DEFAULT_REDIS_HOST = 'localhost'
_DEFAULT_REDIS_PORT = 6379


def _redis_connection_params(url):
    """Split a ``redis://`` URL into ``(host, port, password)``.

    ``urlparse`` reports a missing host or port as ``None``. Passing that
    explicitly to the Redis clients would override their defaults, so the
    defaults above are substituted instead. ``password`` may be ``None``.
    """
    from urllib.parse import urlparse

    parsed = urlparse(url)
    host = parsed.hostname or _DEFAULT_REDIS_HOST
    port = parsed.port or _DEFAULT_REDIS_PORT
    return host, port, parsed.password


if redis_url:
    from rasa_core.tracker_store import RedisTrackerStore
    from redis import Redis

    hostname, port, password = _redis_connection_params(redis_url)

    # Conversation trackers live in db 0; the scheduler uses db 1.
    tracker_store = RedisTrackerStore(None,
                                      host=hostname,
                                      port=port,
                                      db=0,
                                      password=password)

    scheduler_store = Redis(host=hostname, port=port, db=1, password=password)
else:
    tracker_store = None
    scheduler_store = None
Ejemplo n.º 9
0
 def save(self, tracker, timeout=None):
     """Serialise ``tracker`` and store it in Redis under its sender id.

     :param tracker: tracker to persist; ``tracker.sender_id`` is the key.
     :param timeout: expiry passed as ``ex`` to ``self.red.set``; when
         omitted (``None``) the store-wide ``self.timeout`` is used.
     """
     if timeout is None:
         # Only fall back to the store default when the caller gave no
         # explicit timeout, so a per-call value is actually honoured.
         timeout = self.timeout
     serialised_tracker = RedisTrackerStore.serialise_tracker(tracker)
     self.red.set(tracker.sender_id, serialised_tracker, ex=timeout)