class BOWParser:
    """Bag-of-words parser that splits user input into subjects, intents and leftover text.

    This layer stores the NLP specific tooling and logic.

    References:
        https://stackoverflow.com/questions/42322902/how-to-get-parse-tree-using-python-nltk
        https://www.nltk.org/book/ch08.html
    """

    # Extra filler words that are removed along with NLTK's English stopwords.
    extra_stopwords = {'s', 'hey', 'want', 'you'}

    def __init__(self, db_path, keywords):
        """
        Args:
            db_path: Path handed to KnowledgeBaseAPI.
            keywords: Mapping of intent name -> iterable of keyword strings
                that signal that intent.
        """
        # Download the stopwords corpus if it is not already installed.
        try:
            nltk.data.find('corpora/stopwords')
        except LookupError:
            nltk.download('stopwords')

        self.commands = keywords
        self.kb_api = KnowledgeBaseAPI(db_path)
        self.db_nouns = self.kb_api.get_all_music_entities()

    def _get_stop_words(self):
        """Return English stopwords plus `extra_stopwords`, minus every command keyword."""
        stop_words = set(stopwords.words('english'))
        stop_words |= BOWParser.extra_stopwords
        for words in self.commands.values():
            for word in words:
                # Keywords must survive stopword filtering; discard() is a
                # no-op for keywords that were never stopwords.
                stop_words.discard(word)
        return stop_words

    def _gen_patterns(self):
        """Return {intent: compiled regex} matching any of the intent's keywords as whole words."""
        patterns = {}
        for intent, keys in self.commands.items():
            patterns[intent] = re.compile(r'\b' + r'\b|\b'.join(keys) + r'\b')
        return patterns

    def _extract_subjects(self, msg):
        """Remove every database entity mentioned in msg.

        Args:
            msg: Raw user input.

        Returns:
            A tuple (subjects, msg). `subjects` holds the matched entity names
            (stripped, in `self.db_nouns` order). `msg` is the input with every
            occurrence of those names removed, matched case-insensitively.
        """
        subjects = []
        for noun in self.db_nouns:
            if not noun:
                # An empty name "matches" everywhere and would be a bogus subject.
                continue
            if noun.lower() in msg.lower():
                # Entity names are literal text, not patterns. Names such as
                # "What do you mean?" contain regex metacharacters.
                msg = re.sub(re.escape(noun), '', msg, flags=re.IGNORECASE)
                subjects.append(noun.strip())
        return subjects, msg

    def __call__(self, msg: str):
        """Parse msg into (subjects, intents, remaining_text).

        Args:
            msg: A string of user input.

        Returns:
            subjects: Database entity names found in msg.
            intents: Intent names whose keywords appeared in the cleaned input.
            remaining_text: Whatever is left after removing subjects,
                punctuation, stopwords and keywords.
        """
        subjects, msg = self._extract_subjects(msg)

        # Remove punctuation from the string.
        msg = re.sub(r"[,.;@#?!&$']+\ *", " ", msg, flags=re.VERBOSE)

        # Clean the stopwords from the input.
        stop_words = self._get_stop_words()
        clean_msg = ' '.join(
            word for word in msg.lower().split(' ') if word not in stop_words)

        # Parse the keywords from the filtered input, consuming them as we go.
        intents = []
        for intent, pattern in self._gen_patterns().items():
            sub_msg = pattern.sub('', clean_msg)
            if sub_msg != clean_msg:
                intents.append(intent)
                clean_msg = sub_msg

        return subjects, intents, clean_msg.strip()
class TreeParser:
    """
    This class contains tree-parsing logic that will convert input text to an
    NLTK Parse-Tree using a Context Free Grammar.

    References:
        https://stackoverflow.com/questions/42322902/how-to-get-parse-tree-using-python-nltk
        https://www.nltk.org/book/ch08.html
    """

    def __init__(self, db_path, keywords):
        """
        Args:
            db_path: Path handed to KnowledgeBaseAPI.
            keywords: Mapping with "unary", "terminal" and "binary" keys, each
                mapping intent name -> iterable of keyword strings.
        """
        self.keywords = keywords
        self.kb_api = KnowledgeBaseAPI(db_path)
        self.kb_named_entities = self.kb_api.get_all_music_entities()

    def __call__(self, msg: str):
        """Creates an NLTK Parse Tree from the user input msg.

        Args:
            msg: A string of user input.
                i.e 'play something similar to Justin Bieber'

        Returns:
            An NLTK parse tree, as defined by the CFG given in `_parser`.
        """
        # Remove punctuation from the string.
        msg = re.sub(r"[.?']+\ *", " ", msg, flags=re.VERBOSE)

        # Parse sentence into list of tokens containing only entities and commands.
        tokens = self._lexer(msg)

        # Generate an NLTK parse tree.
        tree = self._parser(tokens)
        return tree

    def _command_regexes(self, kind):
        """Return {intent: compiled regex} for one command kind ("unary", "terminal" or "binary").

        Each regex matches any of the intent's keywords as a whole word,
        case-insensitively. Intents with no keywords are omitted. Results are
        memoized on the instance, so the cache dies with the parser instead of
        living in a shared lru_cache that pins instances in memory.
        """
        cache = vars(self).setdefault("_command_regex_cache", {})
        if kind not in cache:
            patterns = {}
            for intent, keys in (self.keywords.get(kind) or {}).items():
                if keys:
                    # IGNORECASE: the text reaching the command stage is only
                    # lowercased when an entity was found first.
                    patterns[intent] = re.compile(
                        r'\b' + r'\b|\b'.join(keys) + r'\b', re.IGNORECASE)
            cache[kind] = patterns
        return cache[kind]

    @property
    def _unary_command_regexes(self):
        """Compiled patterns for unary command keywords, keyed by intent."""
        return self._command_regexes("unary")

    @property
    def _terminal_command_regexes(self):
        """Compiled patterns for terminal command keywords, keyed by intent."""
        return self._command_regexes("terminal")

    @property
    def _binary_command_regexes(self):
        """Compiled patterns for binary command keywords, keyed by intent."""
        return self._command_regexes("binary")

    def _lexer(self, msg: str):
        """Lexes an input string into a list of tokens.

        The lexer first looks for Entities in the input string and turns them
        into tokens of the same name (i.e. 'blah U2 blah' -> ['U2']). Next, it
        looks for Commands by searching for keywords that signify the command.
        The keyword+command pairings are defined in the command_evaluation
        layer. Text matching neither is dropped.

        Args:
            msg: A string of user input.
                i.e 'play something similar to justin bieber'

        Returns:
            A tokenized list of commands and Entities.
                i.e. ['control_play', 'query_similar_entities', 'Justin Bieber']
        """
        # Commands are tried in this priority order: unary, terminal, binary.
        command_stages = (
            self._unary_command_regexes,
            self._terminal_command_regexes,
            self._binary_command_regexes,
        )

        def lexing_algorithm(text):
            # TODO: >= O(n^2) (horrible time complexity). In the interest of
            # making forward progress, optimize this later. Could use a bloom
            # filter or a hashmap to look for matches.

            # Base case.
            if text == "":
                return []

            # 1. Parse named entities. Both halves are strictly shorter than
            # text (the needle is non-empty), so the recursion terminates, and
            # later occurrences in `right` are lexed recursively.
            lowered = text.lower()
            for entity in self.kb_named_entities:
                needle = entity.lower()
                if not needle or needle not in lowered:
                    continue
                left, _, right = lowered.partition(needle)
                return (lexing_algorithm(left)
                        + [entity.strip()]
                        + lexing_algorithm(right))

            # 2-4. Parse unary, then terminal, then binary commands. The text is
            # split around the first keyword match; further matches are found
            # by recursing on the remainder.
            for regexes in command_stages:
                for intent, pattern in regexes.items():
                    match = pattern.search(text)
                    if match is None or match.start() == match.end():
                        # Skip empty matches; they would recurse on the same text.
                        continue
                    return (lexing_algorithm(text[:match.start()])
                            + [intent]
                            + lexing_algorithm(text[match.end():]))

            # If no matches, then the text contains only stopwords.
            return []

        return lexing_algorithm(msg)

    def _parser(self, tokens: List[str]):
        """Generates a Parse Tree from a list of tokens provided by the Lexer.

        Args:
            tokens: A tokenized list of commands and Entities.
                i.e. ['control_play', 'query_similar_entities', 'Justin Bieber']

        Returns:
            The most probable nltk parse tree, as defined by the PCFG given in
            this function. If the grammar yields no parse, the bare `next()`
            call raises StopIteration. NOTE(review): nltk presumably raises
            ValueError for tokens the grammar does not cover; confirm.
        """
        # TODO: Improve the CFG work for the following:
        #   - Play songs faster than despicito
        #   - Play something similar to despicito but faster
        #   - Play something similar to u2 and justin bieber

        def gen_lexing_patterns(vals: List[str]):
            """Render vals as NLTK PCFG terminal alternatives with equal probability."""
            # TODO: Entries containing ' are removed, as ' is a special
            # character used by the NLTK parser. We need to fix this eventually.
            # Values are stripped to match the tokens the lexer emits, and
            # de-duplicated while preserving order.
            safe_vals = list(dict.fromkeys(
                s.strip() for s in vals if s.strip() and "'" not in s))
            if not safe_vals:
                safe_vals = ["NONE"]
            # NLTK assigns probability 0.0 to every alternative lacking its own
            # [p] annotation. Give each alternative an equal share instead.
            # Use fixed-point notation because NLTK's probability syntax only
            # accepts digits and dots.
            prob = "[{:.12f}]".format(1.0 / len(safe_vals))
            return " | ".join("'{}' {}".format(s, prob) for s in safe_vals)

        # A Probabilistic Context Free Grammar (PCFG) can be used to simulate
        # "operator precedence", which removes the problems of ambiguity in
        # the grammar.
        grammar = nltk.PCFG.fromstring("""
            Root -> Terminal_Command Result [0.6]
            Root -> Terminal_Command [0.4]
            Result -> Entity [0.5]
            Result -> Unary_Command Result [0.1]
            Result -> Result Binary_Command Result [0.4]
            Entity -> {}
            Unary_Command -> {}
            Terminal_Command -> {}
            Binary_Command -> {}
        """.format(
            gen_lexing_patterns(self.kb_named_entities),
            gen_lexing_patterns((self.keywords.get("unary") or {}).keys()),
            gen_lexing_patterns((self.keywords.get("terminal") or {}).keys()),
            gen_lexing_patterns((self.keywords.get("binary") or {}).keys()),
        ))
        parser = nltk.ViterbiParser(grammar)

        # TODO: Returns the first tree, but need to deal with case where
        # grammar is ambiguous, and more than one tree is returned.
        return next(parser.parse(tokens))
class TestMusicKnowledgeBaseAPI(unittest.TestCase):
    """Integration tests for KnowledgeBaseAPI against a freshly populated test DB."""

    def setUp(self):
        DB_path = test_db_utils.create_and_populate_db()
        self.kb_api = KnowledgeBaseAPI(dbName=DB_path)

    def tearDown(self):
        test_db_utils.remove_db()

    def test_get_song_data(self):
        song_data = self.kb_api.get_song_data("Despacito")
        # we don't care what the node ID is
        self.assertEqual(
            1, len(song_data),
            "Expected exactly one result from query for song 'Despacito'.")
        self.assertEqual(
            song_data[0],
            dict(
                id=10,
                song_name="Despacito",
                artist_name="Justin Bieber",
                duration_ms=222222,
                popularity=10,
            ),
            "Unexpected song data for 'Despacito'.")

    def test_get_song_data_dne(self):
        res = self.kb_api.get_song_data("Not In Database")
        self.assertEqual(
            res, [],
            "Expected empty list of results for queried song not in DB.")

    def test_get_artist_data(self):
        artist_data = self.kb_api.get_artist_data("Justin Bieber")
        self.assertEqual(
            len(artist_data), 1,
            "Expected exactly one result for artist 'Justin Bieber'.")
        # Genre order is not guaranteed, so compare as a set.
        artist_data[0]["genres"] = set(artist_data[0]["genres"])
        self.assertEqual(
            artist_data[0],
            dict(genres=set(["Pop", "Super pop"]),
                 id=1,
                 num_spotify_followers=4000,
                 name="Justin Bieber"),
            "Artist data for 'Justin Bieber' did not match expected.",
        )

    def test_get_artist_data_dne(self):
        artist_data = self.kb_api.get_artist_data("Unknown artist")
        self.assertEqual(artist_data, [],
                         "Expected empty list of results for unknown artist.")

    def test_get_songs(self):
        res = self.kb_api.get_songs_by_artist("Justin Bieber")
        self.assertEqual(
            res, ["Despacito", "Sorry"],
            "Songs retrieved for 'Justin Bieber' did not match expected.")
        res = self.kb_api.get_songs_by_artist("Justin Timberlake")
        self.assertEqual(
            res, ["Rock Your Body"],
            "Songs retrieved for 'Justin Timberlake' did not match expected.")

    def test_get_songs_unknown_artist(self):
        res = self.kb_api.get_songs_by_artist("Unknown artist")
        self.assertIsNone(res, "Unexpected songs retrieved for unknown artist.")

    def test_find_similar_song(self):
        res = self.kb_api.get_related_entities("Despacito")
        self.assertEqual(
            len(res),
            1,
            "Expected only one song similar to \"Despacito\". Found {0}".format(res),
        )
        self.assertEqual(
            res[0],
            "Rock Your Body",
            "Expected to find \"Rock Your Body\" as similar to \"Despacito\".",
        )

    def test_find_similar_artist(self):
        res = self.kb_api.get_related_entities("Justin Bieber")
        self.assertEqual(
            len(res),
            2,
            "Expected exactly two artists similar to Justin Bieber.",
        )
        self.assertEqual(
            res[0],
            "Justin Timberlake",
            "Expected to find Justin Timberlake as similar to Justin Bieber.",
        )
        self.assertEqual(
            res[1],
            "Shawn Mendes",
            "Expected to find Shawn Mendes as similar to Justin Bieber.",
        )

    def test_get_all_music_entities(self):
        res = self.kb_api.get_all_music_entities()
        self.assertIn(
            'Justin Bieber', res,
            'Expected to find "Justin Bieber" in the list of entities.')

    def test_find_similar_to_entity_that_dne(self):
        res = self.kb_api.get_related_entities("Unknown Entity")
        self.assertEqual(res, [])

    def test_get_related_genres(self):
        genre_rel_str = self.kb_api.approved_relations["genre"]
        rel_genres = self.kb_api.get_related_entities("Justin Bieber",
                                                      rel_str=genre_rel_str)
        self.assertEqual(
            set(rel_genres), set(["Pop", "Super pop"]),
            "Did not find expected related genres for artist 'Justin Bieber'")

        rel_genres = self.kb_api.get_related_entities("Justin Timberlake",
                                                      rel_str=genre_rel_str)
        self.assertEqual(
            rel_genres, ["Pop"],
            "Did not find expected related genres for artist 'Justin Timberlake'")

    def test_connect_entities_by_similarity(self):
        res = self.kb_api.get_related_entities("Shawn Mendes")
        self.assertEqual(len(res), 0)

        res = self.kb_api.connect_entities("Shawn Mendes", "Justin Timberlake",
                                           "similar to", 0)
        self.assertEqual(
            res, True,
            "Expected 'Shawn Mendes' -> 'Justin Timberlake' similarity edge to be created.")

        res = self.kb_api.get_related_entities("Shawn Mendes")
        self.assertEqual(len(res), 1)
        self.assertEqual(res[0], "Justin Timberlake")

    def test_connect_entities_by_genre(self):
        genre_rel_str = self.kb_api.approved_relations["genre"]
        res = self.kb_api.get_related_entities("Shawn Mendes",
                                               rel_str=genre_rel_str)
        self.assertEqual(len(res), 0)

        res = self.kb_api.connect_entities("Shawn Mendes", "Pop", "of genre", 100)
        self.assertEqual(
            res, True,
            "Expected 'Shawn Mendes' -> 'Pop' genre edge to be created.")

        res = self.kb_api.get_related_entities("Shawn Mendes",
                                               rel_str=genre_rel_str)
        self.assertEqual(
            len(res), 1,
            "Expected to find exactly one related genre for 'Shawn Mendes'.")
        self.assertEqual(res[0], "Pop",
                         "Found unexpected genre for 'Shawn Mendes'.")

    def test_rejects_connect_ambiguous_entities(self):
        self.kb_api.add_artist("Artist and Song name clash")
        self.kb_api.add_song("Artist and Song name clash", "U2")

        res = self.kb_api.connect_entities("Artist and Song name clash",
                                           "Justin Timberlake", "similar to", 0)
        self.assertEqual(
            res, False,
            "Expected connecting an ambiguous (artist and song) entity to be rejected.")

    # Previously this test shared its name with the later
    # test_get_node_ids_by_entity_type, so it was silently shadowed and never ran.
    def test_get_node_ids_by_entity_type_name_clash(self):
        node_ids_dict = self.kb_api.get_node_ids_by_entity_type("Justin Bieber")
        all_entity_types = list(node_ids_dict.keys())
        self.assertEqual(
            all_entity_types, ["artist"],
            "Expected to find (only) entities of type 'artist' for 'Justin Bieber', but got: {}"
            .format(node_ids_dict))
        self.assertEqual(
            len(node_ids_dict.get("artist")), 1,
            "Expected to find exactly one entity of type 'artist' for 'Justin Bieber', but got: {}"
            .format(node_ids_dict))

        self.kb_api.add_artist("Song and Artist name")
        self.kb_api.add_song("Song and Artist name", "U2")
        node_ids_dict = self.kb_api.get_node_ids_by_entity_type("Song and Artist name")
        alphabetized_entity_types = sorted(node_ids_dict.keys())
        self.assertEqual(
            alphabetized_entity_types, ["artist", "song"],
            "Expected to find (exactly) two entity types: 'artist' and 'song', but got: {}"
            .format(node_ids_dict))

    def test_get_matching_node_ids(self):
        node_ids = self.kb_api._get_matching_node_ids("Justin Bieber")
        self.assertEqual(
            len(node_ids), 1,
            "Expected to find exactly one matching node for 'Justin Bieber', but got: {}"
            .format(node_ids))

    def test_get_matching_node_ids_empty(self):
        node_ids = self.kb_api._get_matching_node_ids("Unknown artist")
        self.assertEqual(
            len(node_ids), 0,
            "Expected to find no matching node for 'Unknown artist', but got: {}"
            .format(node_ids))

    def test_add_artist(self):
        sample_genres = ["Pop", "Very pop", "Omg so pop"]
        new_artist_node_id = self.kb_api.add_artist("Heart",
                                                    genres=sample_genres,
                                                    num_spotify_followers=1)
        self.assertNotEqual(new_artist_node_id, None,
                            "Failed to add artist 'Heart' to knowledge base.")

        artist_data = self.kb_api.get_artist_data("Heart")
        self.assertEqual(len(artist_data), 1,
                         "Expected unique match for artist 'Heart'.")

        artist_data = artist_data[0]
        artist_data["genres"] = set(artist_data["genres"])
        self.assertEqual(
            artist_data,
            dict(
                name="Heart",
                id=new_artist_node_id,
                genres=set(sample_genres),
                num_spotify_followers=1,
            ),
            "Did not find expected genres for artist 'Heart'.",
        )

    def test_reject_add_artist_already_exists(self):
        artist_node_id = self.kb_api.get_artist_data("Justin Bieber")[0]["id"]
        res = self.kb_api.add_artist("Justin Bieber")
        self.assertEqual(
            res, artist_node_id,
            "Expected rejection of attempt to add artist 'Justin Bieber' to knowledge base.")

    def test_add_artist_omitted_opt_params(self):
        res = self.kb_api.add_artist("Heart")
        self.assertNotEqual(res, None,
                            "Failed to add artist 'Heart' to knowledge base.")

        artist_data = self.kb_api.get_artist_data("Heart")
        self.assertEqual(len(artist_data), 1,
                         "Expected unique match for artist 'Heart'.")
        self.assertEqual(artist_data[0]["name"], "Heart",
                         "Expected match for artist 'Heart'.")
        self.assertEqual(artist_data[0]["genres"], [],
                         "Expected no genres for artist 'Heart'.")
        self.assertEqual(artist_data[0]["num_spotify_followers"], None,
                         "Expected no follower count for artist 'Heart'.")

    def test_add_song(self):
        new_song_node_id = self.kb_api.add_song("Heart", "Justin Bieber",
                                                duration_ms=11111,
                                                popularity=100)
        self.assertNotEqual(
            new_song_node_id, None,
            "Failed to add song 'Heart' by artist 'Justin Bieber' to knowledge base.")

        song_data = self.kb_api.get_song_data("Heart")
        self.assertEqual(len(song_data), 1, "Expected exactly one result.")

        song_data = song_data[0]
        self.assertEqual(
            song_data,
            dict(id=new_song_node_id,
                 song_name="Heart",
                 artist_name="Justin Bieber",
                 duration_ms=11111,
                 popularity=100),
            "Received unexpected song data")

    def test_add_song_omitted_opt_params(self):
        new_song_node_id = self.kb_api.add_song("What do you mean?", "Justin Bieber")
        self.assertNotEqual(
            new_song_node_id, None,
            "Failed to add song 'What do you mean?' by artist 'Justin Bieber' to knowledge base.")

        song_data = self.kb_api.get_song_data("What do you mean?")
        self.assertEqual(len(song_data), 1, "Expected exactly one result.")

        song_data = song_data[0]
        self.assertEqual(
            song_data,
            dict(id=new_song_node_id,
                 song_name="What do you mean?",
                 artist_name="Justin Bieber",
                 duration_ms=None,
                 popularity=None),
            "Received unexpected song data")

    def test_add_duplicate_song_for_different_artist(self):
        new_song_node_id = self.kb_api.add_song("Despacito", "Justin Timberlake")
        self.assertNotEqual(
            new_song_node_id, None,
            "Failed to add song 'Despacito' by artist 'Justin Timberlake' to knowledge base.")

        res = self.kb_api.get_song_data("Despacito")
        self.assertEqual(len(res), 2,
                         "Expected exactly two matches for song 'Despacito'.")

        artists = set([res[0]["artist_name"], res[1]["artist_name"]])
        self.assertEqual(
            artists, set(["Justin Bieber", "Justin Timberlake"]),
            "Expected to find duplicate artists 'Justin Bieber' and 'Justin Timberlake' for song 'Despacito'.")

    def test_reject_add_song_already_exists(self):
        res = self.kb_api.add_song("Despacito", "Justin Bieber")
        self.assertEqual(
            res, None,
            "Expected rejection of attempt to add song 'Despacito' by 'Justin Bieber' to knowledge base.")

    # The logic tested here is currently implemented in the KR API.
    # However, if it is moved to the schema (e.g. trigger functions),
    # then this test can be moved to the schema test module.
    def test_new_song_with_unknown_artist_rejected(self):
        res = self.kb_api.add_song("Song by Unknown Artist", "Unknown artist")
        self.assertEqual(res, None,
                         "Expected song with unknown artist to be rejected")

        res = self.kb_api.get_song_data("Song by Unknown Artist")
        self.assertEqual(
            len(res), 0,
            "Insertion of song with unknown artist should not have been added to nodes table")

    def test_add_genre(self):
        node_id = self.kb_api.add_genre("hip hop")
        self.assertEqual(
            type(node_id), int,
            "Genre addition appears to have failed: expected int return value (node id) on valid attempt to add genre.")

        res = self.kb_api.add_genre("hip hop")
        self.assertEqual(
            res, node_id,
            "Expected original node id to be fetched when attempting to add duplicate genre.")

    def test_add_genre_creates_node(self):
        res = self.kb_api.add_genre("hip hop")
        self.assertEqual(
            type(res), int,
            "Genre addition appears to have failed: expected int return value (node id) on valid attempt to add genre.")

        entities = self.kb_api.get_node_ids_by_entity_type("hip hop")
        self.assertIn(
            "genre", entities,
            "Expected to find node associated with genre 'hip hop'.")

    def test_get_node_ids_by_entity_type(self):
        res = self.kb_api.get_node_ids_by_entity_type("Justin Timberlake")
        self.assertIn(
            "artist", res,
            "Expected to find an 'artist' entity with name 'Justin Timberlake', but got: {0}"
            .format(res))
        self.assertEqual(
            len(res["artist"]), 1,
            "Expected to find exactly one entity (of type 'artist') with name 'Justin Timberlake', but got: {0}"
            .format(res))

        res = self.kb_api.get_node_ids_by_entity_type("Despacito")
        self.assertIn(
            "song", res,
            "Expected to find an 'song' entity with name 'Despacito', but got: {0}"
            .format(res))
        self.assertEqual(
            len(res["song"]), 1,
            "Expected to find exactly one entity (of type 'song') with name 'Despacito', but got: {0}"
            .format(res))

        res = self.kb_api.get_node_ids_by_entity_type("Unknown entity")
        self.assertEqual(
            res, {},
            "Expected no results from query for unknown entity, but got {}".format(res))