示例#1
0
    def test_should_get_slots_after_deserialization(self):
        # Checks that a parser rebuilt from its own dict representation
        # returns the expected unresolved slots for every sample text.

        # Given
        formatted_dataset = validate_and_format_dataset(self.slots_dataset)
        fitted_parser = DeterministicIntentParser().fit(formatted_dataset)
        restored_parser = DeterministicIntentParser.from_dict(
            fitted_parser.to_dict())

        plain_query = (
            "this is a dummy a query with another dummy_c at 10p.m. or at"
            " 12p.m."
        )
        plain_query_slots = [
            unresolved_slot(
                match_range=(10, 17),
                value="dummy a",
                entity="dummy_entity_1",
                slot_name="dummy_slot_name"),
            unresolved_slot(
                match_range=(37, 44),
                value="dummy_c",
                entity="dummy_entity_2",
                slot_name="dummy_slot_name2"),
            unresolved_slot(
                match_range=(45, 54),
                value="at 10p.m.",
                entity="snips/datetime",
                slot_name="startTime"),
            unresolved_slot(
                match_range=(58, 67),
                value="at 12p.m.",
                entity="snips/datetime",
                slot_name="startTime"),
        ]

        punctuated_query = (
            "this, is,, a, dummy a query with another dummy_c at 10pm or "
            "at 12p.m."
        )
        punctuated_query_slots = [
            unresolved_slot(
                match_range=(14, 21),
                value="dummy a",
                entity="dummy_entity_1",
                slot_name="dummy_slot_name"),
            unresolved_slot(
                match_range=(41, 48),
                value="dummy_c",
                entity="dummy_entity_2",
                slot_name="dummy_slot_name2"),
            unresolved_slot(
                match_range=(49, 56),
                value="at 10pm",
                entity="snips/datetime",
                slot_name="startTime"),
            unresolved_slot(
                match_range=(60, 69),
                value="at 12p.m.",
                entity="snips/datetime",
                slot_name="startTime"),
        ]

        short_query = "this is a dummy b"
        short_query_slots = [
            unresolved_slot(
                match_range=(10, 17),
                value="dummy b",
                entity="dummy_entity_1",
                slot_name="dummy_slot_name"),
        ]

        # Same sentence surrounded by spaces: the range shifts by one.
        padded_query = " this is a dummy b "
        padded_query_slots = [
            unresolved_slot(
                match_range=(11, 18),
                value="dummy b",
                entity="dummy_entity_1",
                slot_name="dummy_slot_name"),
        ]

        cases = [
            (plain_query, plain_query_slots),
            (punctuated_query, punctuated_query_slots),
            (short_query, short_query_slots),
            (padded_query, padded_query_slots),
        ]

        for utterance, slots in cases:
            # When
            result = restored_parser.parse(utterance)

            # Then
            self.assertListEqual(slots, result[RES_SLOTS])
示例#2
0
    def test_should_be_deserializable_with_stop_words(self):
        # Writes a serialized parser (with a stop words whitelist) to disk,
        # loads it through from_path, and compares its dict form with the
        # dict form of a parser whose attributes are set by hand.

        # Given
        serialized_parser = {
            "config": {
                "max_queries": 42,
                "max_pattern_length": 43
            },
            "language_code": "en",
            "group_names_to_slot_names": {
                "hello_group": "hello_slot",
                "world_group": "world_slot"
            },
            "patterns": {
                "my_intent": [
                    "(?P<hello_group>hello?)",
                    "(?P<world_group>world$)"
                ]
            },
            "slot_names_to_entities": {
                "my_intent": {
                    "hello_slot": "hello_entity",
                    "world_slot": "world_entity"
                }
            },
            "stop_words_whitelist": {
                "my_intent": ["this", "that"],
            }
        }
        unit_metadata = {"unit_name": "deterministic_intent_parser"}
        parser_dir = self.tmp_file_path
        parser_dir.mkdir()
        self.writeJsonContent(
            parser_dir / "intent_parser.json", serialized_parser)
        self.writeJsonContent(parser_dir / "metadata.json", unit_metadata)

        # When
        loaded_parser = DeterministicIntentParser.from_path(parser_dir)

        # Then
        expected_config = DeterministicIntentParserConfig(
            max_queries=42, max_pattern_length=43)
        reference_parser = DeterministicIntentParser(config=expected_config)
        reference_parser.language = LANGUAGE_EN
        reference_parser.group_names_to_slot_names = {
            "hello_group": "hello_slot",
            "world_group": "world_slot"
        }
        reference_parser.slot_names_to_entities = {
            "my_intent": {
                "hello_slot": "hello_entity",
                "world_slot": "world_entity"
            }
        }
        reference_parser.patterns = {
            "my_intent": [
                "(?P<hello_group>hello?)",
                "(?P<world_group>world$)"
            ]
        }
        # The whitelist is set as a dict of sets here, whereas the JSON
        # payload above stores it as a dict of lists.
        # pylint:disable=protected-access
        reference_parser._stop_words_whitelist = {
            "my_intent": {"this", "that"}
        }
        # pylint:enable=protected-access

        loaded_dict = loaded_parser.to_dict()
        reference_dict = reference_parser.to_dict()
        self.assertEqual(loaded_dict, reference_dict)
示例#3
0
    def test_should_be_serializable(self, mock_get_stop_words):
        # Given
        dataset_stream = io.StringIO("""
---
type: intent
name: searchFlight
slots:
  - name: origin
    entity: city
  - name: destination
    entity: city
utterances:
  - find me a flight from [origin](Paris) to [destination](New York)
  - I need a flight to [destination](Berlin)

---
type: entity
name: city
values:
  - london
  - [new york, big apple]
  - [paris, city of lights]
            """)

        dataset = Dataset.from_yaml_files("en", [dataset_stream]).json

        mock_get_stop_words.return_value = {"a", "me"}
        config = DeterministicIntentParserConfig(max_queries=42,
                                                 max_pattern_length=100,
                                                 ignore_stop_words=True)
        parser = DeterministicIntentParser(config=config).fit(dataset)

        # When
        parser.persist(self.tmp_file_path)

        # Then
        expected_dict = {
            "config": {
                "unit_name": "deterministic_intent_parser",
                "max_queries": 42,
                "max_pattern_length": 100,
                "ignore_stop_words": True
            },
            "language_code": "en",
            "group_names_to_slot_names": {
                "group0": "destination",
                "group1": "origin",
            },
            "patterns": {
                "searchFlight": [
                    "^\\s*find\\s*flight\\s*from\\s*(?P<group1>%CITY%)\\s*to"
                    "\\s*(?P<group0>%CITY%)\\s*$",
                    "^\\s*i\\s*need\\s*flight\\s*to\\s*(?P<group0>%CITY%)"
                    "\\s*$",
                ]
            },
            "slot_names_to_entities": {
                "searchFlight": {
                    "destination": "city",
                    "origin": "city",
                }
            },
            "stop_words_whitelist": dict()
        }
        metadata = {"unit_name": "deterministic_intent_parser"}
        self.assertJsonContent(self.tmp_file_path / "metadata.json", metadata)
        self.assertJsonContent(self.tmp_file_path / "intent_parser.json",
                               expected_dict)
示例#4
0
    def test_should_parse_slots(self):
        # Fits a parser on the slots dataset and checks the unresolved slots
        # returned for each sample text.

        # Given
        fitted_parser = DeterministicIntentParser().fit(self.slots_dataset)

        utterances = [
            "this is a dummy a query with another dummy_c at 10p.m. or at"
            " 12p.m.",
            "this, is,, a, dummy a query with another dummy_c at 10pm or "
            "at 12p.m.",
            "this is a dummy b",
            " this is a dummy b ",
            " at 8am ’ there is a dummy  a",
        ]

        # One list of expected slots per utterance, in the same order.
        slots_per_utterance = [
            [
                unresolved_slot(match_range=(10, 17), value="dummy a",
                                entity="dummy_entity_1",
                                slot_name="dummy_slot_name"),
                unresolved_slot(match_range=(37, 44), value="dummy_c",
                                entity="dummy_entity_2",
                                slot_name="dummy_slot_name2"),
                unresolved_slot(match_range=(45, 54), value="at 10p.m.",
                                entity="snips/datetime",
                                slot_name="startTime"),
                unresolved_slot(match_range=(58, 67), value="at 12p.m.",
                                entity="snips/datetime",
                                slot_name="startTime"),
            ],
            [
                unresolved_slot(match_range=(14, 21), value="dummy a",
                                entity="dummy_entity_1",
                                slot_name="dummy_slot_name"),
                unresolved_slot(match_range=(41, 48), value="dummy_c",
                                entity="dummy_entity_2",
                                slot_name="dummy_slot_name2"),
                unresolved_slot(match_range=(49, 56), value="at 10pm",
                                entity="snips/datetime",
                                slot_name="startTime"),
                unresolved_slot(match_range=(60, 69), value="at 12p.m.",
                                entity="snips/datetime",
                                slot_name="startTime"),
            ],
            [
                unresolved_slot(match_range=(10, 17), value="dummy b",
                                entity="dummy_entity_1",
                                slot_name="dummy_slot_name"),
            ],
            [
                unresolved_slot(match_range=(11, 18), value="dummy b",
                                entity="dummy_entity_1",
                                slot_name="dummy_slot_name"),
            ],
            [
                unresolved_slot(match_range=(1, 7), value="at 8am",
                                entity="snips/datetime",
                                slot_name="startTime"),
                # The value keeps the two spaces found in the input text.
                unresolved_slot(match_range=(21, 29), value="dummy  a",
                                entity="dummy_entity_1",
                                slot_name="dummy_slot_name"),
            ],
        ]

        for utterance, slots in zip(utterances, slots_per_utterance):
            # When
            result = fitted_parser.parse(utterance)

            # Then
            self.assertListEqual(slots, result[RES_SLOTS])
示例#5
0
    def test_training_should_be_reproducible(self):
        # Given
        random_state = 42
        dataset_stream = io.StringIO("""
---
type: intent
name: MakeTea
utterances:
- make me a [beverage_temperature:Temperature](hot) cup of tea
- make me [number_of_cups:snips/number](five) tea cups

---
type: intent
name: MakeCoffee
utterances:
- make me [number_of_cups:snips/number](one) cup of coffee please
- brew [number_of_cups] cups of coffee""")
        dataset = Dataset.from_yaml_files("en", [dataset_stream]).json

        # When
        parser1 = DeterministicIntentParser(random_state=random_state)
        parser1.fit(dataset)

        parser2 = DeterministicIntentParser(random_state=random_state)
        parser2.fit(dataset)

        # Then
        with temp_dir() as tmp_dir:
            dir_parser1 = tmp_dir / "parser1"
            dir_parser2 = tmp_dir / "parser2"
            parser1.persist(dir_parser1)
            parser2.persist(dir_parser2)
            hash1 = dirhash(str(dir_parser1), 'sha256')
            hash2 = dirhash(str(dir_parser2), 'sha256')
            self.assertEqual(hash1, hash2)
示例#6
0
    def test_should_parse_top_intents(self):
        # Given
        dataset_stream = io.StringIO("""
---
type: intent
name: intent1
utterances:
  - meeting [time:snips/datetime](today)

---
type: intent
name: intent2
utterances:
  - meeting tomorrow
  
---
type: intent
name: intent3
utterances:
  - "[event_type](call) [time:snips/datetime](at 9pm)"

---
type: entity
name: event_type
values:
  - meeting
  - feedback session""")
        dataset = Dataset.from_yaml_files("en", [dataset_stream]).json
        parser = DeterministicIntentParser().fit(dataset)
        text = "meeting tomorrow"

        # When
        results = parser.parse(text, top_n=3)

        # Then
        time_slot = {
            "entity": "snips/datetime",
            "range": {
                "end": 16,
                "start": 8
            },
            "slotName": "time",
            "value": "tomorrow"
        }
        event_slot = {
            "entity": "event_type",
            "range": {
                "end": 7,
                "start": 0
            },
            "slotName": "event_type",
            "value": "meeting"
        }
        weight_intent_1 = 1. / 2.
        weight_intent_2 = 1.
        weight_intent_3 = 1. / 3.
        total_weight = weight_intent_1 + weight_intent_2 + weight_intent_3
        proba_intent2 = weight_intent_2 / total_weight
        proba_intent1 = weight_intent_1 / total_weight
        proba_intent3 = weight_intent_3 / total_weight
        expected_results = [
            extraction_result(intent_classification_result(
                intent_name="intent2", probability=proba_intent2),
                              slots=[]),
            extraction_result(intent_classification_result(
                intent_name="intent1", probability=proba_intent1),
                              slots=[time_slot]),
            extraction_result(intent_classification_result(
                intent_name="intent3", probability=proba_intent3),
                              slots=[event_slot, time_slot])
        ]
        self.assertEqual(expected_results, results)