def test_should_be_serializable_before_fitting(self):
    """Persisting an unfitted parser writes its config and empty state."""
    # Given
    unfitted_parser = LookupIntentParser(
        config=LookupIntentParserConfig(ignore_stop_words=True))

    # When
    unfitted_parser.persist(self.tmp_file_path)

    # Then
    # Nothing has been learned yet, so every data-dependent field is either
    # None or an empty list.
    expected_parser_content = {
        "config": {
            "unit_name": "lookup_intent_parser",
            "ignore_stop_words": True,
        },
        "language_code": None,
        "intents_names": [],
        "map": None,
        "slots_names": [],
        "entity_scopes": None,
        "stop_words_whitelist": None,
    }
    expected_metadata = {"unit_name": "lookup_intent_parser"}
    self.assertJsonContent(
        self.tmp_file_path / "metadata.json", expected_metadata)
    self.assertJsonContent(
        self.tmp_file_path / "intent_parser.json", expected_parser_content)
def test_should_be_serializable(self, mock_get_stop_words):
    """Persisting a fitted parser writes its lookup map, slots and scopes.

    NOTE(review): ``mock_get_stop_words`` must be injected by a mock.patch
    decorator on the stop-words lookup; that decorator is not visible in
    this chunk -- confirm it is present.
    """
    # Given
    yaml_stream = io.StringIO(
        """
---
type: intent
name: searchFlight
slots:
  - name: origin
    entity: city
  - name: destination
    entity: city
utterances:
  - find me a flight from [origin](Paris) to [destination](New York)
  - I need a flight to [destination](Berlin)

---
type: entity
name: city
values:
  - london
  - [new york, big apple]
  - [paris, city of lights]""")
    dataset = Dataset.from_yaml_files("en", [yaml_stream]).json
    # Stop words are stubbed so the hashed keys in the expected map are stable.
    mock_get_stop_words.return_value = {"a", "me"}
    fitted_parser = LookupIntentParser(
        config=LookupIntentParserConfig(ignore_stop_words=True)
    ).fit(dataset)

    # When
    fitted_parser.persist(self.tmp_file_path)

    # Then
    expected_parser_content = {
        "config": {
            "unit_name": "lookup_intent_parser",
            "ignore_stop_words": True,
        },
        "intents_names": ["searchFlight"],
        "language_code": "en",
        # Each map value is [intent index, [slot indices]].
        "map": {
            "-2020846245": [0, [0, 1]],
            "-1558674456": [0, [1]],
        },
        "slots_names": ["origin", "destination"],
        "entity_scopes": [
            {
                "entity_scope": {"builtin": [], "custom": ["city"]},
                "intent_group": ["searchFlight"],
            }
        ],
        "stop_words_whitelist": {},
    }
    expected_metadata = {"unit_name": "lookup_intent_parser"}
    self.assertJsonContent(
        self.tmp_file_path / "metadata.json", expected_metadata)
    self.assertJsonContent(
        self.tmp_file_path / "intent_parser.json", expected_parser_content)
def test_should_parse_intent_after_deserialization(self):
    """A parser reloaded from disk still resolves the expected intent."""
    # Given
    dataset = self.slots_dataset
    shared = self.get_shared_data(dataset)
    LookupIntentParser(**shared).fit(dataset).persist(self.tmp_file_path)
    reloaded_parser = LookupIntentParser.from_path(
        self.tmp_file_path, **shared)
    query = ("this is a dummy_a query with another dummy_c at 10p.m. or "
             "at 12p.m.")

    # When
    result = reloaded_parser.parse(query)

    # Then
    self.assertEqual(
        intent_classification_result(
            intent_name="dummy_intent_1", probability=1.0),
        result[RES_INTENT])
def test_should_parse_slots_after_deserialization(self):
    """A parser reloaded from disk extracts the expected slots per query."""
    # Given
    dataset = self.slots_dataset
    shared = self.get_shared_data(dataset)
    LookupIntentParser(**shared).fit(dataset).persist(self.tmp_file_path)
    reloaded_parser = LookupIntentParser.from_path(
        self.tmp_file_path, **shared)

    def expected_slot(match_range, value, entity, slot_name):
        # Thin wrapper so each expectation fits on one or two lines.
        return unresolved_slot(match_range=match_range, value=value,
                               entity=entity, slot_name=slot_name)

    # Each case pairs an input query with its expected slots; match ranges
    # are character offsets into that query.
    cases = [
        (
            "this is a dummy a query with another dummy_c at 10p.m. or at"
            " 12p.m.",
            [
                expected_slot((10, 17), "dummy a", "dummy_entity_1",
                              "dummy_slot_name"),
                expected_slot((37, 44), "dummy_c", "dummy_entity_2",
                              "dummy_slot_name2"),
                expected_slot((45, 54), "at 10p.m.", "snips/datetime",
                              "startTime"),
                expected_slot((58, 67), "at 12p.m.", "snips/datetime",
                              "startTime"),
            ],
        ),
        (
            "this, is,, a, dummy a query with another dummy_c at 10pm or "
            "at 12p.m.",
            [
                expected_slot((14, 21), "dummy a", "dummy_entity_1",
                              "dummy_slot_name"),
                expected_slot((41, 48), "dummy_c", "dummy_entity_2",
                              "dummy_slot_name2"),
                expected_slot((49, 56), "at 10pm", "snips/datetime",
                              "startTime"),
                expected_slot((60, 69), "at 12p.m.", "snips/datetime",
                              "startTime"),
            ],
        ),
        (
            "this is a dummy b",
            [
                expected_slot((10, 17), "dummy b", "dummy_entity_1",
                              "dummy_slot_name"),
            ],
        ),
        (
            " this is a dummy b ",
            [
                expected_slot((11, 18), "dummy b", "dummy_entity_1",
                              "dummy_slot_name"),
            ],
        ),
    ]

    for query, slots_expected in cases:
        # When
        parsing = reloaded_parser.parse(query)

        # Then
        self.assertListEqual(slots_expected, parsing[RES_SLOTS])