def test_extract_requested_slot_default():
    """Test default extraction of a slot value from an entity with the same name."""
    form, tracker = new_form_and_tracker({"name": "default_form"}, "some_slot")

    matching_entity = {"entity": "some_slot", "value": "some_value"}
    tracker.update(UserUttered(entities=[matching_entity]))

    # NOTE(review): `nlg` is not a parameter or local here; presumably it is
    # provided at module level — confirm.
    extracted = form.extract_requested_slot(
        OutputChannel(), nlg, tracker, Domain.empty()
    )

    assert extracted == {"some_slot": "some_value"}
def test_provide_removes_or_replaces_expected_information(
    default_model_storage: ModelStorage,
    default_execution_context: ExecutionContext,
    input_domain: Union[Text, Dict],
):
    """The core-training domain keeps all keys but strips training-irrelevant data.

    `config` and the session config are reset to the empty-domain defaults,
    responses and forms keep only their names, and everything else is kept.
    """
    original_domain = (
        Domain.from_file(path=input_domain)
        if isinstance(input_domain, str)
        else Domain.from_dict(input_domain)
    )

    provider = DomainForCoreTrainingProvider.create(
        {"arbitrary-unused": 234},
        default_model_storage,
        Resource("xy"),
        default_execution_context,
    )
    modified_domain = provider.provide(domain=original_domain)

    # Compare the dictionary representations.
    modified = modified_domain.as_dict()
    original = original_domain.as_dict()
    defaults = Domain.empty().as_dict()

    assert sorted(original.keys()) == sorted(modified.keys())

    for key, original_value in original.items():
        modified_value = modified[key]
        if key in ["config", SESSION_CONFIG_KEY]:
            # Reset to the default values.
            assert modified_value == defaults[key]
        elif key == KEY_RESPONSES:
            # Only the response names survive.
            assert set(modified_value.keys()) == set(original_value.keys())
            for response_name in original_value:
                assert modified_value[response_name] == []
        elif key == KEY_FORMS:
            # Only the form names survive (the Domain adds a default key).
            assert set(modified_value.keys()) == set(original_value.keys())
            for form_name in original_value:
                assert set(modified_value[form_name].keys()) == {REQUIRED_SLOTS_KEY}
                assert modified_value[form_name][REQUIRED_SLOTS_KEY] == []
        else:
            # Everything else is left untouched.
            assert original_value == modified_value
async def test_only_getting_e2e_conversation_tests_if_e2e_enabled(tmp_path: Path):
    """With `use_e2e=True` the importer returns the end-to-end test stories."""
    from rasa.shared.core.training_data.structures import StoryGraph
    import rasa.shared.core.training_data.loading as core_loading

    config_path = str(tmp_path / "config.yml")
    utils.dump_obj_as_yaml_to_file(config_path, {"imports": ["bots/Bot A"]})

    bot_dir = tmp_path / "bots" / "Bot A"

    story_file = bot_dir / "data" / "stories.yml"
    story_file.parent.mkdir(parents=True)
    rasa.shared.utils.io.write_text_file(
        """
stories:
- story: story
  steps:
  - intent: greet
  - action: utter_greet
""",
        story_file,
    )

    story_test_file = bot_dir / "test_stories.yml"
    rasa.shared.utils.io.write_text_file(
        """
stories:
- story: story test
  steps:
  - user: hello
    intent: greet
  - action: utter_greet
""",
        story_test_file,
    )

    importer = MultiProjectImporter(config_path)

    e2e_steps = await core_loading.load_data_from_resource(
        resource=str(story_test_file),
        domain=Domain.empty(),
        template_variables=None,
        use_e2e=True,
        exclusion_percentage=None,
    )
    expected = StoryGraph(e2e_steps)
    actual = await importer.get_stories(use_e2e=True)

    assert expected.as_story_string() == actual.as_story_string()
def _create_domain(domain: Union[Domain, Text, None]) -> Domain:
    """Normalize ``domain`` into a ``Domain`` instance.

    Args:
        domain: A ``Domain`` (returned as-is), a path to a domain file (loaded
            and checked for missing responses), or ``None`` (an empty domain).

    Returns:
        The resulting ``Domain``.

    Raises:
        InvalidParameterException: If ``domain`` is of any other type.
    """
    if domain is None:
        return Domain.empty()

    if isinstance(domain, Domain):
        return domain

    if isinstance(domain, str):
        loaded = Domain.load(domain)
        loaded.check_missing_responses()
        return loaded

    raise InvalidParameterException(
        f"Invalid param `domain`. Expected a path to a domain "
        f"specification or a domain instance. But got "
        f"type '{type(domain)}' with value '{domain}'."
    )
def test_unfeaturized_slot_in_domain_warnings():
    """Only featurized slots are reported in the in-domain slot warnings."""
    unfeaturized_slot = UnfeaturizedSlot("unfeaturized_slot", "value1")
    text_slot = TextSlot("text_slot", "value2")

    domain = Domain.empty()
    domain.slots.extend([unfeaturized_slot, text_slot])

    # Both slots must have been added.
    for slot in (unfeaturized_slot, text_slot):
        assert slot in domain.slots

    slot_warnings = domain.domain_warnings()["slot_warnings"]["in_domain"]

    # The text slot is warned about; the unfeaturized one is not.
    assert text_slot.name in slot_warnings
    assert unfeaturized_slot.name not in slot_warnings
def _create_domain(domain: Union[Domain, Text, None]) -> Domain:
    """Normalize ``domain`` into a ``Domain`` instance.

    Args:
        domain: A ``Domain`` (returned as-is), a path to a domain file (loaded
            and checked for missing templates), or ``None`` (an empty domain).

    Returns:
        The resulting ``Domain``.

    Raises:
        ValueError: If ``domain`` is of any other type.
    """
    if domain is None:
        return Domain.empty()
    if isinstance(domain, Domain):
        return domain
    if not isinstance(domain, str):
        raise ValueError(
            "Invalid param `domain`. Expected a path to a domain "
            "specification or a domain instance. But got "
            f"type '{type(domain)}' with value '{domain}'"
        )

    loaded = Domain.load(domain)
    loaded.check_missing_templates()
    return loaded
def get_domain(self) -> Domain:
    """Retrieves model domain (see parent class for full docstring)."""
    fallback = Domain.empty()

    # Without a configured domain path there is nothing to load.
    if not self._domain_path:
        return fallback

    try:
        return Domain.load(self._domain_path)
    except InvalidDomain as e:
        rasa.shared.utils.io.raise_warning(
            f"Loading domain from '{self._domain_path}' failed. Using "
            f"empty domain. Error: '{e}'"
        )
        return fallback
async def test_ask_affirmation():
    """The two-stage fallback activates itself and asks the user to affirm."""
    tracker = DialogueStateTracker.from_events(
        "some-sender", evts=_message_requiring_fallback()
    )
    domain = Domain.empty()
    fallback = TwoStageFallbackAction()
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)

    events = await fallback.run(CollectingOutputChannel(), nlg, tracker, domain)

    assert len(events) == 2
    # First the loop is activated, then the affirmation request is uttered.
    assert events[0] == ActiveLoop(ACTION_TWO_STAGE_FALLBACK_NAME)
    assert isinstance(events[1], BotUttered)
async def test_ask_rephrasing_successful(default_processor: MessageProcessor):
    """A successful rephrase removes the whole fallback from applied events."""
    conversation = [
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("my name is John", {"name": "say_name", "confidence": 1.0}),
        SlotSet("some_slot", "example_value"),
        # User sends message with low NLU confidence
        *_message_requiring_fallback(),
        ActiveLoop(ACTION_TWO_STAGE_FALLBACK_NAME),
        # Action asks user to affirm
        *_two_stage_clarification_request(),
        ActionExecuted(ACTION_LISTEN_NAME),
        # User denies suggested intents
        UserUttered("hi", {"name": USER_INTENT_OUT_OF_SCOPE}),
        *_two_stage_clarification_request(),
        # Action asks user to rephrase
        ActionExecuted(ACTION_LISTEN_NAME),
        # User rephrases successfully
        UserUttered("hi", {"name": "greet"}),
    ]
    tracker = DialogueStateTracker.from_events("some-sender", evts=conversation)

    domain = Domain.empty()
    fallback = TwoStageFallbackAction()
    nlg = TemplatedNaturalLanguageGenerator(domain.responses)

    await default_processor._run_action(
        fallback,
        tracker,
        CollectingOutputChannel(),
        nlg,
        PolicyPrediction([], "some policy"),
    )

    expected = [
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("my name is John", {"name": "say_name", "confidence": 1.0}),
        SlotSet("some_slot", "example_value"),
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("hi", {"name": "greet"}),
    ]
    assert tracker.applied_events() == expected
async def test_update_with_new_domain(trained_rasa_model: Text, tmpdir: Path):
    """Updating an unpacked model with an empty domain persists that domain."""
    model.unpack_model(trained_rasa_model, tmpdir)

    replacement = Domain.empty()

    async def get_domain() -> Domain:
        return replacement

    importer = Mock()
    importer.get_domain = get_domain

    await model.update_model_with_new_domain(importer, tmpdir)

    persisted_path = tmpdir / DEFAULT_CORE_SUBDIRECTORY_NAME / DEFAULT_DOMAIN_PATH
    assert Domain.load(persisted_path).is_empty()
def test_extract_requested_slot_default():
    """Test default extraction of a slot value from entity with the same name."""
    history = [
        SlotSet(REQUESTED_SLOT, "some_slot"),
        UserUttered("bla", entities=[{"entity": "some_slot", "value": "some_value"}]),
        ActionExecuted(ACTION_LISTEN_NAME),
    ]
    form = FormAction("some form", None)
    tracker = DialogueStateTracker.from_events("default", history)

    extracted = form.extract_requested_slot(tracker, Domain.empty())

    assert extracted == {"some_slot": "some_value"}
def __init__(
    self,
    domain: Optional[Domain],
    event_broker: Optional[EventBroker] = None,
    **kwargs: Any,
) -> None:
    """Create a TrackerStore.

    Args:
        domain: The `Domain` to initialize the `DialogueStateTracker`. If it
            is `None` (or otherwise falsy), an empty domain is used instead.
        event_broker: An event broker to publish any new events to another
            destination.
        kwargs: Additional kwargs. They are accepted but not used here.
    """
    # `**kwargs` is annotated with the type of each *value* (arbitrary), not
    # with the type of the collected mapping.
    self._domain = domain or Domain.empty()
    self.event_broker = event_broker
    # Initialised to None here; NOTE(review): presumably means "no limit" —
    # confirm against the tracker implementation.
    self.max_event_history: Optional[int] = None
def test_domain_validation_with_invalid_marker(depth: int, max_branches: int, seed: int):
    """A random marker is never valid against an empty domain."""
    import warnings

    rng = np.random.default_rng(seed=seed)
    marker, _ = generate_random_marker(
        depth=depth,
        max_branches=max_branches,
        rng=rng,
        possible_conditions=CONDITION_MARKERS,
        possible_operators=OPERATOR_MARKERS,
        constant_condition_text=None,
        constant_negated=None,
    )
    domain = Domain.empty()

    # `pytest.warns(None)` is deprecated and was removed in pytest 8. It only
    # recorded warnings without asserting on them; do the same explicitly.
    with warnings.catch_warnings(record=True):
        warnings.simplefilter("always")
        is_valid = marker.validate_against_domain(domain)

    assert not is_valid
def test_get_next_action_probabilities_pass_policy_predictions_without_interpreter_arg(
    predict_function: Callable,
):
    """Calling a policy without the `interpreter` argument is deprecated."""
    policy = TEDPolicy()
    policy.predict_action_probabilities = predict_function
    ensemble = SimplePolicyEnsemble(policies=[policy])

    domain = Domain.empty()
    processor = MessageProcessor(
        Mock(), ensemble, domain, InMemoryTrackerStore(domain), Mock()
    )

    with pytest.warns(DeprecationWarning):
        processor._get_next_action_probabilities(
            DialogueStateTracker.from_events(
                "lala", [ActionExecuted(ACTION_LISTEN_NAME)]
            )
        )
def _create_from_endpoint_config(
    endpoint_config: Optional[EndpointConfig] = None,
    domain: Optional[Domain] = None,
    event_broker: Optional[EventBroker] = None,
) -> "TrackerStore":
    """Given an endpoint configuration, create a proper tracker store object.

    Without a config (or without a `type`), an in-memory store is used. Known
    types are matched case-insensitively; any other type is treated as a
    module path to a custom tracker store.
    """
    domain = domain or Domain.empty()

    if endpoint_config is None or endpoint_config.type is None:
        # default tracker store if no type is set
        tracker_store = InMemoryTrackerStore(domain, event_broker)
    else:
        store_type = endpoint_config.type.lower()
        # Stores that are constructed with the endpoint URL as `host`.
        host_based_stores = {
            "redis": RedisTrackerStore,
            "mongod": MongoTrackerStore,
            "sql": SQLTrackerStore,
        }
        if store_type in host_based_stores:
            tracker_store = host_based_stores[store_type](
                domain=domain,
                host=endpoint_config.url,
                event_broker=event_broker,
                **endpoint_config.kwargs,
            )
        elif store_type == "dynamo":
            tracker_store = DynamoTrackerStore(
                domain=domain, event_broker=event_broker, **endpoint_config.kwargs
            )
        else:
            tracker_store = _load_from_module_name_in_endpoint_config(
                domain, endpoint_config, event_broker
            )

    logger.debug(f"Connected to {tracker_store.__class__.__name__}.")
    return tracker_store
def __init__(
    self,
    sender_id: Text,
    slots: Optional[Iterable[Slot]],
    max_event_history: Optional[int] = None,
    domain: Optional[Domain] = None,
    is_augmented: bool = False,
    is_rule_tracker: bool = False,
) -> None:
    """Initializes a tracker with cached states.

    A missing `domain` is replaced by an empty one.
    """
    super().__init__(
        sender_id, slots, max_event_history, is_rule_tracker=is_rule_tracker
    )
    self._states_for_hashing: Deque[FrozenState] = deque()

    if domain is None:
        domain = Domain.empty()
    self.domain = domain

    # T/F property to filter augmented stories
    self.is_augmented = is_augmented
    self.__skip_states = False
async def get_domain(self) -> Domain:
    """Load the domain from `self._domain_path`, injecting legacy Botfront forms.

    Returns an empty domain when no path is configured. Loading is
    best-effort: if any step fails, a warning is emitted and whatever domain
    was obtained so far is returned. That is the empty domain if loading
    itself failed, otherwise the loaded (possibly partially updated) domain.
    """
    import warnings

    domain = Domain.empty()
    if not self._domain_path:
        return domain

    # Previously `finally: return domain` silently discarded *every*
    # exception, including BaseExceptions such as asyncio.CancelledError.
    # Only ordinary errors are swallowed now, and they are reported.
    try:
        domain = Domain.load(self._domain_path)
        domain.check_missing_templates()

        # Legacy form injection: forms stored in the initial value of a
        # `bf_forms` slot replace the domain's forms (last matching slot wins).
        bf_forms = []
        for slot in domain.slots:
            if slot.name == "bf_forms":
                bf_forms = slot.initial_value
        if bf_forms:
            domain.forms = {form.get("name"): form for form in bf_forms}
            domain.form_names = list(domain.forms.keys())
            domain.action_names = (
                domain._combine_user_with_default_actions(domain.user_actions)
                + domain.form_names
            )
    except Exception as e:
        warnings.warn(
            f"Loading domain from '{self._domain_path}' failed. Using the "
            f"domain loaded so far. Error: '{e}'"
        )
    return domain
async def test_ask_for_slot_if_not_utter_ask(
    monkeypatch: MonkeyPatch, default_nlg: TemplatedNaturalLanguageGenerator
):
    """Without an `utter_ask_*` response nothing is asked and no action looked up."""
    lookup_mock = Mock(return_value=action.ActionListen())
    monkeypatch.setattr(
        action, action.action_for_name_or_text.__name__, lookup_mock
    )

    form = FormAction("my_form", Mock())
    tracker = DialogueStateTracker.from_events("dasd", [])

    events = await form._ask_for_slot(
        Domain.empty(), default_nlg, CollectingOutputChannel(), "some slot", tracker
    )

    assert not events
    lookup_mock.assert_not_called()
async def test_2nd_affirm_successful():
    """Affirming after a low-confidence rephrase collapses the fallback."""
    history = [
        # User sends message with low NLU confidence
        *_message_requiring_fallback(),
        ActiveLoop(ACTION_TWO_STAGE_FALLBACK_NAME),
        # Action asks user to affirm
        *_two_stage_clarification_request(),
        ActionExecuted(ACTION_LISTEN_NAME),
        # User denies suggested intents
        UserUttered("hi", {"name": USER_INTENT_OUT_OF_SCOPE}),
        # Action asks user to rephrase
        *_two_stage_clarification_request(),
        # User rephrased with low confidence
        *_message_requiring_fallback(),
        *_two_stage_clarification_request(),
        # Actions asks user to affirm for the last time
        ActionExecuted(ACTION_LISTEN_NAME),
        # User affirms successfully
        UserUttered("hi", {"name": "greet"}),
    ]
    tracker = DialogueStateTracker.from_events("some-sender", evts=history)

    domain = Domain.empty()
    fallback = TwoStageFallbackAction()
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)

    new_events = await fallback.run(CollectingOutputChannel(), nlg, tracker, domain)
    for new_event in new_events:
        tracker.update(new_event)

    assert tracker.applied_events() == [
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("hi", {"name": "greet"}),
    ]
def test_tracker_store_retrieve_with_events_from_previous_sessions(
    tracker_store_type: Type[TrackerStore], tracker_store_kwargs: Dict
):
    """`retrieve_full_tracker` returns events from all sessions."""
    store = tracker_store_type(Domain.empty(), **tracker_store_kwargs)
    sender_id = uuid.uuid4().hex

    events = [
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        UserUttered("hi"),
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
    ]
    tracker = DialogueStateTracker.from_events(sender_id, events)
    store.save(tracker)

    full_tracker = store.retrieve_full_tracker(sender_id)

    assert len(full_tracker.events) == len(tracker.events)
async def test_markers_cli_results_save_correctly(tmp_path: Path):
    """Marker evaluation results are written to CSV with correct indices.

    Each of five trackers sets slots "0"-"4", starts a new session, receives
    one user message and then sets slots "5"-"9". `marker1` (slot "2") must
    therefore apply in session 0 and `marker2` (slot "7") in session 1.
    """
    domain = Domain.empty()
    store = InMemoryTrackerStore(domain)

    for i in range(5):
        tracker = DialogueStateTracker(str(i), None)
        tracker.update_with_events([SlotSet(str(j), "slot") for j in range(5)], domain)
        tracker.update(ActionExecuted(ACTION_SESSION_START_NAME))
        tracker.update(UserUttered("hello"))
        tracker.update_with_events(
            [SlotSet(str(5 + j), "slot") for j in range(5)], domain
        )
        await store.save(tracker)

    tracker_loader = MarkerTrackerLoader(store, "all")
    results_path = tmp_path / "results.csv"

    markers = OrMarker(
        markers=[
            SlotSetMarker("2", name="marker1"),
            SlotSetMarker("7", name="marker2"),
        ]
    )
    await markers.evaluate_trackers(tracker_loader.load(), results_path)

    # The csv module requires files to be opened with newline="" so that
    # quoted fields containing line breaks are parsed correctly.
    with open(results_path, "r", newline="") as results:
        result_reader = csv.DictReader(results)
        senders = set()

        for row in result_reader:
            senders.add(row["sender_id"])
            if row["marker"] == "marker1":
                assert row["session_idx"] == "0"
                assert int(row["event_idx"]) >= 2
                assert row["num_preceding_user_turns"] == "0"

            if row["marker"] == "marker2":
                assert row["session_idx"] == "1"
                assert int(row["event_idx"]) >= 3
                assert row["num_preceding_user_turns"] == "1"

        assert len(senders) == 5
async def test_loop_without_deactivate():
    """A loop that is never done activates, runs `do`, and never deactivates."""
    form_name = "my form"
    activation_events = [
        ActionExecutionRejected("tada"),
        ActionExecuted("test"),
    ]
    do_events = [ActionExecuted("do")]

    class MyLoop(LoopAction):
        def name(self) -> Text:
            return form_name

        async def activate(self, *args: Any) -> List[Event]:
            return activation_events

        async def do(self, *args: Any) -> List[Event]:
            return do_events

        async def deactivate(self, *args) -> List[Event]:
            raise ValueError("this shouldn't be called")

        async def is_done(self, *args) -> bool:
            return False

    domain = Domain.empty()
    tracker = DialogueStateTracker.from_events("some sender", [])
    nlg = TemplatedNaturalLanguageGenerator(domain.responses)

    produced = await MyLoop().run(CollectingOutputChannel(), nlg, tracker, domain)

    assert produced == [ActiveLoop(form_name), *activation_events, *do_events]
def test_load_sessions(tmp_path):
    """Tests loading a tracker with multiple sessions."""
    domain = Domain.empty()
    db_path = os.path.join(tmp_path, "temp.db")
    store = SQLTrackerStore(domain, db=db_path)

    tracker = DialogueStateTracker("test123", None)
    two_sessions = [
        UserUttered("0"),
        UserUttered("1"),
        SessionStarted(),
        UserUttered("2"),
        UserUttered("3"),
    ]
    tracker.update_with_events(two_sessions, domain)
    store.save(tracker)

    loaded = list(MarkerTrackerLoader(store, STRATEGY_ALL).load())

    # A single tracker comes back, with every event from both sessions.
    assert len(loaded) == 1
    assert len(loaded[0].events) == len(tracker.events)
def test_metadata_version_check():
    """Metadata of a model trained with an incompatible old version is rejected."""
    old_version = "2.7.2"
    expected_message = (
        "The model version is trained using Rasa Open Source "
        f"{old_version} and is not compatible with your current "
        "installation .*"
    )
    trained_at = datetime.utcnow()

    with pytest.raises(UnsupportedModelVersionError, match=expected_message):
        ModelMetadata(
            trained_at,
            old_version,
            "some id",
            Domain.empty(),
            GraphSchema(nodes={}),
            GraphSchema(nodes={}),
            project_fingerprint="some_fingerprint",
            training_type=TrainingType.NLU,
            core_target="core",
            nlu_target="nlu",
            language="zh",
        )
def test_number_of_examples_per_intent_with_yaml(tmp_path: Path):
    """NLU examples from YAML files are counted per intent."""
    domain_file = tmp_path / "domain.yml"
    domain_file.write_text(Domain.empty().as_yaml())

    config_file = tmp_path / "config.yml"
    config_file.touch()

    training_files = [
        "data/test_number_nlu_examples/nlu.yml",
        "data/test_number_nlu_examples/stories.yml",
        "data/test_number_nlu_examples/rules.yml",
    ]
    importer = TrainingDataImporter.load_from_dict(
        {}, str(config_file), str(domain_file), training_files
    )

    nlu_data = importer.get_nlu_data()
    counts = nlu_data.number_of_examples_per_intent

    assert nlu_data.intents == {"greet", "ask_weather"}
    assert counts["greet"] == 2
    assert counts["ask_weather"] == 3
def test_merge_with_empty_other_domain():
    """Merging an empty domain (with override) leaves the original unchanged."""
    domain_yaml = """config:
  store_entities_as_slots: false
session_config:
  session_expiration_time: 20
  carry_over_slots: true
entities:
- cuisine
intents:
- greet
slots:
  cuisine:
    type: text
responses:
  utter_goodbye:
  - text: bye!
  utter_greet:
  - text: hey you!"""
    domain = Domain.from_yaml(domain_yaml)

    merged = domain.merge(Domain.empty(), override=True)

    assert merged.as_dict() == domain.as_dict()
async def test_1st_affirmation_is_successful():
    """Affirming the first clarification request collapses the fallback."""
    tracker = DialogueStateTracker.from_events(
        "some-sender",
        evts=[
            # User sends message with low NLU confidence
            *_message_requiring_fallback(),
            ActiveLoop(ACTION_TWO_STAGE_FALLBACK_NAME),
            # Action asks user to affirm
            *_two_stage_clarification_request(),
            ActionExecuted(ACTION_LISTEN_NAME),
            # User affirms
            UserUttered("hi", {"name": "greet", "confidence": 1.0}),
        ],
    )
    domain = Domain.empty()
    action = TwoStageFallbackAction()

    events = await action.run(
        CollectingOutputChannel(),
        TemplatedNaturalLanguageGenerator(domain.templates),
        tracker,
        domain,
    )

    # Use a distinct loop variable: `for events in events` rebinds the list's
    # own name to each element, which is confusing and error-prone.
    for event in events:
        tracker.update(event, domain)

    applied_events = tracker.applied_events()
    assert applied_events == [
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("hi", {"name": "greet", "confidence": 1.0}),
    ]
async def test_update_conversation_with_events(
    rasa_app: SanicASGITestClient,
    monkeypatch: MonkeyPatch,
    initial_tracker_events: List[Event],
    events_to_append: List[Event],
    expected_events: List[Event],
):
    """Appending events to a (possibly new) conversation yields the expected events."""
    sender_id = "some-conversation-ID"
    domain = Domain.empty()
    store = InMemoryTrackerStore(domain)
    monkeypatch.setattr(rasa_app.app.agent, "tracker_store", store)

    # Only pre-populate the store when the parametrization provides events.
    if initial_tracker_events:
        store.save(DialogueStateTracker.from_events(sender_id, initial_tracker_events))

    processor = rasa_app.app.agent.create_processor()
    updated = await rasa.server.update_conversation_with_events(
        sender_id, processor, domain, events_to_append
    )

    assert list(updated.events) == expected_events
def test_generating_trackers(
    default_model_storage: ModelStorage,
    default_execution_context: ExecutionContext,
    config: Dict[Text, Any],
    expected_trackers: int,
):
    """The tracker provider generates the expected number of cached trackers."""
    story_steps = YAMLStoryReader().read_from_file(
        "data/test_yaml_stories/stories.yml"
    )

    # Test-specific settings override the provider defaults.
    provider_config = {**TrainingTrackerProvider.get_default_config(), **config}
    provider = TrainingTrackerProvider.create(
        provider_config,
        default_model_storage,
        Resource("xy"),
        default_execution_context,
    )

    trackers = provider.generate_trackers(
        story_graph=StoryGraph(story_steps), domain=Domain.empty()
    )

    assert len(trackers) == expected_trackers
    assert all(isinstance(tracker, TrackerWithCachedStates) for tracker in trackers)
def test_tracker_store_deprecated_session_retrieval_kwarg():
    """With the legacy kwarg set, `retrieve` delegates to `retrieve_full_tracker`."""
    store = SQLTrackerStore(
        Domain.empty(), retrieve_events_from_previous_conversation_sessions=True
    )
    sender_id = uuid.uuid4().hex

    tracker = DialogueStateTracker.from_events(
        sender_id,
        [
            ActionExecuted(ACTION_SESSION_START_NAME),
            SessionStarted(),
            UserUttered("hi"),
        ],
    )

    full_retrieval_mock = Mock()
    store.retrieve_full_tracker = full_retrieval_mock

    store.save(tracker)
    store.retrieve(sender_id)

    full_retrieval_mock.assert_called_once()