예제 #1
0
def test_custom_slot_type(tmpdir):
    """A slot whose ``type`` is a dotted path to a custom slot class loads."""
    domain_yaml = """
       slots:
         custom:
           type: tests.conftest.CustomSlot

       templates:
         utter_greet:
           - hey there!

       actions:
         - utter_greet """
    written_path = utilities.write_text_to_file(tmpdir, "domain.yml",
                                                domain_yaml)
    # loading must simply not raise
    Domain.load(written_path)
예제 #2
0
def test_custom_slot_type(tmpdir):
    """Domain loading accepts a custom slot class referenced by import path."""
    yaml_source = """
       slots:
         custom:
           type: tests.conftest.CustomSlot

       templates:
         utter_greet:
           - hey there!

       actions:
         - utter_greet """
    target = utilities.write_text_to_file(tmpdir, "domain.yml", yaml_source)
    # success criterion: no exception while loading
    Domain.load(target)
예제 #3
0
def test_domain_fails_on_unknown_custom_slot_type(tmpdir):
    """A slot type that cannot be resolved makes ``Domain.load`` fail."""
    domain_yaml = """
        slots:
            custom:
             type: tests.conftest.Unknown

        templates:
            utter_greet:
             - hey there!

        actions:
            - utter_greet"""
    written_path = utilities.write_text_to_file(tmpdir, "domain.yml",
                                                domain_yaml)
    with pytest.raises(ValueError):
        Domain.load(written_path)
예제 #4
0
def test_domain_fails_on_unknown_custom_slot_type(tmpdir):
    """Referencing a non-existent slot class raises ``ValueError``."""
    yaml_source = """
        slots:
            custom:
             type: tests.conftest.Unknown

        templates:
            utter_greet:
             - hey there!

        actions:
            - utter_greet"""
    target = utilities.write_text_to_file(tmpdir, "domain.yml", yaml_source)
    with pytest.raises(ValueError):
        Domain.load(target)
예제 #5
0
def _load_and_set_updated_model(agent: 'Agent', model_directory: Text,
                                fingerprint: Text):
    """Load the persisted model into memory and set the model on the agent.

    If ``_get_stack_model_directory`` finds a stacked model, the NLU part is
    loaded from its ``nlu`` subdirectory into a new ``RasaNLUInterpreter``
    and the Core part is read from its ``core`` subdirectory. Otherwise
    ``model_directory`` itself is treated as the Core model and the agent's
    current interpreter is kept.

    Every step of loading the new model runs inside the ``try``: the
    interpreter, ``domain.yml`` and the policy ensemble. A failure in any of
    them is logged, and the previously loaded model stays active.

    Args:
        agent: agent whose model should be replaced.
        model_directory: directory the new model was unpacked into.
        fingerprint: fingerprint identifying the new model.
    """

    logger.debug("Found new model with fingerprint {}. Loading..."
                 "".format(fingerprint))

    # noinspection PyBroadException
    try:
        stack_model_directory = _get_stack_model_directory(model_directory)
        if stack_model_directory:
            from rasa_core.interpreter import RasaNLUInterpreter
            nlu_model = os.path.join(stack_model_directory, "nlu")
            core_model = os.path.join(stack_model_directory, "core")
            interpreter = RasaNLUInterpreter(model_directory=nlu_model)
        else:
            interpreter = agent.interpreter
            core_model = model_directory

        # a broken or missing domain must not take down the running agent
        # either, so it is loaded inside the guarded section as well
        domain_path = os.path.join(os.path.abspath(core_model), "domain.yml")
        domain = Domain.load(domain_path)

        policy_ensemble = PolicyEnsemble.load(core_model)
        agent.update_model(domain, policy_ensemble, fingerprint, interpreter)
        logger.debug("Finished updating agent to new model.")
    except Exception:
        logger.exception("Failed to load policy and update agent. "
                         "The previous model will stay loaded instead.")
예제 #6
0
def test_tracker_store(store, pair):
    """A tracker saved to ``store`` is retrieved unchanged."""
    dialogue_file, domain_file = pair
    loaded_domain = Domain.load(domain_file)
    original = tracker_from_dialogue_file(dialogue_file, loaded_domain)

    store.save(original)
    assert store.retrieve(original.sender_id) == original
예제 #7
0
def test_inmemory_tracker_store(filename):
    """Round-trip a tracker through an ``InMemoryTrackerStore``."""
    loaded_domain = Domain.load("data/test_domains/default.yml")
    original = tracker_from_dialogue_file(filename, loaded_domain)
    in_memory = InMemoryTrackerStore(loaded_domain)

    in_memory.save(original)
    assert in_memory.retrieve(original.sender_id) == original
예제 #8
0
 def trained_policy(self, featurizer, priority):
     """Return a policy from ``create_policy`` trained on the default domain.

     Training trackers are generated with an augmentation factor of 20.
     """
     domain = Domain.load(DEFAULT_DOMAIN_PATH)
     fitted = self.create_policy(featurizer, priority)
     trackers = train_trackers(domain,
                               augmentation_factor=20)
     fitted.train(trackers, domain)
     return fitted
예제 #9
0
    async def load_model(request: Request):
        """Loads a zipped model, replacing the existing one.

        The uploaded ``model`` file is saved to a temporary ``.zip`` archive
        and extracted into a new temporary directory. The ``domain.yml`` and
        policy ensemble found there are then set on ``app.agent``. The
        temporary archive is always deleted afterwards. The extraction
        directory is deleted too if saving or extracting the upload fails.

        Returns:
            An empty ``204`` response once the new model is in place.

        Raises:
            ErrorResponse: ``400`` if the request carries no ``model`` file.
        """
        import shutil

        if 'model' not in request.files:
            # model file is missing
            raise ErrorResponse(400, "InvalidParameter",
                                "You did not supply a model as part of your "
                                "request.",
                                {"parameter": "model", "in": "body"})

        model_file = request.files['model']

        logger.info("Received new model through REST interface.")
        # Only a unique path is needed. The handle is closed right away so the
        # file can be reopened by name for writing; an open
        # NamedTemporaryFile cannot be reopened by name on every platform.
        zipped_path = tempfile.NamedTemporaryFile(delete=False, suffix=".zip")
        zipped_path.close()
        model_directory = tempfile.mkdtemp()

        try:
            model_file.save(zipped_path.name)
            logger.debug("Downloaded model to {}".format(zipped_path.name))

            # the context manager closes the archive even if extraction fails
            with zipfile.ZipFile(zipped_path.name, 'r') as zip_ref:
                zip_ref.extractall(model_directory)
        except Exception:
            # a failed upload leaves nothing usable in the extraction dir
            shutil.rmtree(model_directory, ignore_errors=True)
            raise
        finally:
            # the archive is not needed any more, whether extraction
            # succeeded or not
            try:
                os.remove(zipped_path.name)
            except OSError:
                logger.warning("Could not remove temporary model archive "
                               "{}".format(zipped_path.name))

        logger.debug("Unzipped model to {}".format(
            os.path.abspath(model_directory)))

        domain_path = os.path.join(os.path.abspath(model_directory),
                                   "domain.yml")
        domain = Domain.load(domain_path)
        ensemble = PolicyEnsemble.load(model_directory)
        app.agent.update_model(domain, ensemble, None)
        logger.debug("Finished loading new agent.")
        return response.text('', 204)
예제 #10
0
def test_inmemory_tracker_store(filename):
    """An in-memory store returns exactly the tracker that was saved."""
    default_domain = Domain.load("data/test_domains/default.yml")
    saved = tracker_from_dialogue_file(filename, default_domain)
    memory_store = InMemoryTrackerStore(default_domain)
    memory_store.save(saved)

    fetched = memory_store.retrieve(saved.sender_id)
    assert fetched == saved
예제 #11
0
파일: agent.py 프로젝트: punitcs81/chatbot
    def load(cls,
             path: Text,
             interpreter: Optional[NaturalLanguageInterpreter] = None,
             generator: Union[EndpointConfig, 'NLG'] = None,
             tracker_store: Optional['TrackerStore'] = None,
             action_endpoint: Optional[EndpointConfig] = None,
             ) -> 'Agent':
        """Load a persisted model from the passed path.

        Args:
            path: model directory containing ``domain.yml`` and the
                persisted policy ensemble.
            interpreter: NLU interpreter handed to the new agent.
            generator: NLG endpoint config or generator handed to the agent.
            tracker_store: tracker store handed to the agent.
            action_endpoint: endpoint config for custom actions.

        Returns:
            A new ``cls`` instance built from the loaded domain and policies.

        Raises:
            ValueError: if ``path`` is empty or points to a file.
        """

        if not path:
            raise ValueError("You need to provide a valid directory where "
                             "to load the agent from when calling "
                             "`Agent.load`.")

        if os.path.isfile(path):
            raise ValueError("You are trying to load a MODEL from a file "
                             "('{}'), which is not possible. \n"
                             "The persisted path should be a directory "
                             "containing the various model files. \n\n"
                             "If you want to load training data instead of "
                             "a model, use `agent.load_data(...)` "
                             "instead.".format(path))

        domain = Domain.load(os.path.join(path, "domain.yml"))
        # ``path`` was validated above, so the ensemble is always loaded
        ensemble = PolicyEnsemble.load(path)

        # ensures the domain hasn't changed between test and train
        domain.compare_with_specification(path)

        return cls(domain=domain,
                   policies=ensemble,
                   interpreter=interpreter,
                   generator=generator,
                   tracker_store=tracker_store,
                   action_endpoint=action_endpoint)
 def __init__(self, log, nlu_model_dir='model/default/Jarvis',
              dialogue_model_dir='model/dialogue',
              endpoints_file='config/endpoint.yml'):
     """Set up the NLU interpreter and, if available, the dialogue agent.

     The interpreter is only created when ``nlu_model_dir`` exists. The
     dialogue agent is only created when ``dialogue_model_dir`` exists as
     well; its action endpoint and ``MongoTrackerStore`` settings are read
     from ``endpoints_file``.

     Args:
         log: logger object stored as ``self.logger``.
         nlu_model_dir: directory of the trained NLU model.
         dialogue_model_dir: directory of the trained dialogue model; its
             ``domain.yml`` is used for the tracker store.
         endpoints_file: YAML file providing ``action_endpoint.url`` and
             ``tracker_store`` (``url``, ``db``, ``username``, ``password``).

     Raises:
         yaml.YAMLError: if ``endpoints_file`` is not valid YAML.
     """
     self.logger = log
     if os.path.isdir(nlu_model_dir):
         self.interpreter = RasaNLUInterpreter(
             model_directory=nlu_model_dir)
         if os.path.isdir(dialogue_model_dir):
             # Let a malformed endpoints file fail loudly: without it there
             # is no action endpoint or tracker store configuration to use.
             with open(endpoints_file, 'r', encoding='utf-8') as stream:
                 config = yaml.safe_load(stream)
             action_endpoint = EndpointConfig(
                 url=config["action_endpoint"]["url"])
             tracker_store = MongoTrackerStore(
                 domain=Domain.load(
                     os.path.join(dialogue_model_dir, 'domain.yml')),
                 host=config["tracker_store"]["url"],
                 db=config["tracker_store"]["db"],
                 username=config["tracker_store"]["username"],
                 password=config["tracker_store"]["password"])
             self.agent = Agent.load(dialogue_model_dir,
                                     interpreter=self.interpreter,
                                     action_endpoint=action_endpoint,
                                     tracker_store=tracker_store)
             self._slots = {}
예제 #13
0
 def trained_policy(self, featurizer):
     """Return a policy trained with attention before and after the RNN."""
     domain = Domain.load(DEFAULT_DOMAIN_PATH)
     fitted = self.create_policy(featurizer)
     trackers = train_trackers(domain)
     fitted.train(trackers, domain,
                  attn_before_rnn=True,
                  attn_after_rnn=True)
     return fitted
예제 #14
0
def test_inmemory_tracker_store(pair):
    """Round-trip the tracker described by ``pair`` through memory."""
    dialogue_file, domain_file = pair
    loaded_domain = Domain.load(domain_file)
    original = tracker_from_dialogue_file(dialogue_file, loaded_domain)
    in_memory = InMemoryTrackerStore(loaded_domain)

    in_memory.save(original)
    assert in_memory.retrieve(original.sender_id) == original
예제 #15
0
 def trained_policy(self, featurizer):
     """Build a policy and train it on the default domain (attention on)."""
     base_domain = Domain.load(DEFAULT_DOMAIN_PATH)
     policy_under_test = self.create_policy(featurizer)
     generated = train_trackers(base_domain)
     policy_under_test.train(generated, base_domain,
                             attn_before_rnn=True,
                             attn_after_rnn=True)
     return policy_under_test
예제 #16
0
def test_action():
    """Smoke test: run ``QuoraSearch`` against a fresh tracker."""
    domain = Domain.load('domain.yml')
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    dispatcher = Dispatcher("my-sender", CollectingOutputChannel(), generator)

    sender_id = str(uuid.uuid1())
    tracker = DialogueStateTracker(sender_id, domain.slots)

    QuoraSearch().run(dispatcher, tracker, domain)
예제 #17
0
def tracker_from_dialogue_file(filename: Text, domain: Domain = None):
    """Replay the dialogue stored in ``filename`` into a fresh tracker.

    A falsy ``domain`` is replaced by the domain loaded from
    ``DEFAULT_DOMAIN_PATH``; the tracker is named after the dialogue and
    uses that domain's slots.
    """
    dialogue = read_dialogue_file(filename)
    effective_domain = domain or Domain.load(DEFAULT_DOMAIN_PATH)

    replayed = DialogueStateTracker(dialogue.name, effective_domain.slots)
    replayed.recreate_from_dialogue(dialogue)
    return replayed
예제 #18
0
def test_utter_templates():
    """The moodbot ``utter_greet`` template comes back with its buttons."""
    domain = Domain.load("examples/moodbot/domain.yml")
    greet_buttons = [{"title": "great", "payload": "great"},
                     {"title": "super sad", "payload": "super sad"}]
    expected_template = {"text": "Hey! How are you?",
                         "buttons": greet_buttons}
    assert domain.random_template_for("utter_greet") == expected_template
예제 #19
0
def test_utter_templates():
    """``random_template_for`` returns moodbot's greeting with buttons."""
    moodbot = Domain.load("examples/moodbot/domain.yml")
    template = moodbot.random_template_for("utter_greet")
    assert template == {
        "text": "Hey! How are you?",
        "buttons": [{"title": "great", "payload": "great"},
                    {"title": "super sad", "payload": "super sad"}],
    }
예제 #20
0
def tracker_from_dialogue_file(filename, domain=None):
    """Build a tracker by replaying the dialogue stored in ``filename``.

    Args:
        filename: path of the dialogue file read by ``read_dialogue_file``.
        domain: domain whose slots the tracker uses; when ``None`` the
            domain at ``DEFAULT_DOMAIN_PATH`` is loaded instead.

    Returns:
        A ``DialogueStateTracker`` named after the dialogue, after the
        dialogue has been passed to ``recreate_from_dialogue``.
    """
    dialogue = read_dialogue_file(filename)

    if domain is None:
        domain = Domain.load(DEFAULT_DOMAIN_PATH)
    tracker = DialogueStateTracker(dialogue.name, domain.slots)
    tracker.recreate_from_dialogue(dialogue)
    return tracker
def test_action():
    """``ActionJoke`` sends a message mentioning Norris."""
    domain = Domain.load('domain.yml')
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    dispatcher = Dispatcher("my-sender", CollectingOutputChannel(), generator)
    sender_id = str(uuid.uuid1())
    tracker = DialogueStateTracker(sender_id, domain.slots)

    ActionJoke().run(dispatcher, tracker, domain)

    latest = dispatcher.output_channel.latest_output()
    assert 'norris' in latest['text'].lower()
예제 #22
0
파일: agent.py 프로젝트: punitcs81/chatbot
    def _create_domain(domain: Union[None, Domain, Text]) -> Optional[Domain]:
        """Turn ``domain`` into a ``Domain`` instance.

        Args:
            domain: path to a domain specification, an existing ``Domain``,
                or ``None``.

        Returns:
            The loaded or passed-through ``Domain``, or ``None`` when
            ``domain`` is ``None``.

        Raises:
            ValueError: if ``domain`` is of any other type.
        """

        if isinstance(domain, str):
            return Domain.load(domain)
        elif isinstance(domain, Domain):
            return domain
        elif domain is not None:
            raise ValueError(
                "Invalid param `domain`. Expected a path to a domain "
                "specification or a domain instance. But got "
                "type '{}' with value '{}'".format(type(domain), domain))
        # no domain given: make the ``None`` result explicit
        return None
예제 #23
0
    def _create_domain(domain):
        # type: (Union[None, Domain, Text]) -> Optional[Domain]
        """Turn ``domain`` into a ``Domain`` instance.

        Args:
            domain: path to a domain specification, an existing ``Domain``,
                or ``None``.

        Returns:
            The loaded or passed-through ``Domain``, or ``None`` when
            ``domain`` is ``None``.

        Raises:
            ValueError: if ``domain`` is of any other type.
        """

        if isinstance(domain, string_types):
            return Domain.load(domain)
        elif isinstance(domain, Domain):
            return domain
        elif domain is not None:
            raise ValueError(
                    "Invalid param `domain`. Expected a path to a domain "
                    "specification or a domain instance. But got "
                    "type '{}' with value '{}'".format(type(domain), domain))
        # no domain given: make the ``None`` result explicit
        return None
예제 #24
0
def test_dispatcher_utter_buttons_from_domain_templ(default_tracker):
    """Uttering moodbot's ``utter_greet`` sends one message with buttons."""
    domain = Domain.load("examples/moodbot/domain.yml")
    output_channel = CollectingOutputChannel()
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    Dispatcher("my-sender", output_channel, generator).utter_template(
        "utter_greet", default_tracker)

    assert len(output_channel.messages) == 1
    message = output_channel.messages[0]
    assert message['text'] == "Hey! How are you?"
    assert message['buttons'] == [
        {'payload': 'great', 'title': 'great'},
        {'payload': 'super sad', 'title': 'super sad'}
    ]
예제 #25
0
def test_dispatcher_utter_buttons_from_domain_templ(default_tracker):
    """Uttering ``utter_greet`` sends one message whose buttons are in
    its ``data`` field."""
    domain = Domain.load("examples/moodbot/domain.yml")
    collector = CollectingOutputChannel()
    generator = TemplatedNaturalLanguageGenerator(domain.templates)
    Dispatcher("my-sender", collector, generator).utter_template(
        "utter_greet", default_tracker)

    assert len(collector.messages) == 1
    first = collector.messages[0]
    assert first['text'] == "Hey! How are you?"
    assert first['data'] == [
        {'payload': 'great', 'title': 'great'},
        {'payload': 'super sad', 'title': 'super sad'}
    ]
예제 #26
0
    def test_memorise(self, trained_policy, default_domain):
        """Check what the trained form policy memorises from form stories.

        The policy is trained on ``stories_form.md`` with the ``form.yml``
        domain. For each training tracker, ``recall`` must return the active
        form only for the "no validation" state combinations listed below,
        and ``None`` otherwise. A random feature dict must not be recalled.
        """
        # NOTE(review): the ``default_domain`` fixture argument is unused.
        domain = Domain.load('data/test_domains/form.yml')
        trackers = training.load_data('data/test_stories/stories_form.md',
                                      domain)
        trained_policy.train(trackers, domain)

        # featurized training data; presumably one list of states and one
        # list of actions per tracker (they are zipped with ``trackers``)
        (all_states, all_actions) = \
            trained_policy.featurizer.training_states_and_actions(
                trackers, domain)

        for tracker, states, actions in zip(trackers, all_states, all_actions):
            for state in states:
                if state is not None:
                    # check that 'form: inform' was ignored
                    assert 'intent_inform' not in state.keys()
            recalled = trained_policy.recall(states, tracker, domain)
            # expected recall value for the no-validation cases below
            active_form = trained_policy._get_active_form_name(states[-1])

            if states[0] is not None and states[-1] is not None:
                # explicitly set intents and actions before listen after
                # which FormPolicy should not predict a form action and
                # should add FormValidation(False) event
                is_no_validation = (
                    ('prev_some_form' in states[0].keys() and
                     'intent_default' in states[-1].keys()) or
                    ('prev_some_form' in states[0].keys() and
                     'intent_stop' in states[-1].keys()) or
                    ('prev_utter_ask_continue' in states[0].keys() and
                     'intent_affirm' in states[-1].keys()) or
                    ('prev_utter_ask_continue' in states[0].keys() and
                     'intent_deny' in states[-1].keys())
                )
            else:
                is_no_validation = False

            # NOTE(review): ``in states[-1]`` raises TypeError if the last
            # state is None; presumably the form stories never produce that.
            if 'intent_start_form' in states[-1]:
                # explicitly check that intent that starts the form
                # is not memorized as non validation intent
                assert recalled is None
            elif is_no_validation:
                assert recalled == active_form
            else:
                assert recalled is None

        # A single state with random values for every domain input state
        # must not be recalled (no tracker is passed).
        # NOTE(review): np.random is unseeded, so these inputs differ per run.
        nums = np.random.randn(domain.num_states)
        random_states = [{f: num
                          for f, num in
                          zip(domain.input_states, nums)}]
        assert trained_policy.recall(random_states, None, domain) is None
예제 #27
0
def get_domain(path):
    """Load ``domain.yml`` from the persisted model directory ``path``.

    Raises:
        ValueError: if ``path`` is empty or names a file rather than a
            model directory.
    """
    if path and not os.path.isfile(path):
        domain_file = os.path.join(path, "domain.yml")
        return Domain.load(domain_file)

    if not path:
        message = ("You need to provide a valid directory where "
                   "to load the agent from when calling "
                   "`Agent.load`.")
    else:
        message = ("You are trying to load a MODEL from a file "
                   "('{}'), which is not possible. \n"
                   "The persisted path should be a directory "
                   "containing the various model files. \n\n"
                   "If you want to load training data instead of "
                   "a model, use `agent.load_data(...)` "
                   "instead.".format(path))
    raise ValueError(message)
예제 #28
0
def _update_model_from_server(model_server: EndpointConfig,
                              agent: 'Agent') -> None:
    """Load a zipped Rasa Core model from a URL and update the passed agent.

    The model is pulled into a fresh temporary directory. If the pull fails
    or reports no model newer than ``agent.fingerprint`` (falsy
    fingerprint), that directory is removed again instead of being left on
    disk.

    Raises:
        InvalidURL: if ``model_server.url`` is not a valid URL.
    """
    import shutil

    if not is_url(model_server.url):
        raise InvalidURL(model_server.url)

    model_directory = tempfile.mkdtemp()

    try:
        new_model_fingerprint = _pull_model_and_fingerprint(
            model_server, model_directory, agent.fingerprint)
    except Exception:
        # nothing usable was downloaded; do not leak the temp directory
        shutil.rmtree(model_directory, ignore_errors=True)
        raise

    if not new_model_fingerprint:
        shutil.rmtree(model_directory, ignore_errors=True)
        logger.debug("No new model found at "
                     "URL {}".format(model_server.url))
        return

    domain_path = os.path.join(os.path.abspath(model_directory),
                               "domain.yml")
    domain = Domain.load(domain_path)
    policy_ensemble = PolicyEnsemble.load(model_directory)
    agent.update_model(domain, policy_ensemble, new_model_fingerprint)
예제 #29
0
def _update_model_from_server(
        model_server,  # type: EndpointConfig
        agent,  # type: Agent
):
    # type: (...) -> None
    """Load a zipped Rasa Core model from a URL and update the passed agent.

    The model is pulled into a fresh temporary directory. If the pull fails
    or reports no model newer than ``agent.fingerprint`` (falsy
    fingerprint), that directory is removed again instead of being left on
    disk.

    Raises:
        InvalidURL: if ``model_server.url`` is not a valid URL.
    """
    import shutil

    if not is_url(model_server.url):
        raise InvalidURL(model_server.url)

    model_directory = tempfile.mkdtemp()

    try:
        new_model_fingerprint = _pull_model_and_fingerprint(
                model_server, model_directory, agent.fingerprint)
    except Exception:
        # nothing usable was downloaded; do not leak the temp directory
        shutil.rmtree(model_directory, ignore_errors=True)
        raise

    if not new_model_fingerprint:
        shutil.rmtree(model_directory, ignore_errors=True)
        logger.debug("No new model found at "
                     "URL {}".format(model_server.url))
        return

    domain_path = os.path.join(os.path.abspath(model_directory),
                               "domain.yml")
    domain = Domain.load(domain_path)
    policy_ensemble = PolicyEnsemble.load(model_directory)
    agent.update_model(domain, policy_ensemble, new_model_fingerprint)
예제 #30
0
    def load(cls,
             path,  # type: Text
             interpreter=None,  # type: Optional[NaturalLanguageInterpreter]
             generator=None,  # type: Union[EndpointConfig, NLG]
             tracker_store=None,  # type: Optional[TrackerStore]
             action_endpoint=None,  # type: Optional[EndpointConfig]
             ):
        # type: (...) -> Agent
        """Load a persisted model from the passed path.

        Args:
            path: model directory containing ``domain.yml`` and the
                persisted policy ensemble.
            interpreter: NLU interpreter handed to the new agent.
            generator: NLG endpoint config or generator handed to the agent.
            tracker_store: tracker store handed to the agent.
            action_endpoint: endpoint config for custom actions.

        Returns:
            A new ``cls`` instance built from the loaded domain and policies.

        Raises:
            ValueError: if ``path`` is empty or points to a file.
        """

        if not path:
            raise ValueError("You need to provide a valid directory where "
                             "to load the agent from when calling "
                             "`Agent.load`.")

        if os.path.isfile(path):
            raise ValueError("You are trying to load a MODEL from a file "
                             "('{}'), which is not possible. \n"
                             "The persisted path should be a directory "
                             "containing the various model files. \n\n"
                             "If you want to load training data instead of "
                             "a model, use `agent.load_data(...)` "
                             "instead.".format(path))

        domain = Domain.load(os.path.join(path, "domain.yml"))
        # ``path`` was validated above, so the ensemble is always loaded
        ensemble = PolicyEnsemble.load(path)

        # ensures the domain hasn't changed between test and train
        domain.compare_with_specification(path)

        return cls(domain=domain,
                   policies=ensemble,
                   interpreter=interpreter,
                   generator=generator,
                   tracker_store=tracker_store,
                   action_endpoint=action_endpoint)
예제 #31
0
def test_policy_priority():
    """The ensemble picks the higher-priority policy wherever it sits."""
    domain = Domain.load("data/test_domains/default.yml")
    tracker = DialogueStateTracker.from_events("test", [UserUttered("hi")], [])

    low = ConstantPolicy(priority=1, predict_index=0)
    high = ConstantPolicy(priority=2, predict_index=1)

    # each case: (ensemble, index of the high-priority policy inside it)
    cases = [(SimplePolicyEnsemble([low, high]), 1),
             (SimplePolicyEnsemble([high, low]), 0)]

    expected_probabilities = high.predict_action_probabilities(
        tracker, domain)

    for ensemble, position in cases:
        result, best_policy = ensemble.probabilities_using_best_policy(
            tracker, domain)
        assert best_policy == 'policy_{}_{}'.format(position,
                                                    type(high).__name__)
        assert result.tolist() == expected_probabilities
예제 #32
0
def test_domain_from_template():
    """The default domain declares 10 intents and 6 actions."""
    loaded = Domain.load(DEFAULT_DOMAIN_PATH)

    assert len(loaded.intents) == 10
    assert len(loaded.action_names) == 6
예제 #33
0
def moodbot_domain():
    """Return the domain stored next to the persisted moodbot model."""
    return Domain.load(os.path.join(MOODBOT_MODEL_PATH, 'domain.yml'))
예제 #34
0
from rasa_core import training, restore
from rasa_core import utils
from rasa_core.actions.action import ActionListen, ACTION_LISTEN_NAME
from rasa_core.channels import UserMessage
from rasa_core.domain import Domain
from rasa_core.events import (
    UserUttered, ActionExecuted, Restarted, ActionReverted,
    UserUtteranceReverted)
from rasa_core.tracker_store import InMemoryTrackerStore, RedisTrackerStore
from rasa_core.tracker_store import (
    TrackerStore)
from rasa_core.trackers import DialogueStateTracker
from tests.conftest import DEFAULT_STORIES_FILE
from tests.utilities import tracker_from_dialogue_file, read_dialogue_file

# Domain shared by the tracker stores created in this module.
domain = Domain.load("data/test_domains/default.yml")


class MockRedisTrackerStore(RedisTrackerStore):
    """``RedisTrackerStore`` that talks to ``fakeredis`` instead of a server."""

    def __init__(self, domain):
        # fake client stored where RedisTrackerStore expects its connection
        self.red = fakeredis.FakeStrictRedis()
        # RedisTrackerStore.__init__ is deliberately skipped (presumably it
        # would connect to a real Redis); only the base TrackerStore setup runs
        TrackerStore.__init__(self, domain)


def stores_to_be_tested():
    """Return one store per tracker store implementation under test."""
    redis_store = MockRedisTrackerStore(domain)
    memory_store = InMemoryTrackerStore(domain)
    return [redis_store, memory_store]


def stores_to_be_tested_ids():
    return ["redis-tracker",
예제 #35
0
    logging.basicConfig(level=logging.DEBUG)

    @app.route("/nlg", methods=['POST', 'OPTIONS'])
    def nlg():
        """Answer an NLG request.

        The request's JSON body and ``domain`` are passed to
        ``generate_response``, and its result is returned as JSON.
        """
        nlg_call = request.json

        response = generate_response(nlg_call, domain)
        return jsonify(response)

    return app


if __name__ == '__main__':
    # Running as standalone python application
    logging.basicConfig(level=logging.DEBUG)

    cmdline_args = create_argument_parser().parse_args()
    domain = Domain.load(cmdline_args.domain)
    app = create_app(domain)

    listen_address = ('0.0.0.0', cmdline_args.port)
    http_server = WSGIServer(listen_address, app)
    http_server.start()
    logger.info("NLG endpoint is up and running. on {}"
                "".format(http_server.address))

    http_server.serve_forever()
예제 #36
0
    logging.basicConfig(level=logging.DEBUG)

    @app.route("/nlg", methods=['POST', 'OPTIONS'])
    def nlg():
        """Answer an NLG request.

        The request's JSON body and ``domain`` are passed to
        ``generate_response``, and its result is returned as JSON.
        """
        nlg_call = request.json

        response = generate_response(nlg_call, domain)
        return jsonify(response)

    return app


if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)

    # Running as standalone python application: parse the CLI, load the
    # domain and serve the NLG app over WSGI.
    cmdline_args = create_argument_parser().parse_args()
    domain = Domain.load(cmdline_args.domain)
    app = create_app(domain)

    http_server = WSGIServer(('0.0.0.0', cmdline_args.port), app)
    http_server.start()
    logger.info("NLG endpoint is up and running. on {}"
                "".format(http_server.address))
    http_server.serve_forever()
예제 #37
0
from rasa_core import utils
from rasa_core.channels import UserMessage
from rasa_core.domain import Domain
from rasa_core.events import SlotSet, ActionExecuted, Restarted
from rasa_core.tracker_store import (TrackerStore, InMemoryTrackerStore,
                                     RedisTrackerStore)
from rasa_core.trackers import DialogueStateTracker
from rasa_core.utils import EndpointConfig
from tests.conftest import DEFAULT_ENDPOINTS_FILE

# Module-level domain used by the tracker store tests in this module.
domain = Domain.load("data/test_domains/default.yml")


def test_get_or_create():
    """A slot set on a new tracker survives saving and re-fetching it."""
    key, value = 'location', 'Easter Island'
    store = InMemoryTrackerStore(domain)
    sender = UserMessage.DEFAULT_SENDER_ID

    tracker = store.get_or_create_tracker(sender)
    tracker.update(SlotSet(key, value))
    assert tracker.get_slot(key) == value

    store.save(tracker)

    reloaded = store.get_or_create_tracker(sender)
    assert reloaded.get_slot(key) == value


def test_restart_after_retrieval_from_tracker_store(default_domain):
    store = InMemoryTrackerStore(default_domain)
예제 #38
0
from rasa_core import training, restore
from rasa_core import utils
from rasa_core.actions.action import ACTION_LISTEN_NAME
from rasa_core.domain import Domain
from rasa_core.events import (UserUttered, ActionExecuted, Restarted,
                              ActionReverted, UserUtteranceReverted)
from rasa_core.tracker_store import (InMemoryTrackerStore, RedisTrackerStore,
                                     SQLTrackerStore)
from rasa_core.tracker_store import TrackerStore
from rasa_core.trackers import DialogueStateTracker, EventVerbosity
from tests.conftest import (DEFAULT_STORIES_FILE, EXAMPLE_DOMAINS,
                            TEST_DIALOGUES)
from tests.utilities import (tracker_from_dialogue_file, read_dialogue_file,
                             user_uttered, get_tracker)

# Module-level domain, loaded once from the moodbot example.
domain = Domain.load("examples/moodbot/domain.yml")


@pytest.fixture(scope="module")
def loop():
    """Module-scoped event loop from pytest-sanic, passed through
    ``utils.enable_async_loop_debugging``."""
    from pytest_sanic.plugin import loop as sanic_loop
    debug_loop = utils.enable_async_loop_debugging(next(sanic_loop()))
    return debug_loop


class MockRedisTrackerStore(RedisTrackerStore):
    """``RedisTrackerStore`` that talks to ``fakeredis`` instead of a server."""

    def __init__(self, domain):
        # fake client stored where RedisTrackerStore expects its connection
        self.red = fakeredis.FakeStrictRedis()
        # presumably the record-expiry setting RedisTrackerStore reads;
        # None here - confirm against RedisTrackerStore
        self.record_exp = None
        # RedisTrackerStore.__init__ is deliberately skipped (presumably it
        # would connect to a real Redis); only the base TrackerStore setup runs
        TrackerStore.__init__(self, domain)

예제 #39
0
 def trained_policy(self, featurizer):
     """Create a policy via ``create_policy`` and train it on trackers
     generated from the default domain."""
     domain = Domain.load(DEFAULT_DOMAIN_PATH)
     fitted = self.create_policy(featurizer)
     trackers = train_trackers(domain)
     fitted.train(trackers, domain)
     return fitted
예제 #40
0
def test_tracker_restaurant():
    """Replaying ``enter_name.json`` sets ``name`` but no undeclared slot."""
    domain = Domain.load("data/test_domains/default_with_slots.yml")
    dialogue_path = 'data/test_dialogues/enter_name.json'
    tracker = tracker_from_dialogue_file(dialogue_path, domain)

    assert tracker.get_slot("name") == "holger"
    # ``location`` is not a slot of this domain, so it reads back as None
    assert tracker.get_slot("location") is None
예제 #41
0
 def trained_policy(self, featurizer):
     """Return a freshly created policy trained on the default domain."""
     base_domain = Domain.load(DEFAULT_DOMAIN_PATH)
     policy_under_test = self.create_policy(featurizer)
     generated = train_trackers(base_domain)
     policy_under_test.train(generated, base_domain)
     return policy_under_test
예제 #42
0
def test_domain_from_template():
    """The default domain declares 10 intents and 10 actions."""
    loaded = Domain.load(DEFAULT_DOMAIN_PATH)

    assert len(loaded.intents) == 10
    assert len(loaded.action_names) == 10
예제 #43
0
def test_domain_fails_on_unknown_custom_slot_type(tmpdir,
                                                  domain_unkown_slot_type):
    """A domain whose slot type cannot be resolved fails to load."""
    written = utilities.write_text_to_file(tmpdir, "domain.yml",
                                           domain_unkown_slot_type)
    with pytest.raises(ValueError):
        Domain.load(written)
예제 #44
0
def test_tracker_restaurant():
    """The ``name`` slot is filled from the dialogue; unknown slots are None."""
    slot_domain = Domain.load("data/test_domains/default_with_slots.yml")
    replayed = tracker_from_dialogue_file(
        'data/test_dialogues/enter_name.json', slot_domain)

    assert replayed.get_slot("name") == "holger"
    # this domain declares no ``location`` slot
    assert replayed.get_slot("location") is None
예제 #45
0
 def default_domain(self):
     """Return the default test domain, freshly loaded from disk."""
     domain_path = DEFAULT_DOMAIN_PATH
     return Domain.load(domain_path)
예제 #46
0
 def default_domain(self):
     """Load and return the domain at ``DEFAULT_DOMAIN_PATH``."""
     loaded = Domain.load(DEFAULT_DOMAIN_PATH)
     return loaded