예제 #1
0
def test_reading_of_trackers_with_legacy_form_events():
    """Legacy form events deserialize, and the last named one sets the active loop."""
    first_loop = "my loop"
    second_loop = "my form"
    serialized_events = [
        {"event": ActiveLoop.type_name, LOOP_NAME: first_loop},
        {"event": LegacyForm.type_name, LOOP_NAME: None},
        {"event": LegacyForm.type_name, LOOP_NAME: second_loop},
    ]
    tracker = DialogueStateTracker.from_dict(
        "sender", events_as_dict=serialized_events
    )

    assert list(tracker.events) == [
        ActiveLoop(first_loop),
        LegacyForm(None),
        LegacyForm(second_loop),
    ]
    assert tracker.active_loop[LOOP_NAME] == second_loop
예제 #2
0
    async def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
        """Load the most recent stored dialogue for ``sender_id``.

        Items are queried newest-first (based on the session_date sort key) and
        only the first one is used. Returns ``None`` if nothing is stored.
        """
        query_result = self.db.query(
            KeyConditionExpression=Key("sender_id").eq(sender_id),
            Limit=1,
            ScanIndexForward=False,
        )
        items = query_result["Items"]
        if not items:
            return None

        raw_events = items[0].get("events", [])
        # Floats come back from storage as `Decimal` objects; turn them back
        # into plain `float`s before rebuilding the tracker.
        events = core_utils.replace_decimals_with_floats(raw_events)

        # Without a domain there are no slot definitions to restore.
        slots = [] if self.domain is None else self.domain.slots
        return DialogueStateTracker.from_dict(sender_id, events, slots)
예제 #3
0
    async def replace_events(request: Request, conversation_id: Text):
        """Use a list of events to set a conversation's tracker to a state.

        The request body is deserialized into a fresh tracker, which then
        replaces any stored tracker for ``conversation_id``. Any failure is
        reported to the client as a 500 ``ConversationError``.
        """
        validate_request_body(
            request,
            "You must provide events in the request body to set the state of the "
            "conversation tracker.",
        )

        verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)

        try:
            # Hold the per-conversation lock while the stored tracker is replaced.
            async with app.agent.lock_store.lock(conversation_id):
                tracker = DialogueStateTracker.from_dict(
                    conversation_id, request.json, app.agent.domain.slots
                )

                # will override an existing tracker with the same id!
                app.agent.tracker_store.save(tracker)

            return response.json(tracker.current_state(verbosity))
        except Exception as e:
            logger.debug(traceback.format_exc())
            # Chain the original error so its cause survives in tracebacks.
            raise ErrorResponse(
                500, "ConversationError", f"An unexpected error occurred. Error: {e}"
            ) from e
예제 #4
0
    def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
        """Retrieve the tracker stored for ``sender_id``.

        Conversations that were stored with an ``int`` sender ID are migrated
        to the string ID on lookup.

        Args:
            sender_id: the message owner ID

        Returns:
            The restored `DialogueStateTracker`, or ``None`` if no conversation
            is stored for ``sender_id``.
        """
        stored = self.conversations.find_one({"sender_id": sender_id})

        # look for conversations which have used an `int` sender_id in the past
        # and update them.
        if not stored and sender_id.isdigit():
            from pymongo import ReturnDocument

            stored = self.conversations.find_one_and_update(
                {"sender_id": int(sender_id)},
                {"$set": {"sender_id": str(sender_id)}},
                return_document=ReturnDocument.AFTER,
            )

        if not stored:
            return None

        events = self._events_from_serialized_tracker(stored)
        if not self.load_events_from_previous_conversation_sessions:
            events = self._events_since_last_session_start(events)

        # Without a domain there are no slot definitions to restore; fall back
        # to no slots instead of failing on `None.slots`.
        slots = [] if self.domain is None else self.domain.slots
        return DialogueStateTracker.from_dict(sender_id, events, slots)
예제 #5
0
    def _retrieve(
        self, sender_id: Text, fetch_events_from_all_sessions: bool
    ) -> Optional[DialogueStateTracker]:
        """Rebuild the tracker for ``sender_id`` from its stored SQL event rows.

        Returns ``None`` when no domain is configured or no events are found.
        """
        with self.session_scope() as session:
            rows = self._event_query(
                session,
                sender_id,
                fetch_events_from_all_sessions=fetch_events_from_all_sessions,
            ).all()

            events = [json.loads(row.data) for row in rows]

            if not self.domain or not events:
                logger.debug(
                    f"Can't retrieve tracker matching "
                    f"sender id '{sender_id}' from SQL storage. "
                    f"Returning `None` instead."
                )
                return None

            logger.debug(f"Recreating tracker from sender id '{sender_id}'")
            return DialogueStateTracker.from_dict(
                sender_id, events, self.domain.slots
            )
예제 #6
0
    def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
        """Retrieves tracker for the latest conversation session.

        Returns ``None`` if nothing is retrieved for ``sender_id``.
        """
        latest_session_events = self._retrieve(
            sender_id, fetch_events_from_all_sessions=False
        )
        if latest_session_events:
            return DialogueStateTracker.from_dict(
                sender_id, latest_session_events, self.domain.slots
            )
        return None
예제 #7
0
 def _convert_tracker(self, sender_id, tracker):
     """Build a `DialogueStateTracker` from the ``"events"`` of a stored tracker.

     Logs a warning and returns ``None`` when no domain is configured.
     """
     if not self.domain:
         logger.warning("Can't recreate tracker from mongo storage "
                        "because no domain is set. Returning `None` "
                        "instead.")
         return None
     return DialogueStateTracker.from_dict(
         sender_id, tracker["events"], self.domain.slots
     )
예제 #8
0
def test_tracker_without_slots(key, value, caplog):
    """Applying a `SlotSet` to a tracker built without slots stores the value quietly."""
    slot_event = SlotSet(key, value)
    tracker = DialogueStateTracker.from_dict("any", [])
    assert key in tracker.slots

    with caplog.at_level(logging.INFO):
        slot_event.apply_to(tracker)
        assert tracker.get_slot(key) == value

    # Applying the event must not have produced any log records.
    assert not caplog.records
예제 #9
0
파일: restore.py 프로젝트: attgua/Geco
def load_tracker_from_json(tracker_dump: Text,
                           domain: Domain) -> DialogueStateTracker:
    """Read a JSON tracker dump from a file and instantiate a tracker from it.

    Uses ``DEFAULT_SENDER_ID`` when the dump has no ``sender_id`` and an empty
    event list when it has no ``events``.
    """
    raw_dump = rasa.shared.utils.io.read_file(tracker_dump)
    tracker_json = json.loads(raw_dump)
    return DialogueStateTracker.from_dict(
        tracker_json.get("sender_id", DEFAULT_SENDER_ID),
        tracker_json.get("events", []),
        domain.slots,
    )
예제 #10
0
    def retrieve_full_tracker(
            self, conversation_id: Text) -> Optional[DialogueStateTracker]:
        """Fetch a tracker holding events from all conversation sessions.

        Returns ``None`` when nothing is retrieved for ``conversation_id``.
        """
        all_events = self._retrieve(
            conversation_id, fetch_events_from_all_sessions=True
        )
        if all_events:
            return DialogueStateTracker.from_dict(
                conversation_id, all_events, self.domain.slots
            )
        return None
예제 #11
0
    async def retrieve_full_tracker(
            self, conversation_id: Text) -> Optional[DialogueStateTracker]:
        """Fetching all tracker events across conversation sessions.

        Returns ``None`` when nothing is retrieved for ``conversation_id``.
        """
        all_events = await self._retrieve(
            conversation_id, fetch_events_from_all_sessions=True
        )
        if all_events:
            return DialogueStateTracker.from_dict(
                conversation_id, all_events, self.domain.slots
            )
        return None
예제 #12
0
    def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
        """Retrieve the tracker for ``sender_id``.

        Only the latest conversation session is loaded, unless
        ``retrieve_events_from_previous_conversation_sessions`` is set. In that
        case the full tracker across all sessions is returned.
        """
        # TODO: Remove this in Rasa Open Source 3.0 along with the
        # deprecation warning in the constructor
        if self.retrieve_events_from_previous_conversation_sessions:
            return self.retrieve_full_tracker(sender_id)

        session_events = self._retrieve(
            sender_id, fetch_events_from_all_sessions=False
        )
        if session_events:
            return DialogueStateTracker.from_dict(
                sender_id, session_events, self.domain.slots
            )
        return None
예제 #13
0
def test_current_state_no_events(default_agent):
    """With `EventVerbosity.NONE` the tracker state carries no events."""
    dump_path = "data/test_trackers/tracker_moodbot.json"
    dump = json.loads(rasa.shared.utils.io.read_file(dump_path))

    tracker = DialogueStateTracker.from_dict(
        dump.get("sender_id"),
        dump.get("events", []),
        default_agent.domain.slots,
    )

    current = tracker.current_state(EventVerbosity.NONE)
    assert current.get("events") is None
예제 #14
0
def new_form_and_tracker(form_spec, requested_slot, additional_slots=None):
    """Create an `ActionBotfrontForm` together with an empty tracker for it.

    Args:
        form_spec: form specification. Its ``"form_name"`` names the form, and
            the whole spec is attached to the form as ``form.form_spec``.
        requested_slot: name of the slot being requested. A slot with this
            name is declared, and the ``"requested_slot"`` slot is initialised
            to it.
        additional_slots: optional iterable of extra slot names to declare.
            Defaults to none. ``None`` replaces the former mutable ``[]``
            default, which would be shared between calls.

    Returns:
        A ``(form, tracker)`` tuple.
    """
    if additional_slots is None:
        additional_slots = []
    form = ActionBotfrontForm(form_spec.get("form_name"))
    tracker = DialogueStateTracker.from_dict(
        "default",
        [],
        [
            Slot(name=requested_slot),
            *[Slot(name=name) for name in additional_slots],
            Slot(name="requested_slot", initial_value=requested_slot),
        ],
    )
    form.form_spec = form_spec  # load spec manually
    return form, tracker
예제 #15
0
async def generate_response(nlg_call, domain):
    """Mock response generator.

    Generates the responses from the bot's domain file.

    Args:
        nlg_call: the NLG request payload. Reads ``"arguments"``,
            ``"response"``, ``"tracker"`` (its ``"sender_id"`` and ``"events"``)
            and ``"channel"``. Missing or null ``arguments``/``tracker``/
            ``events`` are treated as empty.
        domain: the bot's domain, supplying slot definitions and responses.

    Returns:
        The result of `TemplatedNaturalLanguageGenerator.generate` for the
        requested response.
    """
    # `or {}` also covers keys that are present but explicitly null, which
    # would otherwise break `.get(...)` and `**kwargs` below.
    kwargs = nlg_call.get("arguments") or {}
    response = nlg_call.get("response")
    tracker_data = nlg_call.get("tracker") or {}
    sender_id = tracker_data.get("sender_id")
    # Always hand `from_dict` a list of events; `None` is presumably not
    # accepted there when a payload omits events.
    events = tracker_data.get("events") or []
    tracker = DialogueStateTracker.from_dict(sender_id, events, domain.slots)
    channel_name = nlg_call.get("channel")

    return await TemplatedNaturalLanguageGenerator(domain.responses).generate(
        response, tracker, channel_name, **kwargs)
예제 #16
0
def test_current_state_all_events(default_agent):
    """With `EventVerbosity.ALL` the state lists every event, including a restart."""
    dump_path = "data/test_trackers/tracker_moodbot.json"
    dump = json.loads(rasa.shared.utils.io.read_file(dump_path))

    # A restart is spliced in so ALL verbosity is checked against it as well.
    dump["events"].insert(3, {"event": "restart"})

    tracker = DialogueStateTracker.from_dict(
        dump.get("sender_id"),
        dump.get("events", []),
        default_agent.domain.slots,
    )

    serialized_events = [event.as_dict() for event in tracker.events]

    current = tracker.current_state(EventVerbosity.ALL)
    assert current.get("events") == serialized_events
예제 #17
0
def test_reading_of_trackers_with_legacy_form_validation_events():
    """Legacy form validation events deserialize and mark loop interruptions."""
    tracker = DialogueStateTracker.from_dict(
        "sender",
        events_as_dict=[
            {"event": LegacyFormValidation.type_name, "name": None, "validate": flag}
            for flag in (True, False)
        ],
    )

    restored = list(tracker.events)
    assert restored == [LegacyFormValidation(True), LegacyFormValidation(False)]
    assert not restored[0].is_interrupted
    assert restored[1].is_interrupted

    assert tracker.active_loop[LOOP_INTERRUPTED]
예제 #18
0
async def test_parsing_with_tracker():
    """A slot value set on the tracker shows up in the NLU parse result.

    `RasaNLUHttpInterpreter.parse` is replaced by `mocked_parse` (defined
    elsewhere), which presumably reads the slot from the tracker it receives.
    """
    tracker = DialogueStateTracker.from_dict("1", [], [Slot("requested_language")])
    # This value is expected back in the interpreter's result.
    tracker._set_slot("requested_language", "en")

    endpoint = EndpointConfig("https://interpreter.com")
    with aioresponses() as mocked:
        mocked.post("https://interpreter.com/parse", repeat=True, status=200)

        # Swap in the test-specific parse implementation.
        with patch.object(RasaNLUHttpInterpreter, "parse", mocked_parse):
            agent = Agent(
                None, None, RasaNLUHttpInterpreter(endpoint_config=endpoint)
            )
            parsed = await agent.parse_message_using_nlu_interpreter(
                "lunch?", tracker
            )

            assert parsed["requested_language"] == "en"
예제 #19
0
    async def tracker_predict(request: Request) -> HTTPResponse:
        """Given a list of events, predicts the next action.

        Responds with a score per domain action, the policy that produced the
        scores, and the tracker state. Invalid events produce a 400
        ``BadRequest``; any failure while predicting produces a 500
        ``PredictionError``.
        """
        validate_request_body(
            request,
            "No events defined in request_body. Add events to request body in order to "
            "predict the next action.",
        )

        verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)
        events_payload = request.json
        try:
            tracker = DialogueStateTracker.from_dict(
                DEFAULT_SENDER_ID, events_payload, app.agent.domain.slots
            )
        except Exception as err:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                400,
                "BadRequest",
                f"Supplied events are not valid. {err}",
                {"parameter": "", "in": "body"},
            )

        try:
            ensemble = app.agent.policy_ensemble
            probabilities, policy = ensemble.probabilities_using_best_policy(
                tracker, app.agent.domain, app.agent.interpreter
            )
            action_scores = [
                {"action": action_name, "score": score}
                for action_name, score in zip(
                    app.agent.domain.action_names, probabilities
                )
            ]
            body = {
                "scores": action_scores,
                "policy": policy,
                "tracker": tracker.current_state(verbosity),
            }
            return response.json(body)
        except Exception as err:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500, "PredictionError", f"An unexpected error occurred. Error: {err}"
            )
예제 #20
0
def test_session_started_not_part_of_applied_events(default_agent: Agent):
    """Events before a `SessionStarted` event are not part of `applied_events()`."""
    dump_path = "data/test_trackers/tracker_moodbot.json"
    dump = json.loads(rasa.shared.utils.io.read_file(dump_path))

    # Splice a session-start sequence into positions 4 (action) and 5 (event).
    dump["events"][4:4] = [
        {"event": ActionExecuted.type_name, "name": ACTION_SESSION_START_NAME},
        {"event": SessionStarted.type_name},
    ]

    tracker = DialogueStateTracker.from_dict(
        dump.get("sender_id"),
        dump.get("events", []),
        default_agent.domain.slots,
    )

    # `SessionStarted` sits at index 5, so the applied events must be exactly
    # the events from index 6 onwards.
    applied = tracker.applied_events()
    assert applied == list(tracker.events)[6:]
예제 #21
0
def test_current_state_applied_events(default_agent):
    """With `EventVerbosity.APPLIED` only events still in effect are reported."""
    dump_path = "data/test_trackers/tracker_moodbot.json"
    dump = json.loads(rasa.shared.utils.io.read_file(dump_path))

    # add some events that result in other events not being applied anymore
    for position, event_name in ((1, "restart"), (7, "rewind"), (8, "undo")):
        dump["events"].insert(position, {"event": event_name})

    tracker = DialogueStateTracker.from_dict(
        dump.get("sender_id"),
        dump.get("events", []),
        default_agent.domain.slots,
    )

    serialized_events = [event.as_dict() for event in tracker.events]
    expected = [serialized_events[2], serialized_events[9]]

    current = tracker.current_state(EventVerbosity.APPLIED)
    assert current.get("events") == expected
예제 #22
0
파일: server.py 프로젝트: cr33dx/rasa
    async def tracker_predict(request: Request) -> HTTPResponse:
        """Given a list of events, predicts the next action.

        The request body's events are loaded into a tracker for
        ``DEFAULT_SENDER_ID`` and passed to a fresh processor for prediction.
        Invalid events produce a 400 ``BadRequest``; any failure while
        predicting produces a 500 ``PredictionError``.
        """
        validate_request_body(
            request,
            "No events defined in request_body. Add events to request body in order to "
            "predict the next action.",
        )

        verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)
        events_payload = request.json
        try:
            tracker = DialogueStateTracker.from_dict(
                DEFAULT_SENDER_ID, events_payload, app.agent.domain.slots
            )
        except Exception as err:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                400,
                "BadRequest",
                f"Supplied events are not valid. {err}",
                {"parameter": "", "in": "body"},
            )

        try:
            processor = app.agent.create_processor()
            prediction = processor.predict_next_with_tracker(tracker, verbosity)
            return response.json(prediction)
        except Exception as err:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500, "PredictionError", f"An unexpected error occurred. Error: {err}"
            )