示例#1
0
    def test_should_be_serializable_before_fit(self):
        # Serializing a slot filler that was never fitted should yield a dict
        # whose model data, language code, intent and slot name mapping are
        # all None, alongside the serialized config.

        # Given
        shape_ngram_factory_config = {
            "factory_name": ShapeNgramFactory.name,
            "args": {"n": 1},
            "offsets": [0]
        }
        is_digit_factory_config = {
            "factory_name": IsDigitFactory.name,
            "args": {},
            "offsets": [-1, 0]
        }
        config = CRFSlotFillerConfig(
            tagging_scheme=TaggingScheme.BILOU,
            feature_factory_configs=[
                shape_ngram_factory_config,
                is_digit_factory_config
            ])
        unfitted_slot_filler = CRFSlotFiller(config)

        # When
        serialized = unfitted_slot_filler.to_dict()

        # Then
        # The config is serialized only now, after the slot filler has been
        # serialized, exactly as the comparison needs it.
        expected = {
            "unit_name": "crf_slot_filler",
            "crf_model_data": None,
            "language_code": None,
            "config": config.to_dict(),
            "intent": None,
            "slot_name_mapping": None,
        }
        self.assertDictEqual(serialized, expected)
示例#2
0
    def test_should_be_serializable(self, mock_serialize_crf_model):
        # Serializing a fitted slot filler should expose the (mocked) CRF
        # model data, the language code, the intent, the slot name mapping
        # and the config.
        # NOTE(review): mock_serialize_crf_model is presumably injected by a
        # patch decorator that is not visible here -- confirm.

        # Given
        mocked_model_data = "mocked_crf_model_data"
        mock_serialize_crf_model.return_value = mocked_model_data
        shape_ngram_factory_config = {
            "factory_name": ShapeNgramFactory.name,
            "args": {"n": 1},
            "offsets": [0]
        }
        is_digit_factory_config = {
            "factory_name": IsDigitFactory.name,
            "args": {},
            "offsets": [-1, 0]
        }
        config = CRFSlotFillerConfig(
            tagging_scheme=TaggingScheme.BILOU,
            feature_factory_configs=[
                shape_ngram_factory_config,
                is_digit_factory_config
            ])
        dataset = validate_and_format_dataset(SAMPLE_DATASET)

        fitted_slot_filler = CRFSlotFiller(config)
        intent = "dummy_intent_1"
        fitted_slot_filler.fit(dataset, intent=intent)

        # When
        serialized = fitted_slot_filler.to_dict()

        # Then
        # The expected config is built from fresh dicts so that it shares no
        # state with the config handed to the slot filler. The shape ngram
        # args are expected to carry the "en" language code after fitting.
        expected_shape_ngram_factory_config = {
            "factory_name": ShapeNgramFactory.name,
            "args": {"n": 1, "language_code": "en"},
            "offsets": [0]
        }
        expected_is_digit_factory_config = {
            "factory_name": IsDigitFactory.name,
            "args": {},
            "offsets": [-1, 0]
        }
        expected_config = CRFSlotFillerConfig(
            tagging_scheme=TaggingScheme.BILOU,
            feature_factory_configs=[
                expected_shape_ngram_factory_config,
                expected_is_digit_factory_config
            ])
        expected_slot_name_mapping = {
            "dummy_slot_name": "dummy_entity_1",
            "dummy_slot_name2": "dummy_entity_2",
            "dummy_slot_name3": "dummy_entity_2",
        }
        expected = {
            "unit_name": "crf_slot_filler",
            "crf_model_data": mocked_model_data,
            "language_code": "en",
            "config": expected_config.to_dict(),
            "intent": intent,
            "slot_name_mapping": expected_slot_name_mapping
        }
        self.assertDictEqual(serialized, expected)
示例#3
0
    def test_should_get_slots_after_deserialization(self):
        # A slot filler rebuilt from its own dict representation should still
        # extract the expected slot from a sample utterance.

        # Given
        intent = "MakeTea"
        dataset = validate_and_format_dataset(BEVERAGE_DATASET)
        config = CRFSlotFillerConfig(random_seed=42)
        trained_slot_filler = CRFSlotFiller(config)
        trained_slot_filler.fit(dataset, intent)
        # Round trip: serialize the fitted slot filler, then rebuild it.
        slot_filler_dict = trained_slot_filler.to_dict()
        deserialized_slot_filler = CRFSlotFiller.from_dict(slot_filler_dict)

        # When
        text = "make me two cups of tea"
        slots = deserialized_slot_filler.get_slots(text)

        # Then
        # "two" spans characters 8 to 11 of the input text.
        expected_slot = unresolved_slot(
            match_range={START: 8, END: 11},
            value='two',
            entity='snips/number',
            slot_name='number_of_cups')
        self.assertListEqual([expected_slot], slots)