def read_endpoints(endpoint_file):
    """Load the nlg, nlu, action and model endpoint configs from one file.

    Each section is read with ``utils.read_endpoint_config`` using the keys
    ``nlg``, ``nlu``, ``action_endpoint`` and ``models`` (in that order) and
    the results are packed into ``AvailableEndpoints`` positionally.
    """
    section_keys = ["nlg", "nlu", "action_endpoint", "models"]
    configs = [utils.read_endpoint_config(endpoint_file, endpoint_type=key)
               for key in section_keys]
    return AvailableEndpoints(*configs)
def test_tracker_store_from_string(default_domain):
    """A tracker store given by module path is instantiated from config."""
    config_file = "data/test_endpoints/custom_tracker_endpoints.yml"
    tracker_store_config = utils.read_endpoint_config(config_file,
                                                      "tracker_store")
    resolved_store = TrackerStore.find_tracker_store(default_domain,
                                                     tracker_store_config)
    assert isinstance(resolved_store, ExampleTrackerStore)
def test_file_broker_from_config():
    """A file event broker is created from its endpoint config."""
    config_file = "data/test_endpoints/event_brokers/file_endpoint.yml"
    broker_config = utils.read_endpoint_config(config_file, "event_broker")
    producer = broker.from_endpoint_config(broker_config)
    assert isinstance(producer, FileProducer)
    assert producer.path == "rasa_event.log"
def test_tracker_store_from_invalid_module(default_domain):
    """An unimportable tracker-store module falls back to the in-memory store."""
    config_file = "data/test_endpoints/custom_tracker_endpoints.yml"
    tracker_store_config = utils.read_endpoint_config(config_file,
                                                      "tracker_store")
    tracker_store_config.type = "a.module.which.cannot.be.found"
    resolved_store = TrackerStore.find_tracker_store(default_domain,
                                                     tracker_store_config)
    assert isinstance(resolved_store, InMemoryTrackerStore)
def test_tracker_store_from_invalid_string(default_domain):
    """A tracker-store type that is not a module path falls back to memory."""
    config_file = "data/test_endpoints/custom_tracker_endpoints.yml"
    tracker_store_config = utils.read_endpoint_config(config_file,
                                                      "tracker_store")
    tracker_store_config.type = "any string"
    resolved_store = TrackerStore.find_tracker_store(default_domain,
                                                     tracker_store_config)
    assert isinstance(resolved_store, InMemoryTrackerStore)
def create_tracker_store(core_model, endpoints):
    """Build the tracker store described by the ``tracker-store`` section.

    Args:
        core_model: model location passed to ``get_domain`` to load the domain.
        endpoints: path of the endpoints file to read.

    Returns:
        An ``InMemoryTrackerStore`` for type ``memory``, a
        ``RedisTrackerStore`` for type ``redis``, otherwise ``None`` (also
        when no ``tracker-store`` section is configured).
    """
    domain = get_domain(core_model)

    # NOTE(review): ``read_endpoint_config`` presumably returns ``None`` when
    # the section is missing; guard so an absent ``tracker-publish`` section
    # means "no publish URL" instead of an AttributeError on ``.url``.
    tracker_publish = utils.read_endpoint_config(endpoints, "tracker-publish")
    publish_url = tracker_publish.url if tracker_publish is not None else None

    # Setup tracker store
    store = utils.read_endpoint_config(endpoints, "tracker-store")
    if store is None:
        return None

    if store.type == 'memory':
        return InMemoryTrackerStore(domain=domain, publish_url=publish_url)
    if store.type == 'redis':
        # NOTE(review): assumes this project's EndpointConfig exposes
        # host/port/db/password as attributes — confirm against its class.
        return RedisTrackerStore(domain=domain,
                                 host=store.host,
                                 port=store.port,
                                 db=store.db,
                                 password=store.password,
                                 publish_url=publish_url)
    return None
def get_agent():
    """Load the sample dialogue agent with its NLU interpreter.

    The action endpoint is read from the ``action_endpoint`` section of
    ``sample/core_config.yml`` (resolved against ``current_dir``); the model
    paths are used as given (relative to the working directory).
    """
    # NOTE(review): ``config``, ``FallbackPolicy`` and ``KerasPolicy`` are
    # unused here; kept so import behaviour is unchanged.
    from rasa_nlu import utils, config
    from rasa_core.policies.fallback import FallbackPolicy
    from rasa_core.policies.keras_policy import KerasPolicy
    from rasa_core.interpreter import RasaNLUInterpreter
    from rasa_core.agent import Agent

    nlu_interpreter = RasaNLUInterpreter(
        'sample/models/current/nlu/default/default')
    endpoint_file = os.path.join(current_dir, "sample/core_config.yml")
    action_endpoint = utils.read_endpoint_config(
        endpoint_file, endpoint_type="action_endpoint")
    return Agent.load("sample/models/current/dialogue",
                      interpreter=nlu_interpreter,
                      action_endpoint=action_endpoint)
def test_broker_from_config():
    """A Pika producer built from config matches one built by hand."""
    broker_config = utils.read_endpoint_config(EVENT_BROKER_ENDPOINT_FILE,
                                               "event_broker")
    loaded = PikaProducer.from_endpoint_config(broker_config)
    reference = PikaProducer("localhost", "username", "password", "queue")

    assert loaded.host == reference.host
    assert loaded.credentials.username == reference.credentials.username
    assert loaded.queue == reference.queue
def test_pika_broker_from_config():
    """The generic broker factory returns a configured Pika producer."""
    config_file = "data/test_endpoints/event_brokers/pika_endpoint.yml"
    broker_config = utils.read_endpoint_config(config_file, "event_broker")
    producer = broker.from_endpoint_config(broker_config)

    assert isinstance(producer, PikaProducer)
    assert producer.host == "localhost"
    assert producer.credentials.username == "username"
    assert producer.queue == "queue"
def test_tracker_store_endpoint_config_loading():
    """The default endpoints file yields the expected redis store config."""
    loaded = utils.read_endpoint_config(DEFAULT_ENDPOINTS_FILE,
                                        "tracker_store")
    expected = EndpointConfig.from_dict(dict(type="redis",
                                             url="localhost",
                                             port=6379,
                                             db=0,
                                             password="******",
                                             timeout=30000))
    assert loaded == expected
def test_find_tracker_store(default_domain):
    """The default endpoints file resolves to a redis tracker store type."""
    store_config = utils.read_endpoint_config(DEFAULT_ENDPOINTS_FILE,
                                              "tracker_store")
    redis_store = RedisTrackerStore(domain=default_domain,
                                    host="localhost",
                                    port=6379,
                                    db=0,
                                    password="******",
                                    record_exp=3000)
    found_store = TrackerStore.find_tracker_store(default_domain,
                                                  store_config)
    assert isinstance(redis_store, type(found_store))
def read_endpoints(cls, endpoint_file):
    """Build an instance of ``cls`` from every endpoint section in a file.

    Sections are read in this order and passed positionally to ``cls``:
    nlg, nlu, action_endpoint, models, tracker_store, event_broker, rules,
    credentials, nlu_models_info.
    """
    section_keys = (
        "nlg",
        "nlu",
        "action_endpoint",
        "models",
        "tracker_store",
        "event_broker",
        "rules",
        "credentials",
        "nlu_models_info",
    )
    configs = [read_endpoint_config(endpoint_file, endpoint_type=key)
               for key in section_keys]
    return cls(*configs)
def _read_endpoints(endpoint_file):
    """Read different endpoints from a file.

    Possible endpoints are:
        nlg     url to a nlg-server (``nlg`` key)
        nlu     url to a nlu-server (``nlu`` key)
        action  url to a custom-action-server (``action_endpoint`` key)
        model   url from which to fetch the core-model (``model`` key, or
                ``models`` when ``model`` is not present)

    Parameters
    ----------
    endpoint_file : str
        path to yaml file which contains the endpoints

    Returns
    -------
    AvailableEndpoints
        tuple with the endpoints nlg, nlu, action and model
    """
    AvailableEndpoints = namedtuple('AvailableEndpoints',
                                    'nlg nlu action model')

    nlg = utils.read_endpoint_config(endpoint_file, endpoint_type="nlg")
    nlu = utils.read_endpoint_config(endpoint_file, endpoint_type="nlu")
    action = utils.read_endpoint_config(endpoint_file,
                                        endpoint_type="action_endpoint")
    # The conventional key for the model server is ``models`` (as used by
    # ``read_endpoints``); keep honouring the ``model`` key this function
    # historically read, and fall back to ``models`` when it is absent.
    model = utils.read_endpoint_config(endpoint_file, endpoint_type="model")
    if model is None:
        model = utils.read_endpoint_config(endpoint_file,
                                           endpoint_type="models")
    return AvailableEndpoints(nlg, nlu, action, model)
def train_dialogue_model(domain_file, stories_file, output_path,
                         nlu_model_path=None,
                         endpoints=None,
                         max_history=None,
                         dump_flattened_stories=False,
                         kwargs=None):
    """Train a dialogue model with a fixed policy ensemble and persist it.

    The ensemble is FallbackPolicy + MemoizationPolicy + KerasPolicy.
    ``kwargs`` is split in three: fallback thresholds/action name, data
    loading options (forwarded to ``agent.load_data``), and everything else
    (forwarded to ``agent.train``).

    Returns:
        the trained ``Agent``.
    """
    options = kwargs or {}
    action_endpoint = utils.read_endpoint_config(endpoints, "action_endpoint")

    fallback_args, options = utils.extract_args(
        options, {"nlu_threshold", "core_threshold", "fallback_action_name"})

    fallback = FallbackPolicy(
        fallback_args.get("nlu_threshold", DEFAULT_NLU_FALLBACK_THRESHOLD),
        fallback_args.get("core_threshold", DEFAULT_CORE_FALLBACK_THRESHOLD),
        fallback_args.get("fallback_action_name", DEFAULT_FALLBACK_ACTION))
    memoization = MemoizationPolicy(max_history=max_history)
    featurizer = MaxHistoryTrackerFeaturizer(BinarySingleStateFeaturizer(),
                                             max_history=max_history)
    keras = KerasPolicy(featurizer)

    agent = Agent(domain_file,
                  action_endpoint=action_endpoint,
                  interpreter=nlu_model_path,
                  policies=[fallback, memoization, keras])

    data_load_args, options = utils.extract_args(
        options, {"use_story_concatenation",
                  "unique_last_num_states",
                  "augmentation_factor",
                  "remove_duplicates",
                  "debug_plots"})
    training_data = agent.load_data(stories_file, **data_load_args)
    agent.train(training_data, **options)
    agent.persist(output_path, dump_flattened_stories)
    return agent
def test_kafka_broker_from_config():
    """A Kafka producer built from config matches one built by hand."""
    config_file = ("data/test_endpoints/event_brokers/"
                   "kafka_plaintext_endpoint.yml")
    broker_config = utils.read_endpoint_config(config_file, "event_broker")
    loaded = KafkaProducer.from_endpoint_config(broker_config)
    reference = KafkaProducer("localhost", "username", "password",
                              topic="topic",
                              security_protocol="SASL_PLAINTEXT")

    for attribute in ("host", "sasl_username", "sasl_password", "topic"):
        assert getattr(loaded, attribute) == getattr(reference, attribute)
def recreate_agent(
        model_directory,  # type: Text
        nlu_model=None,  # type: Optional[Text]
        tracker_dump=None,  # type: Optional[Text]
        endpoints=None):
    # type: (...) -> Tuple[Agent, DialogueStateTracker]
    """Reload an agent and replay a dumped conversation onto it.

    The ``nlg`` section of ``endpoints`` (if any) becomes the agent's response
    generator. Returns the agent together with the replayed tracker.
    """
    generator = utils.read_endpoint_config(endpoints, "nlg")

    logger.debug("Loading Rasa Core Agent")
    restored_agent = Agent.load(model_directory, nlu_model,
                                generator=generator)

    logger.debug("Finished loading agent. Loading stories now.")
    restored_tracker = load_tracker_from_json(tracker_dump,
                                              restored_agent.domain)
    replay_events(restored_tracker, restored_agent)
    return restored_agent, restored_tracker
def recreate_agent(model_directory,  # type: Text
                   nlu_model=None,  # type: Optional[Text]
                   tracker_dump=None,  # type: Optional[Text]
                   endpoints=None
                   ):
    # type: (...) -> Tuple[Agent, DialogueStateTracker]
    """Rebuild an agent from disk and restore a tracker from a JSON dump.

    Events from ``tracker_dump`` are replayed against the freshly loaded
    agent; both are returned.
    """
    nlg_config = utils.read_endpoint_config(endpoints, "nlg")

    logger.debug("Loading Rasa Core Agent")
    loaded_agent = Agent.load(model_directory, nlu_model, generator=nlg_config)
    logger.debug("Finished loading agent. Loading stories now.")

    dumped_tracker = load_tracker_from_json(tracker_dump, loaded_agent.domain)
    replay_events(dumped_tracker, loaded_agent)
    return loaded_agent, dumped_tracker
# NOTE(review): the lines up to ``return agent`` are the tail of a function
# (presumably ``main``) whose header is outside this view; ``log`` is assumed
# to be a logger obtained earlier in that function — confirm.
    log.setLevel(logging.WARN)

    logger.info("Rasa process starting")
    # Resolve the NLU interpreter from a model path and/or NLU endpoint.
    interpreter = interpreter_from_args(nlu_model, nlu_endpoint)
    # The nlg endpoint (if any) is used as the agent's response generator.
    agent = Agent.load(model_directory, interpreter, generator=nlg_endpoint)

    logger.info("Finished loading agent, starting input channel & server.")
    # Only attach an input channel when one was requested.
    if channel:
        input_channel = create_input_channel(channel, port, credentials_file)
        agent.handle_channel(input_channel)

    return agent


if __name__ == '__main__':
    # Running as standalone python application
    arg_parser = create_argument_parser()
    cmdline_args = arg_parser.parse_args()

    # Console (colored) and file logging share the CLI log level.
    utils.configure_colored_logging(cmdline_args.loglevel)
    utils.configure_file_logging(cmdline_args.loglevel, cmdline_args.log_file)

    # Read the optional nlg / nlu sections from the --endpoints file.
    nlg_endpoint = utils.read_endpoint_config(cmdline_args.endpoints, "nlg")
    nlu_endpoint = utils.read_endpoint_config(cmdline_args.endpoints, "nlu")

    main(cmdline_args.core,
         cmdline_args.nlu,
         cmdline_args.connector,
         cmdline_args.port,
         cmdline_args.credentials,
         nlg_endpoint,
         nlu_endpoint)
def test_nlg_endpoint_config_loading():
    """The default endpoints file contains the local NLG server URL."""
    expected_url = "http://localhost:5055/nlg"
    loaded = utils.read_endpoint_config(DEFAULT_ENDPOINTS_FILE, "nlg")
    assert loaded == EndpointConfig.from_dict({"url": expected_url})
def test_nlg_endpoint_config_loading():
    """Loading the ``nlg`` section yields the expected endpoint config."""
    nlg_config = utils.read_endpoint_config(DEFAULT_ENDPOINTS_FILE, "nlg")
    reference = EndpointConfig.from_dict(dict(url="http://localhost:5055/nlg"))
    assert nlg_config == reference
def create_app(
        model_directory,  # type: Text
        interpreter=None,  # type: Union[Text, NLI, None]
        loglevel="INFO",  # type: Optional[Text]
        logfile="rasa_core.log",  # type: Optional[Text]
        cors_origins=None,  # type: Optional[List[Text]]
        action_factory=None,  # type: Optional[Text]
        auth_token=None,  # type: Optional[Text]
        tracker_store=None,  # type: Optional[TrackerStore]
        endpoints=None):
    """Create the Flask app exposing the Rasa Core HTTP API.

    Args:
        model_directory: directory of the Core model; also the extraction
            target when a new zipped model is posted to ``/load``.
        interpreter: NLU interpreter (or anything
            ``run.interpreter_from_args`` accepts).
        loglevel: level passed to ``utils.configure_file_logging``.
        logfile: file passed to ``utils.configure_file_logging``.
        cors_origins: origins allowed by the per-route ``cross_origin``
            decorators (empty list when not given).
        action_factory: forwarded to ``_create_agent``.
        auth_token: token enforced by ``requires_auth`` on protected routes.
        tracker_store: forwarded to ``_create_agent``.
        endpoints: endpoints file from which the ``nlg`` and ``nlu`` sections
            are read.

    Returns:
        the configured ``Flask`` application.
    """
    app = Flask(__name__)
    CORS(app, resources={r"/*": {"origins": "*"}})

    # Setting up logfile
    utils.configure_file_logging(loglevel, logfile)

    if not cors_origins:
        cors_origins = []

    nlg_endpoint = utils.read_endpoint_config(endpoints, "nlg")
    nlu_endpoint = utils.read_endpoint_config(endpoints, "nlu")

    _interpreter = run.interpreter_from_args(interpreter, nlu_endpoint)

    # A one-element list so nested functions (``load_model``) can swap the
    # agent in place without rebinding the closure variable.
    _agent = [_create_agent(model_directory, _interpreter, action_factory,
                            tracker_store, nlg_endpoint)]

    def agent():
        """Return the currently loaded agent, or ``None`` if there is none."""
        if _agent and _agent[0]:
            return _agent[0]
        else:
            return None

    def error_response(message, status):
        """Return a JSON ``{"error": message}`` response with ``status``."""
        # ``jsonify`` already builds a JSON ``Response``; wrapping it in a
        # second ``Response`` would make the body the inner response object
        # rather than its JSON text, so set the status code directly.
        response = jsonify(error=message)
        response.status_code = status
        return response

    @app.route("/", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def hello():
        """Check if the server is running and responds with the version."""
        return "hello from Rasa Core: " + __version__

    @app.route("/version", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def version():
        """Respond with the version number of the installed rasa core."""
        return jsonify({'version': __version__})

    # <sender_id> can be 'default' if there's only 1 client
    @app.route("/conversations/<sender_id>/continue",
               methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def continue_predicting(sender_id):
        """Continue a prediction started with parse.

        Caller should have executed the action returned from the parse
        endpoint. The events returned from that executed action are passed
        to continue which will trigger the next action prediction.

        If continue predicts action listen, the caller should wait for the
        next user message."""
        request_params = request.get_json(force=True)
        encoded_events = request_params.get("events", [])
        executed_action = request_params.get("executed_action", None)
        evts = events.deserialise_events(encoded_events)
        try:
            response = agent().continue_message_handling(
                sender_id, executed_action, evts)
        except ValueError as e:
            # ``str(e)``: exceptions have no ``.message`` attribute on Python 3.
            return error_response(str(e), 400)
        except Exception as e:
            logger.exception(e)
            return error_response("Server failure. Error: {}".format(e), 500)
        return jsonify(response)

    @app.route("/conversations/<sender_id>/tracker/events",
               methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def append_events(sender_id):
        """Append a list of events to the state of a conversation"""
        request_params = request.get_json(force=True)
        evts = events.deserialise_events(request_params)
        tracker = agent().tracker_store.get_or_create_tracker(sender_id)
        for e in evts:
            tracker.update(e)
        agent().tracker_store.save(tracker)
        return jsonify(tracker.current_state())

    @app.route("/conversations", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def list_trackers():
        """List the sender ids of all stored conversations."""
        return jsonify(list(agent().tracker_store.keys()))

    @app.route("/conversations/<sender_id>/tracker",
               methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def retrieve_tracker(sender_id):
        """Get a dump of a conversations tracker including its events."""
        # parameters
        use_history = bool_arg('ignore_restarts', default=False)
        should_include_events = bool_arg('events', default=True)
        until_time = request.args.get('until', None)

        # retrieve tracker and set to requested state
        tracker = agent().tracker_store.get_or_create_tracker(sender_id)
        if until_time is not None:
            tracker = tracker.travel_back_in_time(float(until_time))

        # dump and return tracker
        state = tracker.current_state(
            should_include_events=should_include_events,
            only_events_after_latest_restart=use_history)
        return jsonify(state)

    @app.route("/conversations/<sender_id>/tracker",
               methods=['PUT', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def update_tracker(sender_id):
        """Use a list of events to set a conversations tracker to a state."""
        request_params = request.get_json(force=True)
        tracker = DialogueStateTracker.from_dict(sender_id,
                                                 request_params,
                                                 agent().domain)
        # will override an existing tracker with the same id!
        agent().tracker_store.save(tracker)
        return jsonify(tracker.current_state(should_include_events=True))

    @app.route("/domain", methods=['GET'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def get_domain():
        """Return the current domain as JSON (default) or YAML.

        The format is chosen from the ``Accept`` header; any other value is
        answered with HTTP 406."""
        accepts = request.headers.get("Accept", default="application/json")
        if accepts.endswith("json"):
            domain = agent().domain.as_dict()
            return jsonify(domain)
        elif accepts.endswith("yml"):
            domain_yaml = agent().domain.as_yaml()
            return Response(domain_yaml,
                            status=200,
                            content_type="application/x-yml")
        else:
            return Response(
                "Invalid accept header. Domain can be provided as json "
                "(\"Accept: application/json\") or yml "
                "(\"Accept: application/x-yml\"). Make sure you've set the "
                "appropriate Accept header.",
                status=406)

    @app.route("/conversations/<sender_id>/parse",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def parse(sender_id):
        """Parse a user message and return the predicted next action."""
        request_params = request_parameters()

        if 'query' in request_params:
            message = request_params.pop('query')
        elif 'q' in request_params:
            message = request_params.pop('q')
        else:
            return error_response("Invalid parse parameter specified.", 400)

        try:
            # Fetches the predicted action in a json format
            response = agent().start_message_handling(message, sender_id)
            return jsonify(response)
        except Exception as e:
            logger.exception("Caught an exception during parse.")
            return error_response("Server failure. Error: {}".format(e), 500)

    @app.route("/conversations/<sender_id>/respond",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def respond(sender_id):
        """Handle a user message end-to-end and return the bot responses."""
        request_params = request_parameters()

        if 'query' in request_params:
            message = request_params.pop('query')
        elif 'q' in request_params:
            message = request_params.pop('q')
        else:
            return error_response("Invalid respond parameter specified.", 400)

        try:
            # Set the output channel
            out = CollectingOutputChannel()
            # Fetches the appropriate bot response in a json format
            responses = agent().handle_message(message,
                                               output_channel=out,
                                               sender_id=sender_id)
            return jsonify(responses)
        except Exception as e:
            logger.exception("Caught an exception during respond.")
            return error_response("Server failure. Error: {}".format(e), 500)

    @app.route("/load", methods=['POST', 'OPTIONS'])
    @requires_auth(auth_token)
    @cross_origin(origins=cors_origins)
    def load_model():
        """Load a zipped model, replacing the existing one."""
        if 'model' not in request.files:
            # model file is missing
            abort(400)

        model_file = request.files['model']

        logger.info("Received new model through REST interface.")
        zipped_path = tempfile.NamedTemporaryFile(delete=False, suffix=".zip")
        zipped_path.close()
        try:
            model_file.save(zipped_path.name)
            logger.debug("Downloaded model to {}".format(zipped_path.name))
            # NOTE: zipfile strips absolute paths and ".." components on
            # extraction, but the archive is still caller-supplied.
            with zipfile.ZipFile(zipped_path.name, 'r') as zip_ref:
                zip_ref.extractall(model_directory)
        finally:
            # The temporary archive is only needed until it is extracted.
            os.remove(zipped_path.name)
        logger.debug("Unzipped model to {}".format(
            os.path.abspath(model_directory)))

        # Resolve the interpreter exactly as at startup so the NLU endpoint
        # from ``endpoints`` is still honoured after a reload.
        reloaded_interpreter = run.interpreter_from_args(interpreter,
                                                         nlu_endpoint)
        _agent[0] = _create_agent(model_directory, reloaded_interpreter,
                                  action_factory, tracker_store, nlg_endpoint)
        logger.debug("Finished loading new agent.")
        return jsonify({'success': 1})

    return app
def test_no_broker_in_config():
    """Without an ``event_broker`` section no broker is created."""
    broker_config = utils.read_endpoint_config(DEFAULT_ENDPOINTS_FILE,
                                               "event_broker")
    assert broker.from_endpoint_config(broker_config) is None
def create_app(model_directory,  # type: Text
               interpreter=None,  # type: Union[Text, NLI, None]
               loglevel="INFO",  # type: Optional[Text]
               logfile="rasa_core.log",  # type: Optional[Text]
               cors_origins=None,  # type: Optional[List[Text]]
               action_factory=None,  # type: Optional[Text]
               auth_token=None,  # type: Optional[Text]
               tracker_store=None,  # type: Optional[TrackerStore]
               endpoints=None
               ):
    """Create the Flask app exposing the Rasa Core HTTP API.

    Args:
        model_directory: directory of the Core model; also the extraction
            target when a new zipped model is posted to ``/load``.
        interpreter: NLU interpreter (or anything
            ``run.interpreter_from_args`` accepts).
        loglevel: level passed to ``utils.configure_file_logging``.
        logfile: file passed to ``utils.configure_file_logging``.
        cors_origins: origins allowed by the per-route ``cross_origin``
            decorators (empty list when not given).
        action_factory: forwarded to ``_create_agent``.
        auth_token: token enforced by ``requires_auth`` on protected routes.
        tracker_store: forwarded to ``_create_agent``.
        endpoints: endpoints file from which the ``nlg`` and ``nlu`` sections
            are read.

    Returns:
        the configured ``Flask`` application.
    """
    app = Flask(__name__)
    CORS(app, resources={r"/*": {"origins": "*"}})

    # Setting up logfile
    utils.configure_file_logging(loglevel, logfile)

    if not cors_origins:
        cors_origins = []

    nlg_endpoint = utils.read_endpoint_config(endpoints, "nlg")
    nlu_endpoint = utils.read_endpoint_config(endpoints, "nlu")

    _interpreter = run.interpreter_from_args(interpreter, nlu_endpoint)

    # A one-element list so nested functions (``load_model``) can swap the
    # agent in place without rebinding the closure variable.
    _agent = [_create_agent(model_directory, _interpreter, action_factory,
                            tracker_store, nlg_endpoint)]

    def agent():
        """Return the currently loaded agent, or ``None`` if there is none."""
        if _agent and _agent[0]:
            return _agent[0]
        else:
            return None

    def error_response(message, status):
        """Return a JSON ``{"error": message}`` response with ``status``."""
        # ``jsonify`` already builds a JSON ``Response``; wrapping it in a
        # second ``Response`` would make the body the inner response object
        # rather than its JSON text, so set the status code directly.
        response = jsonify(error=message)
        response.status_code = status
        return response

    @app.route("/", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def hello():
        """Check if the server is running and responds with the version."""
        return "hello from Rasa Core: " + __version__

    @app.route("/version", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def version():
        """Respond with the version number of the installed rasa core."""
        return jsonify({'version': __version__})

    # <sender_id> can be 'default' if there's only 1 client
    @app.route("/conversations/<sender_id>/continue",
               methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def continue_predicting(sender_id):
        """Continue a prediction started with parse.

        Caller should have executed the action returned from the parse
        endpoint. The events returned from that executed action are passed
        to continue which will trigger the next action prediction.

        If continue predicts action listen, the caller should wait for the
        next user message."""
        request_params = request.get_json(force=True)
        encoded_events = request_params.get("events", [])
        executed_action = request_params.get("executed_action", None)
        evts = events.deserialise_events(encoded_events)
        try:
            response = agent().continue_message_handling(sender_id,
                                                         executed_action,
                                                         evts)
        except ValueError as e:
            # ``str(e)``: exceptions have no ``.message`` attribute on Python 3.
            return error_response(str(e), 400)
        except Exception as e:
            logger.exception(e)
            return error_response("Server failure. Error: {}".format(e), 500)
        return jsonify(response)

    @app.route("/conversations/<sender_id>/tracker/events",
               methods=['POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def append_events(sender_id):
        """Append a list of events to the state of a conversation"""
        request_params = request.get_json(force=True)
        evts = events.deserialise_events(request_params)
        tracker = agent().tracker_store.get_or_create_tracker(sender_id)
        for e in evts:
            tracker.update(e)
        agent().tracker_store.save(tracker)
        return jsonify(tracker.current_state())

    @app.route("/conversations", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def list_trackers():
        """List the sender ids of all stored conversations."""
        return jsonify(list(agent().tracker_store.keys()))

    @app.route("/conversations/<sender_id>/tracker",
               methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def retrieve_tracker(sender_id):
        """Get a dump of a conversations tracker including its events."""
        # parameters
        use_history = bool_arg('ignore_restarts', default=False)
        should_include_events = bool_arg('events', default=True)
        until_time = request.args.get('until', None)

        # retrieve tracker and set to requested state
        tracker = agent().tracker_store.get_or_create_tracker(sender_id)
        if until_time is not None:
            tracker = tracker.travel_back_in_time(float(until_time))

        # dump and return tracker
        state = tracker.current_state(
            should_include_events=should_include_events,
            only_events_after_latest_restart=use_history)
        return jsonify(state)

    @app.route("/conversations/<sender_id>/tracker",
               methods=['PUT', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def update_tracker(sender_id):
        """Use a list of events to set a conversations tracker to a state."""
        request_params = request.get_json(force=True)
        tracker = DialogueStateTracker.from_dict(sender_id,
                                                 request_params,
                                                 agent().domain)
        # will override an existing tracker with the same id!
        agent().tracker_store.save(tracker)
        return jsonify(tracker.current_state(should_include_events=True))

    @app.route("/domain", methods=['GET'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def get_domain():
        """Return the current domain as JSON (default) or YAML.

        The format is chosen from the ``Accept`` header; any other value is
        answered with HTTP 406."""
        accepts = request.headers.get("Accept", default="application/json")
        if accepts.endswith("json"):
            domain = agent().domain.as_dict()
            return jsonify(domain)
        elif accepts.endswith("yml"):
            domain_yaml = agent().domain.as_yaml()
            return Response(domain_yaml,
                            status=200,
                            content_type="application/x-yml")
        else:
            return Response(
                "Invalid accept header. Domain can be provided as json "
                "(\"Accept: application/json\") or yml "
                "(\"Accept: application/x-yml\"). Make sure you've set the "
                "appropriate Accept header.",
                status=406)

    @app.route("/conversations/<sender_id>/parse",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def parse(sender_id):
        """Parse a user message and return the predicted next action."""
        request_params = request_parameters()

        if 'query' in request_params:
            message = request_params.pop('query')
        elif 'q' in request_params:
            message = request_params.pop('q')
        else:
            return error_response("Invalid parse parameter specified.", 400)

        try:
            # Fetches the predicted action in a json format
            response = agent().start_message_handling(message, sender_id)
            return jsonify(response)
        except Exception as e:
            logger.exception("Caught an exception during parse.")
            return error_response("Server failure. Error: {}".format(e), 500)

    @app.route("/conversations/<sender_id>/respond",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def respond(sender_id):
        """Handle a user message end-to-end and return the bot responses."""
        request_params = request_parameters()

        if 'query' in request_params:
            message = request_params.pop('query')
        elif 'q' in request_params:
            message = request_params.pop('q')
        else:
            return error_response("Invalid respond parameter specified.", 400)

        try:
            # Set the output channel
            out = CollectingOutputChannel()
            # Fetches the appropriate bot response in a json format
            responses = agent().handle_message(message,
                                               output_channel=out,
                                               sender_id=sender_id)
            return jsonify(responses)
        except Exception as e:
            logger.exception("Caught an exception during respond.")
            return error_response("Server failure. Error: {}".format(e), 500)

    @app.route("/load", methods=['POST', 'OPTIONS'])
    @requires_auth(auth_token)
    @cross_origin(origins=cors_origins)
    def load_model():
        """Load a zipped model, replacing the existing one."""
        if 'model' not in request.files:
            # model file is missing
            abort(400)

        model_file = request.files['model']

        logger.info("Received new model through REST interface.")
        zipped_path = tempfile.NamedTemporaryFile(delete=False, suffix=".zip")
        zipped_path.close()
        try:
            model_file.save(zipped_path.name)
            logger.debug("Downloaded model to {}".format(zipped_path.name))
            # NOTE: zipfile strips absolute paths and ".." components on
            # extraction, but the archive is still caller-supplied.
            with zipfile.ZipFile(zipped_path.name, 'r') as zip_ref:
                zip_ref.extractall(model_directory)
        finally:
            # The temporary archive is only needed until it is extracted.
            os.remove(zipped_path.name)
        logger.debug("Unzipped model to {}".format(
            os.path.abspath(model_directory)))

        # Resolve the interpreter exactly as at startup so the NLU endpoint
        # from ``endpoints`` is still honoured after a reload.
        reloaded_interpreter = run.interpreter_from_args(interpreter,
                                                         nlu_endpoint)
        _agent[0] = _create_agent(model_directory, reloaded_interpreter,
                                  action_factory, tracker_store, nlg_endpoint)
        logger.debug("Finished loading new agent.")
        return jsonify({'success': 1})

    return app
def create_app(
        model_directory,  # type: Text
        interpreter=None,  # type: Union[Text, NLI, None]
        loglevel="INFO",  # type: Optional[Text]
        logfile="rasa_core.log",  # type: Optional[Text]
        cors_origins=None,  # type: Optional[List[Text]]
        action_factory=None,  # type: Optional[Text]
        auth_token=None,  # type: Optional[Text]
        tracker_store=None,  # type: Optional[TrackerStore]
        endpoints=None):
    """Create the Flask app serving the ``dlg-app`` chat API.

    Routes:
        ``/``                            greeting with the Rasa Core version
        ``/status/health.check``         fixed health-check string
        ``/dlg-app/service/chat/v1/im``  validate the request, run the message
                                         through the agent, translate the bot
                                         reply into the ``resCode``/``data``
                                         envelope and buffer a row for a
                                         batched MySQL insert

    Args:
        model_directory: directory of the Core model.
        interpreter: NLU interpreter (or anything
            ``run.interpreter_from_args`` accepts).
        loglevel: level passed to ``utils.configure_file_logging``.
        logfile: file passed to ``utils.configure_file_logging``.
        cors_origins: origins allowed by the per-route ``cross_origin``
            decorators (empty list when not given).
        action_factory: forwarded to ``_create_agent``.
        auth_token: token enforced by ``requires_auth``.
        tracker_store: forwarded to ``_create_agent``.
        endpoints: endpoints file from which the ``nlg`` and ``nlu`` sections
            are read.

    Returns:
        the configured ``Flask`` application.
    """
    import ast

    app = Flask(__name__)
    CORS(app, resources={r"/*": {"origins": "*"}})

    # Setting up logfile
    utils.configure_file_logging(loglevel, logfile)

    if not cors_origins:
        cors_origins = []

    nlg_endpoint = utils.read_endpoint_config(endpoints, "nlg")
    nlu_endpoint = utils.read_endpoint_config(endpoints, "nlu")

    _interpreter = run.interpreter_from_args(interpreter, nlu_endpoint)

    # A one-element list so nested functions could swap the agent in place.
    _agent = [_create_agent(model_directory, _interpreter, action_factory,
                            tracker_store, nlg_endpoint)]

    def agent():
        """Return the currently loaded agent, or ``None`` if there is none."""
        if _agent and _agent[0]:
            return _agent[0]
        else:
            return None

    def cs_msg(content):
        """Build an ``LU:CsMsg`` payload; ``content`` is the repr of a dict."""
        return {
            "type": 1,
            "msgType": "LU:CsMsg",
            "content": str({"title": "", "content": content}),
        }

    def failure_response():
        """Generic failure envelope (``resCode`` 9999) for unexpected errors."""
        return jsonify({
            "resCode": 9999,
            "resMsg": "failure",
            "data": cs_msg("test3"),
        })

    def param_error(res_code, res_msg, log_msg=None):
        """Log and return a validation-error envelope with empty ``data``."""
        # ``logger.error`` rather than ``logger.exception``: no exception is
        # being handled here, so there is no traceback to attach.
        logger.error(log_msg or res_msg)
        return jsonify({"resCode": res_code, "resMsg": res_msg, "data": ""})

    def reply_payload(text):
        """Translate the bot's reply text into the response ``data`` payload.

        Returns ``None`` when the text parses to an unsupported shape. Errors
        from parsing (text that is not a Python literal) propagate to the
        caller, which substitutes a default payload.
        """
        if "<utter_restart>" in text:
            return cs_msg("restart")
        # SECURITY: the reply text may carry user-influenced content (e.g.
        # slot values filled into templates), so it is parsed with
        # ``ast.literal_eval`` instead of ``eval``.
        parsed = ast.literal_eval(text)
        if isinstance(parsed, list):
            return cs_msg(parsed[0])
        if isinstance(parsed, dict):
            kind, value = list(parsed.items())[0]
            if kind == 'REC':
                return {"type": 1, "msgType": "LU:LucyRecommend",
                        "content": value}
            if kind == 'XIAOAN':
                return {"type": 1, "msgType": "XIAOAN", "content": value}
            if kind == 'CONTENT_LINK':
                link_info = ast.literal_eval(value)
                return {
                    "type": 1,
                    "msgType": "LU:CsMsg",
                    "content": str({
                        "title": "",
                        "content": link_info["content"],
                        "link": link_info["link"],
                    }),
                }
            if kind == 'TO_PERSON':
                return {
                    "type": 0,
                    "msgType": "LU:CsMsg",
                    "content": str({"title": "", "content": value}),
                }
        print("output error format")
        return None

    @app.route("/", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def hello():
        """Check if the server is running and responds with the version."""
        return "hello from Rasa Core: " + __version__

    @app.route("/status/health.check", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def version():
        """Health check: respond with a fixed status/build string."""
        return "OK:dlg-app:reg_20181121_01"

    @app.route("/dlg-app/service/chat/v1/im",
               methods=['GET', 'POST', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    @requires_auth(auth_token)
    @ensure_loaded_agent(agent)
    def respond():
        """Handle one chat message and reply in the ``resCode`` envelope.

        Request parameters (read via ``request_parameters()``): ``partyNo``
        (used as sender id), ``question`` (user text), ``channel``,
        ``appVersion`` and ``msgId``.
        """
        # Module-level buffer of rows waiting for the batched DB insert.
        global list_data
        try:
            request_params = request_parameters()

            if 'partyNo' not in request_params:
                return param_error(400102, "partyNo [1] does not exist",
                                   "partyNo does not exist")
            partyNo = request_params.pop('partyNo')
            if not partyNo.strip():
                return param_error(400101, "partyNo is empty")

            has_question = 'question' in request_params
            if has_question:
                question = request_params.pop('question')
                if not question.strip():
                    logger.error("question is empty")
                    return jsonify({
                        "resCode": 0,
                        "resMsg": "question is empty",
                        "data": cs_msg("question is empty"),
                    })

            if 'channel' not in request_params:
                return param_error(400103, "channel [3] does not exist",
                                   "channel does not exist")
            channel = request_params.pop('channel')
            if not channel.strip():
                return param_error(400101, "channel is empty")

            # The value is unused, but a missing ``appVersion`` raises
            # KeyError and therefore yields the generic failure reply.
            request_params.pop('appVersion')

            if 'msgId' not in request_params:
                # NOTE(review): shares resCode 400103 with missing ``channel``.
                return param_error(400103, "msgId does not exist")
            msgId = request_params.pop('msgId')
            if not msgId.strip():
                return param_error(400101, "msgId is empty")

            if not has_question:
                # Nothing to answer: reply with the generic failure envelope.
                logger.error("question does not exist")
                return failure_response()

            # Set the output channel
            out = CollectingOutputChannel()
            # Fetches the appropriate bot response and translates it.
            try:
                responses = agent().handle_message(question,
                                                   output_channel=out,
                                                   sender_id=partyNo)
                data_json = reply_payload(str(responses[0]["text"]))
            except Exception:
                logger.exception(
                    "Caught an exception of eval(text) is error.")
                data_json = cs_msg("test1")

            if data_json is None:
                # Unsupported reply format: generic failure, nothing stored.
                return failure_response()

            sql = ("INSERT INTO dialogue (msgId, partyNo, question,responses)"
                   " VALUES ( %s, %s, %s, %s)")
            list_data.append((msgId, partyNo, question, str(data_json)))

            if len(list_data) >= 10:
                connect = None
                cursor = None
                try:
                    # ``pool`` is assumed to be a module-level connection
                    # pool; each insert borrows a connection from it.
                    connect = pool.connection()
                    cursor = connect.cursor()
                    cursor.executemany(sql, list_data)
                    connect.commit()
                    print('成功插入', cursor.rowcount, '条数据')
                    list_data = []
                except Exception:
                    # Rows stay buffered and are retried with the next batch.
                    logger.exception(
                        "Caught an exception during insert mysql.")
                    if connect is not None:
                        connect.rollback()  # roll back the transaction
                finally:
                    if cursor is not None:
                        cursor.close()
                    if connect is not None:
                        connect.close()

            return jsonify({
                "resCode": 0,
                "resMsg": "success",
                "data": data_json,
            })
        except Exception:
            logger.exception("Caught an exception during respond.")
            return failure_response()

    return app
# NOTE(review): the lines up to ``return agent`` are the tail of a function
# (presumably ``main``) whose header is outside this view; ``agent``,
# ``channel``, ``port`` and ``credentials_file`` are assumed to be defined
# earlier in that function — confirm.
    logger.info("Finished loading agent, starting input channel & server.")
    # Only attach an input channel when one was requested.
    if channel:
        input_channel = create_input_channel(channel, port, credentials_file)
        agent.handle_channel(input_channel)

    return agent


if __name__ == '__main__':
    # Running as standalone python application
    arg_parser = create_argument_parser()
    cmdline_args = arg_parser.parse_args()

    # Console (colored) and file logging share the CLI log level.
    utils.configure_colored_logging(cmdline_args.loglevel)
    utils.configure_file_logging(cmdline_args.loglevel, cmdline_args.log_file)

    # Read the optional nlg / nlu sections from the --endpoints file.
    nlg_endpoint = utils.read_endpoint_config(cmdline_args.endpoints, "nlg")
    nlu_endpoint = utils.read_endpoint_config(cmdline_args.endpoints, "nlu")

    main(cmdline_args.core,
         cmdline_args.nlu,
         cmdline_args.connector,
         cmdline_args.port,
         cmdline_args.credentials,
         nlg_endpoint,
         nlu_endpoint)