def __init__(self, db_path, keywords):
    """Prepare the parser: ensure stopwords exist, then load KB entities.

    Params:
        db_path: location of the knowledge base; passed straight to
            ``KnowledgeBaseAPI``.
        keywords: command keywords; stored unchanged as ``self.commands``.
    """
    # The stopword corpus is only downloaded when the local lookup fails.
    try:
        nltk.data.find('corpora/stopwords')
    except LookupError:
        nltk.download('stopwords')

    self.commands = keywords

    knowledge_base = KnowledgeBaseAPI(db_path)
    self.kb_api = knowledge_base
    # Every music entity name known to the knowledge base.
    self.db_nouns = knowledge_base.get_all_music_entities()
def setUp(self):
    """Create a populated test DB and wire up the bag-of-words stack."""
    db_location = test_db_utils.create_and_populate_db()
    self.DB_path = db_location
    self.kb_api = KnowledgeBaseAPI(db_location)

    # The same dict object is shared between the test case and the mock
    # controller, so tests can inspect it through ``self.results_dict``.
    shared_results = {}
    self.results_dict = shared_results
    self.player_controller = MockController(shared_results)

    engine = BOWEvalEngine(db_location, self.player_controller)
    self.interactions = engine
    # The parser is built from the keywords exposed by the eval engine.
    self.nlp = BOWParser(db_location, engine.keywords)
    self.keywords = engine.keywords
def __init__(self, db_path, player_controller, parser_type='BagOfWords'):
    """Create the knowledge-base handle plus a matching engine/parser pair.

    Params:
        db_path: location of the knowledge base; stored as ``DB_path``.
        player_controller: forwarded to the evaluation engine constructor.
        parser_type (str): 'BagOfWords' (default) or 'TREE'.
    """
    self.DB_path = db_path
    self.kb_api = KnowledgeBaseAPI(self.DB_path)
    self.parser_type = parser_type

    if parser_type == 'BagOfWords':
        engine_cls, parser_cls = BOWEvalEngine, BOWParser
    elif parser_type == 'TREE':
        engine_cls, parser_cls = TreeEvalEngine, TreeParser
    else:
        # NOTE: any other parser_type leaves ``eval_engine`` and ``parser``
        # unset on the instance.
        return

    self.eval_engine = engine_cls(self.DB_path, player_controller)
    # The parser consumes the keyword table published by its engine.
    self.parser = parser_cls(self.DB_path, self.eval_engine.keywords)
def create_and_populate_db_with_spotify(spotify_client_id, spotify_secret_key, artists, path=None):
    """Create a knowledge base and fill it with Spotify data for ``artists``.

    For every artist returned by ``get_artist_metadata`` this adds the
    artist, its songs, and its related artists, linking each related artist
    to the original one with a "similar to" edge in both directions.

    Params:
        spotify_client_id: first argument to ``SpotifyClient``.
        spotify_secret_key: second argument to ``SpotifyClient``.
        artists: artist names, forwarded to ``get_artist_metadata``.
        path: forwarded to ``create_db``.

    Returns:
        The value returned by ``create_db`` (path to the new database).
    """
    path_to_db = create_db(path=path)
    client = SpotifyClient(spotify_client_id, spotify_secret_key)
    metadata_by_artist = get_artist_metadata(client, artists)
    kb_api = KnowledgeBaseAPI(path_to_db)

    for artist_name, info in metadata_by_artist.items():
        kb_api.add_artist(artist_name, info["genres"], info["num_followers"])

        for song_name, song in info["songs"].items():
            kb_api.add_song(song_name, artist_name, song["popularity"], song["duration_ms"])

        for related_name, related in info["related_artists"].items():
            kb_api.add_artist(related_name, related["genres"], related["num_followers"])
            # Similarity is stored as two directed edges, one each way.
            for source, target in ((artist_name, related_name), (related_name, artist_name)):
                kb_api.connect_entities(source, target, "similar to", 100)

    return path_to_db
def create_and_populate_db_with_spotify(spotify_client_id, spotify_secret_key, artists, path=None):
    """Build a knowledge base from Spotify data for the given artists.

    Everything listed below is fetched and written through the knowledge
    base API for each requested artist:
        - the artist's own metadata (genres, number of followers)
        - the artist's top songs, with their audio features
        - the related artists, each with its own metadata and top songs,
          linked to the requested artist by "similar to" edges both ways

    Params:
        spotify_client_id (str): Spotify client id.
        spotify_secret_key (str): Spotify secret key.
        artists (iterable): artist names (str); may be stdin, an open file,
            or a list.
        path (str): relative path for the db, e.g. "knowledge_base.db".

    Returns:
        path_to_db (str): relative path to the newly created db,
            e.g. "knowledge_base/knowledge_base.db"
    """
    from utils.spotify_client import SpotifyClient

    path_to_db = create_db(path=path)
    spotify = SpotifyClient(spotify_client_id, spotify_secret_key)
    artist_metadata = _get_artist_metadata(spotify, artists)
    kb_api = KnowledgeBaseAPI(path_to_db)

    def store_top_songs(owner_name, owner_info):
        # Fetch the owner's top songs, then their audio features, and hand
        # both to _insert_songs.
        top_songs = spotify.get_top_songs(owner_info['id'], SPOTIFY_COUNTRY_ISO)
        song_ids = [details['id'] for _title, details in top_songs.items()]
        features = spotify.get_audio_features(song_ids)
        _insert_songs(top_songs, owner_name, features, kb_api)

    for artist_name, artist_info in artist_metadata.items():
        kb_api.add_artist(artist_name, artist_info["genres"], artist_info["num_followers"])
        store_top_songs(artist_name, artist_info)

        for related_name, related_info in artist_info["related_artists"].items():
            kb_api.add_artist(related_name, related_info["genres"], related_info["num_followers"])
            # Similarity is recorded as two directed edges.
            for source, target in ((artist_name, related_name), (related_name, artist_name)):
                kb_api.connect_entities(source, target, "similar to", 100)
            store_top_songs(related_name, related_info)

    return path_to_db
for t in test_cases: res = remove_dash_section(t['input']) if res != t['expected']: print(f"FAIL: '{t['input']}' expected '{t['expected']}' got '{res}'") fails += 1 total += 1 print(f"Finished running test_remove_dash_section: Ran {total} tests with {fails} failures.") if __name__ == "__main__": if len(sys.argv) > 1 and sys.argv[1] == "-t": print("Running tests...") test_remove_parenthised_section() test_remove_dash_section() else: music_api = KnowledgeBaseAPI("knowledge_base/knowledge_base.db") # print songs if sys.argv[1] == "-s": print("Fetching song names..") song_names = music_api.get_all_song_names() print_to_csv(song_names, SONG_FILE_NAME) print(f"Wrote to {SONG_FILE_NAME}") # print artists elif sys.argv[1] == "-a": print("Fetching artist names..") artist_names = music_api.get_all_artist_names() print_to_csv(artist_names, ARTIST_FILE_NAME) print(f"Wrote to {ARTIST_FILE_NAME}")
""" from flask import Flask, render_template from flask_socketio import SocketIO import json import logging import os import random import sys sys.path.extend(['.', '../']) # needed for import statement(s) below from knowledge_base.api import KnowledgeBaseAPI app = Flask(__name__) socket_io = SocketIO(app) music_api = KnowledgeBaseAPI('knowledge_base/knowledge_base.db') CLIENT_SESSION_KEYS, NEW_CLIENT_IDX, = [None for i in range(100)], 0 CLIENT_COUNT, MAX_CLIENT_SESSIONS = 0, 100 @app.route('/') def get_player(): return render_template("player.html") @socket_io.on("start session") def start_user_session(): global NEW_CLIENT_IDX, CLIENT_COUNT print(f"Received request from client to begin session..")
def setUpClass(self):
    """Create the populated test DB and attach an API handle as ``kb_api``."""
    populated_db = test_db_utils.create_and_populate_db()
    self.kb_api = KnowledgeBaseAPI(dbName=populated_db)
class TestDbSchema(unittest.TestCase):
    """Checks that invalid nodes and edges are rejected by the knowledge base.

    A single populated test database is shared by every test in the class
    (created in ``setUpClass``, removed in ``tearDownClass``).
    """

    @classmethod
    def setUpClass(cls):
        DB_path = test_db_utils.create_and_populate_db()
        cls.kb_api = KnowledgeBaseAPI(dbName=DB_path)

    @classmethod
    def tearDownClass(cls):
        test_db_utils.remove_db()

    def test_rejects_unknown_entity(self):
        res = self.kb_api.connect_entities("Unknown Entity", "Justin Timberlake", "similar to", 0)
        self.assertEqual(
            res, False,
            "Expected attempt to connect an unknown entity to fail.")

        # The failed attempt must not have created any edge.
        res = self.kb_api.get_related_entities("Unknown Entity")
        self.assertEqual(res, [])

    def test_rejects_score_out_of_range(self):
        res = self.kb_api.connect_entities("Justin Timberlake", "Justin Timberlake", "similar to", -1)
        self.assertEqual(
            res, False,
            "Expected attempt to connect entities with score out-of-range to fail."
        )

        res = self.kb_api.get_related_entities("Justin Timberlake")
        self.assertEqual(len(res), 0)

    def test_rejects_duplicate_edge(self):
        res = self.kb_api.connect_entities("Justin Bieber", "Justin Timberlake", "similar to", 1)
        self.assertEqual(res, False,
                         "Expected attempt to add a duplicate edge to fail.")

    def test_edges_not_null_constraints(self):
        res = self.kb_api.connect_entities(None, "Justin Timberlake", "similar to", 1)
        self.assertEqual(res, False,
                         "Expected 'None' value for artist to be rejected.")

        res = self.kb_api.connect_entities("U2", "U2", None, 1)
        self.assertEqual(
            res, False,
            "Expected 'None' value for edge type to be rejected.")

        res = self.kb_api.connect_entities("U2", "U2", "similar to", None)
        self.assertEqual(
            res, False,
            "Expected 'None' value for edge score to be rejected.")

    def test_entities_not_null_constraints(self):
        res = self.kb_api.add_artist(None)
        self.assertEqual(res, None,
                         "Expected 'None' value for artist to be rejected.")

        res = self.kb_api.add_song("Song name", None)
        self.assertEqual(res, None,
                         "Expected 'None' value for artist to be rejected.")

        # Here the *song name* is the missing value, not the artist.
        res = self.kb_api.add_song(None, "Artist name")
        self.assertEqual(res, None,
                         "Expected 'None' value for song name to be rejected.")

        node_id = self.kb_api._add_node(None, "artist")
        self.assertEqual(
            node_id, None,
            "Expected 'None' value for entity name to be rejected.")

        node_id = self.kb_api._add_node("Some entity", None)
        self.assertEqual(
            node_id, None,
            "Expected 'None' value for entity type to be rejected.")
def __init__(self, db_path, keywords):
    """Store the command keywords and load entity names from the KB.

    Params:
        db_path: location of the knowledge base; passed to
            ``KnowledgeBaseAPI``.
        keywords: command keyword table; stored unchanged.
    """
    self.keywords = keywords

    knowledge_base = KnowledgeBaseAPI(db_path)
    self.kb_api = knowledge_base
    # All music entity names known to the knowledge base.
    self.kb_named_entities = knowledge_base.get_all_music_entities()
class TreeParser:
    """
    This class contains tree-parsing logic that will convert input text
    to an NLTK Parse-Tree using a Context Free Grammar.

    The work is split in two stages: ``_lexer`` turns the raw message into
    a flat list of entity/command tokens, and ``_parser`` feeds those tokens
    to an NLTK Viterbi parser.

    References:
        https://stackoverflow.com/questions/42322902/how-to-get-parse-tree-using-python-nltk
        https://www.nltk.org/book/ch08.html
    """

    def __init__(self, db_path, keywords):
        """Store the keyword table and load entity names from the KB.

        Args:
            db_path: location of the knowledge base, passed to
                ``KnowledgeBaseAPI``.
            keywords: mapping with "unary", "terminal" and "binary" entries,
                each mapping an intent name to its keyword signifiers.
        """
        self.keywords = keywords
        self.kb_api = KnowledgeBaseAPI(db_path)
        self.kb_named_entities = self.kb_api.get_all_music_entities()

    def __call__(self, msg: str):
        """Creates an NLTK Parse Tree from the user input msg.

        Args:
            msg: A string of user input.
                i.e 'play something similar to Justin Bieber'

        Returns:
            An NLTK parse tree, as defined by the CFG given
            in the "parser" function.
        """
        # Remove punctuation from the string
        msg = re.sub(r"[.?']+\ *", " ", msg, flags=re.VERBOSE)

        # Parse sentence into list of tokens containing
        # only entities and commands.
        tokens = self._lexer(msg)

        # Generate an NLTK parse tree
        tree = self._parser(tokens)
        return tree

    def _command_regexes(self, arity: str):
        """Return ``{intent: compiled pattern}`` for one command arity.

        Patterns are built once per instance and memoised on the instance
        itself, so the cache lives and dies with the parser (a shared
        ``lru_cache`` on the method would keep the instance alive).

        Args:
            arity: "unary", "terminal" or "binary".
        """
        cache = self.__dict__.setdefault("_command_regex_cache", {})
        if arity not in cache:
            patterns = {}
            for intent, keys in self.keywords.get(arity).items():
                # Empty keywords would yield a pattern that matches the
                # empty string at every word boundary; ignore them.
                usable = [key for key in keys if key] if keys else []
                if usable:
                    # NOTE: keywords are interpolated as-is, i.e. they are
                    # treated as regex fragments, not escaped literals.
                    patterns[intent] = re.compile(
                        r'\b' + r'\b|\b'.join(usable) + r'\b')
            cache[arity] = patterns
        return cache[arity]

    @property
    def _unary_command_regexes(self):
        """Generates RegEx patterns from unary command signifiers."""
        return self._command_regexes("unary")

    @property
    def _terminal_command_regexes(self):
        """Generates RegEx patterns from terminal command signifiers."""
        return self._command_regexes("terminal")

    @property
    def _binary_command_regexes(self):
        """Generates RegEx patterns from binary command signifiers."""
        return self._command_regexes("binary")

    def _lexer(self, msg: str):
        """Lexes an input string into a list of tokens.

        This lexer first looks for Entities in the input string and parses
        them into tokens of the same name (i.e. 'blah U2 blah' -> ['U2']).

        Next, the lexer will look for Commands by searching for keywords
        that signify the command. These keyword+command pairings are
        defined in the command_evaluation layer.

        The text is always split around the FIRST occurrence of a match and
        both sides are lexed recursively, so repeated entities or commands
        in one message each produce their own token.

        Args:
            msg: A string of user input.
                i.e 'play something similar to justin bieber'

        Returns:
            A tokenized list of commands and Entities.
            i.e. ['control_play', 'query_similar_entities', 'Justin Bieber']
        """
        def lexing_algorithm(text):
            # TODO: >= O(n^2) (horrible time complexity). In the interest of
            # making forward progress, optimize this later. Could use a
            # bloom filter to look for matches, then binary search on
            # entities/commands find specific match once possible match is
            # found. Or even just use a hashmap for searching.

            # Base case.
            if text == "":
                return []

            # 1. Parse named entities (case-insensitively).
            lowered = text.lower()
            for entity in self.kb_named_entities:
                needle = entity.lower()
                if not needle.strip():
                    # A blank entity name can never be a meaningful token,
                    # and an empty separator cannot be split on.
                    continue
                if needle in lowered:
                    # Both sides continue in lower case from here on.
                    left, _, right = lowered.partition(needle)
                    return lexing_algorithm(left) \
                        + [entity.strip()] \
                        + lexing_algorithm(right)

            # 2-4. Parse unary, then terminal, then binary commands.
            # NOTE(review): command patterns are matched case-sensitively
            # against `text`, which is only lower-cased once an entity has
            # been split off above — confirm whether that is intended.
            for arity in ("unary", "terminal", "binary"):
                for intent, pattern in self._command_regexes(arity).items():
                    match = pattern.search(text)
                    if not match or not match.group(0):
                        # No hit, or a zero-width hit that would not
                        # consume any text (and so never terminate).
                        continue
                    return lexing_algorithm(text[:match.start()]) \
                        + [intent] \
                        + lexing_algorithm(text[match.end():])

            # If no matches, then the word is a stopword.
            return []

        return lexing_algorithm(msg)

    def _parser(self, tokens: List[str]):
        """Generates a Parse Tree from a list of tokens
        provided by the Lexer.

        Args:
            tokens: A tokenized list of commands and Entities.
                i.e. ['control_play', 'query_similar_entities', 'Justin Bieber']

        Returns:
            An nltk parse tree, as defined by the CFG given
            in this function.
        """
        # TODO: Improve the CFG work for the following:
        #   - Play songs faster than despicito
        #   - Play something similar to despicito but faster
        #   - Play something similar to u2 and justin bieber

        def gen_lexing_patterns(vals: List[str]):
            # TODO: Here we remove entries containing ',
            # as it is a special character used by
            # the NLTK parser. We need to fix this
            # eventually.
            safe_vals = [s for s in vals if "\'" not in s]
            return "' | '".join(safe_vals) or "NONE"

        # A Probabilistic Context Free Grammar (PCFG)
        # can be used to simulate "operator precedence",
        # which removes the problems of ambiguity in
        # the grammar.
        grammar = nltk.PCFG.fromstring("""
        Root -> Terminal_Command Result [0.6]
        Root -> Terminal_Command [0.4]
        Result -> Entity [0.5]
        Result -> Unary_Command Result [0.1]
        Result -> Result Binary_Command Result [0.4]
        Entity -> '{}' [1.0]
        Unary_Command -> '{}' [1.0]
        Terminal_Command -> '{}' [1.0]
        Binary_Command -> '{}' [1.0]
        """.format(
            gen_lexing_patterns(self.kb_named_entities),
            gen_lexing_patterns(self.keywords.get("unary").keys()),
            gen_lexing_patterns(self.keywords.get("terminal").keys()),
            gen_lexing_patterns(self.keywords.get("binary").keys()),
        ))

        parser = nltk.ViterbiParser(grammar)

        # TODO: Returns the first tree, but need to deal with
        #       case where grammar is ambiguous, and more than
        #       one tree is returned.
        return next(parser.parse(tokens))
class TestMusicKnowledgeBaseAPI(unittest.TestCase):
    """Exercises the read and write operations of ``KnowledgeBaseAPI``.

    Every test runs against its own freshly populated test database
    (created in ``setUp``, removed in ``tearDown``).
    """

    def setUp(self):
        DB_path = test_db_utils.create_and_populate_db()
        self.kb_api = KnowledgeBaseAPI(dbName=DB_path)

    def tearDown(self):
        test_db_utils.remove_db()

    def test_get_song_data(self):
        song_data = self.kb_api.get_song_data("Despacito")
        self.assertEqual(
            1, len(song_data),
            "Expected exactly one result from query for song 'Despacito'.")
        self.assertEqual(
            song_data[0],
            dict(
                id=10,
                song_name="Despacito",
                artist_name="Justin Bieber",
                duration_ms=222222,
                popularity=10,
            ), "Did not find expected values for song data for 'Despacito'.")

    def test_get_song_data_dne(self):
        res = self.kb_api.get_song_data("Not In Database")
        self.assertEqual(
            res, [],
            "Expected empty list of results for queried song not in DB.")

    def test_get_artist_data(self):
        artist_data = self.kb_api.get_artist_data("Justin Bieber")
        self.assertEqual(
            len(artist_data), 1,
            "Expected exactly one result for artist 'Justin Bieber'.")
        # Compare genres as a set so their order does not matter.
        artist_data[0]["genres"] = set(artist_data[0]["genres"])
        self.assertEqual(
            artist_data[0],
            dict(genres=set(["Pop", "Super pop"]),
                 id=1,
                 num_spotify_followers=4000,
                 name="Justin Bieber"),
            "Artist data for 'Justin Bieber' did not match expected.",
        )

    def test_get_artist_data_dne(self):
        artist_data = self.kb_api.get_artist_data("Unknown artist")
        self.assertEqual(artist_data, [],
                         "Expected empty result list for unknown artist.")

    def test_get_songs(self):
        res = self.kb_api.get_songs_by_artist("Justin Bieber")
        self.assertEqual(
            res, ["Despacito", "Sorry"],
            "Songs retrieved for 'Justin Bieber' did not match expected.")

        res = self.kb_api.get_songs_by_artist("Justin Timberlake")
        self.assertEqual(
            res, ["Rock Your Body"],
            "Songs retrieved for 'Justin Timberlake' did not match expected.")

    def test_get_songs_unknown_artist(self):
        res = self.kb_api.get_songs_by_artist("Unknown artist")
        self.assertEqual(res, None,
                         "Unexpected songs retrieved for unknown artist.")

    def test_find_similar_song(self):
        res = self.kb_api.get_related_entities("Despacito")
        self.assertEqual(
            len(res),
            1,
            "Expected only one song similar to \"Despacito\". Found {0}".
            format(res),
        )
        self.assertEqual(
            res[0],
            "Rock Your Body",
            "Expected to find \"Rock Your Body\" as similar to \"Despacito\".",
        )

    def test_find_similar_artist(self):
        res = self.kb_api.get_related_entities("Justin Bieber")
        self.assertEqual(
            len(res),
            2,
            "Expected exactly two artists similar to Justin Bieber.",
        )
        self.assertEqual(
            res[0],
            "Justin Timberlake",
            "Expected to find Justin Timberlake as similar to Justin Bieber.",
        )
        self.assertEqual(
            res[1],
            "Shawn Mendes",
            "Expected to find Shawn Mendes as similar to Justin Bieber.",
        )

    def test_get_all_music_entities(self):
        res = self.kb_api.get_all_music_entities()
        self.assertTrue(
            'Justin Bieber' in res,
            'Expected to find "Justin Bieber" in the list of entities.')

    def test_find_similar_to_entity_that_dne(self):
        res = self.kb_api.get_related_entities("Unknown Entity")
        self.assertEqual(res, [])

    def test_get_related_genres(self):
        genre_rel_str = self.kb_api.approved_relations["genre"]
        rel_genres = self.kb_api.get_related_entities("Justin Bieber",
                                                      rel_str=genre_rel_str)
        self.assertEqual(
            set(rel_genres), set(["Pop", "Super pop"]),
            "Did not find expected related genres for artist 'Justin Bieber'")

        rel_genres = self.kb_api.get_related_entities("Justin Timberlake",
                                                      rel_str=genre_rel_str)
        self.assertEqual(
            rel_genres, ["Pop"],
            "Did not find expected related genres for artist 'Justin Timberlake'"
        )

    def test_connect_entities_by_similarity(self):
        res = self.kb_api.get_related_entities("Shawn Mendes")
        self.assertEqual(len(res), 0)

        res = self.kb_api.connect_entities("Shawn Mendes", "Justin Timberlake", "similar to", 0)
        self.assertEqual(res, True, "")

        res = self.kb_api.get_related_entities("Shawn Mendes")
        self.assertEqual(len(res), 1)
        self.assertEqual(res[0], "Justin Timberlake")

    def test_connect_entities_by_genre(self):
        genre_rel_str = self.kb_api.approved_relations["genre"]
        res = self.kb_api.get_related_entities("Shawn Mendes",
                                               rel_str=genre_rel_str)
        self.assertEqual(len(res), 0)

        res = self.kb_api.connect_entities("Shawn Mendes", "Pop", "of genre", 100)
        self.assertEqual(res, True, "")

        res = self.kb_api.get_related_entities("Shawn Mendes",
                                               rel_str=genre_rel_str)
        self.assertEqual(
            len(res), 1,
            "Expected to find exactly one related genre for 'Shawn Mendes'.")
        self.assertEqual(res[0], "Pop",
                         "Found unexpected genre for 'Shawn Mendes'.")

    def test_rejects_connect_ambiguous_entities(self):
        # The same name now refers to both an artist and a song.
        self.kb_api.add_artist("Artist and Song name clash")
        self.kb_api.add_song("Artist and Song name clash", "U2")

        res = self.kb_api.connect_entities("Artist and Song name clash", "Justin Timberlake", "similar to", 0)
        self.assertEqual(res, False, "")

    # Renamed from test_get_node_ids_by_entity_type: a later method in this
    # class has that exact name, which shadowed this one so it never ran.
    def test_get_node_ids_by_entity_type_name_clash(self):
        node_ids_dict = self.kb_api.get_node_ids_by_entity_type(
            "Justin Bieber")
        all_entity_types = list(node_ids_dict.keys())
        self.assertEqual(
            all_entity_types, ["artist"],
            "Expected to find (only) entities of type 'artist' for 'Justin Bieber', but got: {}"
            .format(node_ids_dict))
        self.assertEqual(
            len(node_ids_dict.get("artist")), 1,
            "Expected to find exactly one entity of type 'artist' for 'Justin Bieber', but got: {}"
            .format(node_ids_dict))

        self.kb_api.add_artist("Song and Artist name")
        self.kb_api.add_song("Song and Artist name", "U2")

        node_ids_dict = self.kb_api.get_node_ids_by_entity_type(
            "Song and Artist name")
        alphabetized_entity_types = sorted(list(node_ids_dict.keys()))
        self.assertEqual(
            alphabetized_entity_types, ["artist", "song"],
            "Expected to find (exactly) two entity types: 'artist' and 'song', but got: {}"
            .format(node_ids_dict))

    def test_get_matching_node_ids(self):
        node_ids = self.kb_api._get_matching_node_ids("Justin Bieber")
        self.assertEqual(
            len(node_ids), 1,
            "Expected to find exactly one matching node for 'Justin Bieber', but got: {}"
            .format(node_ids))

    def test_get_matching_node_ids_empty(self):
        node_ids = self.kb_api._get_matching_node_ids("Unknown artist")
        self.assertEqual(
            len(node_ids), 0,
            "Expected to find no matching node for 'Unknown artist', but got: {}"
            .format(node_ids))

    def test_add_artist(self):
        sample_genres = ["Pop", "Very pop", "Omg so pop"]
        new_artist_node_id = self.kb_api.add_artist("Heart",
                                                    genres=sample_genres,
                                                    num_spotify_followers=1)
        self.assertNotEqual(new_artist_node_id, None,
                            "Failed to add artist 'Heart' to knowledge base.")

        artist_data = self.kb_api.get_artist_data("Heart")
        self.assertEqual(len(artist_data), 1,
                         "Expected unique match for artist 'Heart'.")

        artist_data = artist_data[0]
        # Compare genres as a set so their order does not matter.
        artist_data["genres"] = set(artist_data["genres"])
        self.assertEqual(
            artist_data,
            dict(
                name="Heart",
                id=new_artist_node_id,
                genres=set(sample_genres),
                num_spotify_followers=1,
            ),
            "Did not find expected genres for artist 'Heart'.",
        )

    def test_reject_add_artist_already_exists(self):
        artist_node_id = self.kb_api.get_artist_data("Justin Bieber")[0]["id"]
        # Re-adding an existing artist is expected to hand back the id of
        # the node that already exists.
        res = self.kb_api.add_artist("Justin Bieber")
        self.assertEqual(
            res, artist_node_id,
            "Expected rejection of attempt to add artist 'Justin Bieber' to knowledge base."
        )

    def test_add_artist_omitted_opt_params(self):
        res = self.kb_api.add_artist("Heart")
        self.assertNotEqual(res, None,
                            "Failed to add artist 'Heart' to knowledge base.")

        artist_data = self.kb_api.get_artist_data("Heart")
        self.assertEqual(len(artist_data), 1,
                         "Expected unique match for artist 'Heart'.")
        self.assertEqual(artist_data[0]["name"], "Heart",
                         "Expected match for artist 'Heart'.")
        self.assertEqual(artist_data[0]["genres"], [],
                         "Expected no genres for artist 'Heart'.")
        self.assertEqual(artist_data[0]["num_spotify_followers"], None,
                         "Expected no follower count for artist 'Heart'.")

    def test_add_song(self):
        new_song_node_id = self.kb_api.add_song("Heart",
                                                "Justin Bieber",
                                                duration_ms=11111,
                                                popularity=100)
        self.assertNotEqual(
            new_song_node_id, None,
            "Failed to add song 'Heart' by artist 'Justin Bieber' to knowledge base."
        )

        song_data = self.kb_api.get_song_data("Heart")
        self.assertEqual(len(song_data), 1, "Expected exactly one result.")

        song_data = song_data[0]
        self.assertEqual(
            song_data,
            dict(id=new_song_node_id,
                 song_name="Heart",
                 artist_name="Justin Bieber",
                 duration_ms=11111,
                 popularity=100), "Received unexpected song data")

    def test_add_song_omitted_opt_params(self):
        new_song_node_id = self.kb_api.add_song("What do you mean?",
                                                "Justin Bieber")
        self.assertNotEqual(
            new_song_node_id, None,
            "Failed to add song 'What do you mean?' by artist 'Justin Bieber' to knowledge base."
        )

        song_data = self.kb_api.get_song_data("What do you mean?")
        self.assertEqual(len(song_data), 1, "Expected exactly one result.")

        song_data = song_data[0]
        self.assertEqual(
            song_data,
            dict(id=new_song_node_id,
                 song_name="What do you mean?",
                 artist_name="Justin Bieber",
                 duration_ms=None,
                 popularity=None), "Received unexpected song data")

    def test_add_duplicate_song_for_different_artist(self):
        new_song_node_id = self.kb_api.add_song("Despacito",
                                                "Justin Timberlake")
        self.assertNotEqual(
            new_song_node_id, None,
            "Failed to add song 'Despacito' by artist 'Justin Timberlake' to knowledge base."
        )

        res = self.kb_api.get_song_data("Despacito")
        self.assertEqual(len(res), 2,
                         "Expected exactly two matches for song 'Despacito'.")

        artists = set([res[0]["artist_name"], res[1]["artist_name"]])
        self.assertEqual(
            artists, set(["Justin Bieber", "Justin Timberlake"]),
            "Expected to find duplicate artists 'Justin Bieber' and 'Justin Timberlake' for song 'Despacito'"
        )

    def test_reject_add_song_already_exists(self):
        res = self.kb_api.add_song("Despacito", "Justin Bieber")
        self.assertEqual(
            res, None,
            "Expected rejection of attempt to add song 'Despacito' by 'Justin Bieber' to knowledge base."
        )

    # The logic tested here is currently implemented in the KR API
    # However, if it is moved to the schema (e.g. trigger functions),
    # then this test can be moved to the schema test module
    def test_new_song_with_unknown_artist_rejected(self):
        res = self.kb_api.add_song("Song by Unknown Artist", "Unknown artist")
        self.assertEqual(res, None,
                         "Expected song with unknown artist to be rejected")

        res = self.kb_api.get_song_data("Song by Unknown Artist")
        self.assertEqual(
            len(res), 0,
            "Insertion of song with unknown artist should not have been added to nodes table"
        )

    def test_add_genre(self):
        node_id = self.kb_api.add_genre("hip hop")
        self.assertEqual(
            type(node_id), int,
            "Genre addition appears to have failed: expected int return value (node id) on valid attempt to add genre."
        )

        # Adding the same genre again must return the original node id.
        res = self.kb_api.add_genre("hip hop")
        self.assertEqual(
            res, node_id,
            "Expected original node id to be fetched when attempting to add duplicate genre."
        )

    def test_add_genre_creates_node(self):
        res = self.kb_api.add_genre("hip hop")
        self.assertEqual(
            type(res), int,
            "Genre addition appears to have failed: expected int return value (node id) on valid attempt to add genre."
        )

        entities = self.kb_api.get_node_ids_by_entity_type("hip hop")
        self.assertIn(
            "genre", entities,
            "Expected to find node associated with genre 'hip hop'.")

    def test_get_node_ids_by_entity_type(self):
        res = self.kb_api.get_node_ids_by_entity_type("Justin Timberlake")
        self.assertTrue(
            "artist" in res,
            "Expected to find an 'artist' entity with name 'Justin Timberlake', but got: {0}"
            .format(res))
        self.assertEqual(
            len(res["artist"]), 1,
            "Expected to find exactly one entity (of type 'artist') with name 'Justin Timberlake', but got: {0}"
            .format(res))

        res = self.kb_api.get_node_ids_by_entity_type("Despacito")
        self.assertTrue(
            "song" in res,
            "Expected to find an 'song' entity with name 'Despacito', but got: {0}"
            .format(res))
        self.assertEqual(
            len(res["song"]), 1,
            "Expected to find exactly one entity (of type 'song') with name 'Despacito', but got: {0}"
            .format(res))

        res = self.kb_api.get_node_ids_by_entity_type("Unknown entity")
        self.assertEqual(
            res, {},
            "Expected no results from query for unknown entity, but got {}".
            format(res))
class TestMusicKnowledgeBaseAPI(unittest.TestCase): def setUp(self): DB_path = test_db_utils.create_and_populate_db() self.kb_api = KnowledgeBaseAPI(dbName=DB_path) def tearDown(self): test_db_utils.remove_db() def test_get_song_data(self): song_data = self.kb_api.get_song_data("Despacito") expected_res = dict( id=10, song_name="Despacito", artist_name="Justin Bieber", duration_ms=222222, popularity=10, valence=0.1, spotify_uri='spotify:track:Despacito', acousticness=None, danceability=None, energy=None, instrumentalness=None, liveness=None, loudness=None, speechiness=None, tempo=None, mode=None, musical_key=None, time_signature=None, ) self.assertEqual( 1, len(song_data), "Expected exactly one result from query for song 'Despacito'.") self.assertEqual( song_data[0], expected_res, "Found expected values for song data for 'Despacito'.") song_data = self.kb_api.get_song_data(song_id=10) self.assertEqual( 1, len(song_data), "Expected exactly one result from query for song 'Despacito'.") self.assertEqual( song_data[0], expected_res, "Found expected values for song data for 'Despacito'.") def test_get_song_data_bad_params(self): res = self.kb_api.get_song_data() self.assertEqual( res, [], "Expect no results when both song name and ID are ommitted.") def test_get_song_data_case_insensitive(self): song_data = self.kb_api.get_song_data("dESpAcItO") self.assertEqual( 1, len(song_data), "Expected exactly one result from query for song 'dESpAcItO'.") self.assertEqual( song_data[0], dict( id=10, song_name="Despacito", artist_name="Justin Bieber", duration_ms=222222, popularity=10, valence=0.1, spotify_uri='spotify:track:Despacito', acousticness=None, danceability=None, energy=None, instrumentalness=None, liveness=None, loudness=None, speechiness=None, tempo=None, mode=None, musical_key=None, time_signature=None, ), "Found expected values for song data for 'Despacito'.") song_data = self.kb_api.get_song_data("bEaUTIful DAY") self.assertEqual( 1, len(song_data), "Expected 
exactly one result from query for song 'bEaUTIful DAY'.") self.assertEqual( song_data[0], dict( id=12, song_name="Beautiful Day", artist_name="U2", duration_ms=111111, popularity=60, valence=1.0, spotify_uri='spotify:track:BeautifulDay', acousticness=None, danceability=None, energy=None, instrumentalness=None, liveness=None, loudness=None, speechiness=None, tempo=None, mode=None, musical_key=None, time_signature=None, ), "Found expected values for song data for 'Beautiful Day'.") def test_get_song_data_ambiguous_name(self): res = self.kb_api.get_song_data("Sorry") self.assertEqual(2, len(res), "Expected exactly two results.") self.assertEqual(set([14, 15]), set([hit["id"] for hit in res]), "Found unexpected IDs.") res = self.kb_api.get_song_data("Sorry", 14) self.assertEqual(1, len(res), "Expected exactly one result.") self.assertEqual( res[0], dict( id=14, song_name="Sorry", artist_name="Justin Bieber", popularity=20, duration_ms=333333, valence=0.3, spotify_uri='spotify:track:Sorry', acousticness=None, danceability=None, energy=None, instrumentalness=None, liveness=None, loudness=None, speechiness=None, tempo=None, mode=None, musical_key=None, time_signature=None, ), "Unexpected result contents.") res = self.kb_api.get_song_data("Sorry", 15) self.assertEqual(1, len(res), "Expected exactly one result.") self.assertEqual( res[0], dict( id=15, song_name="Sorry", artist_name="The Anti Justin Bieber", duration_ms=None, popularity=None, spotify_uri=None, acousticness=None, danceability=None, energy=None, instrumentalness=None, liveness=None, loudness=None, speechiness=None, valence=None, tempo=None, mode=None, musical_key=None, time_signature=None, ), "Unexpected result contents.") def test_get_song_data_dne(self): res = self.kb_api.get_song_data("Not In Database") self.assertEqual( res, [], "Expected empty list of results for queried song not in DB.") def test_get_artist_data(self): artist_data = self.kb_api.get_artist_data("Justin Bieber") self.assertEqual( 
len(artist_data), 1, "Expected exactly one result for artist 'Justin Bieber'.") artist_data[0]["genres"] = set(artist_data[0]["genres"]) self.assertEqual( artist_data[0], dict(genres=set(["Pop", "Super pop"]), id=1, num_spotify_followers=4000, name="Justin Bieber"), "Artist data for 'Justin Bieber' did not match expected.", ) def test_get_artist_data_dne(self): artist_data = self.kb_api.get_artist_data("Unknown artist") self.assertEqual(artist_data, [], "Expected 'None' result for unknown artist.") def test_get_all_song_names(self): expected_song_names = set([ "Despacito", "Rock Your Body", "Beautiful Day", "In My Blood", "Sorry" ]) res = self.kb_api.get_all_song_names() self.assertEqual(set(res), expected_song_names, "Unexpected result from fetching all songs from db.") def test_get_all_artist_names(self): expected_artist_names = set([ "Justin Bieber", "Justin Timberlake", "U2", "Shawn Mendes", "The Anti Justin Bieber" ]) res = self.kb_api.get_all_artist_names() self.assertEqual( set(res), expected_artist_names, "Unexpected result from fetching all artists from db.") def test_get_songs(self): res = self.kb_api.get_songs_by_artist("Justin Bieber") self.assertEqual( res, [ dict(song_name="Despacito", id=10), dict(song_name="Sorry", id=14) ], "Songs retrieved for 'Justin Bieber' did not match expected.", ) res = self.kb_api.get_songs_by_artist("Justin Timberlake") self.assertEqual( res, [dict(song_name="Rock Your Body", id=11)], "Songs retrieved for 'Justin Timberlake' did not match expected.", ) def test_get_songs_unknown_artist(self): res = self.kb_api.get_songs_by_artist("Unknown artist") self.assertEqual(res, [], "Unexpected songs retrieved for unknown artist.") def test_songs_are_related_popularity(self): self.assertEqual( self.kb_api.songs_are_related(12, 10, "less popular"), False, "'Beautiful Day' is MORE popular than 'Despacito'", ) self.assertEqual( self.kb_api.songs_are_related(12, 10, "more popular"), True, "'Beautiful Day' is MORE popular than 
'Despacito'", ) def test_songs_are_related_valence(self): self.assertEqual( self.kb_api.songs_are_related(12, 10, "more happy"), True, "'Beautiful Day' is MORE happy than 'Despacito'", ) self.assertEqual( self.kb_api.songs_are_related(12, 10, "less happy"), False, "'Beautiful Day' is MORE happy than 'Despacito'", ) def test_songs_are_related_same_song(self): self.assertEqual( self.kb_api.songs_are_related(12, 12, "more popular"), False, "'Beautiful Day' cannot be more popular than itself.", ) self.assertEqual( self.kb_api.songs_are_related(12, 12, "more popular"), False, "'Beautiful Day' cannot be more popular than itself.", ) def test_songs_are_related_unknown_relationship(self): self.assertEqual( self.kb_api.songs_are_related(12, 10, 'more something'), False, "Expected return value False when relationship is invalid.", ) def test_songs_are_related_value_missing(self): self.assertEqual( self.kb_api.songs_are_related(12, 10, 'more dancey'), False, "Expected return value False when DB does not contain value for given (valid) relationship.", ) def test_find_similar_song(self): res = self.kb_api.get_related_entities("Despacito") self.assertEqual( len(res), 1, "Expected only one song similar to \"Despacito\". Found {0}". 
format(res), ) self.assertEqual( res[0], "Rock Your Body", "Expected to find \"Rock Your Body\" as similar to \"Despacito\".", ) def test_find_similar_artist(self): res = self.kb_api.get_related_entities("Justin Bieber") self.assertEqual( len(res), 2, "Expected exactly two artists similar to Justin Bieber.", ) self.assertEqual( res[0], "Justin Timberlake", "Expected to find Justin Timberlake as similar to Justin Bieber.", ) self.assertEqual( res[1], "Shawn Mendes", "Expected to find Justin Timberlake as similar to Justin Bieber.", ) def test_find_similar_to_entity_that_dne(self): res = self.kb_api.get_related_entities("Unknown Entity") self.assertEqual(res, []) def test_get_related_genres(self): genre_rel_str = self.kb_api.approved_relations["genre"] rel_genres = self.kb_api.get_related_entities("Justin Bieber", rel_str=genre_rel_str) self.assertEqual( set(rel_genres), set(["Pop", "Super pop"]), "Did not find expected related genres for artist 'Justin Bieber'") rel_genres = self.kb_api.get_related_entities("Justin Timberlake", rel_str=genre_rel_str) self.assertEqual( rel_genres, ["Pop"], "Did not find expected related genres for artist 'Justin Timberlake'" ) def test_connect_entities_by_similarity(self): res = self.kb_api.get_related_entities("Shawn Mendes") self.assertEqual(len(res), 0) res = self.kb_api.connect_entities("Shawn Mendes", "Justin Timberlake", "similar to", 0) self.assertEqual(res, True, "") res = self.kb_api.get_related_entities("Shawn Mendes") self.assertEqual(len(res), 1) self.assertEqual(res[0], "Justin Timberlake") def test_connect_entities_by_genre(self): genre_rel_str = self.kb_api.approved_relations["genre"] res = self.kb_api.get_related_entities("Shawn Mendes", rel_str=genre_rel_str) self.assertEqual(len(res), 0) res = self.kb_api.connect_entities("Shawn Mendes", "Pop", "of genre", 100) self.assertEqual(res, True, "") res = self.kb_api.get_related_entities("Shawn Mendes", rel_str=genre_rel_str) self.assertEqual( len(res), 1, "Expected to 
find exactly one related genre for 'Shawn Mendes'.") self.assertEqual(res[0], "Pop", "Found unexpected genre for 'Shawn Mendes'.") def test_rejects_connect_ambiguous_entities(self): self.kb_api.add_artist("Artist and Song name clash") self.kb_api.add_song("Artist and Song name clash", "U2") res = self.kb_api.connect_entities("Artist and Song name clash", "Justin Timberlake", "similar to", 0) self.assertEqual(res, False, "") def test_get_node_ids_by_entity_type(self): node_ids_dict = self.kb_api.get_node_ids_by_entity_type( "Justin Bieber") all_entity_types = list(node_ids_dict.keys()) self.assertEqual( all_entity_types, ["artist"], "Expected to find (only) entities of type 'artist' for 'Justin Bieber', but got: {}" .format(node_ids_dict)) self.assertEqual( len(node_ids_dict.get("artist")), 1, "Expected to find exactly one entity of type 'artist' for 'Justin Bieber', but got: {}" .format(node_ids_dict)) self.kb_api.add_artist("Song and Artist name") self.kb_api.add_song("Song and Artist name", "U2") node_ids_dict = self.kb_api.get_node_ids_by_entity_type( "Song and Artist name") alphabetized_entity_types = sorted(list(node_ids_dict.keys())) self.assertEqual( alphabetized_entity_types, ["artist", "song"], "Expected to find (exactly) two entity types: 'artist' and 'song', but got: {}" .format(node_ids_dict)) def test_get_matching_node_ids(self): node_ids = self.kb_api._get_matching_node_ids("Justin Bieber") self.assertEqual( len(node_ids), 1, "Expected to find exactly one matching node for 'Justin Bieber', but got: {}" .format(node_ids)) def test_get_matching_node_ids_empty(self): node_ids = self.kb_api._get_matching_node_ids("Unknown artist") self.assertEqual( len(node_ids), 0, "Expected to find no matching node for 'Unknown artist', but got: {}" .format(node_ids)) def test_add_artist(self): sample_genres = ["Pop", "Very pop", "Omg so pop"] new_artist_node_id = self.kb_api.add_artist("Heart", genres=sample_genres, num_spotify_followers=1) 
self.assertNotEqual(new_artist_node_id, None, "Failed to add artist 'Heart' to knowledge base.") artist_data = self.kb_api.get_artist_data("Heart") self.assertEqual(len(artist_data), 1, "Expected unique match for artist 'Heart'.") artist_data = artist_data[0] artist_data["genres"] = set(artist_data["genres"]) self.assertEqual( artist_data, dict( name="Heart", id=new_artist_node_id, genres=set(sample_genres), num_spotify_followers=1, ), "Did not find expected genres for artist 'Heart'.", ) def test_reject_add_artist_already_exists(self): artist_node_id = self.kb_api.get_artist_data("Justin Bieber")[0]["id"] res = self.kb_api.add_artist("Justin Bieber") self.assertEqual( res, artist_node_id, "Expected rejection of attempt to add artist 'Justin Bieber' to knowledge base." ) def test_add_artist_omitted_opt_params(self): res = self.kb_api.add_artist("Heart") self.assertNotEqual(res, None, "Failed to add artist 'Heart' to knowledge base.") artist_data = self.kb_api.get_artist_data("Heart") self.assertEqual(len(artist_data), 1, "Expected unique match for artist 'Heart'.") self.assertEqual(artist_data[0]["name"], "Heart", "Expected match for artist 'Heart'.") self.assertEqual(artist_data[0]["genres"], [], "Expected no genres for artist 'Heart'.") self.assertEqual(artist_data[0]["num_spotify_followers"], None, "Expected no genres for artist 'Heart'.") def test_add_song(self): new_song_node_id = self.kb_api.add_song( "Heart", "Justin Bieber", duration_ms=11111, popularity=100, spotify_uri="spotify:track:Heart", audio_features=dict(energy=0.5), ) self.assertNotEqual( new_song_node_id, None, "Failed to add song 'Heart' by artist 'Justin Bieber' to knowledge base." 
) song_data = self.kb_api.get_song_data("Heart") self.assertEqual(len(song_data), 1, "Expected exactly one result.") song_data = song_data[0] self.assertEqual( song_data, dict( id=new_song_node_id, song_name="Heart", artist_name="Justin Bieber", duration_ms=11111, popularity=100, spotify_uri="spotify:track:Heart", energy=0.5, acousticness=None, danceability=None, instrumentalness=None, liveness=None, loudness=None, speechiness=None, valence=None, tempo=None, mode=None, musical_key=None, time_signature=None, ), "Received unexpected song data", ) def test_add_song_omitted_opt_params(self): new_song_node_id = self.kb_api.add_song("What do you mean?", "Justin Bieber") self.assertNotEqual( new_song_node_id, None, "Failed to add song 'sorry' by artist 'Justin Bieber' to knowledge base." ) song_data = self.kb_api.get_song_data("What do you mean?") self.assertEqual(len(song_data), 1, "Expected exactly one result.") song_data = song_data[0] self.assertEqual( song_data, dict( id=new_song_node_id, song_name="What do you mean?", artist_name="Justin Bieber", duration_ms=None, popularity=None, spotify_uri=None, acousticness=None, danceability=None, energy=None, instrumentalness=None, liveness=None, loudness=None, speechiness=None, valence=None, tempo=None, mode=None, musical_key=None, time_signature=None, ), "Received unexpected song data") def test_add_song_with_all_audio_features(self): valid_audio_features = dict( acousticness=1, danceability=1, energy=1, instrumentalness=1, liveness=1, loudness=0, speechiness=1, valence=1, tempo=1, mode="major", musical_key=1, time_signature=3, ) new_song_id = self.kb_api.add_song( "Song with all audio features", "Justin Bieber", audio_features=valid_audio_features, ) song_data = self.kb_api.get_song_data("Song with all audio features") self.assertEqual( dict( id=new_song_id, song_name="Song with all audio features", artist_name="Justin Bieber", duration_ms=None, popularity=None, spotify_uri=None, acousticness=1.0, danceability=1.0, 
energy=1.0, instrumentalness=1.0, liveness=1.0, loudness=0.0, speechiness=1.0, valence=1.0, tempo=1.0, mode="major", musical_key=1, time_signature=3, ), song_data[0], "Retrieved song data did not match expected.", ) def test_add_duplicate_song_for_different_artist(self): new_song_node_id = self.kb_api.add_song("Despacito", "Justin Timberlake") self.assertNotEqual( new_song_node_id, None, "Failed to add song 'Despacito' by artist 'Justin Timberlake' to knowledge base." ) res = self.kb_api.get_song_data("Despacito") self.assertEqual(len(res), 2, "Expected exactly one match for song 'Despacito'.") artists = set([res[0]["artist_name"], res[1]["artist_name"]]) self.assertEqual( artists, set(["Justin Bieber", "Justin Timberlake"]), "Expected to find duplicate artists 'Justin Bieber' and 'Justin Timberlake' for song 'Despacito')" ) def test_reject_add_song_already_exists(self): res = self.kb_api.add_song("Despacito", "Justin Bieber") self.assertEqual( res, None, "Expected rejection of attempt to add song 'Despacito' by 'Justin Bieber' to knowledge base." ) # The logic tested here is currently implemented in the KR API # However, if it is moved to the schema (e.g. trigger functions), # then this test can be moved to the schema test module def test_new_song_with_unknown_artist_rejected(self): res = self.kb_api.add_song("Song by Unknown Artist", "Unknown artist") self.assertEqual(res, None, "Expected song with unknown artist to be rejected") res = self.kb_api.get_song_data("Song by Unknown Artist") self.assertEqual( len(res), 0, "Insertion of song with unknown artist should not have been added to nodes table" ) def test_add_genre(self): node_id = self.kb_api.add_genre("hip hop") self.assertEqual( type(node_id), int, "Genre addition appears to have failed: expected int return value (node id) on valid attempt to add genre." ) res = self.kb_api.add_genre("hip hop") self.assertEqual( res, node_id, "Expected original node id to be fetched when attempting to add duplicate genre." 
) def test_add_genre_creates_node(self): res = self.kb_api.add_genre("hip hop") self.assertEqual( type(res), int, "Genre addition appears to have failed: expected int return value (node id) on valid attempt to add genre." ) entities = self.kb_api.get_node_ids_by_entity_type("hip hop") self.assertIn( "genre", entities, "Expected to find node associated with genre 'hip hop'.") def test_get_node_ids_by_entity_type(self): res = self.kb_api.get_node_ids_by_entity_type("Justin Timberlake") self.assertTrue( "artist" in res, "Expected to find an 'artist' entity with name 'Justin Timberlake', but got: {0}" .format(res)) self.assertEqual( len(res["artist"]), 1, "Expected to find exactly one entity (of type 'artist') with name 'Justin Timberlake', but got: {0}" .format(res)) res = self.kb_api.get_node_ids_by_entity_type("Despacito") self.assertTrue( "song" in res, "Expected to find an 'song' entity with name 'Despacito', but got: {0}" .format(res)) self.assertEqual( len(res["song"]), 1, "Expected to find exactly one entity (of type 'song') with name 'Despacito', but got: {0}" .format(res)) res = self.kb_api.get_node_ids_by_entity_type("Unknown entity") self.assertEqual( res, {}, "Expected no results from query for unknown entity, but got {}". format(res))
class TreeEvalEngine:
    """This class stores the possible interactions between the user and
    the system, and the logic to act on them.

    The various "Commands" properties contain mappings that store signifiers
    of user's intent, along with the functions that they map to. This class
    also contains the functions that corresponds to each of the possible user
    interactions.
    """

    def __init__(self, db_path, player_controller):
        self.player = player_controller
        self.DB_path = db_path
        self.kb_api = KnowledgeBaseAPI(self.DB_path)

    def __call__(self, parser, text):
        """Parses ``text`` with ``parser`` and evaluates the resulting tree.

        Args:
            parser: Callable that turns the user's text into a parse tree.
            text: The raw user input.

        Any ordinary error raised while parsing or evaluating is reported to
        the user through the player instead of being propagated.
        """
        try:
            nltk_parse_tree = parser(text)
            self._evaluate(nltk_parse_tree)
        except Exception:
            # Deliberately best-effort, but only for ordinary errors:
            # KeyboardInterrupt / SystemExit must still propagate.
            self.player.respond("I'm sorry, I don't understand.")

    @property
    def unary_commands(self):
        """A unary command operates on just one set of entities
        when evaluating a user's input (encoded as a Parse Tree).

        Returns:
            A mapping from command name to a ``(keywords, handler)`` pair.
        """
        return OrderedDict([
            ('query_similar_entities', (['like', 'similar'],
                                        self._query_similar_entities)),
            ('query_songs_by_artist', (['songs by', 'by'],
                                       self._query_songs_by_artist)),
            ('query_artist_by_song', (['artist'],
                                      self._query_artist_by_song)),
        ])

    @property
    def terminal_commands(self):
        """A terminal command is the final command that is executed on
        the results of evaluating a user's input (encoded as a Parse Tree).

        Returns:
            A mapping from command name to a ``(keywords, handler)`` pair.
        """
        return OrderedDict([
            ('query_commands', (['hi', 'how', 'hello'],
                                self._query_commands)),
            ('control_stop', (['stop'], self._control_stop)),
            ('control_pause', (['pause'], self._control_pause)),
            ('control_play', (['start', 'play'], self._control_play)),
            ('query_info', (['who', 'what'], self._query_info)),
            ('control_forward', (['skip', 'next'], self._control_skip)),
        ])

    @property
    def binary_commands(self):
        """A binary command operates on two sets of entities
        when evaluating a user's input (encoded as a Parse Tree).

        Returns:
            A mapping from command name to a ``(keywords, handler)`` pair.
        """
        return OrderedDict([
            ('control_union', (['or', 'and'], self._control_union)),
        ])

    @property
    def keywords(self):
        """For each type of command, generate a dictionary of
        keywords that map to the specific command (intent).
        """
        return {
            "unary": {k: v[0] for k, v in self.unary_commands.items()},
            "terminal": {k: v[0] for k, v in self.terminal_commands.items()},
            "binary": {k: v[0] for k, v in self.binary_commands.items()},
        }

    @property
    def actions(self):
        """Maps each unary command name to its handler."""
        return {k: v[1] for k, v in self.unary_commands.items()}

    @property
    def command_precedence(self):
        """Names of the unary commands, in declaration order."""
        return list(self.unary_commands)

    def _query_commands(self):
        # TODO: make this work
        """Terminal command.

        Reports on possible commands and interactions.
        """
        self.player.respond("Hi there! Ask me to play artists or songs. "
                            "I can also find songs that are similar to other "
                            "artists.")

    def _control_play(self, entities: List[str]):
        """Terminal command.

        Plays the subjects specified, or apologises when there are none.

        Args:
            entities: A list of Entities.

        TODO:
            - query db for entity
                - if it's of type song, play it
                - if it's anything else, get the songs associated with
                  it and play in DESC order
        """
        if entities:
            self.player.play(entities)
        else:
            self.player.respond("I'm sorry, I couldn't find that for you.")

    def _control_stop(self):
        """Terminal Command

        Stops the player.
        """
        self.player.stop()

    def _control_pause(self):
        """Terminal Command

        Pauses the player.
        """
        self.player.pause()

    def _control_skip(self):
        """Terminal Command

        Skips the song.
        """
        self.player.skip()

    def _control_union(self, entities_1: List[str], entities_2: List[str]):
        """Binary Command

        Returns the union of the two parameters, without duplicates.

        Args:
            entities_1: A list of Entities.
            entities_2: A list of Entities.

        Returns:
            A list of Entities (ordering is not guaranteed).
        """
        return list(set(entities_1).union(set(entities_2)))

    def _query_info(self, entities: List[str]):
        """Terminal Command

        Responds to the user with the given Entities.

        Args:
            entities: A list of Entities.
        """
        self.player.respond(entities)

    def _query_similar_entities(self, entities: List[str]):
        """Unary Command

        Returns a list of related Entities for all
        Entities in the parameters.

        Args:
            entities: A list of Entities.
        """
        similar_entities = []
        for e in entities:
            # Don't return the artists given in the
            # parms (for the case where there are
            # multiple artists and they are related
            # to each other).
            similar_entities += [
                ent for ent in self.kb_api.get_related_entities(e)
                if ent not in entities
            ]
        return similar_entities

    def _query_songs_by_artist(self, entities: List[str]):
        """Unary Command

        Returns the songs of every artist named in the parameters, as
        returned by ``kb_api.get_songs_by_artist``.

        Args:
            entities: A list of Entities.
        """
        songs = []
        for e in entities:
            songs += self.kb_api.get_songs_by_artist(e)
        return songs

    def _query_artist_by_song(self, entities: List[str]):
        """Unary Command

        Returns a list of Artists for all Entities
        in the parameters.

        Args:
            entities: A list of Entities.
        """
        artists = []
        for e in entities:
            artists += [
                song.get('artist_name')
                for song in self.kb_api.get_song_data(e)
            ]
        return artists

    # The annotation is quoted so that it is not evaluated when the class
    # body is executed.
    def _evaluate(self, tree: "nltk.tree.Tree"):
        """This function will evaluate the parse tree generated by the NLP layer.

        It recursively evaluates the tree from the top down. The tree starts
        with a `terminal command` (ie play) at the root. The tree's leaves are
        composed of `entities`, and the inner nodes are composed of operations
        on those entities. This function evaluates the operations and entities
        into a single result, which is then given to the `terminal command` to
        act upon.

        Returns:
            ``None`` for a "Root" node, a list of entities for "Result" and
            "Entity" nodes, and the handler function for command nodes.
        """
        label = tree.label()
        if label == "Root":
            func = self._evaluate(tree[0])
            if len(tree) == 1:
                # Terminal command without operands (e.g. "stop").
                func()
            else:
                func(self._evaluate(tree[1]))
            return
        elif label == "Result":
            if tree[0].label() == "Entity":
                return self._evaluate(tree[0])
            if tree[0].label() == "Unary_Command":
                func = self._evaluate(tree[0])
                result = self._evaluate(tree[1])
                return func(result)
            if tree[1].label() == "Binary_Command":
                result_left = self._evaluate(tree[0])
                func = self._evaluate(tree[1])
                result_right = self._evaluate(tree[2])
                return func(result_left, result_right)
        elif label == "Unary_Command":
            return self.unary_commands.get(tree[0])[1]
        elif label == "Terminal_Command":
            return self.terminal_commands.get(tree[0])[1]
        elif label == "Binary_Command":
            return self.binary_commands.get(tree[0])[1]
        elif label == "Entity":
            return [tree[0]]

        # Reached for unknown labels and for "Result" nodes of unknown shape.
        print(
            "Error: CFG label rule not defined in "
            "evaluateEngine#self._evaluate",
            file=sys.stderr)
def __init__(self, db_path, player_controller):
    """Stores the player controller and opens the knowledge base.

    Args:
        db_path: Location of the database handed to ``KnowledgeBaseAPI``.
        player_controller: Object used to respond to / control playback.
    """
    self.DB_path = db_path
    self.player = player_controller
    # The API is built from the path that was just stored on the instance.
    self.kb_api = KnowledgeBaseAPI(self.DB_path)
class BOWParser:
    """
    This layer stores the NLP specific tooling and logic.

    References:
        https://stackoverflow.com/questions/42322902/how-to-get-parse-tree-using-python-nltk
        https://www.nltk.org/book/ch08.html
    """
    extra_stopwords = {'s', 'hey', 'want', 'you'}

    def __init__(self, db_path, keywords):
        """Loads the known music entities and remembers the intent keywords.

        Args:
            db_path: Location of the database handed to ``KnowledgeBaseAPI``.
            keywords: Mapping of intent name to its list of keywords.
        """
        # Download the stopwords if necessary.
        try:
            nltk.data.find('corpora/stopwords')
        except LookupError:
            nltk.download('stopwords')

        self.commands = keywords
        self.kb_api = KnowledgeBaseAPI(db_path)
        self.db_nouns = self.kb_api.get_all_music_entities()

    def _get_stop_words(self):
        """Returns the English stopwords plus ``extra_stopwords``, minus any
        word that is also an intent keyword."""
        stop_words = set(stopwords.words('english'))
        stop_words |= BOWParser.extra_stopwords
        # Keywords must survive stopword filtering; removing a keyword that
        # is not a stopword is simply a no-op.
        for words in self.commands.values():
            stop_words.difference_update(words)
        return stop_words

    def _gen_patterns(self):
        """Generates one compiled RegEx per intent that matches any of the
        intent's keywords as a whole word."""
        patterns = {}
        for intent, keys in self.commands.items():
            # Keywords are literal text, so escape regex metacharacters.
            alternatives = r'\b|\b'.join(re.escape(key) for key in keys)
            patterns[intent] = re.compile(r'\b' + alternatives + r'\b')
        return patterns

    def __call__(self, msg: str):
        """Splits a user message into subjects, intents and leftover text.

        Args:
            msg: The raw user input.

        Returns:
            A ``(subjects, intents, remaining_text)`` tuple: the database
            entities mentioned in ``msg``, the names of the intents whose
            keywords were found, and whatever text is left afterwards.
        """
        # Identify every subject from the database that is mentioned, and
        # cut it out of the message so it cannot be mistaken for a keyword.
        subjects = []
        for noun in self.db_nouns:
            if noun.lower() in msg.lower():
                # Entity names are data, not patterns: without escaping, a
                # name such as "Song (Remix)" would not match itself (or
                # would fail to compile).
                msg = re.sub(re.escape(noun), '', msg, flags=re.IGNORECASE)
                subjects.append(noun.strip())

        # Remove punctuation from the string
        msg = re.sub(r"[,.;@#?!&$']+\ *",
                     " ",
                     msg,
                     flags=re.VERBOSE)

        # Clean the stopwords from the input.
        stop_words = self._get_stop_words()
        clean_msg = ' '.join([word for word in msg.lower().split(' ')
                              if word not in stop_words])

        # Parse the keywords from the filtered input.
        patterns = self._gen_patterns()
        intents = []
        for intent, pattern in patterns.items():
            sub_msg = re.sub(pattern, '', clean_msg)
            if sub_msg != clean_msg:
                intents.append(intent)
                clean_msg = sub_msg

        remaining_text = clean_msg.strip()
        return subjects, intents, remaining_text
class BOWEvalEngine:
    """This class stores the possible interactions between the user and
    the system, and the logic to act on them.

    The `intents` property contains a mapping that stores signifiers of user's
    intent, along with the functions that they map to. This class also contains
    the functions that corresponds to each of the possible user interactions.
    """

    def __init__(self, db_path, player_controller):
        self.player = player_controller
        self.DB_path = db_path
        self.kb_api = KnowledgeBaseAPI(self.DB_path)

    def __call__(
            self,
            subjects: List[str] = None,
            commands: List[str] = None,
            remaining_text: str = None,
    ):
        """Initiates a sequence of commands that will act on the
        subject parameters in order of command precedence.

        Args:
            subjects: List of subjects in the database.
            commands: Names of commands to enact. ``None`` is treated as
                "no commands", which ends up in the default handler.
            remaining_text: Any remaining text not parsed by the NLP layer.
        """
        requested = commands or []

        # Sort the commands by precedence. This builds a new list, so the
        # caller's list is never mutated by the pops in `_next_operation`.
        sorted_commands = [
            c for c in self.command_precedence if c in requested
        ]

        # Call the highest precedence command.
        self._next_operation(
            subjects=subjects,
            commands=sorted_commands,
            remaining_text=remaining_text,
        )

    @property
    def intents(self):
        """Returns a mapping that stores signifiers of user's intent,
        along with the `commands` and the functions that they map to.

        The ordering of these commands matters, as it is used to store
        the precedence of operations.

        Returns (OrderedDict): A mapping of intentions, their corresponding
            signifiers (as keywords), and function handlers.
        """
        return OrderedDict([
            ('query_commands', (['hi', 'how', 'hello'],
                                self._query_commands)),
            ('control_stop', (['stop'], self._control_stop)),
            ('control_pause', (['pause'], self._control_pause)),
            ('control_forward', (['skip', 'next'], self._control_skip)),
            ('query_similar_entities', (['like', 'similar'],
                                        self._query_similar_entities)),
            ('control_play', (['start', 'play'], self._control_play)),
            ('query_artist', (['who', 'artist'], self._query_artist)),
            ('default', ([], self._default)),
        ])

    @property
    def keywords(self):
        """For each type of command, generate a dictionary of
        keywords that map to the specific command (intent)
        """
        return {k: v[0] for k, v in self.intents.items()}

    @property
    def actions(self):
        """Maps each intent name to its handler."""
        return {k: v[1] for k, v in self.intents.items()}

    @property
    def command_precedence(self):
        """Intent names, highest precedence first."""
        return list(self.intents)

    def _query_commands(
            self,
            subjects: List[str] = None,
            commands: List[str] = None,
            remaining_text: str = None,
            response_msg: str = None,
    ):
        """Terminal command.

        Reports on possible commands and interactions.
        """
        if subjects:
            # Handle question about specific `subjects`.
            self.player.respond("I'm sorry, I can't answer that one")
        else:
            # Handle generic inquiries.
            self.player.respond(
                "Hi there! Ask me to play artists or songs. "
                "I can also find songs that are similar to other "
                "artists.")

    def _control_play(
            self,
            subjects: List[str] = None,
            commands: List[str] = None,
            remaining_text: str = None,
            response_msg: str = None,
    ):
        """Terminal command.

        Plays the subjects specified. Without subjects it either apologises
        (some text was left unparsed) or announces that playback resumes.

        TODO:
            - query db for entity
                - if it's of type song, play it
                - if it's anything else, get the songs associated with
                  it and play in DESC order
        """
        if subjects:
            self.player.play(subjects)
        elif remaining_text:
            self.player.respond("I'm sorry, I couldn't find that for you.")
        else:
            self.player.respond('Resuming the current song')

    def _control_stop(
            self,
            subjects: List[str] = None,
            commands: List[str] = None,
            remaining_text: str = None,
            response_msg: str = None,
    ):
        """Terminal command. Stops the player."""
        self.player.stop(subjects)

    def _control_pause(
            self,
            subjects: List[str] = None,
            commands: List[str] = None,
            remaining_text: str = None,
            response_msg: str = None,
    ):
        """Terminal command. Pauses the player."""
        self.player.pause(subjects)

    def _control_skip(
            self,
            subjects: List[str] = None,
            commands: List[str] = None,
            remaining_text: str = None,
            response_msg: str = None,
    ):
        """Terminal command. Skips the current song."""
        # TODO: Add number parsing for "skip forward 2 songs".
        self.player.skip(subjects)

    def _control_intersection(
            self,
            subjects: List[str] = None,
            commands: List[str] = None,
            remaining_text: str = None,
            response_msg: str = None,
    ):
        # TODO: placeholder - not registered in `intents`; currently just
        # forwards to the player's skip.
        self.player.skip(subjects)

    def _control_union(
            self,
            subjects: List[str] = None,
            commands: List[str] = None,
            remaining_text: str = None,
            response_msg: str = None,
    ):
        # TODO: placeholder - not registered in `intents`; currently just
        # forwards to the player's skip.
        self.player.skip(subjects)

    def _query_artist(
            self,
            subjects: List[str] = None,
            commands: List[str] = None,
            remaining_text: str = None,
            response_msg: str = None,
    ):
        # TODO: implement for fetching current state, ie "who is this artist?"
        self.player.respond(subjects)

    def _query_similar_entities(
            self,
            subjects: List[str] = None,
            commands: List[str] = None,
            remaining_text: str = None,
            response_msg: str = None,
    ):
        """Replaces the subjects with their related entities and hands them
        to the next command; apologises when nothing related is found."""
        similar_entities = []
        # `subjects` may be left at its None default.
        for e in subjects or []:
            similar_entities += self.kb_api.get_related_entities(e)

        if not similar_entities:
            self.player.respond("I'm sorry, I couldn't find that for you.")
        else:
            self._next_operation(
                subjects=similar_entities,
                commands=commands,
                remaining_text=remaining_text,
                response_msg=response_msg,
            )

    def _default(
            self,
            subjects: List[str] = None,
            commands: List[str] = None,
            remaining_text: str = None,
            response_msg: str = None,
    ):
        """Fallback handler used when no command applies."""
        self.player.respond("I'm sorry, I don't understand.")

    def _next_operation(
            self,
            subjects: List[str] = None,
            commands: List[str] = None,
            remaining_text: str = None,
            response_msg: str = None,
    ):
        """Calls the next function in the commands parameter.

        The first entry of ``commands`` is popped (the list is mutated) and
        its handler is invoked with the remaining commands. With no commands
        left, or an unknown command name, the default handler runs.
        """
        next_command_name = commands.pop(0) if commands else 'default'
        next_func = self.actions.get(next_command_name, self._default)
        next_func(
            subjects=subjects,
            commands=commands,
            remaining_text=remaining_text,
            response_msg=response_msg,
        )
class TestDbSchema(unittest.TestCase):
    """Checks that ``KnowledgeBaseAPI`` refuses data that breaks DB constraints.

    A populated test database is created once for the whole class and removed
    afterwards. Every test expects the API to signal rejection through its
    return value (``False`` for edges, ``None`` for new nodes/songs).
    """

    @classmethod
    def setUpClass(cls):
        """Create the shared test database and the API handle used by all tests."""
        db_path = test_db_utils.create_and_populate_db()
        cls.kb_api = KnowledgeBaseAPI(dbName=db_path)

    @classmethod
    def tearDownClass(cls):
        """Remove the test database created in ``setUpClass``."""
        test_db_utils.remove_db()

    def test_rejects_unknown_entity(self):
        """An edge that starts at an entity missing from the DB is refused."""
        res = self.kb_api.connect_entities("Unknown Entity", "Justin Timberlake", "similar to", 0)
        self.assertEqual(res, False, "Expected attempt to connect an unknown entity to fail.")

        res = self.kb_api.get_related_entities("Unknown Entity")
        self.assertEqual(res, [])

    def test_rejects_score_out_of_range(self):
        """An edge with a negative score is refused and leaves no relation behind."""
        res = self.kb_api.connect_entities("Justin Timberlake", "Justin Timberlake", "similar to", -1)
        self.assertEqual(res, False, "Expected attempt to connect entities with score out-of-range to fail.")

        res = self.kb_api.get_related_entities("Justin Timberlake")
        self.assertEqual(len(res), 0)

    def test_rejects_duplicate_edge(self):
        """Adding an edge that the test asserts is a duplicate is refused."""
        res = self.kb_api.connect_entities("Justin Bieber", "Justin Timberlake", "similar to", 1)
        self.assertEqual(res, False, "Expected attempt to add a duplicate edge to fail.")

    def test_rejects_song_with_duplicate_id(self):
        """A song reusing an existing ``spotify_uri`` is refused."""
        new_song_id = self.kb_api.add_song(
            "some song",
            "Justin Bieber",
            spotify_uri="spotify:track:Despacito",  # already in DB
        )
        self.assertIsNone(new_song_id, "Expected song addition to fail.")

    def test_edges_not_null_constraints(self):
        """``None`` is refused for an edge's source, type and score."""
        res = self.kb_api.connect_entities(None, "Justin Timberlake", "similar to", 1)
        self.assertEqual(res, False, "Expected 'None' value for artist to be rejected.")

        res = self.kb_api.connect_entities("U2", "U2", None, 1)
        self.assertEqual(res, False, "Expected 'None' value for edge type to be rejected.")

        res = self.kb_api.connect_entities("U2", "U2", "similar to", None)
        self.assertEqual(res, False, "Expected 'None' value for edge score to be rejected.")

    def test_entities_not_null_constraints(self):
        """``None`` is refused for artist/song names and for node name/type."""
        res = self.kb_api.add_artist(None)
        self.assertIsNone(res, "Expected 'None' value for artist to be rejected.")

        res = self.kb_api.add_song("Song name", None)
        self.assertIsNone(res, "Expected 'None' value for song to be rejected.")

        res = self.kb_api.add_song(None, "Artist name")
        self.assertIsNone(res, "Expected 'None' value for artist to be rejected.")

        node_id = self.kb_api._add_node(None, "artist")
        self.assertIsNone(node_id, "Expected 'None' value for entity name to be rejected.")

        node_id = self.kb_api._add_node("Some entity", None)
        self.assertIsNone(node_id, "Expected 'None' value for entity type to be rejected.")

    def test_song_audio_features_range_constraints(self):
        """Each numeric audio feature pushed past its upper limit is refused."""
        # NOTE: all values are at their upper limit so that we can just add 1
        # to them and test that the schema constraints reject their addition
        # (mode is left out because it is not numerical)
        audio_features = dict(
            acousticness=1,
            danceability=1,
            energy=1,
            instrumentalness=1,
            liveness=1,
            loudness=1,
            speechiness=1,
            valence=1,
            tempo=999,
            musical_key=11,
            time_signature=7,
        )

        for feature_name in self.kb_api.song_audio_features:
            if feature_name == 'mode':
                continue

            # subTest keeps checking the remaining features after a failure,
            # so one run reports every feature whose constraint is missing.
            with self.subTest(feature=feature_name):
                # Work on a copy so exactly one feature is out of range per
                # attempt and the baseline never needs to be reset.
                out_of_range = dict(audio_features)
                out_of_range[feature_name] += 1

                new_song_id = self.kb_api.add_song(
                    "Song by Justin Bieber",
                    "Justin Bieber",
                    audio_features=out_of_range,
                )
                self.assertIsNone(
                    new_song_id,
                    f"Expected song with {feature_name} out of range to be rejected.",
                )

    def test_song_mode_constraint(self):
        """A song whose ``mode`` is not an accepted value is refused."""
        new_song_id = self.kb_api.add_song(
            "Song by Justin Bieber",
            "Justin Bieber",
            audio_features=dict(mode='not major or minor'),
        )
        self.assertIsNone(new_song_id, "Expected song with invalid mode to be rejected.")