예제 #1
0
def test_restart(default_dispatcher, default_domain):
    """Running ``ActionRestart`` must produce exactly one ``Restarted`` event."""
    # Build a tracker from the slots and topic configuration of the fixture
    # domain.
    state_tracker = DialogueStateTracker(
        "default",
        default_domain.slots,
        default_domain.topics,
        default_domain.default_topic)

    restart_action = ActionRestart()
    produced_events = restart_action.run(
        default_dispatcher, state_tracker, default_domain)

    assert produced_events == [Restarted()]
예제 #2
0
    def _get_next_action(self, tracker, inputQueue=None, outputQueue=None):
        # type: (DialogueStateTracker, Any, Any) -> Action
        """Decide which action should be executed next for ``tracker``.

        Order of precedence:

        1. A pending follow-up action stored on the tracker. It is cleared
           from the tracker and returned if the domain reports an index for
           its name; otherwise an error is logged and prediction continues.
        2. ``ActionRestart`` if the latest message's intent name equals the
           domain's ``restart_intent``.
        3. The action at the index predicted by the policy ensemble.

        ``inputQueue`` and ``outputQueue`` are forwarded untouched to
        ``policy_ensemble.predict_next_action``.
        """

        follow_up_action = tracker.follow_up_action
        if follow_up_action:
            # The follow-up is consumed even if it turns out to be unknown.
            tracker.clear_follow_up_action()
            if self.domain.index_for_action(
                    follow_up_action.name()) is not None:
                return follow_up_action
            else:
                # The adjacent literals are concatenated verbatim, so the
                # first one needs a trailing space to keep the sentences
                # apart in the log output.
                logger.error(
                    "Trying to run unknown follow up action '{}'! "
                    "Instead of running that, we will ignore the action "
                    "and predict the next action.".format(follow_up_action))

        if (tracker.latest_message.intent.get("name") ==
                self.domain.restart_intent):
            return ActionRestart()

        idx = self.policy_ensemble.predict_next_action(tracker, self.domain,
                                                       inputQueue, outputQueue)
        logger.info("inside get next action")
        logger.info(idx)

        # Resolve the index once and reuse the result for logging and return.
        action = self.domain.action_for_index(idx)
        logger.info(action)

        return action
예제 #3
0
    def modified_predict_and_execute_next_action(self, message, tracker):
        """Run the predict/execute loop for ``message`` and report the result.

        Actions are predicted via ``_get_next_action`` and executed via
        ``_run_action`` until ``_run_action`` returns a falsy value,
        ``_should_handle_message`` returns a falsy value, or
        ``max_number_of_predictions`` actions have been executed.

        Returns a dict with:

        - ``"tracker"``: the result of ``tracker.current_state()``.
        - ``"next_action"``: the name of the last predicted action that was
          not ``"action_listen"`` (absent if there was none).

        If the prediction limit is hit while another prediction was still
        requested (the "circuit breaker"), the optional ``on_circuit_break``
        callback is invoked, the names of the two top entries of the
        reported intent ranking are swapped, and ``ActionRestart`` is run.
        The returned tracker state is the snapshot taken *before* that
        restart.
        """
        # this will actually send the response to the user
        response = {}
        dispatcher = Dispatcher(message.sender_id, message.output_channel,
                                self.domain)
        # keep taking actions decided by the policy until it chooses to 'listen'
        should_predict_another_action = True
        num_predicted_actions = 0

        self._log_slots(tracker)

        # action loop. predicts actions until we hit action listen
        while (should_predict_another_action
               and self._should_handle_message(tracker)
               and num_predicted_actions < self.max_number_of_predictions):
            # this actually just calls the policy's method by the same name
            action = self._get_next_action(tracker)
            if action.name() != "action_listen":
                response["next_action"] = action.name()
            should_predict_another_action = self._run_action(
                action, tracker, dispatcher)
            num_predicted_actions += 1

        if (num_predicted_actions == self.max_number_of_predictions
                and should_predict_another_action):
            # circuit breaker was tripped
            # (`Logger.warn` is a deprecated alias; use `warning`.)
            logger.warning("Circuit breaker tripped. Stopped predicting "
                           "more actions for sender '{}'".format(
                               tracker.sender_id))
            if self.on_circuit_break:
                # call a registered callback
                self.on_circuit_break(tracker, dispatcher)
            response["tracker"] = tracker.current_state()

            # we used contexts and next action has changed so we need to change
            # intent accordingly
            latest_message = response["tracker"]["latest_message"]
            ranking = latest_message.get("intent_ranking") or []
            # Only swap when there really are two ranked intents; without
            # this guard a missing or short ranking raised KeyError/IndexError.
            if len(ranking) >= 2:
                ranking[0]["name"], ranking[1]["name"] = (
                    ranking[1]["name"], ranking[0]["name"])
                latest_message["intent"] = ranking[0]

            # added to restart the bot when it becomes unstable when a followup action is triggered
            # due to contexts
            self._run_action(ActionRestart(), tracker, dispatcher)
            return response

        response["tracker"] = tracker.current_state()
        return response
예제 #4
0
    def _predict_and_execute_next_action(self, message, tracker):
        """Predict and execute actions for ``message`` until told to stop.

        Actions are predicted via ``_get_next_action`` and executed via
        ``_run_action`` until ``_run_action`` returns a falsy value,
        ``_should_handle_message`` returns a falsy value, or
        ``max_number_of_predictions`` actions have been executed.

        If the prediction limit is hit while another prediction was still
        requested (the "circuit breaker"), a warning is logged, the optional
        ``on_circuit_break`` callback is invoked and ``ActionRestart`` is
        run. Returns ``None``.
        """
        # this will actually send the response to the user

        dispatcher = Dispatcher(message.sender_id, message.output_channel,
                                self.domain)
        # keep taking actions decided by the policy until it chooses to 'listen'
        should_predict_another_action = True
        num_predicted_actions = 0

        self._log_slots(tracker)

        # action loop. predicts actions until we hit action listen
        while (should_predict_another_action
               and self._should_handle_message(tracker)
               and num_predicted_actions < self.max_number_of_predictions):
            # this actually just calls the policy's method by the same name
            action = self._get_next_action(tracker)

            should_predict_another_action = self._run_action(
                action, tracker, dispatcher)
            num_predicted_actions += 1

        if (num_predicted_actions == self.max_number_of_predictions
                and should_predict_another_action):
            # circuit breaker was tripped
            # (`Logger.warn` is a deprecated alias; use `warning`.)
            logger.warning("Circuit breaker tripped. Stopped predicting "
                           "more actions for sender '{}'".format(
                               tracker.sender_id))
            if self.on_circuit_break:
                # call a registered callback
                self.on_circuit_break(tracker, dispatcher)

            # added to restart the bot when it becomes unstable when a followup action is triggered
            # due to contexts
            # The return value is irrelevant here: nothing is predicted
            # after the restart.
            self._run_action(ActionRestart(), tracker, dispatcher)
예제 #5
0
파일: domain.py 프로젝트: viktara/rasa_core
class Domain(with_metaclass(abc.ABCMeta, object)):
    """The domain specifies the universe in which the bot's policy acts.

    A Domain subclass provides the actions the bot can take, the intents
    and entities it can recognise, and the topics it knows about.

    Subclasses must implement the abstract properties ``slots``,
    ``entities``, ``intents``, ``actions`` and ``templates``."""

    DEFAULT_ACTIONS = [ActionListen(), ActionRestart()]

    def __init__(self, topics=None, store_entities_as_slots=True,
                 restart_intent="restart"):
        """Store the domain configuration.

        ``topics`` defaults to an empty list. ``store_entities_as_slots``
        controls whether ``slots_for_entities`` produces any events.
        ``restart_intent`` is the intent name stored for restart handling."""
        self.default_topic = DefaultTopic
        self.topics = topics if topics is not None else []
        self.store_entities_as_slots = store_entities_as_slots
        self.restart_intent = restart_intent

    @utils.lazyproperty
    def num_actions(self):
        """Returns the number of available actions."""

        # noinspection PyTypeChecker
        return len(self.actions)

    @utils.lazyproperty
    def action_names(self):
        # type: () -> List[Text]
        """Returns the name of available actions."""

        return [a.name() for a in self.actions]

    @utils.lazyproperty
    def action_map(self):
        # type: () -> Dict[Text, Tuple[int, Action]]
        """Provides a mapping from action names to indices and actions."""
        return {a.name(): (i, a) for i, a in enumerate(self.actions)}

    @utils.lazyproperty
    def num_features(self):
        """Number of used input features for the action prediction."""

        return len(self.input_features)

    def action_for_name(self, action_name):
        # type: (Text) -> Optional[Action]
        """Looks up which action corresponds to this action name.

        Raises an ``Exception`` if the name is not a registered action."""

        if action_name in self.action_map:
            # Entries are (index, action) tuples.
            return self.action_map[action_name][1]
        else:
            self._raise_action_not_found_exception(action_name)

    def action_for_index(self, index):
        """Resolves an index into the action list to the action itself.

        Raises an ``Exception`` if ``index`` is negative or not smaller
        than the number of actions."""

        if len(self.actions) <= index or index < 0:
            raise Exception(
                    "Can not access action at index {}. "
                    "Domain has {} actions.".format(index, len(self.actions)))
        return self.actions[index]

    def index_for_action(self, action_name):
        # type: (Text) -> Optional[int]
        """Looks up which action index corresponds to this action name.

        Raises an ``Exception`` if the name is not a registered action."""

        if action_name in self.action_map:
            # Entries are (index, action) tuples.
            return self.action_map[action_name][0]
        else:
            self._raise_action_not_found_exception(action_name)

    def _raise_action_not_found_exception(self, action_name):
        """Raise an ``Exception`` listing all registered action names."""
        actions = "\n".join(["\t - {}".format(a)
                             for a in sorted(self.action_map)])
        raise Exception(
                "Can not access action '{}', "
                "as that name is not a registered action for this domain. "
                "Available actions are: \n{}".format(action_name, actions))

    @staticmethod
    def _is_predictable_event(event):
        """True for ``ActionExecuted`` events not flagged as unpredictable."""
        return isinstance(event, ActionExecuted) and not event.unpredictable

    def slice_feature_history(self,
                              featurizer,
                              tracker_history,
                              slice_length):
        # type: (Featurizer, List[Dict[Text, float]], int) -> np.ndarray
        """Slices a featurization from the trackers history.

        The last ``slice_length`` entries of ``tracker_history`` are
        encoded with ``featurizer`` and stacked. If the history is shorter
        than ``slice_length``, ``None`` entries are prepended as padding
        to ensure the slice length."""

        slice_end = len(tracker_history)
        slice_start = max(0, slice_end - slice_length)
        padding = [None] * max(0, slice_length - slice_end)
        state_features = padding + tracker_history[slice_start:]
        encoded_features = [featurizer.encode(f, self.input_feature_map)
                            for f in state_features]
        return np.vstack(encoded_features)

    def features_for_tracker_history(self, tracker):
        """Array of features for each state of the trackers history."""

        return [self.get_active_features(tr) for tr in
                tracker.generate_all_prior_states()]

    def feature_vector_for_tracker(self, featurizer, tracker, max_history):
        """Creates a 2D array of shape (max_history,num_features)

        max_history specifies the number of previous steps to be included
        in the input. Each row in the array corresponds to the binarised
        features of each state. Result is padded with default values if
        there are fewer than `max_history` states present."""

        all_features = self.features_for_tracker_history(tracker)
        return self.slice_feature_history(featurizer, all_features, max_history)

    def random_template_for(self, utter_action):
        """Pick a random template for ``utter_action``.

        Returns ``None`` if no templates are registered under that name."""
        if utter_action in self.templates:
            return np.random.choice(self.templates[utter_action])
        else:
            return None

    # noinspection PyTypeChecker
    @utils.lazyproperty
    def slot_features(self):
        # type: () -> List[Text]
        """Returns all available slot feature strings."""

        return ["slot_{}_{}".format(s.name, i)
                for s in self.slots
                for i in range(0, s.feature_dimensionality())]

    # noinspection PyTypeChecker
    @utils.lazyproperty
    def prev_action_features(self):
        # type: () -> List[Text]
        """Returns all available previous action feature strings."""

        return ["prev_{0}".format(a.name())
                for a in self.actions]

    # noinspection PyTypeChecker
    @utils.lazyproperty
    def intent_features(self):
        # type: () -> List[Text]
        """Returns all available intent feature strings."""

        return ["intent_{0}".format(i)
                for i in self.intents]

    # noinspection PyTypeChecker
    @utils.lazyproperty
    def entity_features(self):
        # type: () -> List[Text]
        """Returns all available entity feature strings."""

        return ["entity_{0}".format(e)
                for e in self.entities]

    def index_of_feature(self, feature_name):
        # type: (Text) -> Optional[int]
        """Provides the index of a feature, or ``None`` if it is unknown."""

        return self.input_feature_map.get(feature_name)

    @utils.lazyproperty
    def input_feature_map(self):
        # type: () -> Dict[Text, int]
        """Provides a mapping from feature names to indices."""
        return {f: i for i, f in enumerate(self.input_features)}

    @utils.lazyproperty
    def input_features(self):
        # type: () -> List[Text]
        """Returns all available features.

        The order is: intents, entities, slots, previous actions."""

        return \
            self.intent_features + \
            self.entity_features + \
            self.slot_features + \
            self.prev_action_features

    def get_active_features(self, tracker):
        # type: (DialogueStateTracker) -> Dict[Text, float]
        """Return a bag of active features from the tracker state"""
        feature_dict = self.get_parsing_features(tracker)
        feature_dict.update(self.get_prev_action_features(tracker))
        return feature_dict

    def get_prev_action_features(self, tracker):
        # type: (DialogueStateTracker) -> Dict[Text, float]
        """Turns the previous taken action into a feature name.

        Returns an empty dict if the tracker has no latest action and
        raises an ``Exception`` if that action is not a known feature."""

        latest_action = tracker.latest_action_name
        if latest_action:
            prev_action_name = "prev_{}".format(latest_action)
            if prev_action_name in self.input_feature_map:
                return {prev_action_name: 1}
            else:
                raise Exception(
                        "Failed to use action '{}' in history. "
                        "Please make sure all actions are listed in the "
                        "domains action list.".format(latest_action))
        else:
            return {}

    def get_parsing_features(self, tracker):
        # type: (DialogueStateTracker) -> Dict[Text, float]
        """Collect entity, slot and intent features from the tracker."""

        feature_dict = {}

        # Set all found entities with the feature value 1.0
        for entity in tracker.latest_message.entities:
            key = "entity_{0}".format(entity["entity"])
            feature_dict[key] = 1.

        # Set all set slots with the featurization of the stored value
        for key, slot in tracker.slots.items():
            if slot is not None:
                for i, slot_value in enumerate(slot.as_feature()):
                    slot_id = "slot_{}_{}".format(key, i)
                    feature_dict[slot_id] = slot_value

        latest_msg = tracker.latest_message

        # Prefer the full intent ranking; fall back to the single intent.
        if "intent_ranking" in latest_msg.parse_data:
            for intent in latest_msg.parse_data["intent_ranking"]:
                if intent.get("name"):
                    intent_id = "intent_{}".format(intent["name"])
                    feature_dict[intent_id] = intent["confidence"]

        elif latest_msg.intent.get("name"):
            intent_id = "intent_{}".format(latest_msg.intent["name"])
            feature_dict[intent_id] = latest_msg.intent.get("confidence", 1.0)

        return feature_dict

    def slots_for_entities(self, entities):
        """Create a ``SlotSet`` event per entity whose name matches a slot.

        Returns an empty list if ``store_entities_as_slots`` is falsy."""
        if self.store_entities_as_slots:
            return [SlotSet(entity['entity'], entity['value'])
                    for entity in entities
                    for s in self.slots
                    if entity['entity'] == s.name]
        else:
            return []

    def persist(self, file_name):
        """Not implemented in this base class."""
        raise NotImplementedError

    @classmethod
    def load(cls, file_name):
        """Not implemented in this base class."""
        raise NotImplementedError

    def persist_specification(self, model_path):
        # type: (Text) -> None
        """Persists the domain specification to storage.

        Writes the input feature list as JSON to ``domain.json`` inside
        ``model_path``."""

        domain_spec_path = os.path.join(model_path, 'domain.json')
        utils.create_dir_for_file(domain_spec_path)
        metadata = {
            "features": self.input_features
        }
        with io.open(domain_spec_path, 'w') as f:
            f.write(str(json.dumps(metadata, indent=2)))

    @classmethod
    def load_specification(cls, path):
        """Load the JSON content of ``domain.json`` inside ``path``."""
        metadata_path = os.path.join(path, 'domain.json')
        with io.open(metadata_path) as f:
            specification = json.loads(f.read())
        return specification

    def compare_with_specification(self, path):
        # type: (Text) -> bool
        """Compares the domain spec of the current and the loaded domain.

        Returns ``True`` if the persisted feature list equals the current
        one; raises an ``Exception`` naming the removed and added
        features otherwise."""

        loaded_domain_spec = self.load_specification(path)
        features = loaded_domain_spec["features"]
        if features != self.input_features:
            # Sets are unordered; sort so the message is deterministic.
            missing = ",".join(
                    sorted(set(features) - set(self.input_features)))
            additional = ",".join(
                    sorted(set(self.input_features) - set(features)))
            raise Exception(
                    "Domain specification has changed. "
                    "You MUST retrain the policy. " +
                    "Detected mismatch in domain specification. " +
                    "The following features have been \n"
                    "\t - removed: {} \n"
                    "\t - added:   {} ".format(missing, additional))
        else:
            return True

    # Abstract Methods : These have to be implemented in any domain subclass

    @abc.abstractproperty
    def slots(self):
        # type: () -> List[Slot]
        """Domain subclass must provide a list of slots"""
        pass

    @abc.abstractproperty
    def entities(self):
        # type: () -> List[Text]
        """Domain subclass must provide a list of entities"""
        raise NotImplementedError(
                "domain must provide a list of entities")

    @abc.abstractproperty
    def intents(self):
        # type: () -> List[Text]
        """Domain subclass must provide a list of intents"""
        raise NotImplementedError(
                "domain must provide a list of intents")

    @abc.abstractproperty
    def actions(self):
        # type: () -> List[Action]
        """Domain subclass must provide a list of possible actions"""
        raise NotImplementedError(
                "domain must provide a list of possible actions")

    @abc.abstractproperty
    def templates(self):
        # type: () -> List[Dict[Text, Any]]
        """Domain subclass must provide the response templates"""
        raise NotImplementedError(
                "domain must provide a dictionary of response templates")
예제 #6
0
def test_restart(default_dispatcher_collecting, default_domain):
    """Running ``ActionRestart`` must produce exactly one ``Restarted`` event."""
    state_tracker = DialogueStateTracker("default", default_domain.slots)

    restart_action = ActionRestart()
    produced_events = restart_action.run(
        default_dispatcher_collecting, state_tracker, default_domain)

    assert produced_events == [Restarted()]
예제 #7
0
파일: domain.py 프로젝트: zyt9749/rasa_core
class Domain(with_metaclass(abc.ABCMeta, object)):
    """The domain specifies the universe in which the bot's policy acts.

    A Domain subclass provides the actions the bot can take, the intents
    and entities it can recognise, and the topics it knows about.

    Subclasses must implement the abstract properties ``slots``,
    ``entities``, ``intents``, ``actions`` and ``templates``."""

    DEFAULT_ACTIONS = [ActionListen(), ActionRestart()]

    def __init__(self,
                 topics=None,
                 store_entities_as_slots=True,
                 restart_intent="restart"):
        """Store the domain configuration.

        ``topics`` defaults to an empty list. ``store_entities_as_slots``
        controls whether ``slots_for_entities`` produces any events.
        ``restart_intent`` is the intent name stored for restart handling."""
        self.default_topic = DefaultTopic
        self.topics = topics if topics is not None else []
        self.store_entities_as_slots = store_entities_as_slots
        self.restart_intent = restart_intent

    @utils.lazyproperty
    def num_actions(self):
        """Returns the number of available actions."""

        # noinspection PyTypeChecker
        return len(self.actions)

    @utils.lazyproperty
    def action_names(self):
        # type: () -> List[Text]
        """Returns the name of available actions."""

        return [a.name() for a in self.actions]

    @utils.lazyproperty
    def action_map(self):
        # type: () -> Dict[Text, Tuple[int, Action]]
        """Provides a mapping from action names to indices and actions."""
        return {a.name(): (i, a) for i, a in enumerate(self.actions)}

    @utils.lazyproperty
    def num_states(self):
        """Number of used input states for the action prediction."""

        return len(self.input_states)

    def action_for_name(self, action_name):
        # type: (Text) -> Optional[Action]
        """Looks up which action corresponds to this action name.

        Raises an ``Exception`` if the name is not a registered action."""

        if action_name in self.action_map:
            # Entries are (index, action) tuples.
            return self.action_map[action_name][1]
        else:
            self._raise_action_not_found_exception(action_name)

    def action_for_index(self, index):
        """Resolves an index into the action list to the action itself.

        Raises an ``Exception`` if ``index`` is negative or not smaller
        than the number of actions."""

        if len(self.actions) <= index or index < 0:
            raise Exception("Can not access action at index {}. "
                            "Domain has {} actions.".format(
                                index, len(self.actions)))
        return self.actions[index]

    def index_for_action(self, action_name):
        # type: (Text) -> Optional[int]
        """Looks up which action index corresponds to this action name.

        Raises an ``Exception`` if the name is not a registered action."""

        if action_name in self.action_map:
            # Entries are (index, action) tuples.
            return self.action_map[action_name][0]
        else:
            self._raise_action_not_found_exception(action_name)

    def _raise_action_not_found_exception(self, action_name):
        """Raise an ``Exception`` listing all registered action names."""
        actions = "\n".join(
            ["\t - {}".format(a) for a in sorted(self.action_map)])
        raise Exception(
            "Can not access action '{}', "
            "as that name is not a registered action for this domain. "
            "Available actions are: \n{}".format(action_name, actions))

    def random_template_for(self, utter_action):
        """Pick a random template for ``utter_action``.

        Returns ``None`` if no templates are registered under that name."""
        if utter_action in self.templates:
            return np.random.choice(self.templates[utter_action])
        else:
            return None

    # noinspection PyTypeChecker
    @utils.lazyproperty
    def slot_states(self):
        # type: () -> List[Text]
        """Returns all available slot state strings."""

        return [
            "slot_{}_{}".format(s.name, i) for s in self.slots
            for i in range(0, s.feature_dimensionality())
        ]

    # noinspection PyTypeChecker
    @utils.lazyproperty
    def prev_action_states(self):
        # type: () -> List[Text]
        """Returns all available previous action state strings."""

        return [PREV_PREFIX + a.name() for a in self.actions]

    # noinspection PyTypeChecker
    @utils.lazyproperty
    def intent_states(self):
        # type: () -> List[Text]
        """Returns all available intent state strings."""

        return ["intent_{0}".format(i) for i in self.intents]

    # noinspection PyTypeChecker
    @utils.lazyproperty
    def entity_states(self):
        # type: () -> List[Text]
        """Returns all available entity state strings."""

        return ["entity_{0}".format(e) for e in self.entities]

    def index_of_state(self, state_name):
        # type: (Text) -> Optional[int]
        """Provides the index of a state, or ``None`` if it is unknown."""

        return self.input_state_map.get(state_name)

    @utils.lazyproperty
    def input_state_map(self):
        # type: () -> Dict[Text, int]
        """Provides a mapping from state names to indices."""
        return {f: i for i, f in enumerate(self.input_states)}

    @utils.lazyproperty
    def input_states(self):
        # type: () -> List[Text]
        """Returns all available states.

        The order is: intents, entities, slots, previous actions."""

        return \
            self.intent_states + \
            self.entity_states + \
            self.slot_states + \
            self.prev_action_states

    @staticmethod
    def get_parsing_states(tracker):
        # type: (DialogueStateTracker) -> Dict[Text, float]
        """Collect entity, slot and intent states from the tracker."""

        state_dict = {}

        # Set all found entities with the state value 1.0
        for entity in tracker.latest_message.entities:
            key = "entity_{0}".format(entity["entity"])
            state_dict[key] = 1.0

        # Set all set slots with the featurization of the stored value
        for key, slot in tracker.slots.items():
            if slot is not None:
                for i, slot_value in enumerate(slot.as_feature()):
                    # Zero-valued slot features are left out of the dict.
                    if slot_value != 0:
                        slot_id = "slot_{}_{}".format(key, i)
                        state_dict[slot_id] = slot_value

        latest_msg = tracker.latest_message

        # Prefer the full intent ranking; fall back to the single intent.
        if "intent_ranking" in latest_msg.parse_data:
            for intent in latest_msg.parse_data["intent_ranking"]:
                if intent.get("name"):
                    intent_id = "intent_{}".format(intent["name"])
                    state_dict[intent_id] = intent["confidence"]

        elif latest_msg.intent.get("name"):
            intent_id = "intent_{}".format(latest_msg.intent["name"])
            state_dict[intent_id] = latest_msg.intent.get("confidence", 1.0)

        return state_dict

    def get_prev_action_states(self, tracker):
        # type: (DialogueStateTracker) -> Dict[Text, float]
        """Turns the previous taken action into a state name.

        Returns an empty dict if the tracker has no latest action, or if
        that action is not a known state (a warning is logged then)."""

        latest_action = tracker.latest_action_name
        if latest_action:
            prev_action_name = PREV_PREFIX + latest_action
            if prev_action_name in self.input_state_map:
                return {prev_action_name: 1.0}
            else:
                logger.warning(
                    "Failed to use action '{}' in history. "
                    "Please make sure all actions are listed in the "
                    "domains action list. If you recently removed an "
                    "action, don't worry about this warning. It "
                    "should stop appearing after a while. "
                    "".format(latest_action))
                return {}
        else:
            return {}

    def get_active_states(self, tracker):
        # type: (DialogueStateTracker) -> Dict[Text, float]
        """Return a bag of active states from the tracker state"""
        state_dict = self.get_parsing_states(tracker)
        state_dict.update(self.get_prev_action_states(tracker))
        return state_dict

    def states_for_tracker_history(self, tracker):
        # type: (DialogueStateTracker) -> List[Dict[Text, float]]
        """Array of states for each state of the trackers history."""
        return [
            self.get_active_states(tr)
            for tr in tracker.generate_all_prior_trackers()
        ]

    def slots_for_entities(self, entities):
        """Create ``SlotSet`` events for entities whose name matches a slot.

        For a slot whose ``type_name`` is ``'list'`` all matching entity
        values are stored; for any other slot only the last matching value
        is used. Returns an empty list if ``store_entities_as_slots`` is
        falsy."""
        if self.store_entities_as_slots:
            slot_events = []
            for s in self.slots:
                matching_entities = [
                    e['value'] for e in entities if e['entity'] == s.name
                ]
                if matching_entities:
                    if s.type_name == 'list':
                        slot_events.append(SlotSet(s.name, matching_entities))
                    else:
                        slot_events.append(
                            SlotSet(s.name, matching_entities[-1]))
            return slot_events
        else:
            return []

    def persist(self, filename):
        """Not implemented in this base class."""
        raise NotImplementedError

    @classmethod
    def load(cls, filename):
        """Not implemented in this base class."""
        raise NotImplementedError

    def persist_specification(self, model_path):
        # type: (Text) -> None
        """Persists the domain specification to storage.

        Dumps the input state list as JSON to ``domain.json`` inside
        ``model_path``."""

        domain_spec_path = os.path.join(model_path, 'domain.json')
        utils.create_dir_for_file(domain_spec_path)

        metadata = {"states": self.input_states}
        utils.dump_obj_as_json_to_file(domain_spec_path, metadata)

    @classmethod
    def load_specification(cls, path):
        # type: (Text) -> Dict[Text, Any]
        """Load a domains specification from a dumped model directory."""

        metadata_path = os.path.join(path, 'domain.json')
        with io.open(metadata_path) as f:
            specification = json.loads(f.read())
        return specification

    def compare_with_specification(self, path):
        # type: (Text) -> bool
        """Compares the domain spec of the current and the loaded domain.

        Returns ``True`` if the persisted state list equals the current
        one; raises an ``Exception`` naming the removed and added states
        otherwise."""

        loaded_domain_spec = self.load_specification(path)
        states = loaded_domain_spec["states"]
        if states != self.input_states:
            # Sets are unordered; sort so the message is deterministic.
            missing = ",".join(
                sorted(set(states) - set(self.input_states)))
            additional = ",".join(
                sorted(set(self.input_states) - set(states)))
            raise Exception("Domain specification has changed. "
                            "You MUST retrain the policy. " +
                            "Detected mismatch in domain specification. " +
                            "The following states have been \n"
                            "\t - removed: {} \n"
                            "\t - added:   {} ".format(missing, additional))
        else:
            return True

    # Abstract Methods : These have to be implemented in any domain subclass
    @abc.abstractproperty
    def slots(self):
        # type: () -> List[Slot]
        """Domain subclass must provide a list of slots"""
        pass

    @abc.abstractproperty
    def entities(self):
        # type: () -> List[Text]
        """Domain subclass must provide a list of entities"""
        raise NotImplementedError("domain must provide a list of entities")

    @abc.abstractproperty
    def intents(self):
        # type: () -> List[Text]
        """Domain subclass must provide a list of intents"""
        raise NotImplementedError("domain must provide a list of intents")

    @abc.abstractproperty
    def actions(self):
        # type: () -> List[Action]
        """Domain subclass must provide a list of possible actions"""
        raise NotImplementedError(
            "domain must provide a list of possible actions")

    @abc.abstractproperty
    def templates(self):
        # type: () -> List[Dict[Text, Any]]
        """Domain subclass must provide the response templates"""
        raise NotImplementedError(
            "domain must provide a dictionary of response templates")