Ejemplo n.º 1
0
    def test_should_be_serializable_before_fitting(self):
        # Persisting a parser that was never fitted must still write its
        # config, with every fitted field left unset (None or empty list).
        # Given
        parser_config = LookupIntentParserConfig(ignore_stop_words=True)
        unfitted_parser = LookupIntentParser(config=parser_config)

        # When
        unfitted_parser.persist(self.tmp_file_path)

        # Then
        expected_metadata = {"unit_name": "lookup_intent_parser"}
        expected_parser_dict = {
            "config": {
                "unit_name": "lookup_intent_parser",
                "ignore_stop_words": True,
            },
            "language_code": None,
            "intents_names": [],
            "map": None,
            "slots_names": [],
            "entity_scopes": None,
            "stop_words_whitelist": None,
        }
        metadata_path = self.tmp_file_path / "metadata.json"
        parser_path = self.tmp_file_path / "intent_parser.json"
        self.assertJsonContent(metadata_path, expected_metadata)
        self.assertJsonContent(parser_path, expected_parser_dict)
Ejemplo n.º 2
0
    def test_should_be_serializable(self, mock_get_stop_words):
        # Given
        dataset_stream = io.StringIO(
            """
---
type: intent
name: searchFlight
slots:
  - name: origin
    entity: city
  - name: destination
    entity: city
utterances:
  - find me a flight from [origin](Paris) to [destination](New York)
  - I need a flight to [destination](Berlin)

---
type: entity
name: city
values:
  - london
  - [new york, big apple]
  - [paris, city of lights]""")

        dataset = Dataset.from_yaml_files("en", [dataset_stream]).json

        mock_get_stop_words.return_value = {"a", "me"}
        config = LookupIntentParserConfig(ignore_stop_words=True)
        parser = LookupIntentParser(config=config).fit(dataset)

        # When
        parser.persist(self.tmp_file_path)

        # Then
        expected_dict = {
            "config": {
                "unit_name": "lookup_intent_parser",
                "ignore_stop_words": True,
            },
            "intents_names": ["searchFlight"],
            "language_code": "en",
            "map": {
                "-2020846245": [0, [0, 1]],
                "-1558674456": [0, [1]],
            },
            "slots_names": ["origin", "destination"],
            "entity_scopes": [
                {
                    "entity_scope": {"builtin": [], "custom": ["city"]},
                    "intent_group": ["searchFlight"]
                }
            ],
            "stop_words_whitelist": dict()
        }
        metadata = {"unit_name": "lookup_intent_parser"}
        self.assertJsonContent(self.tmp_file_path / "metadata.json", metadata)
        self.assertJsonContent(
            self.tmp_file_path / "intent_parser.json", expected_dict)
Ejemplo n.º 3
0
    def test_should_parse_intent_after_deserialization(self):
        # Round-trip a fitted parser through disk and check that the reloaded
        # parser still classifies the query into the expected intent.
        # Given
        slots_dataset = self.slots_dataset
        shared_resources = self.get_shared_data(slots_dataset)
        fitted_parser = LookupIntentParser(**shared_resources).fit(
            slots_dataset)
        fitted_parser.persist(self.tmp_file_path)
        reloaded_parser = LookupIntentParser.from_path(
            self.tmp_file_path, **shared_resources)
        query = ("this is a dummy_a query with another dummy_c at 10p.m. or "
                 "at 12p.m.")

        # When
        result = reloaded_parser.parse(query)

        # Then
        expected_intent = intent_classification_result(
            intent_name="dummy_intent_1", probability=1.0)
        self.assertEqual(expected_intent, result[RES_INTENT])
Ejemplo n.º 4
0
    def test_should_parse_slots_after_deserialization(self):
        # Round-trip a fitted parser through disk and check that the reloaded
        # parser extracts the expected slots (with their character ranges)
        # from several queries, including ones with extra punctuation and
        # surrounding whitespace.
        # Given
        slots_dataset = self.slots_dataset
        shared_resources = self.get_shared_data(slots_dataset)
        fitted_parser = LookupIntentParser(**shared_resources).fit(
            slots_dataset)
        fitted_parser.persist(self.tmp_file_path)
        reloaded_parser = LookupIntentParser.from_path(
            self.tmp_file_path, **shared_resources)

        def make_slot(start, end, value, entity, slot_name):
            # Thin wrapper so the expectation table below stays compact.
            return unresolved_slot(
                match_range=(start, end),
                value=value,
                entity=entity,
                slot_name=slot_name,
            )

        entity_1, slot_1 = "dummy_entity_1", "dummy_slot_name"
        entity_2, slot_2 = "dummy_entity_2", "dummy_slot_name2"
        dt_entity, time_slot = "snips/datetime", "startTime"

        # Each case pairs a query with the exact slots expected from it.
        test_cases = [
            (
                "this is a dummy a query with another dummy_c at 10p.m. or at"
                " 12p.m.",
                [
                    make_slot(10, 17, "dummy a", entity_1, slot_1),
                    make_slot(37, 44, "dummy_c", entity_2, slot_2),
                    make_slot(45, 54, "at 10p.m.", dt_entity, time_slot),
                    make_slot(58, 67, "at 12p.m.", dt_entity, time_slot),
                ],
            ),
            (
                "this, is,, a, dummy a query with another dummy_c at 10pm or "
                "at 12p.m.",
                [
                    make_slot(14, 21, "dummy a", entity_1, slot_1),
                    make_slot(41, 48, "dummy_c", entity_2, slot_2),
                    make_slot(49, 56, "at 10pm", dt_entity, time_slot),
                    make_slot(60, 69, "at 12p.m.", dt_entity, time_slot),
                ],
            ),
            (
                "this is a dummy b",
                [make_slot(10, 17, "dummy b", entity_1, slot_1)],
            ),
            (
                " this is a dummy b ",
                [make_slot(11, 18, "dummy b", entity_1, slot_1)],
            ),
        ]

        for query, expected_slots in test_cases:
            # When
            result = reloaded_parser.parse(query)

            # Then
            self.assertListEqual(expected_slots, result[RES_SLOTS])