Exemplo n.º 1
0
    def load_specification(cls, path):
        # type: (Text) -> Dict[Text, Any]
        """Return the domain specification saved in a dumped model directory.

        Args:
            path: directory of a persisted model; ``domain.json`` is read
                from inside it.

        Returns:
            The JSON content of ``<path>/domain.json`` decoded into Python
            objects.
        """
        spec_file = os.path.join(path, 'domain.json')
        raw_spec = utils.read_file(spec_file)
        return json.loads(raw_spec)
Exemplo n.º 2
0
    def load(cls, path: Text) -> 'KerasPolicy':
        """Restore a Keras policy persisted in ``path``.

        The featurizer is always restored from ``path``. When a
        ``keras_policy.json`` metadata file is present, the Keras model it
        names is loaded into a fresh ``tf.Graph``/``tf.Session`` pair, with
        the custom ``Position_Embedding`` and ``Attention`` layers
        registered. Otherwise a policy holding only the featurizer is
        returned.

        Raises:
            Exception: if ``path`` does not exist.
        """
        from keras.models import load_model

        if not os.path.exists(path):
            raise Exception("Failed to load dialogue model. Path {} "
                            "doesn't exist".format(os.path.abspath(path)))

        featurizer = TrackerFeaturizer.load(path)
        meta_file = os.path.join(path, "keras_policy.json")
        if not os.path.isfile(meta_file):
            return cls(featurizer=featurizer)

        meta = json.loads(utils.read_file(meta_file))
        model_path = os.path.join(path, meta["model"])
        # custom layer classes handed to load_model so it can resolve them
        custom_layers = {
            'Position_Embedding': Position_Embedding,
            'Attention': Attention,
        }

        graph = tf.Graph()
        with graph.as_default():
            session = tf.Session()
            with session.as_default():
                model = load_model(model_path, custom_objects=custom_layers)

        return cls(featurizer=featurizer,
                   model=model,
                   graph=graph,
                   session=session,
                   current_epoch=meta["epochs"])
Exemplo n.º 3
0
async def test_undo_latest_msg(mock_endpoint):
    """Undoing the latest message PUTs the events minus the last one."""
    tracker_dump = utils.read_file("data/test_trackers/tracker_moodbot.json")
    original_events = json.loads(tracker_dump).get("events")

    sender_id = uuid.uuid4().hex
    base = mock_endpoint.url

    tracker_url = '{}/conversations/{}/tracker?include_events=ALL'.format(
        base, sender_id)
    events_url = '{}/conversations/{}/tracker/events'.format(
        base, sender_id)

    with aioresponses() as mocked:
        mocked.get(tracker_url, body=tracker_dump)
        mocked.put(events_url)

        await interactive._undo_latest(sender_id, mock_endpoint)

        put_request = latest_request(mocked, 'put', events_url)
        assert put_request

        # the events sent to the endpoint must be the original ones with
        # the last utterance dropped
        sent_events = json_of_latest_request(put_request)
        assert len(sent_events) == 6
        assert sent_events == original_events[:6]
Exemplo n.º 4
0
    def load(cls, path):
        # type: (Text) -> KerasPolicy
        """Restore a Keras policy persisted in ``path``.

        The featurizer is always restored from ``path``. When
        ``keras_policy.json`` exists, the Keras model it names is loaded
        into a fresh ``tf.Graph``/``tf.Session`` pair. Otherwise a policy
        holding only the featurizer is returned.

        Raises:
            Exception: if ``path`` does not exist.
        """
        from tensorflow.keras.models import load_model

        if not os.path.exists(path):
            raise Exception("Failed to load dialogue model. Path {} "
                            "doesn't exist".format(os.path.abspath(path)))

        featurizer = TrackerFeaturizer.load(path)
        meta_file = os.path.join(path, "keras_policy.json")
        if not os.path.isfile(meta_file):
            return cls(featurizer=featurizer)

        meta = json.loads(utils.read_file(meta_file))
        model_path = os.path.join(path, meta["model"])

        graph = tf.Graph()
        with graph.as_default():
            session = tf.Session()
            with session.as_default():
                model = load_model(model_path)

        return cls(featurizer=featurizer,
                   model=model,
                   graph=graph,
                   session=session,
                   current_epoch=meta["epochs"])
Exemplo n.º 5
0
    def load(cls, path: Text) -> 'KerasPolicy':
        """Restore a Keras policy persisted in ``path``.

        The featurizer is always restored. When ``keras_policy.json`` is
        present, the pickled TensorFlow session config in
        ``keras_policy.tf_config.pkl`` is read. The Keras model named in the
        metadata is then loaded into a fresh graph and a session built with
        that config. Otherwise only the featurizer is passed to the policy.

        Raises:
            Exception: if ``path`` does not exist.
        """
        from tensorflow.keras.models import load_model

        if not os.path.exists(path):
            raise Exception("Failed to load dialogue model. Path {} "
                            "doesn't exist".format(os.path.abspath(path)))

        featurizer = TrackerFeaturizer.load(path)
        meta_path = os.path.join(path, "keras_policy.json")
        if not os.path.isfile(meta_path):
            return cls(featurizer=featurizer)

        meta = json.loads(utils.read_file(meta_path))

        config_path = os.path.join(path, "keras_policy.tf_config.pkl")
        # NOTE(review): unpickling can execute arbitrary code; this assumes
        # the model directory comes from a trusted source.
        with io.open(config_path, 'rb') as config_fh:
            session_config = pickle.load(config_fh)

        model_path = os.path.join(path, meta["model"])

        graph = tf.Graph()
        with graph.as_default():
            session = tf.Session(config=session_config)
            with session.as_default():
                model = load_model(model_path)

        return cls(featurizer=featurizer,
                   model=model,
                   graph=graph,
                   session=session,
                   current_epoch=meta["epochs"])
Exemplo n.º 6
0
 def load(cls, filename, action_factory=None):
     """Load a domain from the YAML specification file ``filename``.

     ``action_factory`` is forwarded unchanged to ``cls.load_from_yaml``.

     Raises:
         Exception: if ``filename`` is not an existing file.
     """
     if os.path.isfile(filename):
         return cls.load_from_yaml(read_file(filename),
                                   action_factory=action_factory)
     raise Exception("Failed to load domain specification from '{}'. "
                     "File not found!".format(os.path.abspath(filename)))
Exemplo n.º 7
0
def test_undo_latest_msg(mock_endpoint):
    """Undoing the latest message PUTs the events minus the last one."""
    tracker_dump = utils.read_file(
            "data/test_trackers/tracker_moodbot.json")
    tracker_json = json.loads(tracker_dump)
    evts = tracker_json.get("events")

    sender_id = uuid.uuid4().hex

    url = '{}/conversations/{}/tracker'.format(
            mock_endpoint.url, sender_id)
    replace_url = '{}/conversations/{}/tracker/events'.format(
            mock_endpoint.url, sender_id)
    httpretty.register_uri(httpretty.GET, url, body=tracker_dump)
    httpretty.register_uri(httpretty.PUT, replace_url)

    httpretty.enable()
    try:
        online._undo_latest(sender_id, mock_endpoint)
    finally:
        # always switch the HTTP mock off again, even if the call under
        # test raises, so it does not leak into later tests
        httpretty.disable()

    b = httpretty.latest_requests[-1].body.decode("utf-8")

    # these are the events the online call sent to the endpoint;
    # the last utterance should have been omitted
    replaced_evts = json.loads(b)
    assert len(replaced_evts) == 6
    assert replaced_evts == evts[:6]
Exemplo n.º 8
0
    def load(cls, path):
        """Create the policy from ``custom_fallback_policy.json`` in ``path``.

        The JSON object in that file is passed as keyword arguments to
        ``cls``. If ``path`` or the file does not exist, ``cls`` is called
        without arguments.
        """
        if not os.path.exists(path):
            return cls()

        meta_file = os.path.join(path, "custom_fallback_policy.json")
        if not os.path.isfile(meta_file):
            return cls()

        return cls(**json.loads(utils.read_file(meta_file)))
Exemplo n.º 9
0
 def load(path):
     """Restore a tracker featurizer from ``featurizer.json`` in ``path``.

     Returns ``None`` after logging an error when the file is missing.
     """
     file_path = os.path.join(path, "featurizer.json")
     if not os.path.isfile(file_path):
         logger.error("Couldn't load featurizer for policy. "
                      "File '{}' doesn't exist.".format(file_path))
         return None
     # NOTE(review): jsonpickle.decode can instantiate arbitrary classes;
     # this assumes featurizers are only loaded from trusted model dirs.
     return jsonpickle.decode(utils.read_file(file_path))
Exemplo n.º 10
0
    def load(cls, path: Text) -> 'FallbackPolicy':
        """Create the policy from ``two_stage_fallback_policy.json``.

        The file is looked up inside ``path``, and its JSON object is passed
        as keyword arguments to ``cls``. If the directory or the file is
        missing, ``cls`` is called without arguments.
        """
        kwargs = {}
        meta_file = None
        if os.path.exists(path):
            meta_file = os.path.join(path, "two_stage_fallback_policy.json")
        if meta_file is not None and os.path.isfile(meta_file):
            kwargs = json.loads(utils.read_file(meta_file))
        return cls(**kwargs)
Exemplo n.º 11
0
    def load(cls, path: Text) -> "BottisPolicy":
        """Create the policy from ``bottis_policy.json`` inside ``path``.

        The JSON object in that file is passed as keyword arguments to
        ``cls``. If the directory or the file is missing, ``cls`` is called
        without arguments.
        """
        if os.path.exists(path):
            config_file = os.path.join(path, "bottis_policy.json")
            if os.path.isfile(config_file):
                return cls(**json.loads(utils.read_file(config_file)))
        return cls()
Exemplo n.º 12
0
def _load_tracker_from_json(tracker_dump, agent):
    # type: (Text, Agent) -> DialogueStateTracker
    """Read a tracker JSON dump from disk and rebuild the tracker from it.

    ``tracker_dump`` is the path of the dump file. The tracker is created
    with ``agent.domain``. A missing ``sender_id`` falls back to
    ``UserMessage.DEFAULT_SENDER_ID`` and missing events to an empty list.
    """
    dump = json.loads(utils.read_file(tracker_dump))
    return DialogueStateTracker.from_dict(
        dump.get("sender_id", UserMessage.DEFAULT_SENDER_ID),
        dump.get("events", []),
        agent.domain)
Exemplo n.º 13
0
def test_all_events_before_user_msg():
    """Only the events preceding the latest user message are returned."""
    tracker_json = json.loads(
        utils.read_file("data/test_trackers/tracker_moodbot.json"))
    events = tracker_json.get("events")

    preceding = online.all_events_before_latest_user_msg(events)

    assert preceding is not None
    assert preceding == events[:4]
Exemplo n.º 14
0
def load_tracker_from_json(tracker_dump, domain):
    # type: (Text, Domain) -> DialogueStateTracker
    """Read the JSON dump from the file and instantiate a tracker from it.

    Args:
        tracker_dump: path of the JSON file holding the dumped tracker.
        domain: domain passed to ``DialogueStateTracker.from_dict`` when
            rebuilding the tracker.

    Returns:
        The tracker rebuilt from the dump's ``sender_id`` (defaulting to
        ``UserMessage.DEFAULT_SENDER_ID``) and ``events`` (defaulting to an
        empty list).
    """
    tracker_json = json.loads(utils.read_file(tracker_dump))
    sender_id = tracker_json.get("sender_id", UserMessage.DEFAULT_SENDER_ID)
    return DialogueStateTracker.from_dict(sender_id,
                                          tracker_json.get("events", []),
                                          domain)
Exemplo n.º 15
0
def load_tracker_from_json(tracker_dump: Text,
                           domain: Domain) -> DialogueStateTracker:
    """Rebuild a tracker from the JSON dump stored at ``tracker_dump``.

    The tracker is created from the dump's events and ``domain.slots``.
    A missing ``sender_id`` defaults to ``UserMessage.DEFAULT_SENDER_ID``
    and missing events to an empty list.
    """
    dump = json.loads(utils.read_file(tracker_dump))
    sender = dump.get("sender_id", UserMessage.DEFAULT_SENDER_ID)
    events = dump.get("events", [])
    return DialogueStateTracker.from_dict(sender, events, domain.slots)
Exemplo n.º 16
0
def test_latest_user_message():
    """The latest user message in the moodbot dump is ``/mood_great``."""
    dump_path = "data/test_trackers/tracker_moodbot.json"
    events = json.loads(utils.read_file(dump_path)).get("events")

    latest = online.latest_user_message(events)

    assert latest is not None
    assert latest["event"] == "user"
    assert latest["text"] == "/mood_great"
Exemplo n.º 17
0
    def load(cls, path):
        # type: (Text) -> FallbackPolicy
        """Create the policy from ``fallback_policy.json`` inside ``path``.

        The JSON object in that file is passed as keyword arguments to
        ``cls``. If the directory or the file is missing, ``cls`` is called
        without arguments.
        """
        if not os.path.exists(path):
            return cls()

        settings_file = os.path.join(path, "fallback_policy.json")
        settings = (json.loads(utils.read_file(settings_file))
                    if os.path.isfile(settings_file) else {})
        return cls(**settings)
Exemplo n.º 18
0
def test_current_state_no_events(default_agent):
    """With ``EventVerbosity.NONE`` the current state contains no events."""
    dump = json.loads(
        utils.read_file("data/test_trackers/tracker_moodbot.json"))

    tracker = DialogueStateTracker.from_dict(
        dump.get("sender_id"),
        dump.get("events", []),
        default_agent.domain.slots)

    assert tracker.current_state(EventVerbosity.NONE).get("events") is None
Exemplo n.º 19
0
def test_splitting_conversation_at_restarts():
    """Restart events split the conversation and are dropped from the parts.

    A restart is inserted after the second event and another is appended at
    the end. The split must produce exactly two chunks: the first two events,
    then all remaining original events.
    """
    tracker_dump = "data/test_trackers/tracker_moodbot.json"
    evts = json.loads(utils.read_file(tracker_dump)).get("events")
    evts_wo_restarts = evts[:]
    evts.insert(2, {"event": "restart"})
    evts.append({"event": "restart"})

    split = online._split_conversation_at_restarts(evts)
    assert len(split) == 2
    assert [e for s in split for e in s] == evts_wo_restarts
    assert len(split[0]) == 2
    # the original repeated the split[0] check; verify the second chunk too
    assert split[1] == evts_wo_restarts[2:]
Exemplo n.º 20
0
    def load(cls, path: Text) -> 'MemoizationPolicy':
        """Restore a memoization policy from ``path``.

        If ``memorized_turns.json`` exists, the policy is rebuilt with the
        featurizer loaded from ``path`` and the stored ``lookup`` table.
        Otherwise an info message is logged and ``cls()`` is returned.
        """
        featurizer = TrackerFeaturizer.load(path)
        lookup_file = os.path.join(path, 'memorized_turns.json')
        if not os.path.isfile(lookup_file):
            logger.info("Couldn't load memoization for policy. "
                        "File '{}' doesn't exist. Falling back to empty "
                        "turn memory.".format(lookup_file))
            # NOTE(review): the featurizer loaded above is not passed on
            # here; confirm whether that is intended.
            return cls()

        stored = json.loads(utils.read_file(lookup_file))
        return cls(featurizer=featurizer, lookup=stored["lookup"])
Exemplo n.º 21
0
def test_is_listening_for_messages(mock_endpoint):
    """The moodbot tracker dump is reported as listening for a message."""
    tracker_dump = utils.read_file("data/test_trackers/tracker_moodbot.json")

    sender_id = uuid.uuid4().hex

    url = '{}/conversations/{}/tracker'.format(mock_endpoint.url, sender_id)
    httpretty.register_uri(httpretty.GET, url, body=tracker_dump)

    httpretty.enable()
    try:
        is_listening = online.is_listening_for_message(sender_id,
                                                       mock_endpoint)
    finally:
        # always switch the HTTP mock off again, even if the call under
        # test raises, so it does not leak into later tests
        httpretty.disable()

    assert is_listening
Exemplo n.º 22
0
def test_current_state_all_events(default_agent):
    """With ``EventVerbosity.ALL`` the state lists every tracker event."""
    dump = json.loads(
        utils.read_file("data/test_trackers/tracker_moodbot.json"))
    dump["events"].insert(3, {"event": "restart"})

    tracker = DialogueStateTracker.from_dict(
        dump.get("sender_id"),
        dump.get("events", []),
        default_agent.domain.slots)
    expected = [event.as_dict() for event in tracker.events]

    assert tracker.current_state(EventVerbosity.ALL).get("events") == expected
Exemplo n.º 23
0
def test_read_json_dump(default_agent):
    """Restoring the moodbot dump reproduces its events and its JSON state."""
    dump_file = "data/test_trackers/tracker_moodbot.json"
    expected_state = json.loads(utils.read_file(dump_file))

    tracker = restore.load_tracker_from_json(dump_file, default_agent.domain)

    assert len(tracker.events) == 7
    assert tracker.latest_action_name == "action_listen"
    assert not tracker.is_paused()
    assert tracker.sender_id == "mysender"
    assert tracker.events[-1].timestamp == 1517821726.211042

    assert tracker.current_state(should_include_events=True) == expected_state
Exemplo n.º 24
0
async def test_print_history(mock_endpoint):
    """Printing the history requests the tracker with AFTER_RESTART events."""
    dump = utils.read_file("data/test_trackers/tracker_moodbot.json")
    sender_id = uuid.uuid4().hex

    history_url = (
        '{}/conversations/{}/tracker?include_events=AFTER_RESTART'.format(
            mock_endpoint.url, sender_id))

    with aioresponses() as mocked:
        # NOTE(review): "Accept" is normally a request header; a mocked
        # response would usually set "Content-Type" instead. Confirm intent.
        response_headers = {"Accept": "application/json"}
        mocked.get(history_url, body=dump, headers=response_headers)

        await interactive._print_history(sender_id, mock_endpoint)

        assert latest_request(mocked, 'get', history_url) is not None
Exemplo n.º 25
0
def test_read_json_dump(default_agent):
    """A tracker rebuilt from the moodbot dump reproduces the dump's state."""
    raw = utils.read_file("data/test_trackers/tracker_moodbot.json")
    dump = json.loads(raw)

    tracker = DialogueStateTracker.from_dict(
        dump.get("sender_id", UserMessage.DEFAULT_SENDER_ID),
        dump.get("events", []),
        default_agent.domain)

    assert len(tracker.events) == 7
    assert tracker.latest_action_name == "action_listen"
    assert not tracker.is_paused()
    assert tracker.sender_id == "mysender"
    assert tracker.events[-1].timestamp == 1517821726.211042

    assert tracker.current_state(should_include_events=True) == dump
Exemplo n.º 26
0
def test_read_json_dump(default_agent):
    """The restore helper rebuilds the moodbot tracker exactly as dumped."""
    path = "data/test_trackers/tracker_moodbot.json"
    dumped_state = json.loads(utils.read_file(path))

    restored = restore.load_tracker_from_json(path, default_agent.domain)
    last_event = restored.events[-1]

    assert len(restored.events) == 7
    assert restored.latest_action_name == "action_listen"
    assert not restored.is_paused()
    assert restored.sender_id == "mysender"
    assert last_event.timestamp == 1517821726.211042

    assert restored.current_state(should_include_events=True) == dumped_state
Exemplo n.º 27
0
async def test_is_listening_for_messages(mock_endpoint):
    """The moodbot tracker dump is reported as listening for a message."""
    dump = utils.read_file("data/test_trackers/tracker_moodbot.json")
    sender_id = uuid.uuid4().hex

    tracker_url = '{}/conversations/{}/tracker?include_events=APPLIED'.format(
        mock_endpoint.url, sender_id)
    json_headers = {"Content-Type": "application/json"}

    with aioresponses() as mocked:
        mocked.get(tracker_url, body=dump, headers=json_headers)

        listening = await interactive.is_listening_for_message(
            sender_id, mock_endpoint)

        assert listening
Exemplo n.º 28
0
def test_dump_and_restore_as_json(default_agent, tmpdir_factory):
    """Every training tracker survives a round trip through a JSON file."""
    trackers = default_agent.load_data(DEFAULT_STORIES_FILE)

    for original in trackers:
        dump_dir = tmpdir_factory.mktemp("tracker")
        dump_file = dump_dir.join("dumped_tracker.json")
        utils.dump_obj_as_json_to_file(
            dump_file.strpath,
            original.current_state(should_include_events=True))

        loaded = json.loads(utils.read_file(dump_file.strpath))
        restored = DialogueStateTracker.from_dict(
            loaded.get("sender_id", UserMessage.DEFAULT_SENDER_ID),
            loaded.get("events", []),
            default_agent.domain)

        assert restored == original
Exemplo n.º 29
0
def test_print_history(mock_endpoint):
    """Printing the history issues one GET for the AFTER_RESTART events."""
    tracker_dump = utils.read_file("data/test_trackers/tracker_moodbot.json")

    sender_id = uuid.uuid4().hex

    url = '{}/conversations/{}/tracker'.format(mock_endpoint.url, sender_id)
    httpretty.register_uri(httpretty.GET, url, body=tracker_dump)

    httpretty.enable()
    try:
        interactive._print_history(sender_id, mock_endpoint)
    finally:
        # always switch the HTTP mock off again, even if the call under
        # test raises, so it does not leak into later tests
        httpretty.disable()

    b = httpretty.latest_requests[-1].body.decode("utf-8")
    assert b == ""
    assert (httpretty.latest_requests[-1].path ==
            "/conversations/{}/tracker?include_events=AFTER_RESTART"
            "".format(sender_id))
Exemplo n.º 30
0
def test_dump_and_restore_as_json(default_agent, tmpdir):
    """Extracted trackers survive a round trip through a JSON file."""
    max_history = default_agent.policy_ensemble.max_history()
    trackers = training.extract_trackers(
        DEFAULT_STORIES_FILE, default_agent.domain, default_agent.featurizer,
        default_agent.interpreter, max_history)

    # a single file is reused and overwritten for every tracker
    dump_path = tmpdir.join("dumped_tracker.json").strpath

    for original in trackers:
        utils.dump_obj_as_json_to_file(
            dump_path, original.current_state(should_include_events=True))

        loaded = json.loads(utils.read_file(dump_path))
        restored = DialogueStateTracker.from_dict(
            loaded.get("sender_id", UserMessage.DEFAULT_SENDER_ID),
            loaded.get("events", []),
            default_agent.domain)

        assert restored == original
Exemplo n.º 31
0
def test_current_state_applied_events(default_agent):
    """``EventVerbosity.APPLIED`` drops events invalidated by later events."""
    tracker_json = json.loads(
        utils.read_file("data/test_trackers/tracker_moodbot.json"))
    events_json = tracker_json["events"]

    # these insertions cause some of the other events to no longer be
    # applied; they must happen in this order because indices shift
    for position, extra_event in ((1, {"event": "restart"}),
                                  (7, {"event": "rewind"}),
                                  (8, {"event": "undo"})):
        events_json.insert(position, extra_event)

    tracker = DialogueStateTracker.from_dict(tracker_json.get("sender_id"),
                                             tracker_json.get("events", []),
                                             default_agent.domain.slots)

    all_events = [event.as_dict() for event in tracker.events]
    expected = [all_events[2], all_events[9]]

    assert tracker.current_state(EventVerbosity.APPLIED).get(
        "events") == expected
Exemplo n.º 32
0
def test_concerts_online_example(tmpdir):
    """Online training of the concertbot retrains the agent.

    The test answers the online-training prompts twice, correcting the
    action to ``utter_goodbye``, and then exports the story. It checks that
    the retrained agent answers a greeting with goodbye and that the
    exported file holds one story with the corrected actions.
    """
    example_dir = "examples/concertbot/"
    sys.path.append(example_dir)
    try:
        from train_online import train_agent
        from rasa_core import utils

        story_path = tmpdir.join("stories.md").strpath

        with utilities.cwd("examples/concertbot"):
            msgs = iter(["/greet", "/greet", "/greet"])
            msgs_f = functools.partial(next, msgs)

            with utilities.mocked_cmd_input(
                    utils,
                    text=[
                        "2",  # action is wrong
                        "5",  # choose utter_goodbye action
                        "1"  # yes, action_listen is correct.
                    ] * 2 + [  # repeat this twice
                        "0",  # export
                        story_path  # file path to export to
                    ]):
                agent = train_agent()

                responses = agent.handle_text("/greet", sender_id="user1")
                assert responses[-1]['text'] == "hey there!"

                online.serve_agent(agent, get_next_message=msgs_f)

                # the model should have been retrained and the model should
                # now directly respond with goodbye
                responses = agent.handle_text("/greet", sender_id="user2")
                assert responses[-1]['text'] == "goodbye :("

                assert os.path.exists(story_path)
                print(utils.read_file(story_path))

                t = training.load_data(story_path,
                                       agent.domain,
                                       use_story_concatenation=False)
                assert len(t) == 1
                assert len(t[0].events) == 9
                assert t[0].events[5] == ActionExecuted("utter_goodbye")
                assert t[0].events[6] == ActionExecuted("action_listen")
    finally:
        # undo the sys.path mutation so the example directory does not stay
        # on the import path of every test that runs afterwards
        sys.path.remove(example_dir)
Exemplo n.º 33
0
def test_concerts_online_example(tmpdir):
    """Online training of the concertbot retrains the agent.

    The test answers the online-training prompts twice, correcting the
    action to ``utter_goodbye``, and then exports the story. It checks that
    the retrained agent answers a greeting with goodbye and that the
    exported file holds one story with the corrected actions.
    """
    example_dir = "examples/concertbot/"
    sys.path.append(example_dir)
    try:
        from train_online import train_agent
        from rasa_core import utils

        story_path = tmpdir.join("stories.md").strpath

        with utilities.cwd("examples/concertbot"):
            msgs = iter(["/greet", "/greet", "/greet"])
            msgs_f = functools.partial(next, msgs)

            with utilities.mocked_cmd_input(
                    utils,
                    text=["2",  # action is wrong
                          "5",  # choose utter_goodbye action
                          "1"  # yes, action_listen is correct.
                          ] * 2 + [  # repeat this twice
                             "0",  # export
                             story_path  # file path to export to
                         ]):
                agent = train_agent()

                responses = agent.handle_text("/greet", sender_id="user1")
                assert responses[-1]['text'] == "hey there!"

                online.serve_agent(agent, get_next_message=msgs_f)

                # the model should have been retrained and the model should
                # now directly respond with goodbye
                responses = agent.handle_text("/greet", sender_id="user2")
                assert responses[-1]['text'] == "goodbye :("

                assert os.path.exists(story_path)
                print(utils.read_file(story_path))

                t = training.load_data(story_path, agent.domain,
                                       use_story_concatenation=False)
                assert len(t) == 1
                assert len(t[0].events) == 9
                assert t[0].events[5] == ActionExecuted("utter_goodbye")
                assert t[0].events[6] == ActionExecuted("action_listen")
    finally:
        # undo the sys.path mutation so the example directory does not stay
        # on the import path of every test that runs afterwards
        sys.path.remove(example_dir)
Exemplo n.º 34
0
 def load(cls, filename):
     """Load a domain from the YAML specification file ``filename``.

     Raises:
         Exception: if ``filename`` is not an existing file.
     """
     if os.path.isfile(filename):
         yaml_text = read_file(filename)
         return cls.from_yaml(yaml_text)
     missing = os.path.abspath(filename)
     raise Exception("Failed to load domain specification from '{}'. "
                     "File not found!".format(missing))
Exemplo n.º 35
0
def test_restaurant_domain_is_valid():
    """The restaurant example domain passes YAML validation."""
    domain_yaml = read_file('examples/restaurantbot/restaurant_domain.yml')
    # validation must not raise for this domain
    TemplateDomain.validate_domain_yaml(domain_yaml)