Esempio n. 1
0
    def continue_training():
        """Continue training the agent on a tracker built from the request.

        Query parameters:
            epochs: number of training epochs (default ``30``).
            batch_size: training batch size (default ``5``).

        The JSON body is handed to ``DialogueStateTracker.from_dict`` for
        the default sender id.

        Returns:
            An empty ``204`` response on success, a ``400`` error when the
            tracker cannot be built from the body, or a ``500`` error when
            training raises.
        """
        # Query-string values arrive as strings; `type=int` converts them and
        # falls back to the default when the value is not a valid integer.
        epochs = request.args.get("epochs", 30, type=int)
        batch_size = request.args.get("batch_size", 5, type=int)
        request_params = request.get_json(force=True)
        sender_id = UserMessage.DEFAULT_SENDER_ID

        try:
            tracker = DialogueStateTracker.from_dict(sender_id, request_params,
                                                     agent.domain.slots)
        except Exception as e:
            return error(400, "InvalidParameter",
                         "Supplied events are not valid. {}".format(e), {
                             "parameter": "",
                             "in": "body"
                         })

        try:
            # Train on the single tracker reconstructed from the request.
            agent.continue_training([tracker],
                                    epochs=epochs,
                                    batch_size=batch_size)
            return '', 204

        except Exception as e:
            logger.exception("Caught an exception during training.")
            return error(500, "TrainingException",
                         "Server failure. Error: {}".format(e))
Esempio n. 2
0
    async def continue_training(request: Request):
        """Continue training the agent on a tracker built from the request.

        Query parameters:
            epochs: number of training epochs (default ``30``).
            batch_size: training batch size (default ``5``).

        The JSON body is handed to ``DialogueStateTracker.from_dict`` for
        the default sender id.

        Returns:
            An empty ``204`` text response on success.

        Raises:
            ErrorResponse: ``400`` when a query parameter is not an integer
                or the tracker cannot be built from the body, ``500`` when
                training raises.
        """
        # Values read from the query string are strings, so convert them
        # explicitly and reject anything that is not an integer.
        training_args = {}
        for name, default in (("epochs", 30), ("batch_size", 5)):
            try:
                training_args[name] = int(request.raw_args.get(name, default))
            except ValueError as e:
                raise ErrorResponse(400, "InvalidParameter",
                                    "Query parameter `{}` must be an "
                                    "integer. {}".format(name, e),
                                    {"parameter": name, "in": "query"})

        request_params = request.json
        sender_id = UserMessage.DEFAULT_SENDER_ID

        try:
            tracker = DialogueStateTracker.from_dict(sender_id,
                                                     request_params,
                                                     app.agent.domain.slots)
        except Exception as e:
            raise ErrorResponse(400, "InvalidParameter",
                                "Supplied events are not valid. {}".format(e),
                                {"parameter": "", "in": "body"})

        try:
            # Train on the single tracker reconstructed from the request.
            app.agent.continue_training([tracker], **training_args)
            return response.text('', 204)

        except Exception as e:
            logger.exception("Caught an exception during training.")
            raise ErrorResponse(500, "TrainingException",
                                "Server failure. Error: {}".format(e))
Esempio n. 3
0
    def tracker_predict():
        """Score every domain action for a conversation given as events.

        The JSON body is turned into a tracker for the default sender id.
        Responds with JSON holding one score per action name, the policy
        reported by the ensemble, and the tracker state at the requested
        verbosity. A ``400`` error is returned if the tracker cannot be
        built from the body.
        """

        default_sender = UserMessage.DEFAULT_SENDER_ID
        events = request.get_json(force=True)
        verbosity = event_verbosity_parameter(EventVerbosity.AFTER_RESTART)

        try:
            tracker = DialogueStateTracker.from_dict(
                default_sender, events, agent.domain.slots)
        except Exception as e:
            message = "Supplied events are not valid. {}".format(e)
            details = {"parameter": "", "in": "body"}
            return error(400, "InvalidParameter", message, details)

        ensemble = agent.policy_ensemble
        domain = agent.domain
        probabilities, policy = ensemble.probabilities_using_best_policy(
            tracker, domain)

        # Pair each action name with the probability at the same position.
        scores = []
        for action_name, score in zip(domain.action_names, probabilities):
            scores.append({"action": action_name, "score": score})

        payload = {
            "scores": scores,
            "policy": policy,
            "tracker": tracker.current_state(verbosity),
        }
        return jsonify(payload)
Esempio n. 4
0
    def retrieve(self, sender_id):
        """Rebuild the tracker stored for ``sender_id``.

        Returns ``None`` when no conversation document is found or when no
        domain is set on the store.
        """
        document = self.conversations.find_one({"sender_id": sender_id})

        if document is None and sender_id.isdigit():
            # Conversations may have been stored under an `int` sender_id in
            # the past; rewrite such a document to use the string id.
            from pymongo import ReturnDocument
            legacy_filter = {"sender_id": int(sender_id)}
            rewrite = {"$set": {"sender_id": str(sender_id)}}
            document = self.conversations.find_one_and_update(
                legacy_filter,
                rewrite,
                return_document=ReturnDocument.AFTER)

        if document is None:
            return None

        if not self.domain:
            logger.warning("Can't recreate tracker from mongo storage "
                           "because no domain is set. Returning `None` "
                           "instead.")
            return None

        return DialogueStateTracker.from_dict(sender_id,
                                              document.get("events"),
                                              self.domain.slots)
Esempio n. 5
0
def load_tracker_from_json(tracker_dump, domain):
    # type: (Text, Domain) -> DialogueStateTracker
    """Read a JSON tracker dump from a file and build a tracker from it.

    The sender id falls back to the default sender id and the event list to
    an empty list when the dump does not contain them.
    """

    raw_dump = utils.read_file(tracker_dump)
    parsed = json.loads(raw_dump)
    sender_id = parsed.get("sender_id", UserMessage.DEFAULT_SENDER_ID)
    events = parsed.get("events", [])
    return DialogueStateTracker.from_dict(sender_id, events, domain)
Esempio n. 6
0
def _load_tracker_from_json(tracker_dump, agent):
    # type: (Text, Agent) -> DialogueStateTracker
    """Read a JSON tracker dump from a file and build a tracker from it.

    The tracker is created with the agent's domain. The sender id falls
    back to the default sender id and the event list to an empty list when
    the dump does not contain them.
    """

    raw_dump = utils.read_file(tracker_dump)
    parsed = json.loads(raw_dump)
    sender_id = parsed.get("sender_id", UserMessage.DEFAULT_SENDER_ID)
    events = parsed.get("events", [])
    return DialogueStateTracker.from_dict(sender_id, events, agent.domain)
Esempio n. 7
0
    def replace_events(sender_id):
        """Replace a conversation's tracker with one built from posted events.

        Responds with the JSON state of the new tracker, using the
        ``AFTER_RESTART`` event verbosity.
        """

        posted_events = request.get_json(force=True)
        slots = agent.domain.slots
        new_tracker = DialogueStateTracker.from_dict(sender_id, posted_events,
                                                     slots)

        # Saving overrides any existing tracker stored under the same id.
        agent.tracker_store.save(new_tracker)

        state = new_tracker.current_state(EventVerbosity.AFTER_RESTART)
        return jsonify(state)
Esempio n. 8
0
def load_tracker_from_json(tracker_dump: Text,
                           domain: Domain) -> DialogueStateTracker:
    """Read a JSON tracker dump from a file and build a tracker from it.

    The tracker is created with the domain's slots. The sender id falls
    back to the default sender id and the event list to an empty list when
    the dump does not contain them.
    """

    raw_dump = utils.read_file(tracker_dump)
    parsed = json.loads(raw_dump)
    sender_id = parsed.get("sender_id", UserMessage.DEFAULT_SENDER_ID)
    events = parsed.get("events", [])
    return DialogueStateTracker.from_dict(sender_id, events, domain.slots)
Esempio n. 9
0
    def replace_events(sender_id):
        """Replace a conversation's tracker with one built from posted events.

        Responds with the JSON state of the new tracker, events included.
        """

        posted_events = request.get_json(force=True)
        slots = agent.domain.slots
        new_tracker = DialogueStateTracker.from_dict(
            sender_id, posted_events, slots)

        # Saving overrides any existing tracker stored under the same id.
        agent.tracker_store.save(new_tracker)

        state = new_tracker.current_state(should_include_events=True)
        return jsonify(state)
Esempio n. 10
0
def generate_response(nlg_call, domain):
    """Generate a templated response for an NLG call.

    Reads the template name, extra arguments, channel and tracker dump from
    ``nlg_call``, rebuilds the tracker with the domain's slots, and
    delegates to a generator created from the domain's templates.
    """
    extra_arguments = nlg_call.get("arguments", {})
    template_name = nlg_call.get("template")

    tracker_dump = nlg_call.get("tracker", {})
    tracker = DialogueStateTracker.from_dict(tracker_dump.get("sender_id"),
                                             tracker_dump.get("events"),
                                             domain.slots)
    channel = nlg_call.get("channel")

    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    return generator.generate(template_name, tracker, channel,
                              **extra_arguments)
Esempio n. 11
0
    def _init_tracker(self, sender_id, tracker):
        """Build a tracker from the ``"events"`` entry of ``tracker``.

        Returns ``None`` (after logging a warning) when no domain is set.
        """
        if not self.domain:
            logger.warning("Can't recreate tracker from mongo storage "
                           "because no domain is set. Returning `None` "
                           "instead.")
            return None

        return DialogueStateTracker.from_dict(sender_id,
                                              tracker["events"],
                                              self.domain.slots)
Esempio n. 12
0
def test_current_state_no_events(default_agent):
    """The state dump must carry no events at ``EventVerbosity.NONE``."""
    dump_path = "data/test_trackers/tracker_moodbot.json"
    dump = json.loads(utils.read_file(dump_path))

    sender_id = dump.get("sender_id")
    events = dump.get("events", [])
    tracker = DialogueStateTracker.from_dict(sender_id, events,
                                             default_agent.domain.slots)

    state = tracker.current_state(EventVerbosity.NONE)
    assert state.get("events") is None
Esempio n. 13
0
    def update_tracker(sender_id):
        """Replace a conversation's tracker with one built from posted events.

        The JSON body is handed to ``DialogueStateTracker.from_dict`` together
        with the agent's domain, and the resulting tracker is stored once.

        Returns:
            The JSON state of the new tracker, events included.
        """

        request_params = request.get_json(force=True)
        tracker = DialogueStateTracker.from_dict(sender_id, request_params,
                                                 agent().domain)

        # will override an existing tracker with the same id!
        agent().tracker_store.save(tracker)
        return jsonify(tracker.current_state(should_include_events=True))
Esempio n. 14
0
def generate_response(nlg_call, domain):
    """Generate a templated response for an NLG call.

    Reads the template name, extra arguments, channel and tracker dump from
    ``nlg_call``, rebuilds the tracker with the domain's slots, and
    delegates to a generator created from the domain's templates.
    """
    extra_arguments = nlg_call.get("arguments", {})
    template_name = nlg_call.get("template")

    tracker_dump = nlg_call.get("tracker", {})
    tracker = DialogueStateTracker.from_dict(tracker_dump.get("sender_id"),
                                             tracker_dump.get("events"),
                                             domain.slots)
    channel = nlg_call.get("channel")

    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    return generator.generate(template_name, tracker, channel,
                              **extra_arguments)
Esempio n. 15
0
    def tracker(self,
                sender_id: Text,
                domain: Domain,
                event_verbosity: EventVerbosity = EventVerbosity.ALL,
                until: Optional[int] = None):
        """Fetch a tracker dump via ``tracker_json`` and rebuild the tracker.

        The tracker is created from the dump's ``"events"`` list (empty if
        absent) and the slots of ``domain``.
        """

        remote_state = self.tracker_json(sender_id, event_verbosity, until)
        events = remote_state.get("events", [])

        return DialogueStateTracker.from_dict(sender_id, events, domain.slots)
Esempio n. 16
0
    async def replace_events(request: Request, sender_id: Text):
        """Replace a conversation's tracker with one built from posted events.

        Responds with the JSON state of the new tracker at the verbosity
        taken from the request (``AFTER_RESTART`` is passed as the default).
        """

        posted_events = request.json
        verbosity = event_verbosity_parameter(request,
                                              EventVerbosity.AFTER_RESTART)

        slots = app.agent.domain.slots
        new_tracker = DialogueStateTracker.from_dict(
            sender_id, posted_events, slots)

        # Saving overrides any existing tracker stored under the same id.
        app.agent.tracker_store.save(new_tracker)

        state = new_tracker.current_state(verbosity)
        return response.json(state)
Esempio n. 17
0
    def update_tracker(self, request, sender_id):
        """Replace a conversation's tracker with one built from posted events.

        Sets the response content type to JSON, decodes the request body as
        UTF-8 JSON, builds a tracker from it with the agent's domain, and
        stores that tracker once.

        Returns:
            The JSON-encoded state of the new tracker, events included.
        """

        request.setHeader('Content-Type', 'application/json')
        request_params = json.loads(request.content.read().decode(
            'utf-8', 'strict'))
        tracker = DialogueStateTracker.from_dict(sender_id, request_params,
                                                 self.agent.domain)

        # will override an existing tracker with the same id!
        self.agent.tracker_store.save(tracker)
        return json.dumps(tracker.current_state(should_include_events=True))
Esempio n. 18
0
 def retrieve(self, sender_id):
     """Rebuild the tracker stored for ``sender_id``.

     Returns ``None`` when no conversation document is found or when no
     domain is set on the store.
     """
     document = self.conversations.find_one({"sender_id": sender_id})
     if document is None:
         return None

     if not self.domain:
         logger.warning("Can't recreate tracker from mongo storage "
                        "because no domain is set. Returning `None` "
                        "instead.")
         return None

     events = document.get("events")
     return DialogueStateTracker.from_dict(sender_id, events,
                                           self.domain.slots)
Esempio n. 19
0
def test_current_state_all_events(default_agent):
    """``EventVerbosity.ALL`` must dump every event, even past a restart."""
    dump_path = "data/test_trackers/tracker_moodbot.json"
    dump = json.loads(utils.read_file(dump_path))

    # Put a restart in the middle of the conversation.
    dump["events"].insert(3, {"event": "restart"})

    sender_id = dump.get("sender_id")
    events = dump.get("events", [])
    tracker = DialogueStateTracker.from_dict(sender_id, events,
                                             default_agent.domain.slots)

    expected = [event.as_dict() for event in tracker.events]

    state = tracker.current_state(EventVerbosity.ALL)
    assert state.get("events") == expected
Esempio n. 20
0
 def retrieve(self, sender_id):
     """Rebuild the tracker stored for ``sender_id``.

     Returns ``None`` when no conversation document is found or when no
     domain is set on the store.
     """
     document = self.conversations.find_one({"sender_id": sender_id})
     if document is None:
         return None

     if not self.domain:
         logger.warning("Can't recreate tracker from mongo storage "
                        "because no domain is set. Returning `None` "
                        "instead.")
         return None

     events = document.get("events")
     return DialogueStateTracker.from_dict(sender_id, events,
                                           self.domain.slots)
Esempio n. 21
0
async def generate_response(nlg_call, domain):
    """Mock response generator.

    Reads the template name, extra arguments, channel and tracker dump from
    ``nlg_call``, rebuilds the tracker with the domain's slots, and awaits
    a generator created from the domain's templates.
    """
    extra_arguments = nlg_call.get("arguments", {})
    template_name = nlg_call.get("template")

    tracker_dump = nlg_call.get("tracker", {})
    tracker = DialogueStateTracker.from_dict(tracker_dump.get("sender_id"),
                                             tracker_dump.get("events"),
                                             domain.slots)
    channel = nlg_call.get("channel")

    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    return await generator.generate(template_name, tracker, channel,
                                    **extra_arguments)
Esempio n. 22
0
def test_read_json_dump(default_agent):
    """A tracker restored from the moodbot dump must match that dump."""
    dump = json.loads(
        utils.read_file("data/test_trackers/tracker_moodbot.json"))
    sender_id = dump.get("sender_id", UserMessage.DEFAULT_SENDER_ID)
    events = dump.get("events", [])
    restored = DialogueStateTracker.from_dict(sender_id, events,
                                              default_agent.domain)

    assert len(restored.events) == 7
    assert restored.latest_action_name == "action_listen"
    assert not restored.is_paused()
    assert restored.sender_id == "mysender"
    assert restored.events[-1].timestamp == 1517821726.211042

    # Dumping the restored tracker again must reproduce the original JSON.
    round_tripped = restored.current_state(should_include_events=True)
    assert round_tripped == dump
Esempio n. 23
0
    def tracker(self,
                sender_id,  # type: Text
                domain,  # type: Domain
                only_events_after_latest_restart=False,  # type: bool
                include_events=True,  # type: bool
                until=None  # type: Optional[int]
                ):
        """Fetch a tracker dump via ``tracker_json`` and rebuild the tracker.

        The tracker is created from the dump's ``"events"`` list (empty if
        absent) and ``domain``.
        """

        remote_state = self.tracker_json(sender_id,
                                         only_events_after_latest_restart,
                                         include_events,
                                         until)
        events = remote_state.get("events", [])

        return DialogueStateTracker.from_dict(sender_id, events, domain)
Esempio n. 24
0
def test_dump_and_restore_as_json(default_agent, tmpdir_factory):
    """Each story tracker must survive a JSON dump and restore unchanged."""
    story_trackers = default_agent.load_data(DEFAULT_STORIES_FILE)

    for tracker in story_trackers:
        dump_file = tmpdir_factory.mktemp("tracker").join(
            "dumped_tracker.json")

        # Write the tracker state to disk ...
        state = tracker.current_state(should_include_events=True)
        utils.dump_obj_as_json_to_file(dump_file.strpath, state)

        # ... and rebuild a tracker from what was written.
        loaded = json.loads(utils.read_file(dump_file.strpath))
        sender_id = loaded.get("sender_id", UserMessage.DEFAULT_SENDER_ID)
        events = loaded.get("events", [])
        restored = DialogueStateTracker.from_dict(sender_id, events,
                                                  default_agent.domain)

        assert restored == tracker
Esempio n. 25
0
    def tracker(
            self,
            sender_id,  # type: Text
            domain,  # type: Domain
            should_ignore_restarts=False,  # type: bool
            include_events=True,  # type: bool
            until=None  # type: Optional[int]
    ):
        """Fetch a tracker dump via ``tracker_json`` and rebuild the tracker.

        The tracker is created from the dump's ``"events"`` list (empty if
        absent) and the slots of ``domain``.
        """

        remote_state = self.tracker_json(sender_id,
                                         should_ignore_restarts,
                                         include_events,
                                         until)
        events = remote_state.get("events", [])

        return DialogueStateTracker.from_dict(sender_id, events, domain.slots)
Esempio n. 26
0
    def retrieve(self, sender_id: Text) -> DialogueStateTracker:
        """Create a tracker from all previously stored events.

        Queries every ``SQLEvent`` row for ``sender_id`` and decodes each
        row's ``data`` column as JSON.

        Returns:
            The recreated tracker, or ``None`` when no domain is set or no
            events are stored for ``sender_id``.
        """

        query = self.session.query(self.SQLEvent)
        result = query.filter_by(sender_id=sender_id).all()
        events = [json.loads(event.data) for event in result]

        if self.domain and events:
            logger.debug("Recreating tracker "
                         "from sender id '{}'".format(sender_id))

            return DialogueStateTracker.from_dict(sender_id, events,
                                                  self.domain.slots)

        # The first literal needs its trailing space, otherwise the message
        # reads "matchingsender id".
        logger.debug("Can't retrieve tracker matching "
                     "sender id '{}' from SQL storage.  "
                     "Returning `None` instead.".format(sender_id))
        return None
Esempio n. 27
0
 def tracker_predict():
     """Return action probabilities for a conversation given as events.

     Responds with status ``400`` if any posted item has no ``event``
     entry. Otherwise builds a tracker for the default sender id and
     responds with JSON mapping action names to probabilities.
     """
     sender_id = UserMessage.DEFAULT_SENDER_ID
     posted_events = request.get_json(force=True)

     if any(item.get('event', None) is None for item in posted_events):
         return Response(
             """Invalid list of events provided.""",
             status=400)

     tracker = DialogueStateTracker.from_dict(sender_id,
                                              posted_events,
                                              agent.domain.slots)
     ensemble = agent.policy_ensemble
     probabilities = ensemble.probabilities_using_best_policy(tracker,
                                                              agent.domain)

     # Name each probability after the action at the same index.
     probability_dict = {}
     for idx, probability in enumerate(probabilities):
         action = agent.domain.action_for_index(idx, agent.action_endpoint)
         probability_dict[action.name()] = probability
     return jsonify(probability_dict)
Esempio n. 28
0
def test_current_state_applied_events(default_agent):
    """``EventVerbosity.APPLIED`` must dump only the applied events."""
    dump_path = "data/test_trackers/tracker_moodbot.json"
    dump = json.loads(utils.read_file(dump_path))

    # add some events that result in other events not being applied anymore
    inserted = ((1, "restart"), (7, "rewind"), (8, "undo"))
    for position, event_name in inserted:
        dump["events"].insert(position, {"event": event_name})

    sender_id = dump.get("sender_id")
    events = dump.get("events", [])
    tracker = DialogueStateTracker.from_dict(sender_id, events,
                                             default_agent.domain.slots)

    all_events = [event.as_dict() for event in tracker.events]
    expected = [all_events[2], all_events[9]]

    state = tracker.current_state(EventVerbosity.APPLIED)
    assert state.get("events") == expected
Esempio n. 29
0
def test_dump_and_restore_as_json(default_agent, tmpdir):
    """Each story tracker must survive a JSON dump and restore unchanged."""
    story_trackers = training.extract_trackers(
        DEFAULT_STORIES_FILE,
        default_agent.domain,
        default_agent.featurizer,
        default_agent.interpreter,
        default_agent.policy_ensemble.max_history())

    # The same file is overwritten for every tracker.
    dump_file = tmpdir.join("dumped_tracker.json")

    for tracker in story_trackers:
        state = tracker.current_state(should_include_events=True)
        utils.dump_obj_as_json_to_file(dump_file.strpath, state)

        loaded = json.loads(utils.read_file(dump_file.strpath))
        sender_id = loaded.get("sender_id", UserMessage.DEFAULT_SENDER_ID)
        events = loaded.get("events", [])
        restored = DialogueStateTracker.from_dict(sender_id, events,
                                                  default_agent.domain)

        assert restored == tracker
Esempio n. 30
0
 def tracker_predict():
     """Return action probabilities for a conversation given as events.

     Responds with status ``400`` if any posted item has no ``event``
     entry. Otherwise builds a tracker for the default sender id and
     responds with JSON mapping action names to probabilities.
     """
     sender_id = UserMessage.DEFAULT_SENDER_ID
     posted_events = request.get_json(force=True)

     if any(item.get('event', None) is None for item in posted_events):
         return Response("""Invalid list of events provided.""",
                         status=400)

     tracker = DialogueStateTracker.from_dict(sender_id, posted_events,
                                              agent.domain.slots)
     ensemble = agent.policy_ensemble
     probabilities = ensemble.probabilities_using_best_policy(
         tracker, agent.domain)

     # Name each probability after the action at the same index.
     probability_dict = {}
     for idx, probability in enumerate(probabilities):
         action = agent.domain.action_for_index(idx, agent.action_endpoint)
         probability_dict[action.name()] = probability
     return jsonify(probability_dict)
Esempio n. 31
0
    def continue_training():
        """Continue training the agent on a tracker built from the request.

        Query parameters:
            epochs: number of training epochs (default ``30``).
            batch_size: training batch size (default ``5``).

        The JSON body is handed to ``DialogueStateTracker.from_dict`` for
        the default sender id.

        Returns:
            An empty ``204`` response on success, or a JSON body with an
            ``error`` entry and status ``500`` when training raises.
        """
        # Query-string values arrive as strings; `type=int` converts them and
        # falls back to the default when the value is not a valid integer.
        epochs = request.args.get("epochs", 30, type=int)
        batch_size = request.args.get("batch_size", 5, type=int)
        request_params = request.get_json(force=True)
        tracker = DialogueStateTracker.from_dict(UserMessage.DEFAULT_SENDER_ID,
                                                 request_params,
                                                 agent.domain.slots)

        try:
            # Train on the single tracker reconstructed from the request.
            agent.continue_training([tracker],
                                    epochs=epochs,
                                    batch_size=batch_size)
            return '', 204

        except Exception as e:
            logger.exception("Caught an exception during training.")
            # `jsonify` already builds a JSON response object, so return it
            # with the status code rather than wrapping it in a second
            # `Response`.
            return jsonify(error="Server failure. Error: {}"
                                 "".format(e)), 500
Esempio n. 32
0
    def continue_training():
        """Continue training the agent on a tracker built from the request.

        Query parameters:
            epochs: number of training epochs (default ``30``).
            batch_size: training batch size (default ``5``).

        The JSON body is handed to ``DialogueStateTracker.from_dict`` for
        the default sender id.

        Returns:
            An empty ``204`` response on success, or a JSON body with an
            ``error`` entry and status ``500`` when training raises.
        """
        # Query-string values arrive as strings; `type=int` converts them and
        # falls back to the default when the value is not a valid integer.
        epochs = request.args.get("epochs", 30, type=int)
        batch_size = request.args.get("batch_size", 5, type=int)
        request_params = request.get_json(force=True)
        tracker = DialogueStateTracker.from_dict(UserMessage.DEFAULT_SENDER_ID,
                                                 request_params,
                                                 agent.domain.slots)

        try:
            # Train on the single tracker reconstructed from the request.
            agent.continue_training([tracker],
                                    epochs=epochs,
                                    batch_size=batch_size)
            return '', 204

        except Exception as e:
            logger.exception("Caught an exception during training.")
            # `jsonify` already builds a JSON response object, so return it
            # with the status code rather than wrapping it in a second
            # `Response`.
            return jsonify(error="Server failure. Error: {}"
                                 "".format(e)), 500