Beispiel #1
0
    def __init__(self,
                 model_directory,
                 interpreter=None,
                 loglevel="INFO",
                 logfile="rasa_core.log",
                 cors_origins=None,
                 action_factory=None,
                 auth_token=None,
                 tracker_store=None):
        """Configure file logging, remember the settings and build the agent.

        :param model_directory: directory handed to ``_create_agent``
        :param interpreter: interpreter handed to ``_create_agent``
        :param loglevel: level passed to ``utils.configure_file_logging``
        :param logfile: file passed to ``utils.configure_file_logging``
        :param cors_origins: allowed CORS origins; any falsy value is
            stored as an empty list
        :param action_factory: action factory handed to ``_create_agent``
        :param auth_token: stored as ``self.config["token"]``
        :param tracker_store: tracker store handed to ``_create_agent``
        """
        utils.configure_file_logging(loglevel, logfile)

        # Falsy origins (None, empty list, ...) collapse to an empty list.
        self.config = {
            "cors_origins": cors_origins or [],
            "token": auth_token,
        }

        # Keep the raw constructor arguments on the instance ...
        self.model_directory = model_directory
        self.interpreter = interpreter
        self.tracker_store = tracker_store
        self.action_factory = action_factory

        # ... and build the agent from exactly those values.
        self.agent = self._create_agent(model_directory,
                                        interpreter,
                                        action_factory,
                                        tracker_store)
Beispiel #2
0
def train_core(base_dir=None):
    """Train a dialogue model on the bundled sample data and persist it.

    Builds an ``Agent`` from ``sample/domain.yml`` with a memoization, a
    Keras and a fallback policy, trains it on ``sample/stories`` and writes
    the result to ``sample/models/current/dialogue``.

    :param base_dir: directory that contains the ``sample`` folder; defaults
        to the module-level ``current_dir`` (the previous hard-coded value).
    :return: the trained ``Agent``.
    """
    # Only the names actually used are imported; the former unused imports
    # (interpreter, server, UserMessage, rasa_core_sdk executor) pulled in
    # extra dependencies for nothing.
    from rasa_core.policies.fallback import FallbackPolicy
    from rasa_core.policies.keras_policy import KerasPolicy
    from rasa_core.policies.memoization import MemoizationPolicy
    from rasa_core.agent import Agent
    from rasa_core import utils

    if base_dir is None:
        base_dir = current_dir
    sample_dir = os.path.join(base_dir, "sample")

    utils.configure_colored_logging("DEBUG")
    utils.configure_file_logging("DEBUG", "rasa_core_logs.txt")

    agent = Agent(os.path.join(sample_dir, "domain.yml"),
                  policies=[
                      MemoizationPolicy(),
                      KerasPolicy(),
                      FallbackPolicy(
                          fallback_action_name="action_default_fallback",
                          core_threshold=0.3,
                          nlu_threshold=0.3),
                  ])
    data = agent.load_data(os.path.join(sample_dir, "stories"))
    agent.train(data)
    agent.persist(os.path.join(sample_dir, "models", "current", "dialogue"))
    return agent
Beispiel #3
0
                          interpreter=interpreter,
                          generator=endpoints.nlg,
                          tracker_store=tracker_store,
                          action_endpoint=endpoints.action)


if __name__ == '__main__':
    # Running as standalone python application
    arg_parser = create_argument_parser()
    cmdline_args = arg_parser.parse_args()

    logging.getLogger('werkzeug').setLevel(logging.WARN)
    logging.getLogger('matplotlib').setLevel(logging.WARN)

    utils.configure_colored_logging(cmdline_args.loglevel)
    utils.configure_file_logging(cmdline_args.loglevel,
                                 cmdline_args.log_file)

    logger.info("Rasa process starting")

    _endpoints = AvailableEndpoints.read_endpoints(cmdline_args.endpoints)
    _interpreter = NaturalLanguageInterpreter.create(cmdline_args.nlu,
                                                     _endpoints.nlu)
    _broker = PikaProducer.from_endpoint_config(_endpoints.event_broker)

    _tracker_store = TrackerStore.find_tracker_store(
        None, _endpoints.tracker_store, _broker)
    _agent = load_agent(cmdline_args.core,
                        interpreter=_interpreter,
                        tracker_store=_tracker_store,
                        endpoints=_endpoints)
    serve_application(_agent,
Beispiel #4
0
    return app


if __name__ == '__main__':
    # Running as standalone python application
    from rasa_core import run

    arg_parser = run.create_argument_parser()
    cmdline_args = arg_parser.parse_args()

    logging.getLogger('werkzeug').setLevel(logging.WARN)
    logging.getLogger('matplotlib').setLevel(logging.WARN)

    utils.configure_colored_logging(cmdline_args.loglevel)
    utils.configure_file_logging(cmdline_args.loglevel,
                                 cmdline_args.log_file)

    logger.warning("USING `rasa_core.server` is deprecated and will be "
                   "removed in the future. Use `rasa_core.run --enable_api` "
                   "instead.")

    logger.info("Rasa process starting")

    _endpoints = run.read_endpoints(cmdline_args.endpoints)
    _interpreter = NaturalLanguageInterpreter.create(cmdline_args.nlu,
                                                     _endpoints.nlu)
    _agent = run.load_agent(cmdline_args.core,
                            interpreter=_interpreter,
                            endpoints=_endpoints)

    run.serve_application(_agent,
Beispiel #5
0
def create_app(model_directory,  # type: Text
               interpreter=None,  # type: Union[Text, NLI, None]
               loglevel="INFO",  # type: Optional[Text]
               logfile="rasa_core.log",  # type: Optional[Text]
               cors_origins=None,  # type: Optional[List[Text]]
               action_factory=None,  # type: Optional[Text]
               auth_token=None,  # type: Optional[Text]
               tracker_store=None,  # type: Optional[TrackerStore]
               endpoints=None
               ):
    """Create the Flask app that serves a Rasa Core agent over HTTP.

    :param model_directory: directory the agent is created from; ``/load``
        also extracts uploaded models into it.
    :param interpreter: interpreter spec, resolved together with the ``nlu``
        endpoint through ``run.interpreter_from_args``.
    :param loglevel: level passed to ``utils.configure_file_logging``.
    :param logfile: file passed to ``utils.configure_file_logging``.
    :param cors_origins: origins passed to each route's ``cross_origin``;
        a falsy value becomes ``[]``.
    :param action_factory: forwarded to ``_create_agent``.
    :param auth_token: token passed to ``requires_auth`` on protected routes.
    :param tracker_store: forwarded to ``_create_agent``.
    :param endpoints: endpoint config the ``nlg`` and ``nlu`` endpoints are
        read from.
    :return: the configured ``Flask`` application.
    """

    app = Flask(__name__)
    # NOTE(review): this enables "*" for every route, which may make the
    # per-route ``cross_origin(origins=cors_origins)`` restriction
    # ineffective — confirm this is intended.
    CORS(app, resources={r"/*": {"origins": "*"}})
    # Setting up logfile
    utils.configure_file_logging(loglevel, logfile)

    if not cors_origins:
        cors_origins = []

    nlg_endpoint = utils.read_endpoint_config(endpoints, "nlg")
    nlu_endpoint = utils.read_endpoint_config(endpoints, "nlu")

    _interpreter = run.interpreter_from_args(interpreter, nlu_endpoint)

    # this needs to be a list, so the nested functions can replace the agent
    _agent = [_create_agent(model_directory, _interpreter,
                            action_factory, tracker_store, nlg_endpoint)]

    def agent():
        """Return the currently loaded agent, or ``None`` if there is none."""
        if _agent and _agent[0]:
            return _agent[0]
        else:
            return None

    def _error_response(message, status):
        """Return a JSON ``{"error": message}`` response with ``status``.

        ``jsonify`` already builds an ``application/json`` response, so the
        status is set on it directly instead of wrapping it in another
        ``Response`` (which would nest one response inside another's body).
        """
        response = jsonify(error=message)
        response.status_code = status
        return response

    @app.route("/",
               methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def hello():
        """Check if the server is running and responds with the version."""
        return "hello from Rasa Core: " + __version__

    @app.route("/version",
               methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def version():
        """respond with the version number of the installed rasa core."""

        return jsonify({'version': __version__})

    # <sender_id> can be be 'default' if there's only 1 client
    @app.route("/conversations/<sender_id>/continue",
               methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def continue_predicting(sender_id):
        """Continue a prediction started with parse.

        Caller should have executed the action returned from the parse
        endpoint. The events returned from that executed action are
        passed to continue which will trigger the next action prediction.

        If continue predicts action listen, the caller should wait for the
        next user message."""

        request_params = request.get_json(force=True)
        encoded_events = request_params.get("events", [])
        executed_action = request_params.get("executed_action", None)
        evts = events.deserialise_events(encoded_events)
        try:
            response = agent().continue_message_handling(sender_id,
                                                         executed_action,
                                                         evts)
        except ValueError as e:
            # Python 3 exceptions have no ``.message``; use str(e).
            return _error_response(str(e), 400)
        except Exception as e:
            logger.exception(e)
            return _error_response("Server failure. Error: {}".format(e), 500)
        return jsonify(response)

    @app.route("/conversations/<sender_id>/tracker/events",
               methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def append_events(sender_id):
        """Append a list of events to the state of a conversation"""

        request_params = request.get_json(force=True)
        evts = events.deserialise_events(request_params)
        tracker = agent().tracker_store.get_or_create_tracker(sender_id)
        for e in evts:
            tracker.update(e)
        agent().tracker_store.save(tracker)
        return jsonify(tracker.current_state())

    @app.route("/conversations",
               methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def list_trackers():
        """Return the ids of all conversations in the tracker store."""
        return jsonify(list(agent().tracker_store.keys()))

    @app.route("/conversations/<sender_id>/tracker",
               methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def retrieve_tracker(sender_id):
        """Get a dump of a conversations tracker including its events."""

        # parameters
        use_history = bool_arg('ignore_restarts', default=False)
        should_include_events = bool_arg('events', default=True)
        until_time = request.args.get('until', None)

        # validate ``until`` before touching the store, so a malformed value
        # is reported as a client error instead of an unhandled 500
        timestamp = None
        if until_time is not None:
            try:
                timestamp = float(until_time)
            except ValueError:
                return _error_response(
                    "Invalid 'until' parameter: {}".format(until_time), 400)

        # retrieve tracker and set to requested state
        tracker = agent().tracker_store.get_or_create_tracker(sender_id)
        if timestamp is not None:
            tracker = tracker.travel_back_in_time(timestamp)

        # dump and return tracker
        state = tracker.current_state(
                should_include_events=should_include_events,
                only_events_after_latest_restart=use_history)
        return jsonify(state)

    @app.route("/conversations/<sender_id>/tracker",
               methods=['PUT', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def update_tracker(sender_id):
        """Use a list of events to set a conversations tracker to a state."""

        request_params = request.get_json(force=True)
        tracker = DialogueStateTracker.from_dict(sender_id,
                                                 request_params,
                                                 agent().domain)

        # will override an existing tracker with the same id!
        # (saved once; the previous code saved the same tracker twice)
        agent().tracker_store.save(tracker)
        return jsonify(tracker.current_state(should_include_events=True))

    @app.route("/domain",
               methods=['GET'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def get_domain():
        """Get current domain in yaml format."""
        accepts = request.headers.get("Accept", default="application/json")
        if accepts.endswith("json"):
            domain = agent().domain.as_dict()
            return jsonify(domain)
        elif accepts.endswith("yml"):
            domain_yaml = agent().domain.as_yaml()
            return Response(domain_yaml,
                            status=200,
                            content_type="application/x-yml")
        else:
            return Response(
                    """Invalid accept header. Domain can be provided 
                    as json ("Accept: application/json")  
                    or yml ("Accept: application/x-yml"). 
                    Make sure you've set the appropriate Accept header.""",
                    status=406)

    @app.route("/conversations/<sender_id>/parse",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def parse(sender_id):
        """Return the predicted next action for the message in ``query``/``q``."""
        request_params = request_parameters()

        if 'query' in request_params:
            message = request_params.pop('query')
        elif 'q' in request_params:
            message = request_params.pop('q')
        else:
            return _error_response("Invalid parse parameter specified.", 400)

        try:
            # Fetches the predicted action in a json format
            response = agent().start_message_handling(message, sender_id)
            return jsonify(response)

        except Exception as e:
            logger.exception("Caught an exception during parse.")
            return _error_response("Server failure. Error: {}".format(e), 500)

    @app.route("/conversations/<sender_id>/respond",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def respond(sender_id):
        """Handle the message in ``query``/``q`` and return the bot replies."""
        request_params = request_parameters()

        if 'query' in request_params:
            message = request_params.pop('query')
        elif 'q' in request_params:
            message = request_params.pop('q')
        else:
            return _error_response("Invalid respond parameter specified.",
                                   400)

        try:
            # Set the output channel
            out = CollectingOutputChannel()
            # Fetches the appropriate bot response in a json format
            responses = agent().handle_message(message,
                                               output_channel=out,
                                               sender_id=sender_id)
            return jsonify(responses)

        except Exception as e:
            logger.exception("Caught an exception during respond.")
            return _error_response("Server failure. Error: {}".format(e), 500)

    @app.route("/load", methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    def load_model():
        """Loads a zipped model, replacing the existing one."""

        if 'model' not in request.files:
            # model file is missing
            abort(400)

        model_file = request.files['model']

        logger.info("Received new model through REST interface.")
        zipped_path = tempfile.NamedTemporaryFile(delete=False, suffix=".zip")
        zipped_path.close()

        try:
            model_file.save(zipped_path.name)

            logger.debug("Downloaded model to {}".format(zipped_path.name))

            # the context manager closes the archive even if extraction fails
            with zipfile.ZipFile(zipped_path.name, 'r') as zip_ref:
                zip_ref.extractall(model_directory)
        finally:
            # the upload is only needed until it has been extracted
            os.remove(zipped_path.name)

        logger.debug("Unzipped model to {}".format(
                os.path.abspath(model_directory)))

        # reuse the interpreter resolved at start-up, like the initial agent
        _agent[0] = _create_agent(model_directory, _interpreter,
                                  action_factory, tracker_store, nlg_endpoint)
        logger.debug("Finished loading new agent.")
        return jsonify({'success': 1})

    return app
Beispiel #6
0
def create_app(model_directory,
               interpreter=None,
               loglevel="INFO",
               logfile="rasa_core.log",
               cors_origins=None,
               action_factory=None,
               auth_token=None,
               tracker_store=None):
    """Create the Flask app that serves a Rasa Core agent over HTTP.

    :param model_directory: directory the agent is created from; ``/load``
        also extracts uploaded models into it.
    :param interpreter: forwarded to ``_create_agent``.
    :param loglevel: level passed to ``utils.configure_file_logging``.
    :param logfile: file passed to ``utils.configure_file_logging``.
    :param cors_origins: origins passed to each route's ``cross_origin``;
        a falsy value becomes ``[]``.
    :param action_factory: forwarded to ``_create_agent``.
    :param auth_token: token passed to ``requires_auth``; every route that
        reads or modifies conversations, or replaces the model, is protected.
    :param tracker_store: forwarded to ``_create_agent``.
    :return: the configured ``Flask`` application.
    """

    app = Flask(__name__)
    # NOTE(review): this enables "*" for every route, which may make the
    # per-route ``cross_origin(origins=cors_origins)`` restriction
    # ineffective — confirm this is intended.
    CORS(app, resources={r"/*": {"origins": "*"}})

    utils.configure_file_logging(loglevel, logfile)

    if not cors_origins:
        cors_origins = []

    # this needs to be a list, so the nested functions can replace the agent
    _agent = [
        _create_agent(model_directory, interpreter, action_factory,
                      tracker_store)
    ]

    def agent():
        """Return the currently loaded agent, or ``None`` if there is none."""
        if _agent and _agent[0]:
            return _agent[0]
        else:
            return None

    def _error_response(message, status):
        """Return a JSON ``{"error": message}`` response with ``status``.

        ``jsonify`` already builds an ``application/json`` response, so the
        status is set on it directly instead of wrapping it in another
        ``Response`` (which would nest one response inside another's body).
        """
        response = jsonify(error=message)
        response.status_code = status
        return response

    @app.route("/", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def hello():
        """Check if the server is running and responds with the version."""
        return "hello from Rasa Core: " + __version__

    @app.route("/conversations/<sender_id>/continue",
               methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def continue_predicting(sender_id):
        """Continue a prediction started with parse.

        Caller should have executed the action returned from the parse
        endpoint. The events returned from that executed action are
        passed to continue which will trigger the next action prediction.

        If continue predicts action listen, the caller should wait for the
        next user message."""

        request_params = request.get_json(force=True)
        encoded_events = request_params.get("events", [])
        executed_action = request_params.get("executed_action", None)
        evts = events.deserialise_events(encoded_events)
        try:
            response = agent().continue_message_handling(
                sender_id, executed_action, evts)
        except ValueError as e:
            # Python 3 exceptions have no ``.message``; use str(e).
            return _error_response(str(e), 400)
        except Exception as e:
            logger.exception(e)
            return _error_response("Server failure. Error: {}".format(e), 500)
        return jsonify(response)

    @app.route("/conversations/<sender_id>/tracker/events",
               methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def append_events(sender_id):
        """Append a list of events to the state of a conversation"""

        request_params = request.get_json(force=True)
        evts = events.deserialise_events(request_params)
        tracker = agent().tracker_store.get_or_create_tracker(sender_id)
        for e in evts:
            tracker.update(e)
        agent().tracker_store.save(tracker)
        return jsonify(tracker.current_state())

    @app.route("/conversations", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def list_trackers():
        """Return the ids of all conversations in the tracker store."""
        return jsonify(list(agent().tracker_store.keys()))

    @app.route("/conversations/<sender_id>/tracker",
               methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def retrieve_tracker(sender_id):
        """Get a dump of a conversations tracker including its events."""

        # parameters
        use_history = bool_arg('ignore_restarts', default=False)
        should_include_events = bool_arg('events', default=True)
        until_time = request.args.get('until', None)

        # validate ``until`` before touching the store, so a malformed value
        # is reported as a client error instead of an unhandled 500
        timestamp = None
        if until_time is not None:
            try:
                timestamp = float(until_time)
            except ValueError:
                return _error_response(
                    "Invalid 'until' parameter: {}".format(until_time), 400)

        # retrieve tracker and set to requested state
        tracker = agent().tracker_store.get_or_create_tracker(sender_id)
        if timestamp is not None:
            tracker = tracker.travel_back_in_time(timestamp)

        # dump and return tracker
        state = tracker.current_state(
            should_include_events=should_include_events,
            only_events_after_latest_restart=use_history)
        return jsonify(state)

    @app.route("/conversations/<sender_id>/tracker",
               methods=['PUT', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def update_tracker(sender_id):
        """Use a list of events to set a conversations tracker to a state."""

        request_params = request.get_json(force=True)
        tracker = DialogueStateTracker.from_dict(sender_id, request_params,
                                                 agent().domain)

        # will override an existing tracker with the same id!
        # (saved once; the previous code saved the same tracker twice)
        agent().tracker_store.save(tracker)
        return jsonify(tracker.current_state(should_include_events=True))

    @app.route("/conversations/<sender_id>/parse",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def parse(sender_id):
        """Return the predicted next action for the message in ``query``/``q``."""
        request_params = request_parameters()

        if 'query' in request_params:
            message = request_params.pop('query')
        elif 'q' in request_params:
            message = request_params.pop('q')
        else:
            return _error_response("Invalid parse parameter specified.", 400)

        try:
            response = agent().start_message_handling(message, sender_id)
            return jsonify(response)
        except Exception as e:
            logger.exception("Caught an exception during parse.")
            return _error_response("Server failure. Error: {}".format(e), 500)

    @app.route("/conversations/<sender_id>/respond",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def respond(sender_id):
        """Handle the message in ``query``/``q`` and return the bot replies."""
        request_params = request_parameters()

        if 'query' in request_params:
            message = request_params.pop('query')
        elif 'q' in request_params:
            message = request_params.pop('q')
        else:
            return _error_response("Invalid respond parameter specified.",
                                   400)

        try:
            out = CollectingOutputChannel()
            responses = agent().handle_message(message,
                                               output_channel=out,
                                               sender_id=sender_id)
            return jsonify(responses)

        except Exception as e:
            logger.exception("Caught an exception during respond.")
            return _error_response("Server failure. Error: {}".format(e), 500)

    @app.route("/load", methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    def load_model():
        """Loads a zipped model, replacing the existing one."""

        if 'model' not in request.files:
            # model file is missing
            abort(400)

        model_file = request.files['model']

        logger.info("Received new model through REST interface.")
        zipped_path = tempfile.NamedTemporaryFile(delete=False, suffix=".zip")
        zipped_path.close()

        try:
            model_file.save(zipped_path.name)

            logger.debug("Downloaded model to {}".format(zipped_path.name))

            # the context manager closes the archive even if extraction fails
            with zipfile.ZipFile(zipped_path.name, 'r') as zip_ref:
                zip_ref.extractall(model_directory)
        finally:
            # the upload is only needed until it has been extracted
            os.remove(zipped_path.name)

        logger.debug("Unzipped model to {}".format(
            os.path.abspath(model_directory)))

        _agent[0] = _create_agent(model_directory, interpreter, action_factory,
                                  tracker_store)
        logger.debug("Finished loading new agent.")
        return jsonify({'success': 1})

    @app.route("/version", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def version():
        """respond with the version number of the installed rasa core."""

        return jsonify({'version': __version__})

    return app
Beispiel #7
0
def create_app(env_json_file,
               loglevel="INFO",  # type: Optional[Text]
               logfile="rasa_core.log",  # type: Optional[Text]
               cors_origins=None,  # type: Optional[List[Text]]
               action_factory=None,  # type: Optional[Text]
               auth_token=None,  # type: Optional[Text]
               tracker_store=None  # type: Optional[TrackerStore]
               ):
    """Class representing a Rasa Core HTTP server."""

    app = Flask(__name__)
    CORS(app, resources={r"/*": {"origins": "*"}})
    # Setting up logfile
    utils.configure_file_logging(loglevel, logfile)

    if not cors_origins:
        cors_origins = []

    with open(env_json_file) as f:
        env_json = json.load(f)

    domain_file = env_json['domain']

    interpreter_file = env_json['interpreter_dir']
    path_to_glove = env_json['path_to_glove']


    tracker_store = tracker_store

    action_factory = action_factory

    

    # this needs to be an array, so we can modify it in the nested functions...
    domain = OntologyDomain.load(os.path.join(path, domain_file),
                                     None)


    
    from rasa_core.featurizers import FloatSingleStateFeaturizer,MaxHistoryTrackerFeaturizer
    feat=MaxHistoryTrackerFeaturizer(FloatSingleStateFeaturizer(),max_history=1)
    policy = idare.KerasIDarePolicy(feat)



    from rasa_core.policies.ensemble import SimplePolicyEnsemble
    ensemble = SimplePolicyEnsemble([policy])
    ensemble = ensemble.load(env_json['dst_model_dir'])

    # In[5]:


    from rasa_core.agent import Agent
    from rasa_core.interpreter import NaturalLanguageInterpreter
    from rasa_core.tracker_store import InMemoryTrackerStore
    _interpreter = Interpreter(suprath_dir = interpreter_file, rasa_dir = interpreter_file, 
                                path_to_glove = path_to_glove)
    logger.info("NLU interpreter loaded successfully")
    _tracker_store = Agent.create_tracker_store(None,domain,env_json)
    _agent = [Agent(domain, ensemble, _interpreter, _tracker_store,env_json)]
    global processor
    feedback_logger = sql_feedback_logger
    processor = _agent[0]._create_processor(feedback_logger=feedback_logger)
    
    usr_channel = None
    ha_channel = None
    teams_channel = None

    def agent():
        if _agent and _agent[0]:
            return _agent[0]
        else:
            return None

    @app.route("/",
               methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def hello():
        """Check if the server is running and responds with the version."""
        return "hello from Rasa Core: " + __version__

    @app.route("/version",
               methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def version():
        """respond with the version number of the installed rasa core."""

        return jsonify({'version': __version__})

    # <sender_id> can be be 'default' if there's only 1 client
    @app.route("/run",
               methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def continue_predicting():
        """Continue a prediction started with parse.

        Caller should have executed the action returned from the parse
        endpoint. The events returned from that executed action are
        passed to continue which will trigger the next action prediction.

        If continue predicts action listen, the caller should wait for the
        next user message."""
        from rasa_core.run import create_input_channel
        from rasa_core.channels.custom_websocket import CustomInput
        input_component1=CustomInput(None)
        input_component2=CustomInput(None)
        from rasa_core.channels.websocket import WebSocketInputChannel
        global usr_channel
        global ha_channel
        try:
            usr_channel = WebSocketInputChannel(int(env_json["userchannelport"]),None,input_component1,http_ip='0.0.0.0')

            botf_input_channel = BotFrameworkInput(
                  app_id=env_json['teams_app_id'],
                    app_password=env_json['teams_app_password']
                  )
            teams_channel = HttpInputChannel(int(env_json['userchannelport2']),'/webhooks/botframework',botf_input_channel)

        except OSError as e:
            logger.error(str(e))
            return str(e)


        usr_channel.output_channel = input_component1.output_channel
        teams_channel.output_channel = botf_input_channel.output_channel

        op_agent = agent()
        op_agent.handle_custom_processor([usr_channel,teams_channel],usr_channel,processor)

        return "ok"


    @app.route("/conversations/<sender_id>/tracker/events",
               methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def append_events(sender_id):
        """Append a list of events to the state of a conversation"""

        request_params = request.get_json(force=True)
        evts = events.deserialise_events(request_params)
        tracker = agent().tracker_store.get_or_create_tracker(sender_id)
        for e in evts:
            tracker.update(e)
        agent().tracker_store.save(tracker)
        return jsonify(tracker.current_state())

    @app.route("/conversations",
               methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def list_trackers():
        return jsonify(list(agent().tracker_store.keys()))

    @app.route("/conversations/<sender_id>/tracker",
               methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def retrieve_tracker(sender_id):
        """Get a dump of a conversations tracker including its events."""

        # parameters
        use_history = bool_arg('ignore_restarts', default=False)
        should_include_events = bool_arg('events', default=True)
        until_time = request.args.get('until', None)

        # retrieve tracker and set to requested state
        tracker = agent().tracker_store.get_or_create_tracker(sender_id)
        if until_time is not None:
            tracker = tracker.travel_back_in_time(float(until_time))

        # dump and return tracker
        state = tracker.current_state(
                should_include_events=should_include_events,
                only_events_after_latest_restart=use_history)
        return jsonify(state)

    @app.route("/conversations/<sender_id>/tracker",
               methods=['PUT', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def update_tracker(sender_id):
        """Use a list of events to set a conversations tracker to a state."""

        request_params = request.get_json(force=True)
        tracker = DialogueStateTracker.from_dict(sender_id,
                                                 request_params,
                                                 agent().domain)
        agent().tracker_store.save(tracker)

        # will override an existing tracker with the same id!
        agent().tracker_store.save(tracker)
        return jsonify(tracker.current_state(should_include_events=True))

    @app.route("/conversations/<sender_id>/parse",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def parse(sender_id):
        """Predict the next action for a user message without executing it.

        The message text is taken (and removed) from the ``query`` or ``q``
        request parameter. Returns the result of
        ``agent().start_message_handling`` as JSON; 400 if neither parameter
        is present, 500 if handling raises.
        """
        request_params = request_parameters()

        if 'query' in request_params:
            message = request_params.pop('query')
        elif 'q' in request_params:
            message = request_params.pop('q')
        else:
            # jsonify() already returns a full response object; return it
            # with a status instead of wrapping it in a second Response
            return jsonify(error="Invalid parse parameter specified."), 400

        try:
            # Fetches the predicted action in a json format
            response = agent().start_message_handling(message, sender_id)
            return jsonify(response)

        except Exception as e:
            logger.exception("Caught an exception during parse.")
            return jsonify(error="Server failure. Error: {}"
                                 "".format(e)), 500

    @app.route("/conversations/<sender_id>/respond",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def respond(sender_id):
        """Handle a user message and return the bot's responses.

        The message text is taken (and removed) from the ``query`` or ``q``
        request parameter. Responses are collected with a
        ``CollectingOutputChannel`` and the value returned by
        ``agent().handle_message`` is sent back as JSON; 400 if neither
        parameter is present, 500 if handling raises.
        """
        request_params = request_parameters()

        if 'query' in request_params:
            message = request_params.pop('query')
        elif 'q' in request_params:
            message = request_params.pop('q')
        else:
            # jsonify() already returns a full response object; return it
            # with a status instead of wrapping it in a second Response
            return jsonify(error="Invalid respond parameter "
                                 "specified."), 400

        try:
            # Set the output channel
            out = CollectingOutputChannel()
            # Fetches the appropriate bot response in a json format
            responses = agent().handle_message(message,
                                               output_channel=out,
                                               sender_id=sender_id)
            return jsonify(responses)

        except Exception as e:
            logger.exception("Caught an exception during respond.")
            return jsonify(error="Server failure. Error: {}"
                                 "".format(e)), 500

    @app.route("/nlu_parse",
               methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def nlu_parse():
        """Run only the NLU interpreter on a message and return its parse.

        The message text is read from the ``query`` or ``q`` request
        parameter and passed to ``_interpreter.parse``; 400 if neither
        parameter is present.
        """
        request_params = request_parameters()

        if 'query' in request_params:
            message = request_params.get('query')
        elif 'q' in request_params:
            message = request_params.get('q')
        else:
            # jsonify() already returns a full response object; return it
            # with a status instead of wrapping it in a second Response
            return jsonify(error="Invalid parse parameter "
                                 "specified."), 400
        return jsonify(_interpreter.parse(message))

    @app.route("/email/<sender_id>/respond",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def email(sender_id):
        """Handle an incoming email for ``sender_id`` and return the reply.

        The email text comes from the ``query`` or ``q`` request parameter;
        ``email_id`` is bound into ``idare.email_preprocessor``. The module
        level ``teams_channel`` is updated with the mapping read from the
        ``inter_channel_mapper`` file (path taken from ``env_json``) and is
        passed to ``handle_email`` as the alternate channel. Returns the
        latest collected output as JSON; 400 if no message parameter is
        present, 500 if handling raises.
        """
        global teams_channel

        request_params = request_parameters()

        if 'query' in request_params:
            message = request_params.get('query')
        elif 'q' in request_params:
            message = request_params.get('q')
        else:
            # jsonify() already returns a full response object; return it
            # with a status instead of wrapping it in a second Response
            return jsonify(error="Invalid respond parameter "
                                 "specified."), 400

        inter_channel_mapper = read_json_file(
            env_json.get('inter_channel_mapper'))
        output_channel = teams_channel.output_channel
        output_channel.id_map.update(inter_channel_mapper)
        # temporary code: route this sender through the first mapping entry
        # and map that entry back to this sender (skipped when the mapper is
        # empty, which previously raised IndexError)
        if inter_channel_mapper:
            first_key = next(iter(inter_channel_mapper))
            output_channel.id_map.update(
                {sender_id: inter_channel_mapper[first_key]})
            output_channel.reverse_id_map.update({first_key: sender_id})

        email_id = request_params.get('email_id')
        preprocessor = partial(idare.email_preprocessor, email_id=email_id)
        try:
            # Set the output channel
            out = CollectingOutputChannel()
            agent().handle_email(message,
                                 email_preprocessor=preprocessor,
                                 output_channel=out,
                                 alternate_channel=teams_channel,
                                 sender_id=sender_id)
            return jsonify(out.latest_output())

        except Exception as e:
            logger.exception("Caught an exception during respond.")
            return jsonify(error="Server failure. Error: {}"
                                 "".format(e)), 500


    @app.route("/load_nlu", methods=['POST', 'OPTIONS'])
    @requires_auth(auth_token)
    @cross_origin(origins=cors_origins)
    def load_nlu_model():
        """Replace the NLU model with an uploaded zip and reload it.

        Expects the archive in the ``nlu_model`` multipart field (aborts with
        400 otherwise). The archive is extracted into ``interpreter_file``,
        ``_interpreter.reload()`` is called, and the reloaded interpreter is
        attached to the global ``processor`` and, if loaded, to the agent.
        """
        global processor

        if 'nlu_model' not in request.files:
            # model file is missing
            abort(400)

        model_file = request.files['nlu_model']

        logger.info("Received new nlu_model through REST interface.")
        zipped_path = tempfile.NamedTemporaryFile(delete=False, suffix=".zip")
        zipped_path.close()

        try:
            model_file.save(zipped_path.name)
            logger.debug("Downloaded model to {}".format(zipped_path.name))

            # the context manager closes the archive even if extraction fails
            with zipfile.ZipFile(zipped_path.name, 'r') as zip_ref:
                zip_ref.extractall(interpreter_file)
        finally:
            # the temporary archive is only needed for extraction
            os.remove(zipped_path.name)
        logger.debug("Unzipped model to {}".format(
                os.path.abspath(interpreter_file)))

        del processor.interpreter

        _interpreter.reload()

        processor.interpreter = _interpreter
        current_agent = agent()
        if current_agent is not None:
            current_agent.interpreter = _interpreter
        logger.debug("Finished loading new interpreter.")
        return jsonify({'success': 1})

    @app.route("/load_idare", methods=['GET', 'OPTIONS'])
    @requires_auth(auth_token)
    @cross_origin(origins=cors_origins)
    def load_idare():
        """Re-import the module-level ``idare`` so code changes take effect.

        Returns ``{"success": 1}`` on success. If the reload raises, the
        exception text is returned with HTTP status 500.
        """
        try:
            from importlib import reload
        except ImportError:  # Python 2 has no importlib.reload
            from imp import reload
        global idare
        try:
            idare = reload(idare)
        except Exception as e:
            return str(e), 500
        return jsonify({'success': 1})
    return app
Beispiel #8
0
def _im_text_payload(content, payload_type=1, **extra):
    """Build a ``LU:CsMsg`` ``data`` payload for the IM chat API.

    The ``content`` field is ``str()`` of a dict (a Python repr, not JSON),
    kept as-is for compatibility with the existing response format. Extra
    keyword arguments (e.g. ``link``) are added after ``content``.
    """
    body = {"title": "", "content": content}
    body.update(extra)
    return {"type": payload_type, "msgType": "LU:CsMsg", "content": str(body)}


def _im_error(res_code, res_msg):
    """JSON body for a request-validation failure (HTTP status stays 200)."""
    return jsonify({"resCode": res_code, "resMsg": res_msg, "data": ""})


def _build_im_data_json(text):
    """Translate the bot's utterance text into the IM ``data`` payload.

    ``text`` is either contains the ``<utter_restart>`` marker or is a Python
    literal: a list (its first item is shown) or a dict whose first key
    selects the message kind (``REC``, ``XIAOAN``, ``CONTENT_LINK``,
    ``TO_PERSON``).

    Raises:
        ValueError: if the text matches none of these formats.
        Exception: whatever ``eval`` raises for malformed text.
    """
    if "<utter_restart>" in text:
        return _im_text_payload("restart")

    # NOTE(security): eval() executes arbitrary code. If any part of the bot
    # response can originate from user input (e.g. slot values), this is a
    # code-execution risk; consider ast.literal_eval for plain literals.
    parsed = eval(text)
    if type(parsed) is list:
        return _im_text_payload(parsed[0])
    if type(parsed) is not dict:
        raise ValueError("output error format: {!r}".format(text))

    key, value = list(parsed.items())[0]
    if key == 'REC':
        return {"type": 1, "msgType": "LU:LucyRecommend", "content": value}
    if key == 'XIAOAN':
        return {"type": 1, "msgType": "XIAOAN", "content": value}
    if key == 'CONTENT_LINK':
        link_payload = eval(value)  # NOTE(security): see eval note above
        return _im_text_payload(link_payload["content"],
                                link=link_payload["link"])
    if key == 'TO_PERSON':
        return _im_text_payload(value, payload_type=0)
    raise ValueError("output error format: {!r}".format(text))


def _flush_dialogue_log():
    """Write the buffered module-level ``list_data`` rows to ``dialogue``.

    On success the buffer is cleared. On failure the transaction is rolled
    back and the rows stay buffered for the next attempt. The cursor and
    connection are always closed.
    """
    global list_data
    sql = ("INSERT INTO dialogue (msgId, partyNo, question,responses) "
           "VALUES ( %s, %s, %s, %s)")
    connect = None
    cursor = None
    try:
        # obtain a connection from the pool whenever the database is needed
        connect = pool.connection()
        cursor = connect.cursor()
        cursor.executemany(sql, list_data)
        connect.commit()
        logger.info("成功插入 %s 条数据", cursor.rowcount)
        list_data = []
    except Exception:
        # NOTE(review): rows are kept after a failure, so the buffer keeps
        # growing while the database is unreachable
        logger.exception("Caught an exception during insert mysql.")
        if connect is not None:
            connect.rollback()  # roll back the transaction
    finally:
        if cursor is not None:
            cursor.close()
        if connect is not None:
            connect.close()


def create_app(
        model_directory,  # type: Text
        interpreter=None,  # type: Union[Text, NLI, None]
        loglevel="INFO",  # type: Optional[Text]
        logfile="rasa_core.log",  # type: Optional[Text]
        cors_origins=None,  # type: Optional[List[Text]]
        action_factory=None,  # type: Optional[Text]
        auth_token=None,  # type: Optional[Text]
        tracker_store=None,  # type: Optional[TrackerStore]
        endpoints=None):
    """Create the Flask app serving the dialogue (``dlg-app``) HTTP API.

    Args:
        model_directory: dialogue model directory passed to
            ``_create_agent``.
        interpreter: passed with the ``nlu`` endpoint config to
            ``run.interpreter_from_args``.
        loglevel: level for ``utils.configure_file_logging``.
        logfile: log file path for ``utils.configure_file_logging``.
        cors_origins: origins for the per-route ``cross_origin``
            decorators; a falsy value becomes an empty list.
        action_factory: passed through to ``_create_agent``.
        auth_token: token checked by ``requires_auth`` on the chat route.
        tracker_store: passed through to ``_create_agent``.
        endpoints: endpoint config from which ``nlg`` and ``nlu`` are read.

    Returns:
        The configured ``Flask`` application.
    """
    app = Flask(__name__)
    # NOTE(review): app-wide CORS allowing every origin; confirm how this
    # interacts with the per-route cross_origin(origins=cors_origins).
    CORS(app, resources={r"/*": {"origins": "*"}})
    # Setting up logfile
    utils.configure_file_logging(loglevel, logfile)

    if not cors_origins:
        cors_origins = []

    nlg_endpoint = utils.read_endpoint_config(endpoints, "nlg")
    nlu_endpoint = utils.read_endpoint_config(endpoints, "nlu")

    _interpreter = run.interpreter_from_args(interpreter, nlu_endpoint)

    # this needs to be an array, so we can modify it in the nested functions...
    _agent = [
        _create_agent(model_directory, _interpreter, action_factory,
                      tracker_store, nlg_endpoint)
    ]

    def agent():
        """Return the loaded agent, or ``None`` if there is none."""
        return _agent[0] or None

    @app.route("/", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def hello():
        """Check if the server is running and responds with the version."""
        return "hello from Rasa Core: " + __version__

    @app.route("/status/health.check", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def version():
        """Health check returning a fixed build identifier string."""
        return "OK:dlg-app:reg_20181121_01"

    @app.route("/dlg-app/service/chat/v1/im",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def respond():
        """Answer one chat message and buffer it for the ``dialogue`` table.

        Request parameters: ``partyNo`` (used as sender id), ``question``,
        ``channel``, ``appVersion`` and ``msgId``. Validation failures return
        a non-zero ``resCode`` in the JSON body. An empty or missing question
        returns ``resCode`` 0 with an explanatory message. Any unexpected
        error returns ``resCode`` 9999.
        """
        try:
            request_params = request_parameters()

            if 'partyNo' not in request_params:
                # logger.exception is only meaningful inside an except block
                logger.warning("partyNo does not exist")
                return _im_error(400102, "partyNo [1] does not exist")
            party_no = request_params.pop('partyNo')
            if not party_no.strip():
                logger.warning("partyNo is empty")
                return _im_error(400101, "partyNo is empty")

            # a missing question is treated like an empty one
            question = request_params.pop('question', '')
            if not question.strip():
                logger.warning("question is empty")
                return jsonify({
                    "resCode": 0,
                    "resMsg": "question is empty",
                    "data": _im_text_payload("question is empty")
                })

            if 'channel' not in request_params:
                logger.warning("channel does not exist")
                return _im_error(400103, "channel [3] does not exist")
            channel = request_params.pop('channel')
            if not channel.strip():
                logger.warning("channel is empty")
                return _im_error(400101, "channel is empty")

            # appVersion is required but otherwise unused; a missing value
            # raises KeyError and is answered by the generic handler below
            request_params.pop('appVersion')

            if 'msgId' not in request_params:
                logger.warning("msgId does not exist")
                return _im_error(400103, "msgId does not exist")
            msg_id = request_params.pop('msgId')
            if not msg_id.strip():
                logger.warning("msgId is empty")
                return _im_error(400101, "msgId is empty")

            # Set the output channel
            out = CollectingOutputChannel()
            try:
                responses = agent().handle_message(question,
                                                   output_channel=out,
                                                   sender_id=party_no)
                data_json = _build_im_data_json(str(responses[0]["text"]))
            except Exception:
                logger.exception("Caught an exception of eval(text) is error.")
                data_json = _im_text_payload("test1")

            # buffer the exchange; rows are written to MySQL in batches of 10
            list_data.append((msg_id, party_no, question, str(data_json)))
            if len(list_data) >= 10:
                _flush_dialogue_log()

            return jsonify({
                "resCode": 0,
                "resMsg": "success",
                "data": data_json
            })

        except Exception:
            logger.exception("Caught an exception during respond.")
            return jsonify({
                "resCode": 9999,
                "resMsg": "failure",
                "data": _im_text_payload("test3")
            })

    return app
def enable_logging():
    """Log at DEBUG to the console and ./rasa_core.log; quiet werkzeug.

    The werkzeug logger is raised to WARN so per-request access lines do not
    flood the output.
    """
    debug = logging.DEBUG
    utils.configure_colored_logging(debug)
    utils.configure_file_logging(debug, './rasa_core.log')
    logging.getLogger('werkzeug').setLevel(logging.WARN)
Beispiel #10
0
def create_app(
        model_directory,  # type: Text
        interpreter=None,  # type: Union[Text, NLI, None]
        loglevel="INFO",  # type: Optional[Text]
        logfile="rasa_core.log",  # type: Optional[Text]
        cors_origins=None,  # type: Optional[List[Text]]
        action_factory=None,  # type: Optional[Text]
        auth_token=None,  # type: Optional[Text]
        tracker_store=None  # type: Optional[TrackerStore]
):
    """Create the Flask app exposing the Rasa Core HTTP API.

    Args:
        model_directory: dialogue model directory passed to
            ``_create_agent``; ``/load`` also extracts new models into it.
        interpreter: passed to ``_create_agent`` (at startup and on /load).
        loglevel: level for ``utils.configure_file_logging``.
        logfile: log file path for ``utils.configure_file_logging``.
        cors_origins: origins for the per-route ``cross_origin``
            decorators; a falsy value becomes an empty list.
        action_factory: passed through to ``_create_agent``.
        auth_token: token checked by ``requires_auth``.
        tracker_store: passed through to ``_create_agent``.

    Returns:
        The configured ``Flask`` application.
    """
    app = Flask(__name__)
    CORS(app, resources={r"/*": {"origins": "*"}})
    # Setting up logfile
    utils.configure_file_logging(loglevel, logfile)

    if not cors_origins:
        cors_origins = []

    # this needs to be an array, so we can modify it in the nested functions...
    _agent = [
        _create_agent(model_directory, interpreter, action_factory,
                      tracker_store)
    ]

    def agent():
        """Return the currently loaded agent, or ``None``."""
        return _agent[0] or None

    def error_response(message, status):
        """Return a JSON ``{"error": message}`` response with ``status``.

        ``jsonify`` already builds a complete response object, so it is
        returned with a status instead of being wrapped in another
        ``Response`` (whose body would then not be iterable).
        """
        return jsonify(error=message), status

    def is_safe_file_name(name):
        """Return True if ``name`` is usable as a single file-name component."""
        return (name not in (".", "..")
                and os.path.basename(name) == name
                and not any(char in name for char in "/\\\r\n"))

    @app.route("/", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def hello():
        """Check if the server is running and responds with the version."""
        return "hello from Rasa Core: " + __version__

    @app.route("/version", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def version():
        """respond with the version number of the installed rasa core."""

        return jsonify({'version': __version__})

    # <sender_id> can be 'default' if there's only 1 client
    @app.route("/conversations/<sender_id>/continue",
               methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def continue_predicting(sender_id):
        """Continue a prediction started with parse.

        Caller should have executed the action returned from the parse
        endpoint. The events returned from that executed action are
        passed to continue which will trigger the next action prediction.

        If continue predicts action listen, the caller should wait for the
        next user message."""

        request_params = request.get_json(force=True)
        encoded_events = request_params.get("events", [])
        executed_action = request_params.get("executed_action", None)
        evts = events.deserialise_events(encoded_events)
        try:
            response = agent().continue_message_handling(
                sender_id, executed_action, evts)
        except ValueError as e:
            # Python 3 exceptions have no ``.message`` attribute
            return error_response(str(e), 400)
        except Exception as e:
            logger.exception(e)
            return error_response("Server failure. Error: {}".format(e), 500)
        return jsonify(response)

    @app.route("/conversations/<sender_id>/tracker/events",
               methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def append_events(sender_id):
        """Append a list of events to the state of a conversation"""

        request_params = request.get_json(force=True)
        evts = events.deserialise_events(request_params)
        tracker = agent().tracker_store.get_or_create_tracker(sender_id)
        for e in evts:
            tracker.update(e)
        tracker.replay_events()
        agent().tracker_store.save(tracker)
        return jsonify(tracker.current_state())

    @app.route("/conversations", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def list_trackers():
        """Return the ids of all conversations in the tracker store."""
        return jsonify(list(agent().tracker_store.keys()))

    @app.route("/conversations/<sender_id>/tracker",
               methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def retrieve_tracker(sender_id):
        """Get a dump of a conversations tracker including its events."""

        # parameters
        use_history = bool_arg('ignore_restarts', default=False)
        should_include_events = bool_arg('events', default=True)
        until_time = request.args.get('until', None)

        # retrieve tracker and set to requested state
        tracker = agent().tracker_store.get_or_create_tracker(sender_id)
        if until_time is not None:
            tracker = tracker.travel_back_in_time(float(until_time))

        # dump and return tracker
        state = tracker.current_state(
            should_include_events=should_include_events,
            only_events_after_latest_restart=use_history)
        return jsonify(state)

    @app.route("/conversations/<sender_id>/tracker",
               methods=['PUT', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def update_tracker(sender_id):
        """Use a list of events to set a conversations tracker to a state."""

        request_params = request.get_json(force=True)
        tracker = DialogueStateTracker.from_dict(sender_id, request_params,
                                                 agent().domain)

        # will override an existing tracker with the same id -- one save
        # is sufficient
        agent().tracker_store.save(tracker)
        return jsonify(tracker.current_state(should_include_events=True))

    @app.route("/conversations/<sender_id>/tracker/reset_intent",
               methods=['GET', 'PUT', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def replace_last_intent(sender_id):
        """Replace the intent of the latest user message and continue.

        Events after the most recent ``UserUttered`` are dropped, that
        utterance is re-added with the given ``intent`` (confidence 1.0) and
        the next action is predicted and executed. The utterance text is also
        appended to
        ``/app/nlu/<RASA_NLU_PROJECT_NAME>/user_input/<intent>.md``.
        Returns 400 if ``intent`` is missing or not a safe file name.
        """
        import io

        request_params = request_parameters()

        if 'intent' not in request_params:
            return error_response("No intent parameter specified.", 400)
        intent = request_params.get('intent')
        # the intent becomes part of a file path below: reject values that
        # could escape the user_input directory or break the markdown file
        if not intent or not is_safe_file_name(intent):
            return error_response("Invalid intent parameter specified.", 400)

        tracker = agent().tracker_store.get_or_create_tracker(sender_id)

        # drop everything that happened after the latest user message
        while tracker.events and not isinstance(tracker.events[-1],
                                                UserUttered):
            tracker.events.pop()

        if not tracker.events:
            logger.debug("No user utterance in history of tracker")
            return jsonify(tracker.current_state())

        user_utterance = tracker.events.pop()
        tracker.update(
            UserUttered(text=user_utterance.text,
                        intent={
                            "name": intent,
                            "confidence": 1.0,
                        },
                        entities=user_utterance.entities))

        out = CollectingOutputChannel()
        message = UserMessage(text=user_utterance.text,
                              output_channel=out,
                              sender_id=sender_id)
        processor = agent()._create_processor()
        processor._predict_and_execute_next_action(message, tracker)
        agent().tracker_store.save(tracker)

        response = {
            "responses": message.output_channel.messages,
            "confidence": 1,
            "alternatives": [],
        }

        # save utterance in training data for nlu
        file_path = "/app/nlu/" + os.environ[
            "RASA_NLU_PROJECT_NAME"] + "/user_input/" + intent + ".md"
        add_intent_definition = not os.path.exists(file_path)

        # explicit UTF-8 so non-ASCII user text does not depend on the locale
        with io.open(file_path, "a", encoding="utf-8") as md_file:
            if add_intent_definition:
                md_file.write("## intent:" + intent + "\n")
            md_file.write("- " + user_utterance.text + "\n")

        return jsonify(response)

    @app.route("/conversations/<sender_id>/alternatives",
               methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def alternatives(sender_id):
        """Return ``_get_alternatives`` for this conversation as JSON."""
        return jsonify(_get_alternatives(agent(), sender_id))

    @app.route("/conversations/<sender_id>/parse",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def parse(sender_id):
        """Predict the next action for the ``query``/``q`` message."""
        request_params = request_parameters()

        if 'query' in request_params:
            message = request_params.get('query')
        elif 'q' in request_params:
            message = request_params.get('q')
        else:
            return error_response("Invalid parse parameter specified.", 400)

        try:
            # Fetches the predicted action in a json format
            response = agent().start_message_handling(message, sender_id)
            return jsonify(response)

        except Exception as e:
            logger.exception("Caught an exception during parse.")
            return error_response("Server failure. Error: {}".format(e), 500)

    @app.route("/conversations/<sender_id>/respond",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def respond(sender_id):
        """Handle the ``query``/``q`` message and return the bot responses.

        The JSON result holds the responses, the confidence of the latest
        intent (``None`` if the intent has no confidence) and the output of
        ``_get_alternatives``.
        """
        request_params = request_parameters()

        if 'query' in request_params:
            message = request_params.get('query')
        elif 'q' in request_params:
            message = request_params.get('q')
        else:
            return error_response("Invalid respond parameter specified.", 400)

        try:
            # Set the output channel
            out = CollectingOutputChannel()
            # Fetches the appropriate bot response in a json format
            responses = agent().handle_message(message,
                                               output_channel=out,
                                               sender_id=sender_id)
            tracker = agent().tracker_store.get_or_create_tracker(sender_id)
            result = {
                "responses": responses,
                "confidence": tracker.latest_message.intent.get("confidence"),
                "alternatives": _get_alternatives(agent(), sender_id)
            }
            return jsonify(result)

        except Exception as e:
            logger.exception("Caught an exception during respond.")
            return error_response("Server failure. Error: {}".format(e), 500)

    @app.route("/load", methods=['POST', 'OPTIONS'])
    @requires_auth(auth_token)
    @cross_origin(origins=cors_origins)
    def load_model():
        """Loads a zipped model, replacing the existing one."""

        if 'model' not in request.files:
            # model file is missing
            abort(400)

        model_file = request.files['model']

        logger.info("Received new model through REST interface.")
        zipped_path = tempfile.NamedTemporaryFile(delete=False, suffix=".zip")
        zipped_path.close()

        try:
            model_file.save(zipped_path.name)
            logger.debug("Downloaded model to {}".format(zipped_path.name))

            # the context manager closes the archive even if extraction fails
            with zipfile.ZipFile(zipped_path.name, 'r') as zip_ref:
                zip_ref.extractall(model_directory)
        finally:
            # the temporary archive is only needed for extraction
            os.remove(zipped_path.name)
        logger.debug("Unzipped model to {}".format(
            os.path.abspath(model_directory)))

        _agent[0] = _create_agent(model_directory, interpreter, action_factory,
                                  tracker_store)
        logger.debug("Finished loading new agent.")
        return jsonify({'success': 1})

    return app