def test_should_be_deserializable_before_fitting_with_whitelist(self):
        """Deserializing an unfitted parser keeps only its configuration.

        The persisted payload sets every field other than ``config`` to
        null, including ``stop_words_whitelist``. The loaded parser must
        therefore serialize exactly like a freshly constructed parser with
        the same config.
        """
        # Given
        unit_metadata = {"unit_name": "deterministic_intent_parser"}
        serialized_config = {"max_queries": 42, "max_pattern_length": 43}
        serialized_parser = {
            "config": serialized_config,
            "language_code": None,
            "group_names_to_slot_names": None,
            "patterns": None,
            "slot_names_to_entities": None,
            "stop_words_whitelist": None,
        }
        parser_dir = self.tmp_file_path
        parser_dir.mkdir()
        self.writeJsonContent(parser_dir / "intent_parser.json",
                              serialized_parser)
        self.writeJsonContent(parser_dir / "metadata.json", unit_metadata)

        # When
        loaded_parser = DeterministicIntentParser.from_path(parser_dir)

        # Then
        expected_config = DeterministicIntentParserConfig(
            max_queries=42, max_pattern_length=43)
        reference_parser = DeterministicIntentParser(config=expected_config)
        self.assertEqual(loaded_parser.to_dict(), reference_parser.to_dict())
    def test_should_be_deserializable_without_stop_words(self):
        """A payload with no ``stop_words_whitelist`` key still loads.

        The stored parser dict omits the whitelist key entirely. The test
        expects the loaded parser to serialize like a parser carrying the
        stored language, patterns and mappings plus an empty whitelist.
        """
        # Given
        patterns = {
            "my_intent":
            ["(?P<hello_group>hello?)", "(?P<world_group>world$)"]
        }
        group_names_to_slot_names = {
            "hello_group": "hello_slot",
            "world_group": "world_slot"
        }
        slot_names_to_entities = {
            "my_intent": {
                "hello_slot": "hello_entity",
                "world_slot": "world_entity"
            }
        }
        # The payload deliberately has no "stop_words_whitelist" entry.
        serialized_parser = {
            "config": {"max_queries": 42, "max_pattern_length": 43},
            "language_code": "en",
            "group_names_to_slot_names": group_names_to_slot_names,
            "patterns": patterns,
            "slot_names_to_entities": slot_names_to_entities,
        }
        parser_dir = self.tmp_file_path
        parser_dir.mkdir()
        self.writeJsonContent(parser_dir / "intent_parser.json",
                              serialized_parser)
        self.writeJsonContent(parser_dir / "metadata.json",
                              {"unit_name": "deterministic_intent_parser"})

        # When
        loaded_parser = DeterministicIntentParser.from_path(parser_dir)

        # Then
        expected_config = DeterministicIntentParserConfig(
            max_queries=42, max_pattern_length=43)
        expected_parser = DeterministicIntentParser(config=expected_config)
        expected_parser.language = LANGUAGE_EN
        expected_parser.group_names_to_slot_names = group_names_to_slot_names
        expected_parser.slot_names_to_entities = slot_names_to_entities
        expected_parser.patterns = patterns
        # No whitelist was stored, so the reference parser gets an empty one.
        # pylint:disable=protected-access
        expected_parser._stop_words_whitelist = {}
        # pylint:enable=protected-access

        self.assertEqual(loaded_parser.to_dict(), expected_parser.to_dict())
    def test_should_parse_intent_after_deserialization(self):
        """A persisted-then-reloaded parser still detects the right intent.

        The same shared resources are passed to both construction and
        ``from_path``, so only the parser's own state goes through disk.
        """
        # Given
        training_data = self.slots_dataset
        shared_resources = self.get_shared_data(training_data)
        fitted_parser = DeterministicIntentParser(**shared_resources).fit(
            training_data)
        fitted_parser.persist(self.tmp_file_path)
        reloaded_parser = DeterministicIntentParser.from_path(
            self.tmp_file_path, **shared_resources)
        query = ("this is a dummy_a query with another dummy_c at 10p.m. or "
                 "at 12p.m.")

        # When
        result = reloaded_parser.parse(query)

        # Then
        expected_intent = intent_classification_result(
            intent_name="dummy_intent_1", probability=1.0)
        self.assertEqual(expected_intent, result[RES_INTENT])
    # Example No. 4
    def test_should_get_intent_after_deserialization(self):
        """Reloading with explicitly supplied entity parsers keeps intents.

        The custom entity parser is taken from the fitted parser before it
        is persisted. A fresh English builtin entity parser is built, and
        both are handed to ``from_path``.
        """
        # Given
        formatted_dataset = validate_and_format_dataset(self.slots_dataset)

        fitted_parser = DeterministicIntentParser().fit(formatted_dataset)
        trained_custom_parser = fitted_parser.custom_entity_parser
        fitted_parser.persist(self.tmp_file_path)
        english_builtin_parser = BuiltinEntityParser.build(language="en")
        reloaded_parser = DeterministicIntentParser.from_path(
            self.tmp_file_path,
            builtin_entity_parser=english_builtin_parser,
            custom_entity_parser=trained_custom_parser)
        query = ("this is a dummy_a query with another dummy_c at 10p.m. or "
                 "at 12p.m.")

        # When
        result = reloaded_parser.parse(query)

        # Then
        expected_intent = intent_classification_result(
            intent_name="dummy_intent_1", probability=1.0)
        self.assertEqual(expected_intent, result[RES_INTENT])
    def test_should_parse_slots_after_deserialization(self):
        """A reloaded parser extracts the expected unresolved slots.

        Each case pairs an input query with the exact list of slots (match
        range, raw value, entity and slot name) that parsing it with the
        deserialized parser must return, in order.
        """
        # Given
        training_data = self.slots_dataset
        shared_resources = self.get_shared_data(training_data)
        fitted_parser = DeterministicIntentParser(**shared_resources).fit(
            training_data)
        fitted_parser.persist(self.tmp_file_path)
        reloaded_parser = DeterministicIntentParser.from_path(
            self.tmp_file_path, **shared_resources)

        def expected_slot(match_range, value, entity, slot_name):
            # Positional wrapper that keeps the case table below compact.
            return unresolved_slot(match_range=match_range, value=value,
                                   entity=entity, slot_name=slot_name)

        cases = [
            (
                "this is a dummy a query with another dummy_c at 10p.m. or at"
                " 12p.m.",
                [
                    expected_slot((10, 17), "dummy a", "dummy_entity_1",
                                  "dummy_slot_name"),
                    expected_slot((37, 44), "dummy_c", "dummy_entity_2",
                                  "dummy_slot_name2"),
                    expected_slot((45, 54), "at 10p.m.", "snips/datetime",
                                  "startTime"),
                    expected_slot((58, 67), "at 12p.m.", "snips/datetime",
                                  "startTime"),
                ],
            ),
            (
                "this, is,, a, dummy a query with another dummy_c at 10pm or "
                "at 12p.m.",
                [
                    expected_slot((14, 21), "dummy a", "dummy_entity_1",
                                  "dummy_slot_name"),
                    expected_slot((41, 48), "dummy_c", "dummy_entity_2",
                                  "dummy_slot_name2"),
                    expected_slot((49, 56), "at 10pm", "snips/datetime",
                                  "startTime"),
                    expected_slot((60, 69), "at 12p.m.", "snips/datetime",
                                  "startTime"),
                ],
            ),
            (
                "this is a dummy b",
                [
                    expected_slot((10, 17), "dummy b", "dummy_entity_1",
                                  "dummy_slot_name"),
                ],
            ),
            (
                " this is a dummy b ",
                [
                    expected_slot((11, 18), "dummy b", "dummy_entity_1",
                                  "dummy_slot_name"),
                ],
            ),
        ]

        for query, expected_slots in cases:
            # When
            result = reloaded_parser.parse(query)

            # Then
            self.assertListEqual(expected_slots, result[RES_SLOTS])