コード例 #1
0
ファイル: conftest.py プロジェクト: attgua/Geco
def default_domain() -> Domain:
    """Load and return the domain stored at ``DEFAULT_DOMAIN_PATH_WITH_SLOTS``."""
    domain_file = DEFAULT_DOMAIN_PATH_WITH_SLOTS
    loaded_domain = Domain.load(domain_file)
    return loaded_domain
コード例 #2
0
ファイル: test_domain.py プロジェクト: wavymazy/rasa
def test_merge_yaml_domains():
    """Merge two YAML domains, with and without ``override``.

    Without ``override``, ``domain_1`` wins on conflicting values. With
    ``override=True``, ``domain_2`` wins. Intents, entities, slots, responses
    and action texts from both domains are combined in the merged result.
    """
    # First domain: one response and one end-to-end action text.
    test_yaml_1 = f"""config:
  store_entities_as_slots: true
entities: []
intents: []
slots: {{}}
responses:
  utter_greet:
  - text: hey there!
{KEY_E2E_ACTIONS}:
- Hi"""

    # Second domain: conflicts with the first on `store_entities_as_slots` and
    # `utter_greet`, and adds session config, an entity, an intent, a slot, an
    # extra response and another action text.
    test_yaml_2 = f"""config:
  store_entities_as_slots: false
session_config:
    session_expiration_time: 20
    carry_over_slots: true
entities:
- cuisine
intents:
- greet
slots:
  cuisine:
    type: text
{KEY_E2E_ACTIONS}:
- Bye
responses:
  utter_goodbye:
  - text: bye!
  utter_greet:
  - text: hey you!"""

    domain_1 = Domain.from_yaml(test_yaml_1)
    domain_2 = Domain.from_yaml(test_yaml_2)
    domain = domain_1.merge(domain_2)
    # single attribute should be taken from domain_1
    assert domain.store_entities_as_slots
    # conflicts should be taken from domain_1
    assert domain.responses == {
        "utter_greet": [{
            "text": "hey there!"
        }],
        "utter_goodbye": [{
            "text": "bye!"
        }],
    }
    # lists should be deduplicated and merged
    assert domain.intents == sorted(["greet", *DEFAULT_INTENTS])
    assert domain.entities == ["cuisine"]
    assert isinstance(domain.slots[0], TextSlot)
    assert domain.slots[0].name == "cuisine"
    assert sorted(domain.user_actions) == sorted(
        ["utter_greet", "utter_goodbye"])
    # session config only exists in domain_2, so it is taken from there
    assert domain.session_config == SessionConfig(20, True)

    domain = domain_1.merge(domain_2, override=True)
    # single attribute should be taken from domain_2
    assert not domain.store_entities_as_slots
    # conflicts should take value from domain_2
    assert domain.responses == {
        "utter_greet": [{
            "text": "hey you!"
        }],
        "utter_goodbye": [{
            "text": "bye!"
        }],
    }
    assert domain.session_config == SessionConfig(20, True)
    # action texts of both domains are kept
    assert domain.action_texts == ["Bye", "Hi"]
コード例 #3
0
ファイル: test_domain.py プロジェクト: wavymazy/rasa
def test_load_on_invalid_domain_duplicate_entities():
    """Loading ``duplicate_entities.yml`` must raise ``InvalidDomain``."""
    invalid_domain_file = "data/test_domains/duplicate_entities.yml"
    with pytest.raises(InvalidDomain):
        Domain.load(invalid_domain_file)
コード例 #4
0
ファイル: test_domain.py プロジェクト: wavymazy/rasa
def test_valid_slot_mappings(domain_as_dict: Dict[Text, Any]):
    """Building a domain from ``domain_as_dict`` must not raise."""
    _ = Domain.from_dict(domain_as_dict)
コード例 #5
0
ファイル: test_domain.py プロジェクト: wavymazy/rasa
def test_domain_from_template(domain: Domain):
    """Check the intent and action counts of the ``domain`` fixture."""
    assert not domain.is_empty()

    # 10 custom intents plus every default intent
    expected_intent_count = 10 + len(DEFAULT_INTENTS)
    assert len(domain.intents) == expected_intent_count

    expected_action_count = 16
    assert len(domain.action_names_or_texts) == expected_action_count
コード例 #6
0
async def test_form_unhappy_path_no_validation_from_story():
    """Learn from a story that the turn after an unhappy path skips validation.

    Training uses one story. While ``some_form`` is active, the user greets and
    ``utter_handle_rejection`` runs. The next user turn is then followed by the
    form action. The test checks that no rule applies to a matching
    conversation and that the tracker ends with ``LoopInterrupted(True)``.
    """
    form_name = "some_form"
    handle_rejection_action_name = "utter_handle_rejection"

    domain = Domain.from_yaml(f"""
        intents:
        - {GREET_INTENT_NAME}
        actions:
        - {UTTER_GREET_ACTION}
        - {handle_rejection_action_name}
        - some-action
        slots:
          {REQUESTED_SLOT}:
            type: unfeaturized
        forms:
        - {form_name}
    """)

    # Training data: a story (not a rule tracker) describing the unhappy path.
    unhappy_story = TrackerWithCachedStates.from_events(
        "bla",
        domain=domain,
        slots=domain.slots,
        evts=[
            # We are in an active form
            ActionExecuted(form_name),
            ActiveLoop(form_name),
            ActionExecuted(ACTION_LISTEN_NAME),
            # When a user says "hi", and the form is unhappy,
            # we want to run a specific action
            UserUttered(intent={"name": GREET_INTENT_NAME}),
            ActionExecuted(handle_rejection_action_name),
            ActionExecuted(ACTION_LISTEN_NAME),
            # Next user utterance is an answer to the previous question
            # and shouldn't be validated by the form
            UserUttered(intent={"name": GREET_INTENT_NAME}),
            ActionExecuted(form_name),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
    )

    policy = RulePolicy()
    policy.train([unhappy_story], domain, RegexInterpreter())

    # Check that RulePolicy predicts no validation to handle unhappy path
    conversation_events = [
        ActionExecuted(form_name),
        ActiveLoop(form_name),
        SlotSet(REQUESTED_SLOT, "some value"),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("haha", {"name": GREET_INTENT_NAME}),
        ActionExecutionRejected(form_name),
        ActionExecuted(handle_rejection_action_name),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("haha", {"name": GREET_INTENT_NAME}),
    ]

    tracker = DialogueStateTracker.from_events("casd",
                                               evts=conversation_events,
                                               slots=domain.slots)
    action_probabilities = policy.predict_action_probabilities(
        tracker, domain, RegexInterpreter())
    # there is no rule for next action
    assert max(action_probabilities) == policy._core_fallback_threshold
    # check that RulePolicy entered unhappy path based on the training story
    # NOTE(review): `conversation_events` contains no `LoopInterrupted`, so this
    # assertion relies on `predict_action_probabilities` updating `tracker`.
    assert tracker.events[-1] == LoopInterrupted(True)
コード例 #7
0
ファイル: test_domain.py プロジェクト: wavymazy/rasa
def test_add_default_intents(domain_dict: Dict):
    """Every name in ``DEFAULT_INTENTS`` is present in a domain built from ``domain_dict``."""
    domain = Domain.from_dict(domain_dict)

    missing_intents = [
        name for name in DEFAULT_INTENTS if name not in domain.intents
    ]
    assert not missing_intents
コード例 #8
0
ファイル: multi_project.py プロジェクト: suryatmodulus/rasa
 async def get_domain(self) -> Domain:
     """Retrieves model domain (see parent class for full docstring).

     Every file in ``self._domain_paths`` is loaded first. The loaded domains
     are then merged in order, starting from ``Domain.empty()``.
     """
     loaded_domains = [
         Domain.load(domain_path) for domain_path in self._domain_paths
     ]
     combined = Domain.empty()
     for next_domain in loaded_domains:
         combined = combined.merge(next_domain)
     return combined
コード例 #9
0
def test_process_unpacks_attributes_from_single_message_and_fallsback_if_needed(
    regex_message_handler: RegexMessageHandler,
    confidence: Optional[Text],
    entities: Optional[Text],
    expected_confidence: float,
    expected_entities: Optional[List[Dict[Text, Any]]],
    should_warn: bool,
):
    """Check how ``RegexMessageHandler.process`` unpacks an intent-pattern text.

    The message text is ``INTENT_MESSAGE_PREFIX`` plus an intent name. It is
    optionally followed by ``@<confidence>`` and an entities string, and is
    padded with whitespace. The handler must replace the NLU-provided intent
    with the one from the text, drop the message features, and keep only the
    text, intent, intent ranking and entities attributes.

    Args:
        regex_message_handler: The handler under test.
        confidence: Confidence appended after ``@``; ``None`` omits it.
        entities: Raw entities string appended to the text; ``None`` omits it.
        expected_confidence: Confidence the handler should report.
        expected_entities: Entities the handler should extract. ``None`` or an
            empty list means that no entities are expected.
        should_warn: Whether processing is expected to emit a ``UserWarning``.
    """

    # dummy intent
    expected_intent = "my-intent"

    # construct text according to pattern
    text = " \t  " + INTENT_MESSAGE_PREFIX + expected_intent
    if confidence is not None:
        text += f"@{confidence}"
    if entities is not None:
        text += entities
    text += " \t "

    # create a message with some dummy attributes and features
    message = Message(
        data={
            TEXT: text,
            INTENT: "extracted-from-the-pattern-text-via-nlu"
        },
        features=[
            Features(
                features=np.zeros((1, 1)),
                feature_type=FEATURE_TYPE_SENTENCE,
                attribute=TEXT,
                origin="nlu-pipeline",
            )
        ],
    )

    # construct domain from expected intent/entities
    # `expected_entities` is Optional (see the `else` branch below), so treat
    # None as "no entities" instead of failing with a TypeError here.
    domain_entities = [
        item[ENTITY_ATTRIBUTE_TYPE] for item in (expected_entities or [])
    ]
    domain = Domain(
        intents=[expected_intent],
        entities=domain_entities,
        slots=[],
        responses={},
        action_names=[],
        forms={},
    )

    # extract information
    if should_warn:
        with pytest.warns(UserWarning):
            results = regex_message_handler.process([message], domain)
    else:
        results = regex_message_handler.process([message], domain)

    assert len(results) == 1
    unpacked_message = results[0]

    # features created by the NLU pipeline must not survive unpacking
    assert not unpacked_message.features

    assert set(unpacked_message.data.keys()) == {
        TEXT,
        INTENT,
        INTENT_RANKING_KEY,
        ENTITIES,
    }

    # the surrounding whitespace is stripped from the text
    assert unpacked_message.data[TEXT] == message.data[TEXT].strip()

    assert set(unpacked_message.data[INTENT].keys()) == {
        INTENT_NAME_KEY,
        PREDICTED_CONFIDENCE_KEY,
    }
    assert unpacked_message.data[INTENT][INTENT_NAME_KEY] == expected_intent
    assert (unpacked_message.data[INTENT][PREDICTED_CONFIDENCE_KEY] ==
            expected_confidence)

    intent_ranking = unpacked_message.data[INTENT_RANKING_KEY]
    assert len(intent_ranking) == 1
    assert intent_ranking[0] == {
        INTENT_NAME_KEY: expected_intent,
        PREDICTED_CONFIDENCE_KEY: expected_confidence,
    }
    if expected_entities:
        entity_data: List[Dict[Text, Any]] = unpacked_message.data[ENTITIES]
        assert all(
            set(item.keys()) == {
                ENTITY_ATTRIBUTE_VALUE,
                ENTITY_ATTRIBUTE_TYPE,
                ENTITY_ATTRIBUTE_START,
                ENTITY_ATTRIBUTE_END,
            } for item in entity_data)
        assert set(
            (item[ENTITY_ATTRIBUTE_TYPE], item[ENTITY_ATTRIBUTE_VALUE])
            for item in expected_entities) == set(
                (item[ENTITY_ATTRIBUTE_TYPE], item[ENTITY_ATTRIBUTE_VALUE])
                for item in entity_data)
    else:
        assert unpacked_message.data[ENTITIES] is not None
        assert len(unpacked_message.data[ENTITIES]) == 0
コード例 #10
0
 async def get_domain(self) -> Domain:
     """Return ``Domain.empty()``."""
     empty_domain = Domain.empty()
     return empty_domain
コード例 #11
0
    async def get_domain(self) -> Domain:
        """Collect the domains of all wrapped importers and merge them.

        The importers' ``get_domain`` coroutines are awaited together with
        ``asyncio.gather``. Their results are merged in importer order,
        starting from ``Domain.empty()``.
        """
        pending = [importer.get_domain() for importer in self._importers]
        collected = await asyncio.gather(*pending)

        merged_domain = Domain.empty()
        for importer_domain in collected:
            merged_domain = merged_domain.merge(importer_domain)
        return merged_domain
コード例 #12
0
ファイル: test_actions.py プロジェクト: sorumehta/rasa-server
def template_sender_tracker(default_domain_path: Text):
    """Create a ``DialogueStateTracker`` with sender id ``"template-sender"``.

    The tracker's slots come from the domain loaded from ``default_domain_path``.
    """
    domain_slots = Domain.load(default_domain_path).slots
    return DialogueStateTracker("template-sender", domain_slots)
コード例 #13
0
def test_tracker_default():
    """Build a tracker from ``default.json`` and check two of its slots."""
    loaded_domain = Domain.load(DEFAULT_DOMAIN_PATH_WITH_SLOTS)
    dialogue_path = "data/test_dialogues/default.json"
    tracker = tracker_from_dialogue_file(dialogue_path, loaded_domain)

    assert tracker.get_slot("name") == "Peter"
    # `price` is not a slot of this domain, so the lookup yields None.
    assert tracker.get_slot("price") is None
コード例 #14
0
ファイル: conftest.py プロジェクト: attgua/Geco
def moodbot_domain() -> Domain:
    """Load ``examples/moodbot/domain.yml`` as a ``Domain``."""
    return Domain.load(os.path.join("examples", "moodbot", "domain.yml"))
コード例 #15
0
async def test_form_unhappy_path_from_story():
    """Learn an unhappy-path form interruption from a story (not a rule).

    In the story, the user greets while ``some_form`` is active. The bot then
    runs ``UTTER_GREET_ACTION`` followed by ``utter_handle_rejection``. After
    training on this story together with ``GREET_RULE``, the test checks two
    things. First, ``UTTER_GREET_ACTION`` is predicted once the form is
    rejected. Second, after the rejection handler has run, no rule applies and
    the maximum probability equals the core fallback threshold.
    """
    form_name = "some_form"
    handle_rejection_action_name = "utter_handle_rejection"

    domain = Domain.from_yaml(f"""
        intents:
        - {GREET_INTENT_NAME}
        actions:
        - {UTTER_GREET_ACTION}
        - {handle_rejection_action_name}
        - some-action
        slots:
          {REQUESTED_SLOT}:
            type: unfeaturized
        forms:
        - {form_name}
    """)

    # Training data: a story (not a rule tracker) describing the unhappy path.
    unhappy_story = TrackerWithCachedStates.from_events(
        "bla",
        domain=domain,
        slots=domain.slots,
        evts=[
            # We are in an active form
            ActionExecuted(form_name),
            ActiveLoop(form_name),
            ActionExecuted(ACTION_LISTEN_NAME),
            # in training stories there is either intent or text, never both
            UserUttered(intent={"name": GREET_INTENT_NAME}),
            ActionExecuted(UTTER_GREET_ACTION),
            # After our bot says "hi", we want to run a specific action
            ActionExecuted(handle_rejection_action_name),
            ActionExecuted(form_name),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
    )

    policy = RulePolicy()
    policy.train([GREET_RULE, unhappy_story], domain, RegexInterpreter())

    # Check that RulePolicy predicts action to handle unhappy path
    conversation_events = [
        ActionExecuted(form_name),
        ActiveLoop(form_name),
        SlotSet(REQUESTED_SLOT, "some value"),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("haha", {"name": GREET_INTENT_NAME}),
        ActionExecutionRejected(form_name),
    ]

    action_probabilities = policy.predict_action_probabilities(
        DialogueStateTracker.from_events("casd",
                                         evts=conversation_events,
                                         slots=domain.slots),
        domain,
        RegexInterpreter(),
    )
    assert_predicted_action(action_probabilities, domain, UTTER_GREET_ACTION)

    # Check that RulePolicy doesn't trigger form or action_listen
    # after handling unhappy path
    conversation_events.append(ActionExecuted(handle_rejection_action_name))
    action_probabilities = policy.predict_action_probabilities(
        DialogueStateTracker.from_events("casd",
                                         evts=conversation_events,
                                         slots=domain.slots),
        domain,
        RegexInterpreter(),
    )
    # no rule applies, so the best score is the core fallback threshold
    assert max(action_probabilities) == policy._core_fallback_threshold
コード例 #16
0
ファイル: test_model.py プロジェクト: sorumehta/rasa-server
    fingerprint = _fingerprint()
    output_directory = tempfile.mkdtemp()

    persist_fingerprint(output_directory, fingerprint)
    actual = fingerprint_from_path(output_directory)

    assert actual == fingerprint


@pytest.mark.parametrize(
    ("fingerprint2", "changed"),
    [
        (_fingerprint(config=["other"]), True),
        (_fingerprint(config_core=["other"]), True),
        (_fingerprint(domain=["other"]), True),
        (_fingerprint(domain=Domain.empty()), True),
        (_fingerprint(stories=["test", "other"]), True),
        (_fingerprint(rasa_version="100"), True),
        (_fingerprint(config=["other"], domain=["other"]), True),
        (_fingerprint(nlg=["other"]), False),
        (_fingerprint(nlu=["test", "other"]), False),
        (_fingerprint(config_nlu=["other"]), False),
        (_fingerprint(config_without_epochs=["other"]), False),
    ],
)
def test_core_fingerprint_changed(fingerprint2: Fingerprint, changed: bool):
    """Check which fingerprint differences count as a change of the Core section.

    ``changed`` is the expected result of comparing a default fingerprint with
    ``fingerprint2`` for ``SECTION_CORE``.
    """
    reference = _fingerprint()
    core_changed = did_section_fingerprint_change(
        reference, fingerprint2, SECTION_CORE
    )
    assert core_changed is changed

コード例 #17
0
async def test_form_unhappy_path_no_validation_from_rule():
    """Learn from a multi-turn rule that the turn after an unhappy path skips validation.

    The rule covers this flow: while ``some_form`` is active, the user greets
    and ``utter_handle_rejection`` runs. After the next greet, the form runs
    again. After training on the rule together with ``GREET_RULE``, the test
    checks that the policy predicts, in order:

    * the rejection handler once the form is rejected,
    * ``action_listen`` after the handler,
    * the form after the next ``action_listen``.

    It also checks that the tracker ends with ``LoopInterrupted(True)``.
    """
    form_name = "some_form"
    handle_rejection_action_name = "utter_handle_rejection"

    domain = Domain.from_yaml(f"""
        intents:
        - {GREET_INTENT_NAME}
        actions:
        - {UTTER_GREET_ACTION}
        - {handle_rejection_action_name}
        - some-action
        slots:
          {REQUESTED_SLOT}:
            type: unfeaturized
        forms:
        - {form_name}
    """)

    unhappy_rule = TrackerWithCachedStates.from_events(
        "bla",
        domain=domain,
        slots=domain.slots,
        evts=[
            # We are in an active form
            ActiveLoop(form_name),
            SlotSet(REQUESTED_SLOT, "bla"),
            ActionExecuted(RULE_SNIPPET_ACTION_NAME),
            ActionExecuted(ACTION_LISTEN_NAME),
            # When a user says "hi", and the form is unhappy,
            # we want to run a specific action
            UserUttered(intent={"name": GREET_INTENT_NAME}),
            ActionExecuted(handle_rejection_action_name),
            # Next user utterance is an answer to the previous question
            # and shouldn't be validated by the form
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered(intent={"name": GREET_INTENT_NAME}),
            ActionExecuted(form_name),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
        is_rule_tracker=True,
    )
    # unhappy rule is multi user turn rule, therefore remove restriction for policy
    policy = RulePolicy(restrict_rules=False)
    # RulePolicy should memorize that unhappy_rule overrides GREET_RULE
    policy.train([GREET_RULE, unhappy_rule], domain, RegexInterpreter())

    # Check that RulePolicy predicts action to handle unhappy path
    conversation_events = [
        ActionExecuted(form_name),
        ActiveLoop(form_name),
        SlotSet(REQUESTED_SLOT, "some value"),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("haha", {"name": GREET_INTENT_NAME}),
        ActionExecutionRejected(form_name),
    ]

    action_probabilities = policy.predict_action_probabilities(
        DialogueStateTracker.from_events("casd",
                                         evts=conversation_events,
                                         slots=domain.slots),
        domain,
        RegexInterpreter(),
    )
    assert_predicted_action(action_probabilities, domain,
                            handle_rejection_action_name)

    # Check that RulePolicy predicts action_listen
    conversation_events.append(ActionExecuted(handle_rejection_action_name))
    action_probabilities = policy.predict_action_probabilities(
        DialogueStateTracker.from_events("casd",
                                         evts=conversation_events,
                                         slots=domain.slots),
        domain,
        RegexInterpreter(),
    )
    assert_predicted_action(action_probabilities, domain, ACTION_LISTEN_NAME)

    # Check that RulePolicy triggers form again after handling unhappy path
    conversation_events.append(ActionExecuted(ACTION_LISTEN_NAME))
    tracker = DialogueStateTracker.from_events("casd",
                                               evts=conversation_events,
                                               slots=domain.slots)
    action_probabilities = policy.predict_action_probabilities(
        tracker, domain, RegexInterpreter())
    assert_predicted_action(action_probabilities, domain, form_name)
    # check that RulePolicy entered unhappy path based on the training rule
    # NOTE(review): `conversation_events` contains no `LoopInterrupted`, so this
    # assertion relies on `predict_action_probabilities` updating `tracker`.
    assert tracker.events[-1] == LoopInterrupted(True)
コード例 #18
0
ファイル: test_model.py プロジェクト: sorumehta/rasa-server
 async def get_domain() -> Domain:
     """Load the domain stored at ``domain_with_categorical_slot_path``."""
     categorical_domain_path = domain_with_categorical_slot_path
     return Domain.load(categorical_domain_path)
コード例 #19
0
ファイル: test_domain.py プロジェクト: wavymazy/rasa
- greet
slots:
  cuisine:
    type: text
responses:
  utter_goodbye:
  - text: bye!
  utter_greet:
  - text: hey you!""")

    merged = Domain.empty().merge(domain)

    assert merged.as_dict() == domain.as_dict()


@pytest.mark.parametrize("other", [Domain.empty(), None])
def test_merge_with_empty_other_domain(other: Optional[Domain]):
    domain = Domain.from_yaml("""config:
  store_entities_as_slots: false
session_config:
    session_expiration_time: 20
    carry_over_slots: true
entities:
- cuisine
intents:
- greet
slots:
  cuisine:
    type: text
responses:
  utter_goodbye:
コード例 #20
0
def test_convert_config(
    run: Callable[..., RunResult], tmp_path: Path, default_domain_path: Text
):
    """Convert deprecated policies with ``rasa data convert config``.

    The test writes a config with ``MappingPolicy`` and ``FallbackPolicy``, and
    a domain whose ``greet`` intent uses ``triggers``. It then runs the CLI
    command and checks the rewritten config, the rewritten domain, the
    generated rules file, and the ``.bak`` backups of the config and domain.
    """
    # NOTE(review): `default_domain_path` is not used in this body.
    deprecated_config = {
        "policies": [{"name": "MappingPolicy"}, {"name": "FallbackPolicy"}],
        "pipeline": [{"name": "WhitespaceTokenizer"}],
    }
    config_file = tmp_path / "config.yml"
    rasa.shared.utils.io.write_yaml(deprecated_config, config_file)

    # Domain whose `greet` intent triggers `action_greet` via `triggers`.
    domain = Domain.from_dict(
        {
            "intents": [{"greet": {"triggers": "action_greet"}}, "leave"],
            "actions": ["action_greet"],
        }
    )
    domain_file = tmp_path / "domain.yml"
    domain.persist(domain_file)

    rules_file = tmp_path / "rules.yml"

    result = run(
        "data",
        "convert",
        "config",
        "--config",
        str(config_file),
        "--domain",
        str(domain_file),
        "--out",
        str(rules_file),
    )

    # The command must succeed; config and domain are re-read from the same
    # paths to check that they were rewritten.
    assert result.ret == 0
    new_config = rasa.shared.utils.io.read_yaml_file(config_file)
    new_domain = rasa.shared.utils.io.read_yaml_file(domain_file)
    new_rules = rasa.shared.utils.io.read_yaml_file(rules_file)

    # The deprecated policies are replaced by RulePolicy, and a
    # FallbackClassifier is appended to the NLU pipeline.
    assert new_config == {
        "policies": [
            {
                "name": "RulePolicy",
                "core_fallback_action_name": "action_default_fallback",
                "core_fallback_threshold": DEFAULT_CORE_FALLBACK_THRESHOLD,
            }
        ],
        "pipeline": [
            {"name": "WhitespaceTokenizer"},
            {
                "name": "FallbackClassifier",
                "ambiguity_threshold": DEFAULT_NLU_FALLBACK_AMBIGUITY_THRESHOLD,
                "threshold": DEFAULT_NLU_FALLBACK_THRESHOLD,
            },
        ],
    }
    # The `triggers` mapping is gone; only plain intent names remain.
    assert new_domain["intents"] == ["greet", "leave"]
    # Both the intent mapping and the fallback policy become rules.
    assert new_rules == {
        "rules": [
            {
                "rule": "Rule to map `greet` intent to `action_greet` "
                "(automatic conversion)",
                "steps": [{"intent": "greet"}, {"action": "action_greet"}],
            },
            {
                "rule": "Rule to handle messages with low NLU confidence "
                "(automated conversion from 'FallbackPolicy')",
                "steps": [
                    {"intent": "nlu_fallback"},
                    {"action": "action_default_fallback"},
                ],
            },
        ],
        "version": LATEST_TRAINING_DATA_FORMAT_VERSION,
    }

    # The original files are kept as `.bak` backups.
    domain_backup = tmp_path / "domain.yml.bak"
    assert domain_backup.exists()

    config_backup = tmp_path / "config.yml.bak"
    assert config_backup.exists()
コード例 #21
0
ファイル: test_domain.py プロジェクト: wavymazy/rasa
def test_is_retrieval_intent_response(response_key, validation,
                                      domain: Domain):
    """``is_retrieval_intent_response`` returns ``validation`` for ``response_key``."""
    response_entry = (response_key, [{}])
    outcome = domain.is_retrieval_intent_response(response_entry)
    assert outcome == validation
コード例 #22
0
ファイル: test_structures.py プロジェクト: ncnynl/rasa
    SlotSet,
    UserUttered,
    ActionExecuted,
    DefinePrevUserUtteredFeaturization,
)
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.core.training_data.story_reader.yaml_story_reader import (
    YAMLStoryReader,
)
from rasa.shared.core.training_data.story_writer.yaml_story_writer import (
    YAMLStoryWriter,
)
from rasa.shared.core.training_data.structures import Story
from rasa.shared.nlu.constants import INTENT_NAME_KEY

# Module-level moodbot domain.
# NOTE(review): tests that take a `domain` parameter (such as
# `test_session_start_is_not_serialised`) shadow this name inside their bodies.
domain = Domain.load("examples/moodbot/domain.yml")


def test_session_start_is_not_serialised(domain: Domain):
    tracker = DialogueStateTracker("default", domain.slots)
    # the retrieved tracker should be empty
    assert len(tracker.events) == 0

    # add SlotSet event
    tracker.update(SlotSet("slot", "value"))

    # add the two SessionStarted events and a user event
    tracker.update(ActionExecuted(ACTION_SESSION_START_NAME))
    tracker.update(SessionStarted())
    tracker.update(
        UserUttered("say something", intent={INTENT_NAME_KEY: "some_intent"})
コード例 #23
0
ファイル: test_domain.py プロジェクト: wavymazy/rasa
def test_form_invalid_mappings(domain_as_dict: Dict[Text, Any]):
    """``Domain.from_dict`` must reject ``domain_as_dict`` with ``InvalidDomain``."""
    expected_error = InvalidDomain
    with pytest.raises(expected_error):
        Domain.from_dict(domain_as_dict)
コード例 #24
0
def trained_rule_policy_domain() -> Domain:
    """Load ``examples/rules/domain.yml`` as a ``Domain``."""
    rules_domain_file = "examples/rules/domain.yml"
    return Domain.load(rules_domain_file)
コード例 #25
0
ファイル: test_domain.py プロジェクト: wavymazy/rasa
def test_domain_fails_on_unknown_custom_slot_type(tmpdir,
                                                  domain_unkown_slot_type):
    """Loading a domain with an unknown custom slot type must fail.

    The test writes ``domain_unkown_slot_type`` to ``domain.yml`` inside
    ``tmpdir``. Loading that file must raise ``InvalidSlotTypeException``.
    """
    yaml_file = str(tmpdir / "domain.yml")
    rasa.shared.utils.io.write_text_file(domain_unkown_slot_type, yaml_file)

    with pytest.raises(InvalidSlotTypeException):
        Domain.load(yaml_file)
コード例 #26
0
async def test_one_stage_fallback_rule():
    """Handle low NLU confidence with a one-stage fallback rule.

    ``RulePolicy`` is trained on two rules. The first maps
    ``DEFAULT_NLU_FALLBACK_INTENT_NAME`` to ``ACTION_DEFAULT_FALLBACK_NAME``.
    The second is a greet rule that starts directly with ``action_listen``,
    without a rule-snippet prefix. The test first checks that the fallback
    action is predicted for the fallback intent. It then checks that, after
    running the fallback action and receiving a rephrased greet,
    ``UTTER_GREET_ACTION`` is predicted.
    """
    domain = Domain.from_yaml(f"""
        intents:
        - {GREET_INTENT_NAME}
        - {DEFAULT_NLU_FALLBACK_INTENT_NAME}
        actions:
        - {UTTER_GREET_ACTION}
    """)

    fallback_recover_rule = TrackerWithCachedStates.from_events(
        "bla",
        domain=domain,
        slots=domain.slots,
        evts=[
            ActionExecuted(RULE_SNIPPET_ACTION_NAME),
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered(intent={"name": DEFAULT_NLU_FALLBACK_INTENT_NAME}),
            ActionExecuted(ACTION_DEFAULT_FALLBACK_NAME),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
        is_rule_tracker=True,
    )

    # No RULE_SNIPPET_ACTION_NAME prefix here, unlike the rule above.
    greet_rule_which_only_applies_at_start = TrackerWithCachedStates.from_events(
        "bla",
        domain=domain,
        evts=[
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered(intent={"name": GREET_INTENT_NAME}),
            ActionExecuted(UTTER_GREET_ACTION),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
        is_rule_tracker=True,
    )
    policy = RulePolicy()
    policy.train(
        [greet_rule_which_only_applies_at_start, fallback_recover_rule],
        domain,
        RegexInterpreter(),
    )

    # RulePolicy predicts fallback action
    conversation_events = [
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("dasdakl;fkasd",
                    {"name": DEFAULT_NLU_FALLBACK_INTENT_NAME}),
    ]
    tracker = DialogueStateTracker.from_events("casd",
                                               evts=conversation_events,
                                               slots=domain.slots)
    action_probabilities = policy.predict_action_probabilities(
        tracker, domain, RegexInterpreter())
    assert_predicted_action(action_probabilities, domain,
                            ACTION_DEFAULT_FALLBACK_NAME)

    # Fallback action reverts fallback events, next action is `ACTION_LISTEN`
    # (the events returned by the action are appended to the conversation).
    conversation_events += await ActionDefaultFallback().run(
        CollectingOutputChannel(),
        TemplatedNaturalLanguageGenerator(domain.templates),
        tracker,
        domain,
    )

    # Rasa is back on track when user rephrased intent
    conversation_events += [
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("haha", {"name": GREET_INTENT_NAME}),
    ]
    tracker = DialogueStateTracker.from_events("casd",
                                               evts=conversation_events,
                                               slots=domain.slots)

    action_probabilities = policy.predict_action_probabilities(
        tracker, domain, RegexInterpreter())
    assert_predicted_action(action_probabilities, domain, UTTER_GREET_ACTION)
コード例 #27
0
ファイル: test_domain.py プロジェクト: wavymazy/rasa
def test_load_on_invalid_domain_duplicate_responses():
    """Loading ``duplicate_responses.yml`` must raise ``YamlSyntaxException``."""
    broken_domain_file = "data/test_domains/duplicate_responses.yml"
    with pytest.raises(YamlSyntaxException):
        Domain.load(broken_domain_file)
コード例 #28
0
async def test_form_unhappy_path_from_in_form_rule():
    """Learn an unhappy-path form interruption from a rule inside the form.

    The rule says: while ``some_form`` is active and the user greets, run
    ``utter_handle_rejection`` and then the form again. After training on the
    rule together with ``GREET_RULE``, the test checks that the policy
    predicts the rejection handler once the form is rejected. It then checks
    that the form is predicted after the handler has run.
    """
    form_name = "some_form"
    handle_rejection_action_name = "utter_handle_rejection"

    domain = Domain.from_yaml(f"""
        intents:
        - {GREET_INTENT_NAME}
        actions:
        - {UTTER_GREET_ACTION}
        - {handle_rejection_action_name}
        - some-action
        slots:
          {REQUESTED_SLOT}:
            type: unfeaturized
        forms:
        - {form_name}
    """)

    unhappy_rule = TrackerWithCachedStates.from_events(
        "bla",
        domain=domain,
        slots=domain.slots,
        evts=[
            # We are in an active form
            ActiveLoop(form_name),
            SlotSet(REQUESTED_SLOT, "bla"),
            ActionExecuted(RULE_SNIPPET_ACTION_NAME),
            ActionExecuted(ACTION_LISTEN_NAME),
            # When a user says "hi", and the form is unhappy,
            # we want to run a specific action
            UserUttered(intent={"name": GREET_INTENT_NAME}),
            ActionExecuted(handle_rejection_action_name),
            ActionExecuted(form_name),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
        is_rule_tracker=True,
    )

    policy = RulePolicy()
    # RulePolicy should memorize that unhappy_rule overrides GREET_RULE
    policy.train([GREET_RULE, unhappy_rule], domain, RegexInterpreter())

    # Check that RulePolicy predicts action to handle unhappy path
    conversation_events = [
        ActionExecuted(form_name),
        ActiveLoop(form_name),
        SlotSet(REQUESTED_SLOT, "some value"),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("haha", {"name": GREET_INTENT_NAME}),
        ActionExecutionRejected(form_name),
    ]

    action_probabilities = policy.predict_action_probabilities(
        DialogueStateTracker.from_events("casd",
                                         evts=conversation_events,
                                         slots=domain.slots),
        domain,
        RegexInterpreter(),
    )
    assert_predicted_action(action_probabilities, domain,
                            handle_rejection_action_name)

    # Check that RulePolicy triggers form again after handling unhappy path
    conversation_events.append(ActionExecuted(handle_rejection_action_name))
    action_probabilities = policy.predict_action_probabilities(
        DialogueStateTracker.from_events("casd",
                                         evts=conversation_events,
                                         slots=domain.slots),
        domain,
        RegexInterpreter(),
    )
    assert_predicted_action(action_probabilities, domain, form_name)
コード例 #29
0
ファイル: test_domain.py プロジェクト: wavymazy/rasa
def test_is_empty():
    """A domain created with ``Domain.empty()`` reports itself as empty."""
    empty_domain = Domain.empty()
    assert empty_domain.is_empty()
コード例 #30
0
ファイル: test_domain.py プロジェクト: ysinjab/rasa
def test_is_retrieval_intent_template(template_key, validation):
    """``is_retrieval_intent_template`` returns ``validation`` for ``template_key``."""
    default_domain = Domain.load(DEFAULT_DOMAIN_PATH_WITH_SLOTS)
    template_entry = (template_key, [{}])
    assert default_domain.is_retrieval_intent_template(template_entry) == validation