Example #1
0
    def test_should_not_use_crf_when_dataset_with_no_slots(self):
        # When the fitted intent has no slot-annotated utterance, getting
        # slots must return an empty list without computing any feature.

        # Given
        text = "This is an utterance without slots"
        slotless_utterance = {"data": [{"text": text}]}
        dataset = {
            "language": "en",
            "intents": {"intent1": {"utterances": [slotless_utterance]}},
            "entities": {},
        }
        shared = self.get_shared_data(dataset)
        slot_filler = CRFSlotFiller(**shared)
        # Replace feature computation with a spy so any call is detectable
        features_spy = MagicMock()
        slot_filler.compute_features = features_spy

        # When
        slot_filler.fit(dataset, "intent1")
        slots = slot_filler.get_slots(text)

        # Then
        features_spy.assert_not_called()
        self.assertListEqual([], slots)
Example #2
0
    def test_should_compute_features(self):
        # Fits a slot filler configured with a single unigram feature that
        # declares a drop out, then checks the features computed for a
        # four-token sentence.

        # Given
        unigram_factory_config = {
            "factory_name": NgramFactory.name,
            "args": {
                "n": 1,
                "use_stemming": False,
                "common_words_gazetteer_name": None,
            },
            "offsets": [0],
            "drop_out": 0.3,
        }
        config = CRFSlotFillerConfig(
            feature_factory_configs=[unigram_factory_config],
            random_seed=40)
        slot_filler = CRFSlotFiller(config)

        tokens = tokenize("foo hello world bar", LANGUAGE_EN)
        slot_filler.fit(validate_and_format_dataset(SAMPLE_DATASET),
                        intent="dummy_intent_1")

        # When
        # NOTE(review): the positional True presumably turns drop out on;
        # confirm against the signature of compute_features
        computed_features = slot_filler.compute_features(tokens, True)

        # Then
        # One feature dict per token; the second and fourth are expected empty
        expected_features = [
            {"ngram_1": "foo"}, {}, {"ngram_1": "world"}, {}]
        self.assertListEqual(expected_features, computed_features)
Example #3
0
    def test_should_compute_features(self):
        # Given
        features_factories = [
            {
                "factory_name": NgramFactory.name,
                "args": {
                    "n": 1,
                    "use_stemming": False,
                    "common_words_gazetteer_name": None
                },
                "offsets": [0],
                "drop_out": 0.3
            },
        ]
        slot_filler_config = CRFSlotFillerConfig(
            feature_factory_configs=features_factories, random_seed=40)

        tokens = tokenize("foo hello world bar", LANGUAGE_EN)
        dataset_stream = io.StringIO("""
---
type: intent
name: my_intent
utterances:
- this is [slot1:entity1](my first entity)
- this is [slot2:entity2](second_entity)""")
        dataset = Dataset.from_yaml_files("en", [dataset_stream]).json
        shared = self.get_shared_data(dataset,
                                      CustomEntityParserUsage.WITHOUT_STEMS)
        slot_filler = CRFSlotFiller(slot_filler_config, **shared)
        slot_filler.fit(dataset, intent="my_intent")

        # When
        features_with_drop_out = slot_filler.compute_features(tokens, True)

        # Then
        expected_features = [
            {
                "ngram_1": "foo"
            },
            {},
            {
                "ngram_1": "world"
            },
            {},
        ]
        self.assertListEqual(expected_features, features_with_drop_out)
    def test_augment_slots(self):
        # Starting from an all-'O' tagging with two missing slots, slot
        # augmentation (driven by a mocked sequence probability) is expected
        # to return a single "end_date" slot covering "after 8pm".

        # Given
        text = "Find me a flight before 10pm and after 8pm"
        tokens = tokenize(text, LANGUAGE_EN)
        missing_slots = {"start_date", "end_date"}

        tags = ['O' for _ in tokens]

        # Two-tag building blocks for the spans that may carry a slot
        no_slot = ['O', 'O']
        start_date = ['%sstart_date' % BEGINNING_PREFIX,
                      '%sstart_date' % INSIDE_PREFIX]
        end_date = ['%send_date' % BEGINNING_PREFIX,
                    '%send_date' % INSIDE_PREFIX]

        def tagging(first_span, second_span):
            # Nine tags: the two spans sit at positions 4-5 and 7-8
            return ['O'] * 4 + first_span + ['O'] + second_span

        # Probability returned by the mock for each known tagging, checked
        # in this order
        scored_taggings = [
            (tagging(start_date, end_date), 0.6),
            (tagging(end_date, start_date), 0.8),
            (tagging(no_slot, no_slot), 0.2),
            (tagging(no_slot, start_date), 0.2),
            (tagging(no_slot, end_date), 0.99),
            (tagging(start_date, no_slot), 0.0),
            (tagging(end_date, no_slot), 0.0),
            (tagging(start_date, start_date), 0.5),
            (tagging(end_date, end_date), 0.5),
        ]

        def mocked_sequence_probability(_, tags_):
            for known_tags, probability in scored_taggings:
                if tags_ == known_tags:
                    return probability
            # Any other tagging means the candidates changed unexpectedly
            raise ValueError("Unexpected tag sequence: %s" % tags_)

        slot_filler = CRFSlotFiller(
            config=CRFSlotFillerConfig(random_seed=42))
        slot_filler.language = LANGUAGE_EN
        slot_filler.intent = "intent1"
        slot_filler.slot_name_mapping = {
            "start_date": "snips/datetime",
            "end_date": "snips/datetime",
        }

        # pylint:disable=protected-access
        slot_filler._get_sequence_probability = MagicMock(
            side_effect=mocked_sequence_probability)
        # pylint:enable=protected-access

        # Features are irrelevant here since the probability is mocked
        slot_filler.compute_features = MagicMock(return_value=None)

        # When
        # pylint: disable=protected-access
        augmented_slots = slot_filler._augment_slots(
            text, tokens, tags, missing_slots)
        # pylint: enable=protected-access

        # Then
        end_date_slot = unresolved_slot(
            value='after 8pm',
            match_range={START: 33, END: 42},
            entity='snips/datetime',
            slot_name='end_date')
        expected_slots = [end_date_slot]
        self.assertListEqual(augmented_slots, expected_slots)