class BOWParser:
    """
    Bag-of-words parser: this layer stores the NLP specific tooling and logic.

    A message is reduced to three parts: the knowledge-base entities it
    mentions, the intents whose keywords appear in it, and whatever text
    remains once entities, punctuation, stopwords and keywords are removed.

    References:
        https://stackoverflow.com/questions/42322902/how-to-get-parse-tree-using-python-nltk
        https://www.nltk.org/book/ch08.html
    """
    # Words filtered out in addition to NLTK's English stopword list.
    extra_stopwords = {'s', 'hey', 'want', 'you'}

    def __init__(self, db_path, keywords):
        """
        Args:
            db_path: Path to the knowledge-base database; passed to
                     KnowledgeBaseAPI.
            keywords: Mapping of intent name -> iterable of keyword strings.
                      Keywords are spliced into regexes (see _gen_patterns)
                      and are never filtered out as stopwords.
        """
        # Download the stopwords if necessary.
        try:
            nltk.data.find('corpora/stopwords')
        except LookupError:
            nltk.download('stopwords')

        self.commands = keywords
        self.kb_api = KnowledgeBaseAPI(db_path)
        self.db_nouns = self.kb_api.get_all_music_entities()

    def _get_stop_words(self):
        """Return NLTK English stopwords plus extra_stopwords, minus all keywords."""
        stop_words = set(stopwords.words('english'))
        stop_words |= BOWParser.extra_stopwords
        # Keywords must survive stopword filtering so intents can still be
        # detected; removing a word that is not present is simply a no-op.
        for words in self.commands.values():
            stop_words.difference_update(words)
        return stop_words

    def _gen_patterns(self):
        """Compile one regex per intent that matches any of its keywords.

        Each keyword is wrapped in word boundaries. Keywords are inserted
        unescaped, so they are interpreted as regex syntax.
        """
        patterns = {}
        for intent, keys in self.commands.items():
            patterns[intent] = re.compile(r'\b'+r'\b|\b'.join(keys)+r'\b')
        return patterns

    def _extract_subjects(self, msg):
        """Find every knowledge-base entity mentioned in msg and remove it.

        Matching is a case-insensitive, literal substring search. Every
        occurrence of a matched entity is cut out of the message.

        Args:
            msg: Raw user input.

        Returns:
            (subjects, msg): the stripped names of all matched entities, in
            self.db_nouns order, and the message with those entities removed.
        """
        subjects = []
        for noun in self.db_nouns:
            # A blank name would "match" everywhere and strip every space.
            if not noun.strip():
                continue
            if noun.lower() in msg.lower():
                # Escape the name: titles such as "What do you mean?" contain
                # regex metacharacters but must be matched literally, exactly
                # like the substring test above.
                pattern = re.compile(re.escape(noun), re.IGNORECASE)
                msg = pattern.sub('', msg)
                subjects.append(noun.strip())
        return subjects, msg

    def __call__(self, msg: str):
        """Split a user message into subjects, intents and leftover text.

        Args:
            msg: A string of user input,
                 i.e. 'play something similar to Justin Bieber'.

        Returns:
            A tuple (subjects, intents, remaining_text):
              subjects: every knowledge-base entity found in msg (all of
                        them, not just the first), stripped.
              intents: intent names whose keyword pattern matched, in
                       self.commands order.
              remaining_text: the lower-cased text left after removing
                              entities, punctuation, stopwords and keywords.
        """
        # Identify and remove every subject from the database that matches.
        subjects, msg = self._extract_subjects(msg)

        # Replace runs of punctuation (and any spaces after them) with a space.
        msg = re.sub(r"[,.;@#?!&$']+\ *",
                     " ",
                     msg,
                     flags=re.VERBOSE)

        # Clean the stopwords from the input.
        stop_words = self._get_stop_words()
        clean_msg = ' '.join([word for word in msg.lower().split(' ')
                              if word not in stop_words])

        # An intent is detected when its pattern removes something; the
        # removed keywords do not end up in the remaining text.
        intents = []
        for intent, pattern in self._gen_patterns().items():
            sub_msg = pattern.sub('', clean_msg)
            if sub_msg != clean_msg:
                intents.append(intent)
                clean_msg = sub_msg

        remaining_text = clean_msg.strip()
        return subjects, intents, remaining_text
# Example No. 2
# 0
class TreeParser:
    """
    This class contains tree-parsing logic that will convert
    input text to an NLTK Parse-Tree using a Probabilistic Context
    Free Grammar.

    Parsing happens in two stages: ``_lexer`` reduces the message to a
    list of entity and intent names, and ``_parser`` builds a tree from
    that list with an NLTK Viterbi parser.

    References:
        https://stackoverflow.com/questions/42322902/how-to-get-parse-tree-using-python-nltk
        https://www.nltk.org/book/ch08.html

    """
    def __init__(self, db_path, keywords):
        """
        Args:
            db_path: Path to the knowledge-base database; passed to
                     KnowledgeBaseAPI.
            keywords: Mapping with "unary", "terminal" and "binary"
                      entries, each mapping an intent name to the list of
                      keyword strings that signal it.
        """
        self.keywords = keywords
        self.kb_api = KnowledgeBaseAPI(db_path)
        self.kb_named_entities = self.kb_api.get_all_music_entities()

    def __call__(self, msg: str):
        """Creates an NLTK Parse Tree from the user input msg.

        Args:
            msg: A string of user input.
                 i.e 'play something similar to Justin Bieber'

        Returns: An NLTK parse tree, as defined by the PCFG built
                 in the ``_parser`` method.

        """
        # Replace runs of '.', '?' and "'" (plus trailing spaces) with a space.
        msg = re.sub(r"[.?']+\ *", " ", msg, flags=re.VERBOSE)

        # Parse sentence into list of tokens containing
        #  only entities and commands.
        tokens = self._lexer(msg)

        # Generate an NLTK parse tree
        tree = self._parser(tokens)
        return tree

    def _command_regexes(self, kind):
        """Return {intent: compiled regex} for one command kind.

        Args:
            kind: "unary", "terminal" or "binary" (a key of self.keywords).

        Intents with an empty keyword list are skipped. Keywords are joined
        unescaped (so they are interpreted as regex syntax), each wrapped in
        word boundaries. The result is computed once per instance and kind.
        """
        # Per-instance cache instead of functools.lru_cache on the method:
        # lru_cache(1) was keyed on `self`, kept one instance alive and
        # recompiled whenever a different TreeParser instance was used.
        cache = self.__dict__.setdefault("_command_regex_cache", {})
        if kind not in cache:
            patterns = {}
            for intent, keys in self.keywords.get(kind).items():
                if keys:
                    patterns[intent] = re.compile(r'\b' + r'\b|\b'.join(keys) +
                                                  r'\b')
            cache[kind] = patterns
        return cache[kind]

    @property
    def _unary_command_regexes(self):
        """RegEx patterns for unary command signifiers."""
        return self._command_regexes("unary")

    @property
    def _terminal_command_regexes(self):
        """RegEx patterns for terminal command signifiers."""
        return self._command_regexes("terminal")

    @property
    def _binary_command_regexes(self):
        """RegEx patterns for binary command signifiers."""
        return self._command_regexes("binary")

    @staticmethod
    def _grammar_alternatives(vals):
        """Build the right-hand side of a terminal-only PCFG rule.

        Each distinct value becomes its own alternative, and the values
        share the probability mass equally, e.g.
        ['a', 'b'] -> "'a' [0.500000000000000] | 'b' [0.500000000000000]".
        NLTK gives an alternative without its own "[p]" probability 0.0,
        so a single trailing "[1.0]" would leave every alternative but the
        last with probability 0. Probabilities are written in fixed-point
        notation because NLTK's grammar reader only accepts digits and dots
        inside the brackets (no exponent such as 1e-05).

        Args:
            vals: Iterable of terminal strings.

        Returns: The rule body, or "'NONE' [1.0]" when no usable value
                 remains.
        """
        # TODO: Here we remove entries containing ',
        #       as it is a special character used by
        #       the NLTK parser. We need to fix this
        #       eventually. Empty strings are dropped too.
        safe_vals = list(dict.fromkeys(
            v for v in vals if v and "'" not in v))
        if not safe_vals:
            return "'NONE' [1.0]"
        prob = '{:.15f}'.format(1.0 / len(safe_vals))
        return ' | '.join("'{}' [{}]".format(v, prob) for v in safe_vals)

    def _lexer(self, msg: str):
        """Lexes an input string into a list of tokens.

        This lexer first looks for Entities in the input string
        and parses them into tokens of the same name
        (i.e. 'blah U2 blah' -> ['U2']).

        Next, the lexer will look for Commands by searching for
        keywords that signify the command. These keyword+command
        pairings are defined in the command_evaluation layer.

        Entities are matched case-insensitively as plain substrings, and
        the text around a matched entity is lower-cased before it is lexed
        further. Commands are matched with the keyword regexes in the order
        unary, terminal, binary. At each step only the first match is split
        on and both sides are lexed recursively, so repeated entities or
        commands each produce their own token.

        Args:
            msg: A string of user input.
                 i.e 'play something similar to justin bieber'

        Returns: A tokenized list of commands and Entities.
            i.e. ['control_play', 'query_similar_entities', 'Justin Bieber']

        """
        command_regex_groups = (
            self._unary_command_regexes,     # 2. unary commands
            self._terminal_command_regexes,  # 3. terminal commands
            self._binary_command_regexes,    # 4. binary commands
        )

        def lexing_algorithm(text):
            # TODO: >= O(n^2) (horrible time complexity). In the interest of
            # making forward progress, optimize this later.  Could use a
            # bloom filter to look for matches, then binary search  on
            # entities/commands find specific match once possible match is
            # found.  Or even just use a hashmap for searching.

            # Base case.
            if text == "":
                return []

            # 1. Parse named entities.
            lowered = text.lower()
            for entity in self.kb_named_entities:
                needle = entity.lower()
                # Blank names would match everywhere (and '' cannot be
                # split on), so they are never treated as entities.
                if not needle.strip() or needle not in lowered:
                    continue
                # partition() splits on the FIRST occurrence only; later
                # occurrences stay in `right` and are lexed there.
                left, _, right = lowered.partition(needle)
                if left == text or right == text:
                    # Safety measure to prevent infinite recursion.
                    break
                return lexing_algorithm(left) + [
                    entity.strip()
                ] + lexing_algorithm(right)

            # 2-4. Parse commands, in priority order.
            for regexes in command_regex_groups:
                for intent, pattern in regexes.items():
                    match = pattern.search(text)
                    # Zero-width matches consume nothing and would recurse
                    # forever on the same text, so they are ignored.
                    if match is None or match.start() == match.end():
                        continue
                    return lexing_algorithm(text[:match.start()]) \
                        + [intent] \
                        + lexing_algorithm(text[match.end():])

            # If no matches, then the remaining text holds only stopwords.
            return []

        return lexing_algorithm(msg)

    def _parser(self, tokens: List[str]):
        """Generates a Parse Tree from a list of tokens
        provided by the Lexer.

        Args:
            tokens: A tokenized list of commands and Entities.
            i.e. ['control_play', 'query_similar_entities', 'Justin Bieber']

        Returns: An nltk parse tree, as defined by the PCFG given
                 in this function.

        Raises: StopIteration when no tree spans all of the tokens.
                NLTK's parser also rejects tokens the grammar does not
                cover (e.g. names containing ', which are left out of
                the grammar).

        """

        # TODO:   Improve the CFG work for the following:
        #          -  Play songs faster than despicito
        #          -  Play something similar to despicito but faster
        #          -  Play something similar to u2 and justin bieber

        # The lexer emits entity.strip(), so the grammar must list the
        # stripped names or those tokens would never be covered.
        entities = [entity.strip() for entity in self.kb_named_entities]

        # A Probabilistic Context Free Grammar (PCFG)
        # can be used to simulate "operator precedence",
        # which removes the problems of ambiguity in
        # the grammar. Terminal rules carry per-alternative
        # probabilities (see _grammar_alternatives).
        grammar = nltk.PCFG.fromstring("""
        Root -> Terminal_Command Result         [0.6]
        Root -> Terminal_Command                [0.4]
        Result -> Entity                        [0.5]
        Result -> Unary_Command Result          [0.1]
        Result -> Result Binary_Command Result  [0.4]
        Entity -> {}
        Unary_Command -> {}
        Terminal_Command -> {}
        Binary_Command -> {}
        """.format(
            self._grammar_alternatives(entities),
            self._grammar_alternatives(self.keywords.get("unary").keys()),
            self._grammar_alternatives(self.keywords.get("terminal").keys()),
            self._grammar_alternatives(self.keywords.get("binary").keys()),
        ))

        parser = nltk.ViterbiParser(grammar)
        # TODO: Returns the first tree, but need to deal with
        #       case where grammar is ambiguous, and more than
        #       one tree is returned.
        return next(parser.parse(tokens))
# Example No. 3
# 0
class TestMusicKnowledgeBaseAPI(unittest.TestCase):
    """Tests for KnowledgeBaseAPI against the database built by test_db_utils."""

    def setUp(self):
        # Every test runs against a freshly created and populated database.
        db_path = test_db_utils.create_and_populate_db()
        self.kb_api = KnowledgeBaseAPI(dbName=db_path)

    def tearDown(self):
        test_db_utils.remove_db()

    def test_get_song_data(self):
        song_data = self.kb_api.get_song_data("Despacito")
        self.assertEqual(
            1, len(song_data),
            "Expected exactly one result from query for song 'Despacito'.")
        self.assertEqual(
            song_data[0],
            dict(
                id=10,
                song_name="Despacito",
                artist_name="Justin Bieber",
                duration_ms=222222,
                popularity=10,
            ), "Song data for 'Despacito' did not match expected.")

    def test_get_song_data_dne(self):
        res = self.kb_api.get_song_data("Not In Database")
        self.assertEqual(
            res, [],
            "Expected empty list of results for queried song not in DB.")

    def test_get_artist_data(self):
        artist_data = self.kb_api.get_artist_data("Justin Bieber")
        self.assertEqual(
            len(artist_data), 1,
            "Expected exactly one result for artist 'Justin Bieber'.")
        # Genre order is not significant, so compare as a set.
        artist_data[0]["genres"] = set(artist_data[0]["genres"])
        self.assertEqual(
            artist_data[0],
            dict(genres={"Pop", "Super pop"},
                 id=1,
                 num_spotify_followers=4000,
                 name="Justin Bieber"),
            "Artist data for 'Justin Bieber' did not match expected.",
        )

    def test_get_artist_data_dne(self):
        artist_data = self.kb_api.get_artist_data("Unknown artist")
        self.assertEqual(artist_data, [],
                         "Expected empty list of results for unknown artist.")

    def test_get_songs(self):
        res = self.kb_api.get_songs_by_artist("Justin Bieber")
        self.assertEqual(
            res, ["Despacito", "Sorry"],
            "Songs retrieved for 'Justin Bieber' did not match expected.")

        res = self.kb_api.get_songs_by_artist("Justin Timberlake")
        self.assertEqual(
            res, ["Rock Your Body"],
            "Songs retrieved for 'Justin Timberlake' did not match expected.")

    def test_get_songs_unknown_artist(self):
        res = self.kb_api.get_songs_by_artist("Unknown artist")
        self.assertIsNone(res,
                          "Unexpected songs retrieved for unknown artist.")

    def test_find_similar_song(self):
        res = self.kb_api.get_related_entities("Despacito")
        self.assertEqual(
            len(res),
            1,
            "Expected only one song similar to \"Despacito\". Found {0}".
            format(res),
        )
        self.assertEqual(
            res[0],
            "Rock Your Body",
            "Expected to find \"Rock Your Body\" as similar to \"Despacito\".",
        )

    def test_find_similar_artist(self):
        res = self.kb_api.get_related_entities("Justin Bieber")
        self.assertEqual(
            len(res),
            2,
            "Expected exactly two artists similar to Justin Bieber.",
        )
        self.assertEqual(
            res[0],
            "Justin Timberlake",
            "Expected to find Justin Timberlake as similar to Justin Bieber.",
        )
        self.assertEqual(
            res[1],
            "Shawn Mendes",
            "Expected to find Shawn Mendes as similar to Justin Bieber.",
        )

    def test_get_all_music_entities(self):
        res = self.kb_api.get_all_music_entities()
        self.assertIn(
            'Justin Bieber', res,
            'Expected to find "Justin Bieber" in the list of entities.')

    def test_find_similar_to_entity_that_dne(self):
        res = self.kb_api.get_related_entities("Unknown Entity")
        self.assertEqual(res, [])

    def test_get_related_genres(self):
        genre_rel_str = self.kb_api.approved_relations["genre"]
        rel_genres = self.kb_api.get_related_entities("Justin Bieber",
                                                      rel_str=genre_rel_str)
        self.assertEqual(
            set(rel_genres), {"Pop", "Super pop"},
            "Did not find expected related genres for artist 'Justin Bieber'")

        rel_genres = self.kb_api.get_related_entities("Justin Timberlake",
                                                      rel_str=genre_rel_str)
        self.assertEqual(
            rel_genres, ["Pop"],
            "Did not find expected related genres for artist 'Justin Timberlake'"
        )

    def test_connect_entities_by_similarity(self):
        res = self.kb_api.get_related_entities("Shawn Mendes")
        self.assertEqual(len(res), 0)

        res = self.kb_api.connect_entities("Shawn Mendes", "Justin Timberlake",
                                           "similar to", 0)
        self.assertEqual(
            res, True,
            "Expected 'Shawn Mendes' to be connected to 'Justin Timberlake'.")

        res = self.kb_api.get_related_entities("Shawn Mendes")
        self.assertEqual(len(res), 1)
        self.assertEqual(res[0], "Justin Timberlake")

    def test_connect_entities_by_genre(self):
        genre_rel_str = self.kb_api.approved_relations["genre"]
        res = self.kb_api.get_related_entities("Shawn Mendes",
                                               rel_str=genre_rel_str)
        self.assertEqual(len(res), 0)

        res = self.kb_api.connect_entities("Shawn Mendes", "Pop", "of genre",
                                           100)
        self.assertEqual(res, True,
                         "Expected 'Shawn Mendes' to be connected to 'Pop'.")

        res = self.kb_api.get_related_entities("Shawn Mendes",
                                               rel_str=genre_rel_str)
        self.assertEqual(
            len(res), 1,
            "Expected to find exactly one related genre for 'Shawn Mendes'.")
        self.assertEqual(res[0], "Pop",
                         "Found unexpected genre for 'Shawn Mendes'.")

    def test_rejects_connect_ambiguous_entities(self):
        # Create a name that refers to both an artist and a song.
        self.kb_api.add_artist("Artist and Song name clash")
        self.kb_api.add_song("Artist and Song name clash", "U2")

        res = self.kb_api.connect_entities("Artist and Song name clash",
                                           "Justin Timberlake", "similar to",
                                           0)
        self.assertEqual(
            res, False,
            "Expected connecting an ambiguous entity name to be rejected.")

    # Renamed: this method previously shared its name with the later
    # test_get_node_ids_by_entity_type, which silently replaced it, so
    # this test never ran.
    def test_get_node_ids_by_entity_type_name_clash(self):
        node_ids_dict = self.kb_api.get_node_ids_by_entity_type(
            "Justin Bieber")
        all_entity_types = list(node_ids_dict.keys())
        self.assertEqual(
            all_entity_types, ["artist"],
            "Expected to find (only) entities of type 'artist' for 'Justin Bieber', but got: {}"
            .format(node_ids_dict))
        self.assertEqual(
            len(node_ids_dict.get("artist")), 1,
            "Expected to find exactly one entity of type 'artist' for 'Justin Bieber', but got: {}"
            .format(node_ids_dict))

        self.kb_api.add_artist("Song and Artist name")
        self.kb_api.add_song("Song and Artist name", "U2")

        node_ids_dict = self.kb_api.get_node_ids_by_entity_type(
            "Song and Artist name")
        alphabetized_entity_types = sorted(node_ids_dict.keys())
        self.assertEqual(
            alphabetized_entity_types, ["artist", "song"],
            "Expected to find (exactly) two entity types: 'artist' and 'song', but got: {}"
            .format(node_ids_dict))

    def test_get_matching_node_ids(self):
        node_ids = self.kb_api._get_matching_node_ids("Justin Bieber")
        self.assertEqual(
            len(node_ids), 1,
            "Expected to find exactly one matching node for 'Justin Bieber', but got: {}"
            .format(node_ids))

    def test_get_matching_node_ids_empty(self):
        node_ids = self.kb_api._get_matching_node_ids("Unknown artist")
        self.assertEqual(
            len(node_ids), 0,
            "Expected to find no matching node for 'Unknown artist', but got: {}"
            .format(node_ids))

    def test_add_artist(self):
        sample_genres = ["Pop", "Very pop", "Omg so pop"]
        new_artist_node_id = self.kb_api.add_artist("Heart",
                                                    genres=sample_genres,
                                                    num_spotify_followers=1)
        self.assertIsNotNone(new_artist_node_id,
                             "Failed to add artist 'Heart' to knowledge base.")

        artist_data = self.kb_api.get_artist_data("Heart")
        self.assertEqual(len(artist_data), 1,
                         "Expected unique match for artist 'Heart'.")

        artist_data = artist_data[0]
        # Genre order is not significant, so compare as a set.
        artist_data["genres"] = set(artist_data["genres"])

        self.assertEqual(
            artist_data,
            dict(
                name="Heart",
                id=new_artist_node_id,
                genres=set(sample_genres),
                num_spotify_followers=1,
            ),
            "Did not find expected data for artist 'Heart'.",
        )

    def test_reject_add_artist_already_exists(self):
        artist_node_id = self.kb_api.get_artist_data("Justin Bieber")[0]["id"]
        res = self.kb_api.add_artist("Justin Bieber")
        self.assertEqual(
            res, artist_node_id,
            "Expected re-adding artist 'Justin Bieber' to return the existing node id."
        )

    def test_add_artist_omitted_opt_params(self):
        res = self.kb_api.add_artist("Heart")
        self.assertIsNotNone(res,
                             "Failed to add artist 'Heart' to knowledge base.")

        artist_data = self.kb_api.get_artist_data("Heart")
        self.assertEqual(len(artist_data), 1,
                         "Expected unique match for artist 'Heart'.")
        self.assertEqual(artist_data[0]["name"], "Heart",
                         "Expected match for artist 'Heart'.")
        self.assertEqual(artist_data[0]["genres"], [],
                         "Expected no genres for artist 'Heart'.")
        self.assertIsNone(
            artist_data[0]["num_spotify_followers"],
            "Expected no Spotify follower count for artist 'Heart'.")

    def test_add_song(self):
        new_song_node_id = self.kb_api.add_song("Heart",
                                                "Justin Bieber",
                                                duration_ms=11111,
                                                popularity=100)
        self.assertIsNotNone(
            new_song_node_id,
            "Failed to add song 'Heart' by artist 'Justin Bieber' to knowledge base."
        )

        song_data = self.kb_api.get_song_data("Heart")
        self.assertEqual(len(song_data), 1, "Expected exactly one result.")

        song_data = song_data[0]
        self.assertEqual(
            song_data,
            dict(id=new_song_node_id,
                 song_name="Heart",
                 artist_name="Justin Bieber",
                 duration_ms=11111,
                 popularity=100), "Received unexpected song data")

    def test_add_song_omitted_opt_params(self):
        new_song_node_id = self.kb_api.add_song("What do you mean?",
                                                "Justin Bieber")
        self.assertIsNotNone(
            new_song_node_id,
            "Failed to add song 'What do you mean?' by artist 'Justin Bieber' to knowledge base."
        )

        song_data = self.kb_api.get_song_data("What do you mean?")
        self.assertEqual(len(song_data), 1, "Expected exactly one result.")

        song_data = song_data[0]
        self.assertEqual(
            song_data,
            dict(id=new_song_node_id,
                 song_name="What do you mean?",
                 artist_name="Justin Bieber",
                 duration_ms=None,
                 popularity=None), "Received unexpected song data")

    def test_add_duplicate_song_for_different_artist(self):
        new_song_node_id = self.kb_api.add_song("Despacito",
                                                "Justin Timberlake")
        self.assertIsNotNone(
            new_song_node_id,
            "Failed to add song 'Despacito' by artist 'Justin Timberlake' to knowledge base."
        )

        res = self.kb_api.get_song_data("Despacito")
        self.assertEqual(len(res), 2,
                         "Expected exactly two matches for song 'Despacito'.")

        artists = {res[0]["artist_name"], res[1]["artist_name"]}
        self.assertEqual(
            artists, {"Justin Bieber", "Justin Timberlake"},
            "Expected to find duplicate artists 'Justin Bieber' and 'Justin Timberlake' for song 'Despacito'."
        )

    def test_reject_add_song_already_exists(self):
        res = self.kb_api.add_song("Despacito", "Justin Bieber")
        self.assertIsNone(
            res,
            "Expected rejection of attempt to add song 'Despacito' by 'Justin Bieber' to knowledge base."
        )

    # The logic tested here is currently implemented in the KR API
    # However, if it is moved to the schema (e.g. trigger functions),
    # then this test can be moved to the schema test module
    def test_new_song_with_unknown_artist_rejected(self):
        res = self.kb_api.add_song("Song by Unknown Artist", "Unknown artist")
        self.assertIsNone(res,
                          "Expected song with unknown artist to be rejected")

        res = self.kb_api.get_song_data("Song by Unknown Artist")
        self.assertEqual(
            len(res), 0,
            "Insertion of song with unknown artist should not have been added to nodes table"
        )

    def test_add_genre(self):
        node_id = self.kb_api.add_genre("hip hop")
        self.assertIsInstance(
            node_id, int,
            "Genre addition appears to have failed: expected int return value (node id) on valid attempt to add genre."
        )

        res = self.kb_api.add_genre("hip hop")
        self.assertEqual(
            res, node_id,
            "Expected original node id to be fetched when attempting to add duplicate genre."
        )

    def test_add_genre_creates_node(self):
        res = self.kb_api.add_genre("hip hop")
        self.assertIsInstance(
            res, int,
            "Genre addition appears to have failed: expected int return value (node id) on valid attempt to add genre."
        )

        entities = self.kb_api.get_node_ids_by_entity_type("hip hop")
        self.assertIn(
            "genre", entities,
            "Expected to find node associated with genre 'hip hop'.")

    def test_get_node_ids_by_entity_type(self):
        res = self.kb_api.get_node_ids_by_entity_type("Justin Timberlake")
        self.assertIn(
            "artist", res,
            "Expected to find an 'artist' entity with name 'Justin Timberlake', but got: {0}"
            .format(res))
        self.assertEqual(
            len(res["artist"]), 1,
            "Expected to find exactly one entity (of type 'artist') with name 'Justin Timberlake', but got: {0}"
            .format(res))

        res = self.kb_api.get_node_ids_by_entity_type("Despacito")
        self.assertIn(
            "song", res,
            "Expected to find a 'song' entity with name 'Despacito', but got: {0}"
            .format(res))
        self.assertEqual(
            len(res["song"]), 1,
            "Expected to find exactly one entity (of type 'song') with name 'Despacito', but got: {0}"
            .format(res))

        res = self.kb_api.get_node_ids_by_entity_type("Unknown entity")
        self.assertEqual(
            res, {},
            "Expected no results from query for unknown entity, but got {}".
            format(res))