Example #1
0
def endpoint_app(cors_origins=None, action_package_name=None):
    """Build a Flask app that serves custom actions over HTTP.

    Args:
        cors_origins: Origins allowed by the app-wide CORS configuration.
            ``None`` (or any other falsy value) is replaced by an empty list.
        action_package_name: Package passed to
            ``ActionExecutor.register_package`` to discover actions.

    Returns:
        The configured ``Flask`` application exposing ``/health`` and
        ``/webhook``.
    """
    app = Flask(__name__)

    if not cors_origins:
        cors_origins = []

    executor = ActionExecutor()
    executor.register_package(action_package_name)

    CORS(app, resources={r"/*": {"origins": cors_origins}})

    @app.route("/health", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def health():
        """Ping endpoint to check if the server is running and well."""
        return jsonify({"status": "ok"})

    # NOTE(review): unlike /health, this route's cross_origin() is not given
    # cors_origins, so it uses flask_cors' own defaults — confirm intended.
    @app.route("/webhook", methods=['POST', 'OPTIONS'])
    @cross_origin()
    def webhook():
        """Webhook to retrieve action calls.

        Expects a JSON object body. Responds 400 when the body is missing
        or not a JSON object, or when the action rejects the call.
        """
        # silent=True yields None instead of raising for a missing/malformed
        # JSON body; without this check ``None.get`` below would surface as
        # an unhelpful 500.
        action_call = request.get_json(silent=True)
        if not isinstance(action_call, dict):
            response = jsonify({"error": "Request body must be a JSON object."})
            response.status_code = 400
            return response

        check_version_compatibility(action_call.get("version"))
        try:
            response = executor.run(action_call)
        except ActionExecutionRejection as e:
            logger.error(str(e))
            result = {"error": str(e), "action_name": e.action_name}
            response = jsonify(result)
            response.status_code = 400
            return response

        return jsonify(response)

    return app
Example #2
0
def endpoint_app(cors_origins=None, action_package_name=None):
    """Build a Flask app exposing ``/health`` and ``/webhook`` for actions.

    Args:
        cors_origins: Origins allowed by the app-wide CORS configuration;
            a falsy value is replaced with an empty list.
        action_package_name: Package registered on the ``ActionExecutor``.

    Returns:
        The configured ``Flask`` application.
    """
    app = Flask(__name__)

    if not cors_origins:
        cors_origins = []

    executor = ActionExecutor()
    executor.register_package(action_package_name)

    CORS(app, resources={r"/*": {"origins": cors_origins}})

    @app.route("/health", methods=['GET', 'OPTIONS'])
    @cross_origin(origins=cors_origins)
    def health():
        """Ping endpoint to check if the server is running and well."""
        return jsonify({"status": "ok"})

    # NOTE(review): unlike /health, this route's cross_origin() is not given
    # cors_origins, so it uses flask_cors' own defaults — confirm intended.
    @app.route("/webhook", methods=['POST', 'OPTIONS'])
    @cross_origin()
    def webhook():
        """Webhook to retrieve action calls.

        Responds 400 when the body is missing or not a JSON object.
        """
        # silent=True returns None for a missing/malformed body instead of
        # raising, so bad input gets a clear 400 rather than a crash inside
        # the executor.
        action_call = request.get_json(silent=True)
        if not isinstance(action_call, dict):
            response = jsonify({"error": "Request body must be a JSON object."})
            response.status_code = 400
            return response

        response = executor.run(action_call)

        return jsonify(response)

    return app
Example #3
0
def my_custom_action():
    """Run ``MyCustomAction`` once against a minimal hand-built request.

    Registers the action on a fresh ``ActionExecutor``, sends it a request
    whose tracker has no events and an unset ``name`` slot, and hands the
    executor's result to ``dump``.
    """
    from actions.procs.myactions import MyCustomAction

    # Plain literal; building it first has no side effects.
    tracker_state = {
        'latest_message': {'entities': [], 'intent': {}, 'text': None},
        'active_form': {},
        'latest_action_name': None,
        'sender_id': 'default',
        'paused': False,
        'followup_action': 'action_listen',
        'latest_event_time': None,
        'slots': {'name': None},
        'events': [],
        'latest_input_channel': None,
    }

    runner = ActionExecutor()
    runner.register_action(MyCustomAction())

    payload = {
        'domain': default_domain().as_dict(),
        'next_action': 'my_custom_action',
        'sender_id': 'default',
        'tracker': tracker_state,
    }
    dump(runner.run(payload))
Example #4
0
def test_abstract_action():
    """Concrete actions from ``tests`` are registered; the abstract base is not.

    Also checks that running ``CustomAction`` yields a single ``SlotSet``.
    """
    action_executor = ActionExecutor()
    action_executor.register_package("tests")
    registered = action_executor.actions
    assert CustomAction.name() in registered
    assert CustomActionBase.name() not in registered

    collecting = CollectingDispatcher()
    empty_tracker = Tracker("test", {}, {}, [], False, None, {}, "listen")

    produced = CustomAction().run(collecting, empty_tracker, {})
    assert produced == [SlotSet("test", "test")]
Example #5
0
 def __init__(self):
     """Load config, the agent, the local action executor and a processor.

     The ``MessageProcessor`` is assembled from the loaded agent's parts,
     with no interpreter and no message preprocessor.
     """
     self.config = Config()
     cfg = self.config
     self.endpoints = AvailableEndpoints.read_endpoints(
         cfg.RASA_CONFIG_ENDPOINTS_FILE)
     # No NLU interpreter is created here; both the agent and the processor
     # below receive None in its place.
     self.agent = load_agent(cfg.RASA_CONFIG_CORE_DIALOGUE_PACKAGE_NAME,
                             interpreter=None,
                             endpoints=self.endpoints)
     self.executor = ActionExecutor()
     self.executor.register_package(
         cfg.RASA_CONFIG_ENDPOINTS_ACTION_PACKAGE_NAME)
     agent = self.agent
     self.message_processor = MessageProcessor(
         None,  # interpreter
         agent.policy_ensemble,
         agent.domain,
         agent.tracker_store,
         agent.nlg,
         action_endpoint=agent.action_endpoint,
         message_preprocessor=None)
Example #6
0
class Endpoints:
    """Glue between a loaded Rasa Core agent and HTTP request handlers.

    Holds the project ``Config``, the agent loaded from the configured
    dialogue package, an ``ActionExecutor`` for locally registered custom
    actions, and a ``MessageProcessor`` assembled from the agent's parts.
    """

    def __init__(self):
        self.config = Config()
        self.endpoints = AvailableEndpoints.read_endpoints(
            self.config.RASA_CONFIG_ENDPOINTS_FILE)
        # NOTE(review): this interpreter is created but not handed to the
        # agent or the processor below (both receive None) — confirm whether
        # that is intentional.
        self.interpreter = NaturalLanguageInterpreter.create(
            self.config.RASA_CONFIG_NLU_TRAIN_PACKAGE_NAME, self.endpoints.nlu)
        self.agent = load_agent(
            self.config.RASA_CONFIG_CORE_DIALOGUE_PACKAGE_NAME,
            interpreter=None,
            endpoints=self.endpoints)
        self.executor = ActionExecutor()
        self.executor.register_package(
            self.config.RASA_CONFIG_ENDPOINTS_ACTION_PACKAGE_NAME)

        self.message_processor = MessageProcessor(
            None,  # interpreter
            self.agent.policy_ensemble,
            self.agent.domain,
            self.agent.tracker_store,
            self.agent.nlg,
            action_endpoint=self.agent.action_endpoint,
            message_preprocessor=self.agent.preprocessor)

    def execute_actions(self, action_call):
        """Run a custom-action call and wrap the outcome as a JSON response.

        Returns the executor's result as JSON, or a 400 response carrying
        ``error`` and ``action_name`` when the action rejects the call.
        """
        try:
            response = self.executor.run(action_call)
        except ActionExecutionRejection as e:
            logger.error(str(e))
            result = {"error": str(e), "action_name": e.action_name}
            response = jsonify(result)
            response.status_code = 400
            return response

        return jsonify(response)

    def event_verbosity_parameter(self, default_verbosity):
        """Read the ``include_events`` query parameter as an ``EventVerbosity``.

        Falls back to ``default_verbosity.name`` when the parameter is absent
        and is case-insensitive. An unknown value aborts the request with a
        404 ``InvalidParameter`` error listing the accepted values.
        """
        event_verbosity_str = request.args.get(
            'include_events', default=default_verbosity.name).upper()
        try:
            return EventVerbosity[event_verbosity_str]
        except KeyError:
            enum_values = ", ".join([e.name for e in EventVerbosity])
            abort(
                error(
                    404, "InvalidParameter",
                    "Invalid parameter value for 'include_events'. "
                    "Should be one of {}".format(enum_values), {
                        "parameter": "include_events",
                        "in": "query"
                    }))

    def ask_for_action(self, action_name, action_endpoint):
        """Map an action name to an action instance.

        Returns:
            ``None`` (after logging a warning) if the name is not in the
            domain; the matching default action when the name is a default
            action not overridden by the domain's user actions; an
            ``UtterAction`` for names starting with ``utter_``; otherwise a
            ``RemoteAction`` bound to ``action_endpoint``.
        """
        if action_name not in self.agent.domain.action_names:
            logger.warning("action not found")
            return None
        defaults = {a.name(): a for a in action.default_actions()}
        if action_name in defaults and action_name not in self.agent.domain.user_actions:
            return defaults.get(action_name)
        elif action_name.startswith("utter_"):
            return UtterAction(action_name)
        else:
            return RemoteAction(action_name, action_endpoint)

    def handle_actions(self, message, action_name):
        """Run ``action_name`` for ``message.sender_id`` and report the result.

        Returns:
            A JSON response with the sender's tracker state (at the verbosity
            given by the ``include_events`` query parameter) and the messages
            the action produced. A ``ValueError`` — including an action name
            unknown to the domain — yields a 400 ``error`` response; any other
            exception yields a 500.
        """
        verbosity = self.event_verbosity_parameter(
            EventVerbosity.AFTER_RESTART)
        try:
            output_channel = CollectingOutputChannel()
            dispatcher = Dispatcher(message.sender_id, output_channel,
                                    self.agent.nlg)
            tracker = self.message_processor._get_tracker(message.sender_id)
            if tracker:
                # Resolve the action instance to run. ask_for_action takes
                # exactly (action_name, action_endpoint).
                action_to_run = self.ask_for_action(
                    action_name, self.message_processor.action_endpoint)
                if action_to_run is None:
                    # Otherwise _run_action would fail on a None action and
                    # surface as a generic 500.
                    raise ValueError(
                        "Unknown action '{}'.".format(action_name))
                self.message_processor._run_action(action_to_run, tracker,
                                                   dispatcher)
                # save tracker state to continue conversation from this state
                self.message_processor._save_tracker(tracker)

            # retrieve tracker and set to requested state
            tracker = self.agent.tracker_store.get_or_create_tracker(
                message.sender_id)
            state = tracker.current_state(verbosity)
            return jsonify({
                "tracker": state,
                "messages": output_channel.messages
            })

        except ValueError as e:
            # Pass the message text, not the exception object: the error
            # payload is presumably JSON-serialized by ``error`` — confirm.
            return error(400, "ValueError", str(e))
        except Exception as e:
            return error(500, "ValueError",
                         "Server failure. Error: {}".format(e))
Example #7
0
class RasaCore:
    """Drives a loaded Rasa Core agent one step at a time.

    Re-implements parts of ``MessageProcessor``'s loop (predicting and running
    actions, logging them on the tracker, scheduling reminders) while
    resolving custom actions through a local ``ActionExecutor``.
    """

    def __init__(self):
        self.config = Config()
        self.endpoints = AvailableEndpoints.read_endpoints(
            self.config.RASA_CONFIG_ENDPOINTS_FILE)
        # No NLU interpreter is created; the agent and the processor below
        # both receive None in its place.
        self.agent = load_agent(
            self.config.RASA_CONFIG_CORE_DIALOGUE_PACKAGE_NAME,
            interpreter=None,
            endpoints=self.endpoints)
        self.executor = ActionExecutor()
        self.executor.register_package(
            self.config.RASA_CONFIG_ENDPOINTS_ACTION_PACKAGE_NAME)
        self.message_processor = MessageProcessor(
            None,  # interpreter
            self.agent.policy_ensemble,
            self.agent.domain,
            self.agent.tracker_store,
            self.agent.nlg,
            action_endpoint=self.agent.action_endpoint,
            message_preprocessor=None)

    def handle_message(self, message):
        """Delegate ``message`` to the underlying ``MessageProcessor``.

        ``message`` is presumably a ``UserMessage`` (text, output channel,
        sender id) — confirm against callers.
        """
        return self.message_processor.handle_message(message)

    def resolve_nlu_message(self, message):
        """Return ``message.parse_data``, or parse the message via the agent.

        The parse data is presumably a dict with ``intent``, ``entities`` and
        ``text`` keys; ``receive_nlu_message`` reads ``intent`` and
        ``entities``.
        """
        if message.parse_data:
            return message.parse_data
        return self.agent._parse_message(message)

    def receive_nlu_message(self, message, parse_data):
        """Record a parsed user message on its tracker and respond to it.

        Logs a ``UserUttered`` event plus slot events derived from the
        entities, runs the action loop, and saves the tracker.

        Returns:
            The collected bot messages when the output channel is a
            ``CollectingOutputChannel``; otherwise (or when no tracker is
            available) ``None``.
        """
        tracker = self.message_processor._get_tracker(message.sender_id)
        if tracker:
            tracker.update(
                UserUttered(message.text,
                            parse_data["intent"],
                            parse_data["entities"],
                            parse_data,
                            input_channel=message.input_channel))
            # store all entities as slots
            for e in self.agent.domain.slots_for_entities(
                    parse_data["entities"]):
                tracker.update(e)
            self.predict_and_execute_next_action(message, tracker)
            self.message_processor._save_tracker(tracker)
            if isinstance(message.output_channel, CollectingOutputChannel):
                return message.output_channel.messages
            else:
                return None
        return None

    def predict_and_execute_next_action(self, message, tracker):
        """Predict and run actions until the policy stops asking for more.

        The loop ends when an action signals no further prediction, when
        ``should_handle_message`` is false, or after the processor's
        ``max_number_of_predictions`` steps. If that step limit is hit while
        another action is still requested, the processor's
        ``on_circuit_break`` callback (if set) is invoked.

        Raises:
            Exception: if the policy's best index lies outside the domain's
                action range.
        """
        dispatcher = Dispatcher(message.sender_id, message.output_channel,
                                self.message_processor.nlg)
        # keep taking actions decided by the policy until it chooses to 'listen'
        should_predict_another_action = True
        num_predicted_actions = 0

        self.log_slots(tracker)
        # action loop. predicts actions until we hit action listen
        while (should_predict_another_action
               and self.should_handle_message(tracker)
               and num_predicted_actions <
               self.message_processor.max_number_of_predictions):
            # this actually just calls the policy's method by the same name
            probabilities, policy = self.message_processor._get_next_action_probabilities(
                tracker)
            max_index = int(np.argmax(probabilities))
            if self.message_processor.domain.num_actions <= max_index or max_index < 0:
                raise Exception("Can not access action at index {}. "
                                "Domain has {} actions.".format(
                                    max_index,
                                    self.message_processor.domain.num_actions))

            action = self.ask_for_action(
                self.message_processor.domain.action_names[max_index],
                self.message_processor.action_endpoint)
            confidence = probabilities[max_index]

            should_predict_another_action = self.run_action(
                action, tracker, dispatcher, policy, confidence)
            num_predicted_actions += 1

        if (num_predicted_actions
                == self.message_processor.max_number_of_predictions
                and should_predict_another_action):
            # circuit breaker was tripped
            if self.message_processor.on_circuit_break:
                # call a registered callback
                self.message_processor.on_circuit_break(tracker, dispatcher)

    def ask_for_action(self, action_name, action_endpoint):
        """Map an action name to an action instance.

        Returns:
            ``None`` (after logging a warning) if the name is not in the
            domain; the matching default action when the name is a default
            action not overridden by the domain's user actions; an
            ``UtterAction`` for names starting with ``utter_``; otherwise a
            ``RemoteAction`` given ``action_endpoint`` and the local executor.
        """
        if action_name not in self.agent.domain.action_names:
            logger.warning("action not found")
            return None
        defaults = {a.name(): a for a in action.default_actions()}
        if action_name in defaults and action_name not in self.agent.domain.user_actions:
            return defaults.get(action_name)
        elif action_name.startswith("utter_"):
            return UtterAction(action_name)
        else:
            return RemoteAction(action_name, action_endpoint, self.executor)

    def should_handle_message(self, tracker):
        """True unless the tracker is paused, or the last intent is restart."""
        return (not tracker.is_paused()
                or tracker.latest_message.intent.get("name")
                == self.agent.domain.restart_intent)

    def log_slots(self, tracker):
        """Log every slot's name and value at DEBUG level."""
        slot_values = "\n".join([
            "\t{}: {}".format(s.name, s.value) for s in tracker.slots.values()
        ])
        logger.debug("Current slot values: \n{}".format(slot_values))

    def run_action(self,
                   action,
                   tracker,
                   dispatcher,
                   policy=None,
                   confidence=None):
        """Run ``action`` and apply its outcome to ``tracker``.

        An exception raised by the action is logged and treated as the
        action having produced no events. The action and its events are then
        logged on the tracker, bot utterances are recorded and reminders
        scheduled.

        Returns:
            The processor's ``should_predict_another_action`` verdict.
        """
        # events and return values are used to update
        # the tracker state after an action has been taken
        try:
            events = action.run(dispatcher, tracker,
                                self.message_processor.domain)
        except Exception as e:
            logger.error(
                "Encountered an exception while running action '{}'. "
                "Bot will continue, but the actions events are lost. "
                "Make sure to fix the exception in your custom "
                "code.".format(action.name()))
            logger.error(e, exc_info=True)
            events = []

        self.log_action_on_tracker(tracker, action.name(), events, policy,
                                   confidence)
        self.message_processor.log_bot_utterances_on_tracker(
            tracker, dispatcher)
        self.schedule_reminders(events, dispatcher)

        return self.message_processor.should_predict_another_action(
            action.name(), events)

    def schedule_reminders(self, events, dispatcher):
        # type: (List[Event], Dispatcher) -> None
        """Uses the scheduler to time a job to trigger the passed reminder.

        Reminders with the same `id` property will overwrite one another
        (i.e. only one of them will eventually run)."""

        if events is not None:
            for e in events:
                if isinstance(e, ReminderScheduled):
                    scheduler.add_job(self.message_processor.handle_reminder,
                                      "date",
                                      run_date=e.trigger_date_time,
                                      args=[e, dispatcher],
                                      id=e.name,
                                      replace_existing=True)

    def log_action_on_tracker(self, tracker, action_name, events, policy,
                              policy_confidence):
        """Append ``ActionExecuted`` and the action's events to ``tracker``.

        ``None`` events are treated as an empty list. Each event is stamped
        with the current ``time.time()`` before being applied.
        """
        # Ensures that the code still works even if a lazy programmer missed
        # to type `return []` at the end of an action or the run method
        # returns `None` for some other reason.
        if events is None:
            events = []

        logger.debug("Action '{}' ended with events '{}'".format(
            action_name, ['{}'.format(e) for e in events]))

        self.warn_about_new_slots(tracker, action_name, events)

        if action_name is not None:
            # log the action and its produced events
            tracker.update(
                ActionExecuted(action_name, policy, policy_confidence))

        for e in events:
            e.timestamp = time.time()
            tracker.update(e)

    def warn_about_new_slots(self, tracker, action_name, events):
        """Warn when an action sets a featurized slot unseen in training.

        Does nothing for actions without a training fingerprint.
        """
        # these are the events from that action we have seen during training

        if action_name not in self.message_processor.policy_ensemble.action_fingerprints:
            return

        fp = self.message_processor.policy_ensemble.action_fingerprints[
            action_name]
        slots_seen_during_train = fp.get("slots", set())
        for e in events:
            if isinstance(e, SlotSet) and e.key not in slots_seen_during_train:
                s = tracker.slots.get(e.key)
                if s and s.has_features():
                    logger.warning("Action '{0}' set a slot type '{1}' that "
                                   "it never set during the training. This "
                                   "can throw of the prediction. Make sure to "
                                   "include training examples in your stories "
                                   "for the different types of slots this "
                                   "action can return. Remember: you need to "
                                   "set the slots manually in the stories by "
                                   "adding '- slot{{\"{1}\": {2}}}' "
                                   "after the action."
                                   "".format(action_name, e.key,
                                             json.dumps(e.value)))

    def execute_actions(self, action_call):
        """Run a custom-action call and wrap the outcome as a JSON response.

        Returns the executor's result as JSON, or a 400 response carrying
        ``error`` and ``action_name`` when the action rejects the call.
        """
        try:
            response = self.executor.run(action_call)
        except ActionExecutionRejection as e:
            logger.error(str(e))
            result = {"error": str(e), "action_name": e.action_name}
            response = jsonify(result)
            response.status_code = 400
            return response

        return jsonify(response)
import logging

from flask import Blueprint, request, jsonify
from flask_cors import CORS, cross_origin
from rasa_core_sdk.executor import ActionExecutor

from robot import app
from robot.config.setting import Config
from robot.exception import ActionExecutionRejection

logger = logging.getLogger(__name__)
endpoints_api = Blueprint('endpoints_api', __name__)
# Instantiate Config like the other entry points do; attribute access that
# worked on the class keeps working on an instance.
config = Config()
# Single module-level executor serving the configured custom-action package.
executor = ActionExecutor()
executor.register_package(config.RASA_CONFIG_ENDPOINTS_ACTION_PACKAGE_NAME)
# Allowed CORS origins (currently none). Shared by the app-wide CORS setup
# and the per-route cross_origin on /health so the two cannot drift apart.
cors_origins = []
CORS(app, resources={r"/*": {"origins": cors_origins}})


@endpoints_api.route("/health", methods=['GET', 'OPTIONS'])
@cross_origin(origins=cors_origins)
def health():
    """Liveness probe: always answers with the JSON object ``{"status": "ok"}``."""
    status_payload = {"status": "ok"}
    return jsonify(status_payload)


@endpoints_api.route("/webhook", methods=['POST', 'OPTIONS'])
@cross_origin()
def webhook():
    """Webhook to retrieve action calls."""
    # NOTE(review): as visible, the body stops after reading the payload — it
    # never runs the action call and returns None, which Flask does not accept
    # as a view response. It looks truncated; the endpoint_app examples in
    # this file pass ``action_call`` to ``executor.run`` and return
    # ``jsonify(...)`` — confirm and complete.
    action_call = request.json
Example #9
0
def get_action_executor():
    executor = ActionExecutor()
    #executor.register_package('actions')
    print('register actions')
    executor.register_package('actions')