Ejemplo n.º 1
0
    def build(cls, dataset, parser_usage):
        """Build a custom entity parser from a dataset.

        Only the dataset's non-builtin entities are used. Each one is
        deep-copied so the caller's dataset is never mutated. Depending on
        ``parser_usage``, entity utterances are kept as-is, replaced by their
        stemmed form, or merged with their stemmed form.

        Args:
            dataset: raw dataset; it is validated and formatted with
                ``validate_and_format_dataset`` before use.
            parser_usage: a ``CustomEntityParserUsage`` member. ``WITH_STEMS``
                replaces utterances with stemmed ones and
                ``WITH_AND_WITHOUT_STEMS`` merges both. Any other non-None
                value leaves utterances unchanged.

        Returns:
            An instance of ``cls`` wrapping a ``GazetteerEntityParser`` built
            from the custom entities, along with the dataset language and
            ``parser_usage``.

        Raises:
            ValueError: if ``parser_usage`` is None.
        """
        # Fail fast: reject a missing usage before paying for dataset
        # validation and deep-copying every custom entity.
        if parser_usage is None:
            raise ValueError("A parser usage must be defined in order to fit "
                             "a CustomEntityParser")

        from snips_nlu.dataset import validate_and_format_dataset

        dataset = validate_and_format_dataset(dataset)
        language = dataset[LANGUAGE]
        # Deep copies so the utterance rewriting below leaves the input
        # dataset untouched.
        custom_entities = {
            entity_name: deepcopy(entity)
            for entity_name, entity in iteritems(dataset[ENTITIES])
            if not is_builtin_entity(entity_name)
        }
        if parser_usage == CustomEntityParserUsage.WITH_AND_WITHOUT_STEMS:
            for ent in viewvalues(custom_entities):
                stemmed_utterances = _stem_entity_utterances(
                    ent[UTTERANCES], language)
                ent[UTTERANCES] = _merge_entity_utterances(
                    ent[UTTERANCES], stemmed_utterances)
        elif parser_usage == CustomEntityParserUsage.WITH_STEMS:
            for ent in viewvalues(custom_entities):
                ent[UTTERANCES] = _stem_entity_utterances(
                    ent[UTTERANCES], language)
        configuration = _create_custom_entity_parser_configuration(
            custom_entities)
        parser = GazetteerEntityParser.build(configuration)
        return cls(parser, language, parser_usage)
Ejemplo n.º 2
0
    def test_should_parse_from_built_parser_with_scope(self):
        """Scoped parsing returns only entities of the requested identifier."""
        # Given
        gazetteer = GazetteerEntityParser.build(self.get_test_parser_config())

        # When
        sentence = "I want to listen to what s my age again by blink one eight two"
        artist_matches = gazetteer.parse(sentence, ["music_artist"])
        track_matches = gazetteer.parse(sentence, ["music_track"])

        # Then
        expected_artist = [dict(
            value="blink one eight two",
            resolved_value="Blink 182",
            range=dict(start=43, end=62),
            entity_identifier="music_artist",
        )]
        expected_track = [dict(
            value="what s my age again",
            resolved_value="What's my age again",
            range=dict(start=20, end=39),
            entity_identifier="music_track",
        )]

        self.assertListEqual(expected_artist, artist_matches)
        self.assertListEqual(expected_track, track_matches)
Ejemplo n.º 3
0
    def test_should_not_accept_bytes_in_scope(self):
        """A scope made of byte strings must be rejected with TypeError."""
        # Given
        gazetteer = GazetteerEntityParser.from_path(CUSTOM_PARSER_PATH)
        byte_scope = [b"snips/number", b"snips/datetime"]

        # When/Then
        self.assertRaises(
            TypeError, gazetteer.parse, "Raise to sixty", byte_scope)
Ejemplo n.º 4
0
    def test_should_not_accept_bytes_in_text(self):
        """Parsing a byte string instead of text must raise TypeError."""
        # Given
        gazetteer = GazetteerEntityParser.from_path(CUSTOM_PARSER_PATH)

        # When/Then
        self.assertRaises(TypeError, gazetteer.parse, b"Raise to sixty")
Ejemplo n.º 5
0
 def from_path(cls, path):
     """Load a parser from the directory ``path``.

     ``path/metadata.json`` provides the language, the parser usage and the
     name of the sub-directory from which the ``GazetteerEntityParser`` is
     loaded.
     """
     directory = Path(path)
     metadata_file = directory / "metadata.json"
     with metadata_file.open(encoding="utf8") as meta_handle:
         meta = json.load(meta_handle)
     language = meta["language"]
     parser_usage = CustomEntityParserUsage(meta["parser_usage"])
     gazetteer_dir = directory / meta["parser_directory"]
     gazetteer = GazetteerEntityParser.from_path(gazetteer_dir)
     return cls(gazetteer, language, parser_usage)
Ejemplo n.º 6
0
    def test_should_persist_parser(self):
        """A persisted-then-reloaded parser yields the expected match."""
        # Given
        original = GazetteerEntityParser.from_path(CUSTOM_PARSER_PATH)

        # When
        with temp_dir() as tmpdir:
            target = str(tmpdir / "persisted_gazetteer_parser")
            original.persist(target)
            reloaded = GazetteerEntityParser.from_path(target)
        # NOTE: parse runs after the temp_dir context has exited
        # (presumably after the directory is cleaned up).
        matches = reloaded.parse("I want to listen to the stones", None)

        # Then
        stones = dict(
            value="the stones",
            resolved_value="The Rolling Stones",
            range=dict(start=20, end=30),
            entity_identifier="music_artist",
        )
        self.assertListEqual([stones], matches)
Ejemplo n.º 7
0
    def test_should_load_parser_from_path(self):
        """A parser loaded from disk finds the expected artist entity."""
        # Given
        gazetteer = GazetteerEntityParser.from_path(CUSTOM_PARSER_PATH)

        # When
        matches = gazetteer.parse("I want to listen to the stones", None)

        # Then
        stones = dict(
            value="the stones",
            resolved_value="The Rolling Stones",
            range=dict(start=20, end=30),
            entity_identifier="music_artist",
        )
        self.assertListEqual([stones], matches)
Ejemplo n.º 8
0
    def test_should_parse_from_built_parser(self):
        """A parser built from a configuration finds the expected artist."""
        # Given
        gazetteer = GazetteerEntityParser.build(self.get_test_parser_config())

        # When
        matches = gazetteer.parse("I want to listen to the stones", None)

        # Then
        stones = dict(
            value="the stones",
            resolved_value="The Rolling Stones",
            range=dict(start=20, end=30),
            entity_identifier="music_artist",
        )
        self.assertListEqual([stones], matches)