def test_get_string_variations(self):
    """All combinations of '&'/'and', digit/word numbers and casing
    should be generated for the input string."""
    # Given
    language = LANGUAGE_EN
    text = "a and b 2"
    parser = BuiltinEntityParser.build(language="en")

    # When
    variations = get_string_variations(text, language, parser)

    # Then
    expected_variations = {
        "a and b 2",
        "a & b 2",
        "a b 2",
        "a and b two",
        "a & b two",
        "a b two",
        "A and B 2",
        "A & B 2",
        "A B 2",
        "A And B 2",
        "A and B two",
        "A & B two",
        "A B two",
        "A And B two",
    }
    self.assertSetEqual(variations, expected_variations)
def test_should_not_build_builtin_parser_when_provided(self):
    """Fitting an engine that was given a builtin entity parser must
    reuse it and never build a new one.

    The build classmethod is patched so any call to it during fit()
    would be recorded.
    """
    # Given
    # NOTE(review): internal whitespace of this YAML literal was
    # reconstructed from a whitespace-mangled source — confirm against
    # upstream if the dataset shape looks off.
    dataset_stream = io.StringIO("""
---
type: intent
name: MakeTea
utterances:
- make me a [beverage_temperature:Temperature](hot) cup of tea
- make me [number_of_cups:snips/number](five) tea cups

---
type: intent
name: MakeCoffee
utterances:
- make me [number_of_cups:snips/number](one) cup of coffee please
- brew [number_of_cups] cups of coffee""")
    dataset = Dataset.from_yaml_files("en", [dataset_stream]).json
    dataset = validate_and_format_dataset(dataset)
    # Parser is built up-front and handed to the engine
    builtin_entity_parser = BuiltinEntityParser.build(language="en")

    # When
    with patch("snips_nlu.entity_parser.builtin_entity_parser"
               ".BuiltinEntityParser.build") as mocked_build_parser:
        engine = SnipsNLUEngine(
            builtin_entity_parser=builtin_entity_parser)
        engine.fit(dataset)

    # Then
    # The engine must not have built its own parser during fit
    mocked_build_parser.assert_not_called()
def test_builtin_entity_match_factory(self):
    """The builtin_entity_match factory should produce one BILOU-tagged
    feature per builtin entity in scope."""
    # Given
    def mock_builtin_entity_scope(dataset, _):
        # Force a deterministic two-entity scope for English datasets
        if dataset[LANGUAGE] == LANGUAGE_EN:
            return {SNIPS_NUMBER, SNIPS_DATETIME}
        return []

    config = {
        "factory_name": "builtin_entity_match",
        "args": {
            "tagging_scheme_code": TaggingScheme.BILOU.value,
        },
        "offsets": [0]
    }
    tokens = tokenize("one tea tomorrow at 2pm", LANGUAGE_EN)
    cache = [{TOKEN_NAME: token} for token in tokens]
    builtin_entity_parser = BuiltinEntityParser.build(language="en")
    factory = CRFFeatureFactory.from_config(
        config, builtin_entity_parser=builtin_entity_parser)
    # pylint: disable=protected-access
    factory._get_builtin_entity_scope = mock_builtin_entity_scope
    # pylint: enable=protected-access
    mocked_dataset = {"language": "en"}
    factory.fit(mocked_dataset, None)

    # When
    features = sorted(factory.build_features(),
                      key=lambda f: f.base_name)
    datetime_results = [features[0].compute(i, cache) for i in range(5)]
    number_results = [features[1].compute(i, cache) for i in range(5)]

    # Then
    self.assertIsInstance(factory, BuiltinEntityMatchFactory)
    self.assertEqual(len(features), 2)
    self.assertEqual(features[0].base_name,
                     "builtin_entity_match_snips/datetime")
    self.assertEqual(features[1].base_name,
                     "builtin_entity_match_snips/number")
    # "tomorrow at 2pm" is a BILOU-tagged datetime span; "one" is a
    # unit-length match for both entities
    expected_datetime = [UNIT_PREFIX, None, BEGINNING_PREFIX,
                         INSIDE_PREFIX, LAST_PREFIX]
    expected_number = [UNIT_PREFIX, None, None, None, None]
    for computed, expected in zip(datetime_results, expected_datetime):
        self.assertEqual(computed, expected)
    for computed, expected in zip(number_results, expected_number):
        self.assertEqual(computed, expected)
def fit_builtin_entity_parser_if_needed(self, dataset):
    """Build a builtin entity parser for *dataset* unless a freshly
    provided one can be reused.

    A parser is (re)built only when none is set yet, or when this unit
    has already been fitted (its current parser may relate to an older
    dataset). In every other case the parser was provided, already
    fitted, by another unit and is kept as is.
    """
    must_build_parser = self.builtin_entity_parser is None or self.fitted
    if must_build_parser:
        self.builtin_entity_parser = BuiltinEntityParser.build(
            dataset=dataset)
    return self
def test_should_variate_case_and_normalization(self):
    """Accent-normalized and case variants of the input should all be
    produced."""
    # Given
    language = LANGUAGE_EN
    text = "Küche"
    parser = BuiltinEntityParser.build(language="en")

    # When
    variations = get_string_variations(text, language, parser)

    # Then
    self.assertSetEqual(
        variations, {"kuche", "küche", "Kuche", "Küche"})
def get_shared_data(cls, dataset, parser_usage=None):
    """Build the resources and entity parsers shared by processing units
    for *dataset*.

    :param dataset: dataset dict, must contain a "language" key
    :param parser_usage: optional CustomEntityParserUsage; defaults to
        WITH_AND_WITHOUT_STEMS
    :return: dict with "resources", "builtin_entity_parser" and
        "custom_entity_parser" entries
    """
    # Imported lazily to avoid a circular import at module load time
    # (presumably — confirm against the package layout)
    from snips_nlu.entity_parser import (BuiltinEntityParser,
                                         CustomEntityParser,
                                         CustomEntityParserUsage)

    if parser_usage is None:
        parser_usage = CustomEntityParserUsage.WITH_AND_WITHOUT_STEMS
    resources = cls.get_resources(dataset["language"])
    builtin_parser = BuiltinEntityParser.build(dataset)
    custom_parser = CustomEntityParser.build(
        dataset, parser_usage, resources)
    return {
        "resources": resources,
        "builtin_entity_parser": builtin_parser,
        "custom_entity_parser": custom_parser,
    }
def test_should_be_serializable_into_bytearray(self):
    """A fitted deterministic parser round-tripped through a byte array
    should still parse correctly."""
    # Given
    dataset = BEVERAGE_DATASET
    intent_parser = DeterministicIntentParser().fit(dataset)
    custom_entity_parser = intent_parser.custom_entity_parser

    # When
    serialized = intent_parser.to_byte_array()
    restored_parser = DeterministicIntentParser.from_byte_array(
        serialized,
        builtin_entity_parser=BuiltinEntityParser.build(language="en"),
        custom_entity_parser=custom_entity_parser)
    result = restored_parser.parse("make me two cups of coffee")

    # Then
    self.assertEqual("MakeCoffee", result[RES_INTENT][RES_INTENT_NAME])
def test_numbers_variations_should_handle_floats(self):
    """Float values must stay as digits while integers may be spelled
    out or written as digits."""
    # Given
    language = LANGUAGE_EN
    text = "7.62 mm caliber 2 and six"
    parser = BuiltinEntityParser.build(language="en")

    # When
    variations = numbers_variations(text, language, parser)

    # Then
    expected_variations = {
        "7.62 mm caliber 2 and six",
        "7.62 mm caliber two and six",
        "7.62 mm caliber 2 and 6",
        "7.62 mm caliber two and 6",
    }
    self.assertSetEqual(variations, expected_variations)
def test_should_be_serializable_into_bytearray(self):
    """An NLU engine round-tripped through a byte array, reloaded with
    shared parsers, should still parse correctly."""
    # Given
    dataset = BEVERAGE_DATASET
    engine = SnipsNLUEngine().fit(dataset)

    # When
    serialized = engine.to_byte_array()
    builtin_entity_parser = BuiltinEntityParser.build(dataset=dataset)
    custom_entity_parser = CustomEntityParser.build(
        dataset, parser_usage=CustomEntityParserUsage.WITHOUT_STEMS)
    restored_engine = SnipsNLUEngine.from_byte_array(
        serialized,
        builtin_entity_parser=builtin_entity_parser,
        custom_entity_parser=custom_entity_parser)
    result = restored_engine.parse("Make me two cups of coffee")

    # Then
    self.assertEqual(result[RES_INTENT][RES_INTENT_NAME], "MakeCoffee")
def test_should_get_intent_after_deserialization(self):
    """A persisted log-reg intent classifier must classify correctly
    once reloaded from disk."""
    # Given
    dataset = validate_and_format_dataset(BEVERAGE_DATASET)
    classifier = LogRegIntentClassifier().fit(dataset)
    classifier.persist(self.tmp_file_path)

    # When
    builtin_entity_parser = BuiltinEntityParser.build(language="en")
    custom_entity_parser = CustomEntityParser.build(
        dataset, CustomEntityParserUsage.WITHOUT_STEMS)
    restored_classifier = LogRegIntentClassifier.from_path(
        self.tmp_file_path,
        builtin_entity_parser=builtin_entity_parser,
        custom_entity_parser=custom_entity_parser)
    result = restored_classifier.get_intent("Make me two cups of tea")

    # Then
    self.assertEqual("MakeTea", result[RES_INTENT_NAME])
def test_get_france_24(self):
    """French digit/word number variations with casing should be
    generated for 'france 24'."""
    # Given
    language = LANGUAGE_FR
    text = "france 24"
    # NOTE(review): the parser is built for "en" although the variation
    # language is French — presumably digit parsing works regardless of
    # parser language here; confirm the mismatch is intentional.
    parser = BuiltinEntityParser.build(language="en")

    # When
    variations = get_string_variations(text, language, parser)

    # Then
    expected_variations = {
        "france vingt-quatre",
        "France vingt-quatre",
        "france vingt quatre",
        "France vingt quatre",
        "france 24",
        "France 24",
    }
    self.assertSetEqual(variations, expected_variations)
def test_should_be_serializable_into_bytearray(self):
    """A fitted log-reg classifier round-tripped through a byte array
    should still classify correctly."""
    # Given
    dataset = validate_and_format_dataset(BEVERAGE_DATASET)
    intent_classifier = LogRegIntentClassifier().fit(dataset)

    # When
    serialized = intent_classifier.to_byte_array()
    custom_entity_parser = CustomEntityParser.build(
        dataset, CustomEntityParserUsage.WITHOUT_STEMS)
    builtin_entity_parser = BuiltinEntityParser.build(language="en")
    restored_classifier = LogRegIntentClassifier.from_byte_array(
        serialized,
        builtin_entity_parser=builtin_entity_parser,
        custom_entity_parser=custom_entity_parser)
    result = restored_classifier.get_intent("make me two cups of tea")

    # Then
    self.assertEqual("MakeTea", result[RES_INTENT_NAME])
def test_alphabetic_value(self):
    """alphabetic_value should spell out integer entities and return
    None for values it cannot spell (e.g. floats)."""
    # Given
    language = LANGUAGE_EN
    text = "1 time and 23 times and one thousand and sixty and 1.2"
    parser = BuiltinEntityParser.build(language=language)
    entities = sorted(
        parser.parse(text, scope=[SNIPS_NUMBER]),
        key=lambda ent: ent[RES_MATCH_RANGE][START])
    expected_values = [
        "one", "twenty-three", "one thousand and sixty", None
    ]
    self.assertEqual(len(entities), len(expected_values))

    for entity, expected_value in zip(entities, expected_values):
        # When
        value = alphabetic_value(entity, language)

        # Then
        self.assertEqual(value, expected_value)
def test_should_get_intent_after_deserialization(self):
    """A persisted deterministic parser must still classify the intent
    once reloaded from disk."""
    # Given
    dataset = validate_and_format_dataset(self.slots_dataset)
    parser = DeterministicIntentParser().fit(dataset)
    custom_entity_parser = parser.custom_entity_parser
    parser.persist(self.tmp_file_path)
    restored_parser = DeterministicIntentParser.from_path(
        self.tmp_file_path,
        builtin_entity_parser=BuiltinEntityParser.build(language="en"),
        custom_entity_parser=custom_entity_parser)
    text = ("this is a dummy_a query with another dummy_c at 10p.m. or "
            "at 12p.m.")

    # When
    parsing = restored_parser.parse(text)

    # Then
    expected_intent = intent_classification_result(
        intent_name="dummy_intent_1", probability=1.0)
    self.assertEqual(expected_intent, parsing[RES_INTENT])
def test_should_be_deserializable(self):
    """A lookup intent parser written to disk as JSON should be
    deserializable and able to parse both intents with their slots."""
    # Given
    parser_dict = {
        "config": {
            "unit_name": "lookup_intent_parser",
            "ignore_stop_words": True
        },
        "language_code": "en",
        # Keys are hashes of normalized utterance patterns; values are
        # [intent_index, [slot_indices]]
        "map": {
            hash_str("make coffee"): [0, []],
            hash_str("prepare % snipsnumber % coffees"): [0, [0]],
            hash_str("% snipsnumber % teas at % snipstemperature %"):
                [1, [0, 1]],
        },
        "slots_names": ["nb_cups", "tea_temperature"],
        "intents_names": ["MakeCoffee", "MakeTea"],
        # Per-intent-group entity scopes used at parsing time
        "entity_scopes": [
            {
                "entity_scope": {
                    "builtin": ["snips/number"],
                    "custom": [],
                },
                "intent_group": ["MakeCoffee"]
            },
            {
                "entity_scope": {
                    "builtin": ["snips/number", "snips/temperature"],
                    "custom": [],
                },
                "intent_group": ["MakeTea"]
            },
        ],
        "stop_words_whitelist": dict()
    }
    self.tmp_file_path.mkdir()
    metadata = {"unit_name": "lookup_intent_parser"}
    self.writeJsonContent(
        self.tmp_file_path / "intent_parser.json", parser_dict)
    self.writeJsonContent(self.tmp_file_path / "metadata.json", metadata)
    resources = self.get_resources("en")
    builtin_entity_parser = BuiltinEntityParser.build(language="en")
    custom_entity_parser = EntityParserMock()

    # When
    parser = LookupIntentParser.from_path(
        self.tmp_file_path, custom_entity_parser=custom_entity_parser,
        builtin_entity_parser=builtin_entity_parser,
        resources=resources)
    res_make_coffee = parser.parse("make me a coffee")
    res_make_tea = parser.parse("two teas at 90°C please")

    # Then
    expected_result_coffee = parsing_result(
        input="make me a coffee",
        intent=intent_classification_result("MakeCoffee", 1.0),
        slots=[])
    expected_result_tea = parsing_result(
        input="two teas at 90°C please",
        intent=intent_classification_result("MakeTea", 1.0),
        slots=[
            {
                "entity": "snips/number",
                "range": {"end": 3, "start": 0},
                "slotName": "nb_cups",
                "value": "two"
            },
            {
                "entity": "snips/temperature",
                "range": {"end": 16, "start": 12},
                "slotName": "tea_temperature",
                "value": "90°C"
            }
        ])
    self.assertEqual(expected_result_coffee, res_make_coffee)
    self.assertEqual(expected_result_tea, res_make_tea)
def test_preprocess_for_training(self):
    """_preprocess(..., training=True) should normalize/stem utterance
    chunks and return, per utterance: the processed utterance, its
    builtin entities, its custom (annotated) entities and its word
    clusters."""
    # Given
    language = LANGUAGE_EN
    resources = {
        STEMS: {
            "beautiful": "beauty",
            "birdy": "bird",
            "entity": "ent"
        },
        WORD_CLUSTERS: {
            "my_word_clusters": {
                "beautiful": "cluster_1",
                "birdy": "cluster_2",
                "entity": "cluster_3"
            }
        },
        STOP_WORDS: set()
    }
    dataset_stream = io.StringIO("""
---
type: intent
name: intent1
utterances:
- dummy utterance

---
type: entity
name: entity_1
automatically_extensible: false
use_synononyms: false
matching_strictness: 1.0
values:
- [entity 1, alternative entity 1]
- [éntity 1, alternative entity 1]

---
type: entity
name: entity_2
automatically_extensible: false
use_synononyms: true
matching_strictness: 1.0
values:
- entity 1
- [Éntity 2, Éntity_2, Alternative entity 2]""")
    dataset = Dataset.from_yaml_files("en", [dataset_stream]).json
    custom_entity_parser = CustomEntityParser.build(
        dataset, CustomEntityParserUsage.WITH_STEMS, resources)
    builtin_entity_parser = BuiltinEntityParser.build(dataset, language)
    # NOTE(review): whitespace inside the chunk texts below was
    # reconstructed from a whitespace-mangled source; the expected
    # ranges further down are consistent with single spaces — confirm.
    utterances = [{
        "data": [{
            "text": "hÉllo wOrld "
        }, {
            "text": " yo "
        }, {
            "text": " yo "
        }, {
            "text": "yo "
        }, {
            "text": "Éntity_2",
            "entity": "entity_2"
        }, {
            "text": " "
        }, {
            "text": "Éntity_2",
            "entity": "entity_2"
        }]
    }, {
        "data": [{
            "text": "beauTiful World "
        }, {
            "text": "entity 1",
            "entity": "entity_1"
        }, {
            "text": " "
        }, {
            "text": "2",
            "entity": "snips/number"
        }]
    }, {
        "data": [{
            "text": "Bird bïrdy"
        }]
    }, {
        "data": [{
            "text": "Bird birdy"
        }]
    }]
    config = TfidfVectorizerConfig(
        use_stemming=True, word_clusters_name="my_word_clusters")
    vectorizer = TfidfVectorizer(
        config=config,
        custom_entity_parser=custom_entity_parser,
        builtin_entity_parser=builtin_entity_parser,
        resources=resources)
    vectorizer._language = language

    # When
    processed_data = vectorizer._preprocess(utterances, training=True)
    # Transpose so each element is (utterance, builtin_ents,
    # custom_ents, word_clusters)
    processed_data = list(zip(*processed_data))

    # Then
    # Expected processed utterances: lowercased, stemmed, stripped
    u_0 = {
        "data": [{
            "text": "hello world"
        }, {
            "text": "yo"
        }, {
            "text": "yo"
        }, {
            "text": "yo"
        }, {
            "text": "entity_2",
            "entity": "entity_2"
        }, {
            "text": ""
        }, {
            "text": "entity_2",
            "entity": "entity_2"
        }]
    }
    u_1 = {
        "data": [{
            "text": "beauty world"
        }, {
            "text": "ent 1",
            "entity": "entity_1"
        }, {
            "text": ""
        }, {
            "text": "2",
            "entity": "snips/number"
        }]
    }
    u_2 = {"data": [{"text": "bird bird"}]}
    # Ranges are relative to the concatenated original utterance text
    ent_00 = {
        "entity_kind": "entity_2",
        "value": "Éntity_2",
        "range": {
            "start": 23,
            "end": 31
        }
    }
    ent_01 = {
        "entity_kind": "entity_2",
        "value": "Éntity_2",
        "range": {
            "start": 32,
            "end": 40
        }
    }
    ent_1 = {
        "entity_kind": "entity_1",
        "value": "entity 1",
        "range": {
            "start": 16,
            "end": 24
        }
    }
    num_1 = {
        "entity_kind": "snips/number",
        "value": "2",
        "range": {
            "start": 25,
            "end": 26
        }
    }
    # Both "Bird ..." utterances normalize to the same u_2 text; only
    # the unaccented "birdy" maps to a word cluster
    expected_data = [(u_0, [], [ent_00, ent_01], []),
                     (u_1, [num_1], [ent_1],
                      ["cluster_1", "cluster_3"]),
                     (u_2, [], [], []),
                     (u_2, [], [], ["cluster_2"])]
    self.assertSequenceEqual(expected_data, processed_data)
def test_should_get_slots_after_deserialization(self):
    """A persisted deterministic parser must still extract both custom
    and builtin (datetime) slots once reloaded from disk."""
    # Given
    dataset = self.slots_dataset
    dataset = validate_and_format_dataset(dataset)
    parser = DeterministicIntentParser().fit(dataset)
    custom_entity_parser = parser.custom_entity_parser
    parser.persist(self.tmp_file_path)
    deserialized_parser = DeterministicIntentParser.from_path(
        self.tmp_file_path,
        builtin_entity_parser=BuiltinEntityParser.build(language="en"),
        custom_entity_parser=custom_entity_parser)
    # (text, expected slots) pairs, including punctuation-noise and
    # whitespace-padding cases
    texts = [
        ("this is a dummy a query with another dummy_c at 10p.m. or at"
         " 12p.m.",
         [
             unresolved_slot(match_range=(10, 17), value="dummy a",
                             entity="dummy_entity_1",
                             slot_name="dummy_slot_name"),
             unresolved_slot(match_range=(37, 44), value="dummy_c",
                             entity="dummy_entity_2",
                             slot_name="dummy_slot_name2"),
             unresolved_slot(match_range=(45, 54), value="at 10p.m.",
                             entity="snips/datetime",
                             slot_name="startTime"),
             unresolved_slot(match_range=(58, 67), value="at 12p.m.",
                             entity="snips/datetime",
                             slot_name="startTime")
         ]),
        ("this, is,, a, dummy a query with another dummy_c at 10pm or "
         "at 12p.m.",
         [
             unresolved_slot(match_range=(14, 21), value="dummy a",
                             entity="dummy_entity_1",
                             slot_name="dummy_slot_name"),
             unresolved_slot(match_range=(41, 48), value="dummy_c",
                             entity="dummy_entity_2",
                             slot_name="dummy_slot_name2"),
             unresolved_slot(match_range=(49, 56), value="at 10pm",
                             entity="snips/datetime",
                             slot_name="startTime"),
             unresolved_slot(match_range=(60, 69), value="at 12p.m.",
                             entity="snips/datetime",
                             slot_name="startTime")
         ]),
        ("this is a dummy b",
         [
             unresolved_slot(match_range=(10, 17), value="dummy b",
                             entity="dummy_entity_1",
                             slot_name="dummy_slot_name")
         ]),
        (" this is a dummy b ",
         [
             unresolved_slot(match_range=(11, 18), value="dummy b",
                             entity="dummy_entity_1",
                             slot_name="dummy_slot_name")
         ])
    ]

    for text, expected_slots in texts:
        # When
        parsing = deserialized_parser.parse(text)

        # Then
        self.assertListEqual(expected_slots, parsing[RES_SLOTS])
def test_preprocess(self):
    """TfidfVectorizer._preprocess should stem/normalize utterances and
    return, per utterance: the processed utterance, parsed builtin
    entities (restricted to builtin_entity_scope), parsed custom
    entities and word clusters."""
    # Given
    language = LANGUAGE_EN
    resources = {
        STEMS: {
            "beautiful": "beauty",
            "birdy": "bird",
            "entity": "ent"
        },
        WORD_CLUSTERS: {
            "my_word_clusters": {
                "beautiful": "cluster_1",
                "birdy": "cluster_2",
                "entity": "cluster_3"
            }
        },
        STOP_WORDS: set()
    }
    # NOTE(review): whitespace of this YAML literal was reconstructed
    # from a whitespace-mangled source — confirm against upstream.
    dataset_stream = io.StringIO("""
---
type: intent
name: intent1
utterances:
- dummy utterance

---
type: entity
name: entity_1
values:
- [entity 1, alternative entity 1]
- [éntity 1, alternative entity 1]

---
type: entity
name: entity_2
values:
- entity 1
- [Éntity 2, Éntity_2, Alternative entity 2]""")
    dataset = Dataset.from_yaml_files("en", [dataset_stream]).json
    custom_entity_parser = CustomEntityParser.build(
        dataset, CustomEntityParserUsage.WITH_STEMS, resources)
    builtin_entity_parser = BuiltinEntityParser.build(dataset, language)
    utterances = [
        text_to_utterance("hÉllo wOrld Éntity_2"),
        text_to_utterance("beauTiful World entity 1"),
        text_to_utterance("Bird bïrdy"),
        text_to_utterance("Bird birdy"),
    ]
    config = TfidfVectorizerConfig(
        use_stemming=True, word_clusters_name="my_word_clusters")
    vectorizer = TfidfVectorizer(
        config=config,
        custom_entity_parser=custom_entity_parser,
        builtin_entity_parser=builtin_entity_parser,
        resources=resources)
    vectorizer._language = language
    # Only snips/number builtin entities are in scope for parsing
    vectorizer.builtin_entity_scope = {"snips/number"}

    # When
    processed_data = vectorizer._preprocess(utterances)
    # Transpose so each element is (utterance, builtin_ents,
    # custom_ents, word_clusters)
    processed_data = list(zip(*processed_data))

    # Then
    u_0 = {"data": [{"text": "hello world entity_2"}]}
    u_1 = {"data": [{"text": "beauty world ent 1"}]}
    u_2 = {"data": [{"text": "bird bird"}]}
    u_3 = {"data": [{"text": "bird bird"}]}
    # Ranges below are relative to the processed (stemmed) texts
    ent_0 = {
        "entity_kind": "entity_2",
        "value": "entity_2",
        "resolved_value": "Éntity 2",
        "range": {
            "start": 12,
            "end": 20
        }
    }
    num_0 = {
        "entity_kind": "snips/number",
        "value": "2",
        "resolved_value": {
            "value": 2.0,
            "kind": "Number"
        },
        "range": {
            "start": 19,
            "end": 20
        }
    }
    # "entity 1" is ambiguous: it matches both custom entities
    ent_11 = {
        "entity_kind": "entity_1",
        "value": "ent 1",
        "resolved_value": "entity 1",
        "range": {
            "start": 13,
            "end": 18
        }
    }
    ent_12 = {
        "entity_kind": "entity_2",
        "value": "ent 1",
        "resolved_value": "entity 1",
        "range": {
            "start": 13,
            "end": 18
        }
    }
    num_1 = {
        "entity_kind": "snips/number",
        "value": "1",
        "range": {
            "start": 23,
            "end": 24
        },
        "resolved_value": {
            "value": 1.0,
            "kind": "Number"
        },
    }
    expected_data = [(u_0, [num_0], [ent_0], []),
                     (u_1, [num_1], [ent_11, ent_12],
                      ["cluster_1", "cluster_3"]),
                     (u_2, [], [], []),
                     (u_3, [], [], ["cluster_2"])]
    self.assertSequenceEqual(expected_data, processed_data)
def test_preprocess(self):
    """CooccurrenceVectorizer._preprocess should return, per utterance:
    the (unmodified) utterance, its parsed builtin entities and its
    parsed custom entities."""
    # Given
    language = LANGUAGE_EN
    resources = {
        STEMS: {
            "beautiful": "beauty",
            "birdy": "bird",
            "entity": "ent"
        },
        WORD_CLUSTERS: {
            "my_word_clusters": {
                "beautiful": "cluster_1",
                "birdy": "cluster_2",
                "entity": "cluster_3"
            }
        },
        STOP_WORDS: set()
    }
    # NOTE(review): whitespace of this YAML literal was reconstructed
    # from a whitespace-mangled source — confirm against upstream.
    dataset_stream = io.StringIO("""
---
type: intent
name: intent1
utterances:
- dummy utterance

---
type: entity
name: entity_1
automatically_extensible: false
use_synononyms: false
matching_strictness: 1.0
values:
- [entity 1, alternative entity 1]
- [éntity 1, alternative entity 1]

---
type: entity
name: entity_2
automatically_extensible: false
use_synononyms: true
matching_strictness: 1.0
values:
- entity 1
- [Éntity 2, Éntity_2, Alternative entity 2]
""")
    dataset = Dataset.from_yaml_files("en", [dataset_stream]).json
    custom_entity_parser = CustomEntityParser.build(
        dataset, CustomEntityParserUsage.WITHOUT_STEMS, resources)
    builtin_entity_parser = BuiltinEntityParser.build(dataset, language)
    u_0 = text_to_utterance("hÉllo wOrld Éntity_2")
    u_1 = text_to_utterance("beauTiful World entity 1")
    u_2 = text_to_utterance("Bird bïrdy")
    u_3 = text_to_utterance("Bird birdy")
    utterances = [u_0, u_1, u_2, u_3]
    vectorizer = CooccurrenceVectorizer(
        custom_entity_parser=custom_entity_parser,
        builtin_entity_parser=builtin_entity_parser,
        resources=resources)
    vectorizer._language = language

    # When
    processed_data = vectorizer._preprocess(utterances)
    # Transpose so each element is (utterance, builtin_ents,
    # custom_ents)
    processed_data = list(zip(*processed_data))

    # Then
    # Ranges are relative to the original utterance texts
    ent_0 = {
        "entity_kind": "entity_2",
        "value": "Éntity_2",
        "resolved_value": "Éntity 2",
        "range": {
            "start": 12,
            "end": 20
        }
    }
    num_0 = {
        "entity_kind": "snips/number",
        "value": "2",
        "resolved_value": {
            "value": 2.0,
            "kind": "Number"
        },
        "range": {
            "start": 19,
            "end": 20
        }
    }
    # "entity 1" is ambiguous: it matches both custom entities
    ent_11 = {
        "entity_kind": "entity_1",
        "value": "entity 1",
        "resolved_value": "entity 1",
        "range": {
            "start": 16,
            "end": 24
        }
    }
    ent_12 = {
        "entity_kind": "entity_2",
        "value": "entity 1",
        "resolved_value": "entity 1",
        "range": {
            "start": 16,
            "end": 24
        }
    }
    num_1 = {
        "entity_kind": "snips/number",
        "value": "1",
        "range": {
            "start": 23,
            "end": 24
        },
        "resolved_value": {
            "value": 1.0,
            "kind": "Number"
        }
    }
    expected_data = [(u_0, [num_0], [ent_0]),
                     (u_1, [num_1], [ent_11, ent_12]),
                     (u_2, [], []),
                     (u_3, [], [])]
    self.assertSequenceEqual(expected_data, processed_data)
def test_augment_slots(self):
    """_augment_slots should pick, among candidate taggings of the
    missing slots, the assignment whose mocked sequence probability is
    highest (here tags_5: only 'after 8pm' tagged as end_date)."""
    # Given
    language = LANGUAGE_EN
    text = "Find me a flight before 10pm and after 8pm"
    tokens = tokenize(text, language)
    missing_slots = {"start_date", "end_date"}
    tags = ['O' for _ in tokens]

    def mocked_sequence_probability(_, tags_):
        # Enumerate every candidate tag sequence the augmentation can
        # propose and return a fixed probability for each
        tags_1 = [
            'O', 'O', 'O', 'O',
            '%sstart_date' % BEGINNING_PREFIX,
            '%sstart_date' % INSIDE_PREFIX,
            'O',
            '%send_date' % BEGINNING_PREFIX,
            '%send_date' % INSIDE_PREFIX
        ]
        tags_2 = [
            'O', 'O', 'O', 'O',
            '%send_date' % BEGINNING_PREFIX,
            '%send_date' % INSIDE_PREFIX,
            'O',
            '%sstart_date' % BEGINNING_PREFIX,
            '%sstart_date' % INSIDE_PREFIX
        ]
        tags_3 = ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
        tags_4 = [
            'O', 'O', 'O', 'O', 'O', 'O', 'O',
            '%sstart_date' % BEGINNING_PREFIX,
            '%sstart_date' % INSIDE_PREFIX
        ]
        tags_5 = [
            'O', 'O', 'O', 'O', 'O', 'O', 'O',
            '%send_date' % BEGINNING_PREFIX,
            '%send_date' % INSIDE_PREFIX
        ]
        tags_6 = [
            'O', 'O', 'O', 'O',
            '%sstart_date' % BEGINNING_PREFIX,
            '%sstart_date' % INSIDE_PREFIX,
            'O', 'O', 'O'
        ]
        tags_7 = [
            'O', 'O', 'O', 'O',
            '%send_date' % BEGINNING_PREFIX,
            '%send_date' % INSIDE_PREFIX,
            'O', 'O', 'O'
        ]
        tags_8 = [
            'O', 'O', 'O', 'O',
            '%sstart_date' % BEGINNING_PREFIX,
            '%sstart_date' % INSIDE_PREFIX,
            'O',
            '%sstart_date' % BEGINNING_PREFIX,
            '%sstart_date' % INSIDE_PREFIX
        ]
        tags_9 = [
            'O', 'O', 'O', 'O',
            '%send_date' % BEGINNING_PREFIX,
            '%send_date' % INSIDE_PREFIX,
            'O',
            '%send_date' % BEGINNING_PREFIX,
            '%send_date' % INSIDE_PREFIX
        ]
        if tags_ == tags_1:
            return 0.6
        elif tags_ == tags_2:
            return 0.8
        elif tags_ == tags_3:
            return 0.2
        elif tags_ == tags_4:
            return 0.2
        elif tags_ == tags_5:
            # Highest probability -> expected winning assignment
            return 0.99
        elif tags_ == tags_6:
            return 0.0
        elif tags_ == tags_7:
            return 0.0
        elif tags_ == tags_8:
            return 0.5
        elif tags_ == tags_9:
            return 0.5
        else:
            raise ValueError("Unexpected tag sequence: %s" % tags_)

    slot_filler_config = CRFSlotFillerConfig(random_seed=42)
    slot_filler = CRFSlotFiller(
        config=slot_filler_config,
        builtin_entity_parser=BuiltinEntityParser.build(language="en"))
    slot_filler.language = LANGUAGE_EN
    slot_filler.intent = "intent1"
    slot_filler.slot_name_mapping = {
        "start_date": "snips/datetime",
        "end_date": "snips/datetime",
    }

    # pylint:disable=protected-access
    slot_filler._get_sequence_probability = MagicMock(
        side_effect=mocked_sequence_probability)
    # pylint:enable=protected-access

    slot_filler.compute_features = MagicMock(return_value=None)

    # When
    # pylint: disable=protected-access
    augmented_slots = slot_filler._augment_slots(text, tokens, tags,
                                                 missing_slots)
    # pylint: enable=protected-access

    # Then
    expected_slots = [
        unresolved_slot(value='after 8pm',
                        match_range={
                            START: 33,
                            END: 42
                        },
                        entity='snips/datetime',
                        slot_name='end_date')
    ]
    self.assertListEqual(augmented_slots, expected_slots)
def __init__(self, lang="en"):
    """Initialize the wrapper with a builtin entity parser.

    :param lang: language code used to build the parser (default "en")
    """
    super().__init__()
    # The builtin parser is exposed as ``engine`` for the rest of the
    # class to query
    self.engine = BuiltinEntityParser.build(language=lang)