def __init__(self, db_path, keywords):
        """Prepare the parser: make sure NLTK stopwords exist, then load KB entities.

        Args:
            db_path: Path of the knowledge-base database handed to KnowledgeBaseAPI.
            keywords: Command keywords, stored as ``self.commands``.
        """
        # Only hit the network when the stopwords corpus is not installed yet.
        try:
            nltk.data.find('corpora/stopwords')
        except LookupError:
            nltk.download('stopwords')

        self.commands = keywords
        api = KnowledgeBaseAPI(db_path)
        self.kb_api = api
        # Every music entity name known to the KB, cached up front.
        self.db_nouns = api.get_all_music_entities()
コード例 #2
0
 def setUp(self):
     """Create a populated test DB and build the bag-of-words eval/parse stack on it."""
     db_path = test_db_utils.create_and_populate_db()
     self.DB_path = db_path
     self.kb_api = KnowledgeBaseAPI(db_path)

     # The mock controller is handed this dict; presumably it records the
     # player actions there for assertions (confirm in MockController).
     self.results_dict = {}
     self.player_controller = MockController(self.results_dict)

     self.interactions = BOWEvalEngine(db_path, self.player_controller)
     self.nlp = BOWParser(db_path, self.interactions.keywords)
     self.keywords = self.interactions.keywords
コード例 #3
0
 def __init__(self, db_path, player_controller, parser_type='BagOfWords'):
     """Wire an evaluation engine and a matching parser onto one knowledge base.

     Args:
         db_path: Path of the knowledge-base database.
         player_controller: Controller handed to the evaluation engine.
         parser_type: 'BagOfWords' (BOWEvalEngine + BOWParser) or
             'TREE' (TreeEvalEngine + TreeParser).

     Raises:
         ValueError: if ``parser_type`` is not one of the supported values.
             Without this check the object would be left without
             ``eval_engine``/``parser`` and fail later with AttributeError.
     """
     if parser_type not in ('BagOfWords', 'TREE'):
         raise ValueError(
             f"Unknown parser_type {parser_type!r}; expected 'BagOfWords' or 'TREE'.")

     self.DB_path = db_path
     self.kb_api = KnowledgeBaseAPI(self.DB_path)
     self.parser_type = parser_type
     if parser_type == 'BagOfWords':
         self.eval_engine = BOWEvalEngine(self.DB_path, player_controller)
         self.parser = BOWParser(self.DB_path, self.eval_engine.keywords)
     else:  # 'TREE'
         self.eval_engine = TreeEvalEngine(self.DB_path, player_controller)
         self.parser = TreeParser(self.DB_path, self.eval_engine.keywords)
def create_and_populate_db_with_spotify(spotify_client_id, spotify_secret_key, artists, path=None):
    """Create a knowledge base and fill it with Spotify data for ``artists``.

    For every artist this adds the artist, its songs, and its related artists.
    Each related pair is linked in both directions with relation
    "similar to" and score 100.

    Returns:
        The path of the created database, as returned by ``create_db``.
    """
    path_to_db = create_db(path=path)
    client = SpotifyClient(spotify_client_id, spotify_secret_key)
    metadata_by_artist = get_artist_metadata(client, artists)
    kb_api = KnowledgeBaseAPI(path_to_db)

    for name, info in metadata_by_artist.items():
        kb_api.add_artist(name, info["genres"], info["num_followers"])

        for title, track in info["songs"].items():
            kb_api.add_song(title, name, track["popularity"], track["duration_ms"])

        for related_name, related_info in info["related_artists"].items():
            kb_api.add_artist(related_name, related_info["genres"], related_info["num_followers"])
            # Store the similarity edge in both directions.
            for src, dst in ((name, related_name), (related_name, name)):
                kb_api.connect_entities(src, dst, "similar to", 100)

    return path_to_db
コード例 #5
0
def _add_artist_top_songs(spotify, kb_api, artist_name, spotify_artist_id):
    """Fetch an artist's top songs and their audio features, then insert them.

    Params:
        spotify (SpotifyClient): client used for the Spotify lookups.
        kb_api (KnowledgeBaseAPI): knowledge base the songs are inserted into.
        artist_name (str): name the songs are recorded under.
        spotify_artist_id: the artist's Spotify ID (the ``'id'`` field of its metadata).
    """
    songs = spotify.get_top_songs(spotify_artist_id, SPOTIFY_COUNTRY_ISO)
    audio_features = spotify.get_audio_features(
        [song_info['id'] for song_info in songs.values()]
    )
    _insert_songs(songs, artist_name, audio_features, kb_api)


def create_and_populate_db_with_spotify(spotify_client_id, spotify_secret_key, artists, path=None):
    """Pull data from Spotify for the given artists and add it to the knowledge base through its API.

    For each of the given artists, find and add all of the following to the knowledge base:
    - general artist metadata (e.g. number of followers)
    - top songs (with their audio features)
    - related artists, along with their own metadata and top songs

    Each artist/related-artist pair is connected in both directions with
    relation "similar to" and score 100.

    Params:
        spotify_client_id (str) e.g. "".
        spotify_secret_key (str) e.g. "".
        artists (iterable) each element is an artist name (str).
            For example, might be stdin, or open file, or list.
        path (str): relative path e.g. "knowledge_base.db".

    Returns:
        path_to_db (str): relative path to newly created db e.g. "knowledge_base/knowledge_base.db"
    """
    # Function-local import. NOTE(review): presumably to keep the Spotify
    # client out of module import time (or to avoid a cycle) — confirm.
    from utils.spotify_client import SpotifyClient
    path_to_db = create_db(path=path)
    spotify = SpotifyClient(spotify_client_id, spotify_secret_key)

    artist_metadata = _get_artist_metadata(spotify, artists)
    kb_api = KnowledgeBaseAPI(path_to_db)
    for artist_name, artist_info in artist_metadata.items():
        kb_api.add_artist(artist_name, artist_info["genres"], artist_info["num_followers"])
        _add_artist_top_songs(spotify, kb_api, artist_name, artist_info['id'])

        for rel_artist_name, rel_artist_info in artist_info["related_artists"].items():
            kb_api.add_artist(rel_artist_name, rel_artist_info["genres"], rel_artist_info["num_followers"])
            # Similarity is recorded in both directions.
            kb_api.connect_entities(artist_name, rel_artist_name, "similar to", 100)
            kb_api.connect_entities(rel_artist_name, artist_name, "similar to", 100)
            _add_artist_top_songs(spotify, kb_api, rel_artist_name, rel_artist_info['id'])

    return path_to_db
コード例 #6
0
ファイル: extract_entities.py プロジェクト: okjuan/muze
    for t in test_cases:
        res = remove_dash_section(t['input'])
        if res != t['expected']:
            print(f"FAIL: '{t['input']}' expected '{t['expected']}' got '{res}'")
            fails += 1
        total += 1
    print(f"Finished running test_remove_dash_section: Ran {total} tests with {fails} failures.")

if __name__ == "__main__":
    # Supported flags: -t (run self-tests), -s (dump song names),
    # -a (dump artist names).
    mode = sys.argv[1] if len(sys.argv) > 1 else None

    if mode == "-t":
        print("Running tests...")
        test_remove_parenthised_section()
        test_remove_dash_section()

    elif mode in ("-s", "-a"):
        music_api = KnowledgeBaseAPI("knowledge_base/knowledge_base.db")

        # print songs
        if mode == "-s":
            print("Fetching song names..")
            song_names = music_api.get_all_song_names()
            print_to_csv(song_names, SONG_FILE_NAME)
            print(f"Wrote to {SONG_FILE_NAME}")

        # print artists
        else:
            print("Fetching artist names..")
            artist_names = music_api.get_all_artist_names()
            print_to_csv(artist_names, ARTIST_FILE_NAME)
            print(f"Wrote to {ARTIST_FILE_NAME}")

    else:
        # A missing flag used to hit sys.argv[1] and raise IndexError, and an
        # unknown flag did nothing; report usage and exit non-zero instead.
        sys.exit(f"usage: {sys.argv[0]} -t | -s | -a")
コード例 #7
0
ファイル: server.py プロジェクト: okjuan/muze
"""

from flask import Flask, render_template
from flask_socketio import SocketIO
import json
import logging
import os
import random
import sys

sys.path.extend(['.', '../'])  # needed for import statement(s) below
from knowledge_base.api import KnowledgeBaseAPI

app = Flask(__name__)
socket_io = SocketIO(app)
music_api = KnowledgeBaseAPI('knowledge_base/knowledge_base.db')

# Client-session bookkeeping. The key table is sized from MAX_CLIENT_SESSIONS
# so the two values cannot drift apart.
# NOTE(review): NEW_CLIENT_IDX presumably indexes the next slot of
# CLIENT_SESSION_KEYS — confirm against start_user_session.
MAX_CLIENT_SESSIONS = 100
CLIENT_SESSION_KEYS = [None] * MAX_CLIENT_SESSIONS
NEW_CLIENT_IDX = 0
CLIENT_COUNT = 0


@app.route('/')
def get_player():
    """Render the ``player.html`` template for the site root."""
    template_name = "player.html"
    return render_template(template_name)


@socket_io.on("start session")
def start_user_session():
    """Handle a client's "start session" Socket.IO event.

    NOTE(review): only the opening lines of this handler are visible here;
    the logic that presumably allocates a session slot is not shown —
    confirm against the full source.
    """
    # Declared global so that assignments in this function rebind the
    # module-level NEW_CLIENT_IDX / CLIENT_COUNT values.
    global NEW_CLIENT_IDX, CLIENT_COUNT
    print(f"Received request from client to begin session..")
コード例 #8
0
 def setUpClass(self):
     """Create and populate the test database and attach a KnowledgeBaseAPI to it."""
     self.kb_api = KnowledgeBaseAPI(
         dbName=test_db_utils.create_and_populate_db())
コード例 #9
0
class TestDbSchema(unittest.TestCase):
    """Constraint checks (unknown entities, score range, duplicates, NULLs)
    exercised through KnowledgeBaseAPI.

    One populated test database is shared by every test in this class: it is
    created in setUpClass and cleaned up with test_db_utils.remove_db() in
    tearDownClass.
    """

    @classmethod
    def setUpClass(cls):
        DB_path = test_db_utils.create_and_populate_db()
        cls.kb_api = KnowledgeBaseAPI(dbName=DB_path)

    @classmethod
    def tearDownClass(cls):
        test_db_utils.remove_db()

    def test_rejects_unknown_entity(self):
        res = self.kb_api.connect_entities("Unknown Entity",
                                           "Justin Timberlake", "similar to",
                                           0)
        self.assertEqual(
            res, False,
            "Expected attempt to connect an unknown entity to fail.")

        res = self.kb_api.get_related_entities("Unknown Entity")
        self.assertEqual(res, [])

    def test_rejects_score_out_of_range(self):
        # -1 is used here as an out-of-range score.
        res = self.kb_api.connect_entities("Justin Timberlake",
                                           "Justin Timberlake", "similar to",
                                           -1)
        self.assertEqual(
            res, False,
            "Expected attempt to connect entities with score out-of-range to fail."
        )

        res = self.kb_api.get_related_entities("Justin Timberlake")
        self.assertEqual(len(res), 0)

    def test_rejects_duplicate_edge(self):
        res = self.kb_api.connect_entities("Justin Bieber",
                                           "Justin Timberlake", "similar to",
                                           1)
        self.assertEqual(res, False,
                         "Expected attempt to add a duplicate edge to fail.")

    def test_edges_not_null_constraints(self):
        res = self.kb_api.connect_entities(None, "Justin Timberlake",
                                           "similar to", 1)
        self.assertEqual(res, False,
                         "Expected 'None' value for artist to be rejected.")

        res = self.kb_api.connect_entities("U2", "U2", None, 1)
        self.assertEqual(
            res, False, "Expected 'None' value for edge type to be rejected.")

        res = self.kb_api.connect_entities("U2", "U2", "similar to", None)
        self.assertEqual(
            res, False, "Expected 'None' value for edge score to be rejected.")

    def test_entities_not_null_constraints(self):
        res = self.kb_api.add_artist(None)
        self.assertIsNone(res,
                          "Expected 'None' value for artist to be rejected.")

        res = self.kb_api.add_song("Song name", None)
        self.assertIsNone(res,
                          "Expected 'None' value for artist to be rejected.")

        res = self.kb_api.add_song(None, "Artist name")
        self.assertIsNone(res,
                          "Expected 'None' value for song name to be rejected.")

        node_id = self.kb_api._add_node(None, "artist")
        self.assertIsNone(
            node_id,
            "Expected 'None' value for entity name to be rejected.")

        node_id = self.kb_api._add_node("Some entity", None)
        self.assertIsNone(
            node_id,
            "Expected 'None' value for entity type to be rejected.")
コード例 #10
0
 def __init__(self, db_path, keywords):
     """Keep the command keywords and cache every named entity from the KB."""
     self.keywords = keywords
     api = KnowledgeBaseAPI(db_path)
     self.kb_api = api
     self.kb_named_entities = api.get_all_music_entities()
コード例 #11
0
class TreeParser:
    """
    This class contains tree-parsing logic that will convert
    input text to an NLTK Parse-Tree using a Context Free
    Grammar.

    References:
        https://stackoverflow.com/questions/42322902/how-to-get-parse-tree-using-python-nltk
        https://www.nltk.org/book/ch08.html

    """

    # Keyword categories, in the order the lexer tries them.
    _COMMAND_KINDS = ("unary", "terminal", "binary")

    def __init__(self, db_path, keywords):
        """Store command keywords and load every named entity from the KB.

        Args:
            db_path: Path of the knowledge-base database.
            keywords: Mapping of category ("unary", "terminal", "binary") to
                a mapping of intent name -> list of keyword strings. The
                keyword strings are used as regex alternatives.
        """
        self.keywords = keywords
        self.kb_api = KnowledgeBaseAPI(db_path)
        self.kb_named_entities = self.kb_api.get_all_music_entities()

    def __call__(self, msg: str):
        """Creates an NLTK Parse Tree from the user input msg.

        Args:
            msg: A string of user input,
                 e.g. 'play something similar to Justin Bieber'

        Returns: An NLTK parse tree, as defined by the CFG given
                 in the "_parser" method.

        """
        # Replace runs of '.', '?' and apostrophes (plus any spaces that
        # follow them) with a single space.
        msg = re.sub(r"[.?']+\ *", " ", msg, flags=re.VERBOSE)

        # Parse sentence into list of tokens containing
        #  only entities and commands.
        tokens = self._lexer(msg)

        # Generate an NLTK parse tree
        tree = self._parser(tokens)
        return tree

    def _command_regexes(self, kind):
        """Return ``{intent: compiled regex}`` for keyword category ``kind``.

        Each intent's keywords are joined into one word-bounded alternation.
        The result is cached on the instance itself rather than through
        ``lru_cache`` on the method, which keys on ``self`` and keeps every
        instance alive.
        """
        cache = self.__dict__.setdefault("_command_regex_cache", {})
        if kind not in cache:
            patterns = {}
            for intent, keys in (self.keywords.get(kind) or {}).items():
                # An empty keyword would give a zero-width pattern; drop it.
                keys = [key for key in (keys or []) if key]
                if keys:
                    patterns[intent] = re.compile(r'\b' + r'\b|\b'.join(keys) +
                                                  r'\b')
            cache[kind] = patterns
        return cache[kind]

    @property
    def _unary_command_regexes(self):
        """Compiled regexes for unary command keywords, keyed by intent."""
        return self._command_regexes("unary")

    @property
    def _terminal_command_regexes(self):
        """Compiled regexes for terminal command keywords, keyed by intent."""
        return self._command_regexes("terminal")

    @property
    def _binary_command_regexes(self):
        """Compiled regexes for binary command keywords, keyed by intent."""
        return self._command_regexes("binary")

    def _lexer(self, msg: str):
        """Lexes an input string into a list of tokens.

        This lexer first looks for Entities in the input string
        and parses them into tokens of the same name
        (e.g. 'blah U2 blah' -> ['U2']).

        Next, the lexer will look for Commands by searching for
        keywords that signify the command (unary, then terminal, then
        binary). These keyword+command pairings are defined in the
        command_evaluation layer.

        Every occurrence is tokenized: the text is split at the first
        match and both sides are lexed recursively.

        Args:
            msg: A string of user input,
                 e.g. 'play something similar to justin bieber'

        Returns: A tokenized list of commands and Entities,
            e.g. ['control_play', 'query_similar_entities', 'Justin Bieber']

        """
        def lexing_algorithm(text):
            # TODO: >= O(n^2) (horrible time complexity). In the interest of
            # making forward progress, optimize this later.  Could use a
            # bloom filter to look for matches, then binary search  on
            # entities/commands find specific match once possible match is
            # found.  Or even just use a hashmap for searching.

            # Base case.
            if text == "":
                return []

            # 1. Parse named entities (case-insensitive substring match).
            lowered = text.lower()
            for entity in self.kb_named_entities:
                needle = entity.lower()
                if not needle:
                    # An empty name matches everywhere (and str.split('')
                    # raises ValueError), so it can never be a useful token.
                    continue
                idx = lowered.find(needle)
                if idx == -1:
                    continue
                # Split on the first occurrence only; later occurrences stay
                # in the right-hand side and are lexed by the recursion.
                left = lowered[:idx]
                right = lowered[idx + len(needle):]
                return lexing_algorithm(left) + [
                    entity.strip()
                ] + lexing_algorithm(right)

            # 2-4. Parse unary, then terminal, then binary commands.
            for kind in self._COMMAND_KINDS:
                for intent, pattern in self._command_regexes(kind).items():
                    match = pattern.search(text)
                    # A zero-width match would not shrink the text and
                    # would recurse forever.
                    if match is None or match.end() == match.start():
                        continue
                    return lexing_algorithm(text[:match.start()]) \
                           + [intent] \
                           + lexing_algorithm(text[match.end():])

            # If no matches, then the word is a stopword.
            return []

        return lexing_algorithm(msg)

    def _parser(self, tokens: List[str]):
        """Generates a Parse Tree from a list of tokens
        provided by the Lexer.

        Args:
            tokens: A tokenized list of commands and Entities,
            e.g. ['control_play', 'query_similar_entities', 'Justin Bieber']

        Returns: An nltk parse tree, as defined by the CFG given
                 in this function.

        """

        # TODO:   Improve the CFG work for the following:
        #          -  Play songs faster than despicito
        #          -  Play something similar to despicito but faster
        #          -  Play something similar to u2 and justin bieber

        def gen_lexing_patterns(vals: List[str]):
            # TODO: Here we remove entries containing ',
            #       as it is a special character used by
            #       the NLTK parser. We need to fix this
            #       eventually.
            safe_vals = [s for s in vals if "\'" not in s]
            return "' | '".join(safe_vals) or "NONE"

        # A Probabilistic Context Free Grammar (PCFG)
        # can be used to simulate "operator precedence",
        # which removes the problems of ambiguity in
        # the grammar.
        grammar = nltk.PCFG.fromstring("""
        Root -> Terminal_Command Result         [0.6]
        Root -> Terminal_Command                [0.4]
        Result -> Entity                        [0.5]
        Result -> Unary_Command Result          [0.1]
        Result -> Result Binary_Command Result  [0.4]
        Entity -> '{}'                          [1.0]
        Unary_Command -> '{}'                   [1.0]
        Terminal_Command -> '{}'                [1.0]
        Binary_Command -> '{}'                  [1.0]
        """.format(
            gen_lexing_patterns(self.kb_named_entities),
            gen_lexing_patterns((self.keywords.get("unary") or {}).keys()),
            gen_lexing_patterns((self.keywords.get("terminal") or {}).keys()),
            gen_lexing_patterns((self.keywords.get("binary") or {}).keys()),
        ))

        parser = nltk.ViterbiParser(grammar)
        # TODO: Returns the first tree, but need to deal with
        #       case where grammar is ambiguous, and more than
        #       one tree is returned.
        return next(parser.parse(tokens))
コード例 #12
0
class TestMusicKnowledgeBaseAPI(unittest.TestCase):
    """Behavioural tests for KnowledgeBaseAPI against the populated test DB.

    setUp creates a fresh populated database for every test and tearDown
    cleans it up with test_db_utils.remove_db(), so tests may mutate it.
    """

    def setUp(self):
        DB_path = test_db_utils.create_and_populate_db()
        self.kb_api = KnowledgeBaseAPI(dbName=DB_path)

    def tearDown(self):
        test_db_utils.remove_db()

    def test_get_song_data(self):
        song_data = self.kb_api.get_song_data("Despacito")
        # we don't care what the node ID is
        self.assertEqual(
            1, len(song_data),
            "Expected exactly one result from query for song 'Despacito'.")
        self.assertEqual(
            song_data[0],
            dict(
                id=10,
                song_name="Despacito",
                artist_name="Justin Bieber",
                duration_ms=222222,
                popularity=10,
            ), "Song data for 'Despacito' did not match expected.")

    def test_get_song_data_dne(self):
        res = self.kb_api.get_song_data("Not In Database")
        self.assertEqual(
            res, [],
            "Expected empty list of results for queried song not in DB.")

    def test_get_artist_data(self):
        artist_data = self.kb_api.get_artist_data("Justin Bieber")
        self.assertEqual(
            len(artist_data), 1,
            "Expected exactly one result for artist 'Justin Bieber'.")
        # Genre order is not significant; compare as a set.
        artist_data[0]["genres"] = set(artist_data[0]["genres"])
        self.assertEqual(
            artist_data[0],
            dict(genres={"Pop", "Super pop"},
                 id=1,
                 num_spotify_followers=4000,
                 name="Justin Bieber"),
            "Artist data for 'Justin Bieber' did not match expected.",
        )

    def test_get_artist_data_dne(self):
        artist_data = self.kb_api.get_artist_data("Unknown artist")
        self.assertEqual(artist_data, [],
                         "Expected empty list of results for unknown artist.")

    def test_get_songs(self):
        res = self.kb_api.get_songs_by_artist("Justin Bieber")
        self.assertEqual(
            res, ["Despacito", "Sorry"],
            "Songs retrieved for 'Justin Bieber' did not match expected.")

        res = self.kb_api.get_songs_by_artist("Justin Timberlake")
        self.assertEqual(
            res, ["Rock Your Body"],
            "Songs retrieved for 'Justin Timberlake' did not match expected.")

    def test_get_songs_unknown_artist(self):
        res = self.kb_api.get_songs_by_artist("Unknown artist")
        self.assertEqual(res, None,
                         "Unexpected songs retrieved for unknown artist.")

    def test_find_similar_song(self):
        res = self.kb_api.get_related_entities("Despacito")
        self.assertEqual(
            len(res),
            1,
            "Expected only one song similar to \"Despacito\". Found {0}".
            format(res),
        )
        self.assertEqual(
            res[0],
            "Rock Your Body",
            "Expected to find \"Rock Your Body\" as similar to \"Despacito\".",
        )

    def test_find_similar_artist(self):
        res = self.kb_api.get_related_entities("Justin Bieber")
        self.assertEqual(
            len(res),
            2,
            "Expected exactly two artists similar to Justin Bieber.",
        )
        self.assertEqual(
            res[0],
            "Justin Timberlake",
            "Expected to find Justin Timberlake as similar to Justin Bieber.",
        )
        self.assertEqual(
            res[1],
            "Shawn Mendes",
            "Expected to find Shawn Mendes as similar to Justin Bieber.",
        )

    def test_get_all_music_entities(self):
        res = self.kb_api.get_all_music_entities()
        self.assertIn(
            'Justin Bieber', res,
            'Expected to find "Justin Bieber" in the list of entities.')

    def test_find_similar_to_entity_that_dne(self):
        res = self.kb_api.get_related_entities("Unknown Entity")
        self.assertEqual(res, [])

    def test_get_related_genres(self):
        genre_rel_str = self.kb_api.approved_relations["genre"]
        rel_genres = self.kb_api.get_related_entities("Justin Bieber",
                                                      rel_str=genre_rel_str)
        self.assertEqual(
            set(rel_genres), {"Pop", "Super pop"},
            "Did not find expected related genres for artist 'Justin Bieber'")

        rel_genres = self.kb_api.get_related_entities("Justin Timberlake",
                                                      rel_str=genre_rel_str)
        self.assertEqual(
            rel_genres, ["Pop"],
            "Did not find expected related genres for artist 'Justin Timberlake'"
        )

    def test_connect_entities_by_similarity(self):
        res = self.kb_api.get_related_entities("Shawn Mendes")
        self.assertEqual(len(res), 0)

        res = self.kb_api.connect_entities("Shawn Mendes", "Justin Timberlake",
                                           "similar to", 0)
        self.assertEqual(res, True, "")

        res = self.kb_api.get_related_entities("Shawn Mendes")
        self.assertEqual(len(res), 1)
        self.assertEqual(res[0], "Justin Timberlake")

    def test_connect_entities_by_genre(self):
        genre_rel_str = self.kb_api.approved_relations["genre"]
        res = self.kb_api.get_related_entities("Shawn Mendes",
                                               rel_str=genre_rel_str)
        self.assertEqual(len(res), 0)

        res = self.kb_api.connect_entities("Shawn Mendes", "Pop", "of genre",
                                           100)
        self.assertEqual(res, True, "")

        res = self.kb_api.get_related_entities("Shawn Mendes",
                                               rel_str=genre_rel_str)
        self.assertEqual(
            len(res), 1,
            "Expected to find exactly one related genre for 'Shawn Mendes'.")
        self.assertEqual(res[0], "Pop",
                         "Found unexpected genre for 'Shawn Mendes'.")

    def test_rejects_connect_ambiguous_entities(self):
        self.kb_api.add_artist("Artist and Song name clash")
        self.kb_api.add_song("Artist and Song name clash", "U2")

        res = self.kb_api.connect_entities("Artist and Song name clash",
                                           "Justin Timberlake", "similar to",
                                           0)
        self.assertEqual(res, False, "")

    # Renamed from a second ``test_get_node_ids_by_entity_type`` definition
    # further down, which silently replaced this one so it never ran.
    def test_get_node_ids_by_entity_type_name_clash(self):
        node_ids_dict = self.kb_api.get_node_ids_by_entity_type(
            "Justin Bieber")
        all_entity_types = list(node_ids_dict.keys())
        self.assertEqual(
            all_entity_types, ["artist"],
            "Expected to find (only) entities of type 'artist' for 'Justin Bieber', but got: {}"
            .format(node_ids_dict))
        self.assertEqual(
            len(node_ids_dict.get("artist")), 1,
            "Expected to find exactly one entity of type 'artist' for 'Justin Bieber', but got: {}"
            .format(node_ids_dict))

        self.kb_api.add_artist("Song and Artist name")
        self.kb_api.add_song("Song and Artist name", "U2")

        node_ids_dict = self.kb_api.get_node_ids_by_entity_type(
            "Song and Artist name")
        alphabetized_entity_types = sorted(node_ids_dict.keys())
        self.assertEqual(
            alphabetized_entity_types, ["artist", "song"],
            "Expected to find (exactly) two entity types: 'artist' and 'song', but got: {}"
            .format(node_ids_dict))

    def test_get_matching_node_ids(self):
        node_ids = self.kb_api._get_matching_node_ids("Justin Bieber")
        self.assertEqual(
            len(node_ids), 1,
            "Expected to find exactly one matching node for 'Justin Bieber', but got: {}"
            .format(node_ids))

    def test_get_matching_node_ids_empty(self):
        node_ids = self.kb_api._get_matching_node_ids("Unknown artist")
        self.assertEqual(
            len(node_ids), 0,
            "Expected to find no matching node for 'Unknown artist', but got: {}"
            .format(node_ids))

    def test_add_artist(self):
        sample_genres = ["Pop", "Very pop", "Omg so pop"]
        new_artist_node_id = self.kb_api.add_artist("Heart",
                                                    genres=sample_genres,
                                                    num_spotify_followers=1)
        self.assertNotEqual(new_artist_node_id, None,
                            "Failed to add artist 'Heart' to knowledge base.")

        artist_data = self.kb_api.get_artist_data("Heart")
        self.assertEqual(len(artist_data), 1,
                         "Expected unique match for artist 'Heart'.")

        artist_data = artist_data[0]
        artist_data["genres"] = set(artist_data["genres"])

        self.assertEqual(
            artist_data,
            dict(
                name="Heart",
                id=new_artist_node_id,
                genres=set(sample_genres),
                num_spotify_followers=1,
            ),
            "Did not find expected genres for artist 'Heart'.",
        )

    def test_reject_add_artist_already_exists(self):
        artist_node_id = self.kb_api.get_artist_data("Justin Bieber")[0]["id"]
        res = self.kb_api.add_artist("Justin Bieber")
        self.assertEqual(
            res, artist_node_id,
            "Expected re-adding existing artist 'Justin Bieber' to return its existing node id."
        )

    def test_add_artist_omitted_opt_params(self):
        res = self.kb_api.add_artist("Heart")
        self.assertNotEqual(res, None,
                            "Failed to add artist 'Heart' to knowledge base.")

        artist_data = self.kb_api.get_artist_data("Heart")
        self.assertEqual(len(artist_data), 1,
                         "Expected unique match for artist 'Heart'.")
        self.assertEqual(artist_data[0]["name"], "Heart",
                         "Expected match for artist 'Heart'.")
        self.assertEqual(artist_data[0]["genres"], [],
                         "Expected no genres for artist 'Heart'.")
        self.assertEqual(artist_data[0]["num_spotify_followers"], None,
                         "Expected no follower count for artist 'Heart'.")

    def test_add_song(self):
        new_song_node_id = self.kb_api.add_song("Heart",
                                                "Justin Bieber",
                                                duration_ms=11111,
                                                popularity=100)
        self.assertNotEqual(
            new_song_node_id, None,
            "Failed to add song 'Heart' by artist 'Justin Bieber' to knowledge base."
        )

        song_data = self.kb_api.get_song_data("Heart")
        self.assertEqual(len(song_data), 1, "Expected exactly one result.")

        song_data = song_data[0]
        self.assertEqual(
            song_data,
            dict(id=new_song_node_id,
                 song_name="Heart",
                 artist_name="Justin Bieber",
                 duration_ms=11111,
                 popularity=100), "Received unexpected song data")

    def test_add_song_omitted_opt_params(self):
        new_song_node_id = self.kb_api.add_song("What do you mean?",
                                                "Justin Bieber")
        self.assertNotEqual(
            new_song_node_id, None,
            "Failed to add song 'What do you mean?' by artist 'Justin Bieber' to knowledge base."
        )

        song_data = self.kb_api.get_song_data("What do you mean?")
        self.assertEqual(len(song_data), 1, "Expected exactly one result.")

        song_data = song_data[0]
        self.assertEqual(
            song_data,
            dict(id=new_song_node_id,
                 song_name="What do you mean?",
                 artist_name="Justin Bieber",
                 duration_ms=None,
                 popularity=None), "Received unexpected song data")

    def test_add_duplicate_song_for_different_artist(self):
        new_song_node_id = self.kb_api.add_song("Despacito",
                                                "Justin Timberlake")
        self.assertNotEqual(
            new_song_node_id, None,
            "Failed to add song 'Despacito' by artist 'Justin Timberlake' to knowledge base."
        )

        res = self.kb_api.get_song_data("Despacito")
        self.assertEqual(len(res), 2,
                         "Expected exactly two matches for song 'Despacito'.")

        artists = {res[0]["artist_name"], res[1]["artist_name"]}
        self.assertEqual(
            artists, {"Justin Bieber", "Justin Timberlake"},
            "Expected to find duplicate artists 'Justin Bieber' and 'Justin Timberlake' for song 'Despacito'."
        )

    def test_reject_add_song_already_exists(self):
        res = self.kb_api.add_song("Despacito", "Justin Bieber")
        self.assertEqual(
            res, None,
            "Expected rejection of attempt to add song 'Despacito' by 'Justin Bieber' to knowledge base."
        )

    # The logic tested here is currently implemented in the KR API
    # However, if it is moved to the schema (e.g. trigger functions),
    # then this test can be moved to the schema test module
    def test_new_song_with_unknown_artist_rejected(self):
        res = self.kb_api.add_song("Song by Unknown Artist", "Unknown artist")
        self.assertEqual(res, None,
                         "Expected song with unknown artist to be rejected")

        res = self.kb_api.get_song_data("Song by Unknown Artist")
        self.assertEqual(
            len(res), 0,
            "Insertion of song with unknown artist should not have been added to nodes table"
        )

    def test_add_genre(self):
        node_id = self.kb_api.add_genre("hip hop")
        self.assertEqual(
            type(node_id), int,
            "Genre addition appears to have failed: expected int return value (node id) on valid attempt to add genre."
        )

        res = self.kb_api.add_genre("hip hop")
        self.assertEqual(
            res, node_id,
            "Expected original node id to be fetched when attempting to add duplicate genre."
        )

    def test_add_genre_creates_node(self):
        res = self.kb_api.add_genre("hip hop")
        self.assertEqual(
            type(res), int,
            "Genre addition appears to have failed: expected int return value (node id) on valid attempt to add genre."
        )

        entities = self.kb_api.get_node_ids_by_entity_type("hip hop")
        self.assertIn(
            "genre", entities,
            "Expected to find node associated with genre 'hip hop'.")

    def test_get_node_ids_by_entity_type(self):
        res = self.kb_api.get_node_ids_by_entity_type("Justin Timberlake")
        self.assertIn(
            "artist", res,
            "Expected to find an 'artist' entity with name 'Justin Timberlake', but got: {0}"
            .format(res))
        self.assertEqual(
            len(res["artist"]), 1,
            "Expected to find exactly one entity (of type 'artist') with name 'Justin Timberlake', but got: {0}"
            .format(res))

        res = self.kb_api.get_node_ids_by_entity_type("Despacito")
        self.assertIn(
            "song", res,
            "Expected to find a 'song' entity with name 'Despacito', but got: {0}"
            .format(res))
        self.assertEqual(
            len(res["song"]), 1,
            "Expected to find exactly one entity (of type 'song') with name 'Despacito', but got: {0}"
            .format(res))

        res = self.kb_api.get_node_ids_by_entity_type("Unknown entity")
        self.assertEqual(
            res, {},
            "Expected no results from query for unknown entity, but got {}".
            format(res))
コード例 #13
0
class TestMusicKnowledgeBaseAPI(unittest.TestCase):
    """Integration tests for ``KnowledgeBaseAPI`` against the test music DB.

    ``setUp`` builds and populates a fresh database via
    ``test_db_utils.create_and_populate_db()`` and ``tearDown`` removes it,
    so every test starts from the same fixture data.
    """

    # Every key a single ``get_song_data`` row is expected to contain.
    _SONG_DATA_FIELDS = (
        "id",
        "song_name",
        "artist_name",
        "duration_ms",
        "popularity",
        "spotify_uri",
        "acousticness",
        "danceability",
        "energy",
        "instrumentalness",
        "liveness",
        "loudness",
        "speechiness",
        "valence",
        "tempo",
        "mode",
        "musical_key",
        "time_signature",
    )

    @classmethod
    def _song_record(cls, **fields):
        """Return an expected ``get_song_data`` row.

        Fields not passed explicitly default to None, which is the value these
        tests expect for song columns that were never populated.
        """
        record = dict.fromkeys(cls._SONG_DATA_FIELDS)
        record.update(fields)
        return record

    def setUp(self):
        DB_path = test_db_utils.create_and_populate_db()
        self.kb_api = KnowledgeBaseAPI(dbName=DB_path)

    def tearDown(self):
        test_db_utils.remove_db()

    def test_get_song_data(self):
        expected_res = self._song_record(
            id=10,
            song_name="Despacito",
            artist_name="Justin Bieber",
            duration_ms=222222,
            popularity=10,
            valence=0.1,
            spotify_uri='spotify:track:Despacito',
        )

        song_data = self.kb_api.get_song_data("Despacito")
        self.assertEqual(
            1, len(song_data),
            "Expected exactly one result from query for song 'Despacito'.")
        self.assertEqual(
            song_data[0], expected_res,
            "Did not find expected values for song data for 'Despacito'.")

        # Lookup by ID must return the same row as lookup by name.
        song_data = self.kb_api.get_song_data(song_id=10)
        self.assertEqual(
            1, len(song_data),
            "Expected exactly one result from query for song 'Despacito'.")
        self.assertEqual(
            song_data[0], expected_res,
            "Did not find expected values for song data for 'Despacito'.")

    def test_get_song_data_bad_params(self):
        res = self.kb_api.get_song_data()
        self.assertEqual(
            res, [],
            "Expect no results when both song name and ID are omitted.")

    def test_get_song_data_case_insensitive(self):
        song_data = self.kb_api.get_song_data("dESpAcItO")
        self.assertEqual(
            1, len(song_data),
            "Expected exactly one result from query for song 'dESpAcItO'.")
        self.assertEqual(
            song_data[0],
            self._song_record(
                id=10,
                song_name="Despacito",
                artist_name="Justin Bieber",
                duration_ms=222222,
                popularity=10,
                valence=0.1,
                spotify_uri='spotify:track:Despacito',
            ), "Did not find expected values for song data for 'Despacito'.")

        song_data = self.kb_api.get_song_data("bEaUTIful DAY")
        self.assertEqual(
            1, len(song_data),
            "Expected exactly one result from query for song 'bEaUTIful DAY'.")
        self.assertEqual(
            song_data[0],
            self._song_record(
                id=12,
                song_name="Beautiful Day",
                artist_name="U2",
                duration_ms=111111,
                popularity=60,
                valence=1.0,
                spotify_uri='spotify:track:BeautifulDay',
            ), "Did not find expected values for song data for 'Beautiful Day'.")

    def test_get_song_data_ambiguous_name(self):
        res = self.kb_api.get_song_data("Sorry")

        self.assertEqual(2, len(res), "Expected exactly two results.")
        self.assertEqual({14, 15}, {hit["id"] for hit in res},
                         "Found unexpected IDs.")

        # Supplying the ID disambiguates between the two 'Sorry' songs.
        res = self.kb_api.get_song_data("Sorry", 14)
        self.assertEqual(1, len(res), "Expected exactly one result.")
        self.assertEqual(
            res[0],
            self._song_record(
                id=14,
                song_name="Sorry",
                artist_name="Justin Bieber",
                popularity=20,
                duration_ms=333333,
                valence=0.3,
                spotify_uri='spotify:track:Sorry',
            ), "Unexpected result contents.")

        res = self.kb_api.get_song_data("Sorry", 15)
        self.assertEqual(1, len(res), "Expected exactly one result.")
        self.assertEqual(
            res[0],
            self._song_record(
                id=15,
                song_name="Sorry",
                artist_name="The Anti Justin Bieber",
            ), "Unexpected result contents.")

    def test_get_song_data_dne(self):
        res = self.kb_api.get_song_data("Not In Database")
        self.assertEqual(
            res, [],
            "Expected empty list of results for queried song not in DB.")

    def test_get_artist_data(self):
        artist_data = self.kb_api.get_artist_data("Justin Bieber")
        self.assertEqual(
            len(artist_data), 1,
            "Expected exactly one result for artist 'Justin Bieber'.")
        # Genre order is not part of the contract; compare as a set.
        artist_data[0]["genres"] = set(artist_data[0]["genres"])
        self.assertEqual(
            artist_data[0],
            dict(genres={"Pop", "Super pop"},
                 id=1,
                 num_spotify_followers=4000,
                 name="Justin Bieber"),
            "Artist data for 'Justin Bieber' did not match expected.",
        )

    def test_get_artist_data_dne(self):
        artist_data = self.kb_api.get_artist_data("Unknown artist")
        self.assertEqual(artist_data, [],
                         "Expected empty list of results for unknown artist.")

    def test_get_all_song_names(self):
        expected_song_names = {
            "Despacito", "Rock Your Body", "Beautiful Day", "In My Blood",
            "Sorry"
        }
        res = self.kb_api.get_all_song_names()
        self.assertEqual(set(res), expected_song_names,
                         "Unexpected result from fetching all songs from db.")

    def test_get_all_artist_names(self):
        expected_artist_names = {
            "Justin Bieber", "Justin Timberlake", "U2", "Shawn Mendes",
            "The Anti Justin Bieber"
        }
        res = self.kb_api.get_all_artist_names()
        self.assertEqual(
            set(res), expected_artist_names,
            "Unexpected result from fetching all artists from db.")

    def test_get_songs(self):
        res = self.kb_api.get_songs_by_artist("Justin Bieber")
        self.assertEqual(
            res,
            [
                dict(song_name="Despacito", id=10),
                dict(song_name="Sorry", id=14)
            ],
            "Songs retrieved for 'Justin Bieber' did not match expected.",
        )

        res = self.kb_api.get_songs_by_artist("Justin Timberlake")
        self.assertEqual(
            res,
            [dict(song_name="Rock Your Body", id=11)],
            "Songs retrieved for 'Justin Timberlake' did not match expected.",
        )

    def test_get_songs_unknown_artist(self):
        res = self.kb_api.get_songs_by_artist("Unknown artist")
        self.assertEqual(res, [],
                         "Unexpected songs retrieved for unknown artist.")

    def test_songs_are_related_popularity(self):
        self.assertEqual(
            self.kb_api.songs_are_related(12, 10, "less popular"),
            False,
            "'Beautiful Day' is MORE popular than 'Despacito'",
        )

        self.assertEqual(
            self.kb_api.songs_are_related(12, 10, "more popular"),
            True,
            "'Beautiful Day' is MORE popular than 'Despacito'",
        )

    def test_songs_are_related_valence(self):
        self.assertEqual(
            self.kb_api.songs_are_related(12, 10, "more happy"),
            True,
            "'Beautiful Day' is MORE happy than 'Despacito'",
        )

        self.assertEqual(
            self.kb_api.songs_are_related(12, 10, "less happy"),
            False,
            "'Beautiful Day' is MORE happy than 'Despacito'",
        )

    def test_songs_are_related_same_song(self):
        # A song is neither more nor less popular than itself.
        self.assertEqual(
            self.kb_api.songs_are_related(12, 12, "more popular"),
            False,
            "'Beautiful Day' cannot be more popular than itself.",
        )

        self.assertEqual(
            self.kb_api.songs_are_related(12, 12, "less popular"),
            False,
            "'Beautiful Day' cannot be less popular than itself.",
        )

    def test_songs_are_related_unknown_relationship(self):
        self.assertEqual(
            self.kb_api.songs_are_related(12, 10, 'more something'),
            False,
            "Expected return value False when relationship is invalid.",
        )

    def test_songs_are_related_value_missing(self):
        self.assertEqual(
            self.kb_api.songs_are_related(12, 10, 'more dancey'),
            False,
            "Expected return value False when DB does not contain value for given (valid) relationship.",
        )

    def test_find_similar_song(self):
        res = self.kb_api.get_related_entities("Despacito")
        self.assertEqual(
            len(res),
            1,
            "Expected only one song similar to \"Despacito\". Found {0}".
            format(res),
        )
        self.assertEqual(
            res[0],
            "Rock Your Body",
            "Expected to find \"Rock Your Body\" as similar to \"Despacito\".",
        )

    def test_find_similar_artist(self):
        res = self.kb_api.get_related_entities("Justin Bieber")
        self.assertEqual(
            len(res),
            2,
            "Expected exactly two artists similar to Justin Bieber.",
        )
        self.assertEqual(
            res[0],
            "Justin Timberlake",
            "Expected to find Justin Timberlake as similar to Justin Bieber.",
        )
        self.assertEqual(
            res[1],
            "Shawn Mendes",
            "Expected to find Shawn Mendes as similar to Justin Bieber.",
        )

    def test_find_similar_to_entity_that_dne(self):
        res = self.kb_api.get_related_entities("Unknown Entity")
        self.assertEqual(res, [])

    def test_get_related_genres(self):
        genre_rel_str = self.kb_api.approved_relations["genre"]
        rel_genres = self.kb_api.get_related_entities("Justin Bieber",
                                                      rel_str=genre_rel_str)
        self.assertEqual(
            set(rel_genres), {"Pop", "Super pop"},
            "Did not find expected related genres for artist 'Justin Bieber'")

        rel_genres = self.kb_api.get_related_entities("Justin Timberlake",
                                                      rel_str=genre_rel_str)
        self.assertEqual(
            rel_genres, ["Pop"],
            "Did not find expected related genres for artist 'Justin Timberlake'"
        )

    def test_connect_entities_by_similarity(self):
        res = self.kb_api.get_related_entities("Shawn Mendes")
        self.assertEqual(len(res), 0)

        res = self.kb_api.connect_entities("Shawn Mendes", "Justin Timberlake",
                                           "similar to", 0)
        self.assertEqual(res, True, "")

        res = self.kb_api.get_related_entities("Shawn Mendes")
        self.assertEqual(len(res), 1)
        self.assertEqual(res[0], "Justin Timberlake")

    def test_connect_entities_by_genre(self):
        genre_rel_str = self.kb_api.approved_relations["genre"]
        res = self.kb_api.get_related_entities("Shawn Mendes",
                                               rel_str=genre_rel_str)
        self.assertEqual(len(res), 0)

        res = self.kb_api.connect_entities("Shawn Mendes", "Pop", "of genre",
                                           100)
        self.assertEqual(res, True, "")

        res = self.kb_api.get_related_entities("Shawn Mendes",
                                               rel_str=genre_rel_str)
        self.assertEqual(
            len(res), 1,
            "Expected to find exactly one related genre for 'Shawn Mendes'.")
        self.assertEqual(res[0], "Pop",
                         "Found unexpected genre for 'Shawn Mendes'.")

    def test_rejects_connect_ambiguous_entities(self):
        self.kb_api.add_artist("Artist and Song name clash")
        self.kb_api.add_song("Artist and Song name clash", "U2")

        res = self.kb_api.connect_entities("Artist and Song name clash",
                                           "Justin Timberlake", "similar to",
                                           0)
        self.assertEqual(res, False, "")

    # Renamed from a duplicate ``test_get_node_ids_by_entity_type``: the later
    # definition of that name in this class silently replaced this one, so it
    # never ran.
    def test_get_node_ids_by_entity_type_artist_and_song_clash(self):
        node_ids_dict = self.kb_api.get_node_ids_by_entity_type(
            "Justin Bieber")
        all_entity_types = list(node_ids_dict.keys())
        self.assertEqual(
            all_entity_types, ["artist"],
            "Expected to find (only) entities of type 'artist' for 'Justin Bieber', but got: {}"
            .format(node_ids_dict))
        self.assertEqual(
            len(node_ids_dict.get("artist")), 1,
            "Expected to find exactly one entity of type 'artist' for 'Justin Bieber', but got: {}"
            .format(node_ids_dict))

        self.kb_api.add_artist("Song and Artist name")
        self.kb_api.add_song("Song and Artist name", "U2")

        node_ids_dict = self.kb_api.get_node_ids_by_entity_type(
            "Song and Artist name")
        alphabetized_entity_types = sorted(node_ids_dict.keys())
        self.assertEqual(
            alphabetized_entity_types, ["artist", "song"],
            "Expected to find (exactly) two entity types: 'artist' and 'song', but got: {}"
            .format(node_ids_dict))

    def test_get_matching_node_ids(self):
        node_ids = self.kb_api._get_matching_node_ids("Justin Bieber")
        self.assertEqual(
            len(node_ids), 1,
            "Expected to find exactly one matching node for 'Justin Bieber', but got: {}"
            .format(node_ids))

    def test_get_matching_node_ids_empty(self):
        node_ids = self.kb_api._get_matching_node_ids("Unknown artist")
        self.assertEqual(
            len(node_ids), 0,
            "Expected to find no matching node for 'Unknown artist', but got: {}"
            .format(node_ids))

    def test_add_artist(self):
        sample_genres = ["Pop", "Very pop", "Omg so pop"]
        new_artist_node_id = self.kb_api.add_artist("Heart",
                                                    genres=sample_genres,
                                                    num_spotify_followers=1)
        self.assertNotEqual(new_artist_node_id, None,
                            "Failed to add artist 'Heart' to knowledge base.")

        artist_data = self.kb_api.get_artist_data("Heart")
        self.assertEqual(len(artist_data), 1,
                         "Expected unique match for artist 'Heart'.")

        artist_data = artist_data[0]
        artist_data["genres"] = set(artist_data["genres"])

        self.assertEqual(
            artist_data,
            dict(
                name="Heart",
                id=new_artist_node_id,
                genres=set(sample_genres),
                num_spotify_followers=1,
            ),
            "Did not find expected genres for artist 'Heart'.",
        )

    def test_reject_add_artist_already_exists(self):
        artist_node_id = self.kb_api.get_artist_data("Justin Bieber")[0]["id"]
        res = self.kb_api.add_artist("Justin Bieber")
        self.assertEqual(
            res, artist_node_id,
            "Expected rejection of attempt to add artist 'Justin Bieber' to knowledge base."
        )

    def test_add_artist_omitted_opt_params(self):
        res = self.kb_api.add_artist("Heart")
        self.assertNotEqual(res, None,
                            "Failed to add artist 'Heart' to knowledge base.")

        artist_data = self.kb_api.get_artist_data("Heart")
        self.assertEqual(len(artist_data), 1,
                         "Expected unique match for artist 'Heart'.")
        self.assertEqual(artist_data[0]["name"], "Heart",
                         "Expected match for artist 'Heart'.")
        self.assertEqual(artist_data[0]["genres"], [],
                         "Expected no genres for artist 'Heart'.")
        self.assertEqual(
            artist_data[0]["num_spotify_followers"], None,
            "Expected no Spotify follower count for artist 'Heart'.")

    def test_add_song(self):
        new_song_node_id = self.kb_api.add_song(
            "Heart",
            "Justin Bieber",
            duration_ms=11111,
            popularity=100,
            spotify_uri="spotify:track:Heart",
            audio_features=dict(energy=0.5),
        )
        self.assertNotEqual(
            new_song_node_id, None,
            "Failed to add song 'Heart' by artist 'Justin Bieber' to knowledge base."
        )

        song_data = self.kb_api.get_song_data("Heart")
        self.assertEqual(len(song_data), 1, "Expected exactly one result.")

        self.assertEqual(
            song_data[0],
            self._song_record(
                id=new_song_node_id,
                song_name="Heart",
                artist_name="Justin Bieber",
                duration_ms=11111,
                popularity=100,
                spotify_uri="spotify:track:Heart",
                energy=0.5,
            ),
            "Received unexpected song data",
        )

    def test_add_song_omitted_opt_params(self):
        new_song_node_id = self.kb_api.add_song("What do you mean?",
                                                "Justin Bieber")
        self.assertNotEqual(
            new_song_node_id, None,
            "Failed to add song 'What do you mean?' by artist 'Justin Bieber' to knowledge base."
        )

        song_data = self.kb_api.get_song_data("What do you mean?")
        self.assertEqual(len(song_data), 1, "Expected exactly one result.")

        self.assertEqual(
            song_data[0],
            self._song_record(
                id=new_song_node_id,
                song_name="What do you mean?",
                artist_name="Justin Bieber",
            ), "Received unexpected song data")

    def test_add_song_with_all_audio_features(self):
        valid_audio_features = dict(
            acousticness=1,
            danceability=1,
            energy=1,
            instrumentalness=1,
            liveness=1,
            loudness=0,
            speechiness=1,
            valence=1,
            tempo=1,
            mode="major",
            musical_key=1,
            time_signature=3,
        )
        new_song_id = self.kb_api.add_song(
            "Song with all audio features",
            "Justin Bieber",
            audio_features=valid_audio_features,
        )

        song_data = self.kb_api.get_song_data("Song with all audio features")
        self.assertEqual(
            self._song_record(
                id=new_song_id,
                song_name="Song with all audio features",
                artist_name="Justin Bieber",
                acousticness=1.0,
                danceability=1.0,
                energy=1.0,
                instrumentalness=1.0,
                liveness=1.0,
                loudness=0.0,
                speechiness=1.0,
                valence=1.0,
                tempo=1.0,
                mode="major",
                musical_key=1,
                time_signature=3,
            ),
            song_data[0],
            "Retrieved song data did not match expected.",
        )

    def test_add_duplicate_song_for_different_artist(self):
        new_song_node_id = self.kb_api.add_song("Despacito",
                                                "Justin Timberlake")
        self.assertNotEqual(
            new_song_node_id, None,
            "Failed to add song 'Despacito' by artist 'Justin Timberlake' to knowledge base."
        )

        res = self.kb_api.get_song_data("Despacito")
        self.assertEqual(len(res), 2,
                         "Expected exactly two matches for song 'Despacito'.")

        artists = {res[0]["artist_name"], res[1]["artist_name"]}
        self.assertEqual(
            artists, {"Justin Bieber", "Justin Timberlake"},
            "Expected to find duplicate artists 'Justin Bieber' and 'Justin Timberlake' for song 'Despacito'."
        )

    def test_reject_add_song_already_exists(self):
        res = self.kb_api.add_song("Despacito", "Justin Bieber")
        self.assertEqual(
            res, None,
            "Expected rejection of attempt to add song 'Despacito' by 'Justin Bieber' to knowledge base."
        )

    # The logic tested here is currently implemented in the KR API
    # However, if it is moved to the schema (e.g. trigger functions),
    # then this test can be moved to the schema test module
    def test_new_song_with_unknown_artist_rejected(self):
        res = self.kb_api.add_song("Song by Unknown Artist", "Unknown artist")
        self.assertEqual(res, None,
                         "Expected song with unknown artist to be rejected")

        res = self.kb_api.get_song_data("Song by Unknown Artist")
        self.assertEqual(
            len(res), 0,
            "Insertion of song with unknown artist should not have been added to nodes table"
        )

    def test_add_genre(self):
        node_id = self.kb_api.add_genre("hip hop")
        self.assertEqual(
            type(node_id), int,
            "Genre addition appears to have failed: expected int return value (node id) on valid attempt to add genre."
        )

        # Re-adding an existing genre returns the original node id.
        res = self.kb_api.add_genre("hip hop")
        self.assertEqual(
            res, node_id,
            "Expected original node id to be fetched when attempting to add duplicate genre."
        )

    def test_add_genre_creates_node(self):
        res = self.kb_api.add_genre("hip hop")
        self.assertEqual(
            type(res), int,
            "Genre addition appears to have failed: expected int return value (node id) on valid attempt to add genre."
        )

        entities = self.kb_api.get_node_ids_by_entity_type("hip hop")
        self.assertIn(
            "genre", entities,
            "Expected to find node associated with genre 'hip hop'.")

    def test_get_node_ids_by_entity_type(self):
        res = self.kb_api.get_node_ids_by_entity_type("Justin Timberlake")
        self.assertIn(
            "artist", res,
            "Expected to find an 'artist' entity with name 'Justin Timberlake', but got: {0}"
            .format(res))
        self.assertEqual(
            len(res["artist"]), 1,
            "Expected to find exactly one entity (of type 'artist') with name 'Justin Timberlake', but got: {0}"
            .format(res))

        res = self.kb_api.get_node_ids_by_entity_type("Despacito")
        self.assertIn(
            "song", res,
            "Expected to find a 'song' entity with name 'Despacito', but got: {0}"
            .format(res))
        self.assertEqual(
            len(res["song"]), 1,
            "Expected to find exactly one entity (of type 'song') with name 'Despacito', but got: {0}"
            .format(res))

        res = self.kb_api.get_node_ids_by_entity_type("Unknown entity")
        self.assertEqual(
            res, {},
            "Expected no results from query for unknown entity, but got {}".
            format(res))
コード例 #14
0
class TreeEvalEngine:
    """Evaluates NLP parse trees and acts on the user's intent.

    This class stores the possible interactions between the user and the
    system, and the logic to act on them.

    The ``*_commands`` properties map each command (intent) name to a pair
    ``(keywords, handler)``. ``keywords`` are the words that signal the
    user's intent and ``handler`` is the bound method that carries it out.
    """
    def __init__(self, db_path, player_controller):
        self.player = player_controller
        self.DB_path = db_path
        self.kb_api = KnowledgeBaseAPI(self.DB_path)

    def __call__(self, parser, text):
        """Parse ``text`` with ``parser`` and evaluate the resulting tree.

        Args:
            parser: Callable that turns the user's text into an nltk parse tree.
            text: Raw user input.
        """
        try:
            nltk_parse_tree = parser(text)

            self._evaluate(nltk_parse_tree)
        except Exception:
            # Deliberately best-effort: any parse/evaluation failure becomes a
            # user-facing apology. Catching Exception (rather than using a bare
            # except) still lets KeyboardInterrupt/SystemExit propagate.
            self.player.respond("I'm sorry, I don't understand.")

    @property
    def unary_commands(self):
        """Commands that operate on one set of entities.

        Returns:
            An OrderedDict mapping each command name to a
            ``(keywords, handler)`` pair. The handler takes one list of
            entities and returns a list.
        """
        return OrderedDict([
            ('query_similar_entities', (['like', 'similar'],
                                        self._query_similar_entities)),
            ('query_songs_by_artist', (['songs by',
                                        'by'], self._query_songs_by_artist)),
            ('query_artist_by_song', (['artist'], self._query_artist_by_song)),
        ])

    @property
    def terminal_commands(self):
        """Commands executed last, on the fully evaluated result.

        Returns:
            An OrderedDict mapping each command name to a
            ``(keywords, handler)`` pair.
        """
        return OrderedDict([
            ('query_commands', (['hi', 'how', 'hello'], self._query_commands)),
            ('control_stop', (['stop'], self._control_stop)),
            ('control_pause', (['pause'], self._control_pause)),
            ('control_play', (['start', 'play'], self._control_play)),
            ('query_info', (['who', 'what'], self._query_info)),
            ('control_forward', (['skip', 'next'], self._control_skip)),
        ])

    @property
    def binary_commands(self):
        """Commands that combine two sets of entities.

        Returns:
            An OrderedDict mapping each command name to a
            ``(keywords, handler)`` pair. The handler takes two lists of
            entities and returns a list.
        """
        return OrderedDict([
            ('control_union', (['or', 'and'], self._control_union)),
        ])

    @property
    def keywords(self):
        """Map each command category to ``{command_name: keywords}``.

        The categories are "unary", "terminal" and "binary".
        """
        return {
            "unary": {k: v[0]
                      for k, v in self.unary_commands.items()},
            "terminal": {k: v[0]
                         for k, v in self.terminal_commands.items()},
            "binary": {k: v[0]
                       for k, v in self.binary_commands.items()},
        }

    @property
    def actions(self):
        """Map each unary command name to its handler."""
        return {k: v[1] for k, v in self.unary_commands.items()}

    @property
    def command_precedence(self):
        """Unary command names in declaration order."""
        return list(self.unary_commands)

    def _query_commands(self):
        """Terminal command.

        Reports on possible commands and interactions.
        """
        # TODO: make this work
        self.player.respond("Hi there! Ask me to play artists or songs. "
                            "I can also find songs that are similar to other "
                            "artists.")

    def _control_play(self, entities: List[str]):
        """Terminal command.

        Plays the subjects specified. If ``entities`` is empty, apologizes
        to the user instead.

        Args:
            entities: A list of Entities.

        TODO:
            - query db for entity
                - if it's of type song, play it
                - if it's anything else, get the songs associated with
                  it and play in DESC order
        """
        if entities:
            self.player.play(entities)
        else:
            self.player.respond("I'm sorry, I couldn't find that for you.")

    def _control_stop(self):
        """Terminal command. Stops the player."""
        self.player.stop()

    def _control_pause(self):
        """Terminal command. Pauses the player."""
        self.player.pause()

    def _control_skip(self):
        """Terminal command. Skips the current song."""
        self.player.skip()

    def _control_union(self, entities_1: List[str], entities_2: List[str]):
        """Binary command.

        Returns the union of the two parameters, with duplicates removed.
        The order of the result is not defined because it goes through a set.

        Args:
            entities_1: A list of Entities.
            entities_2: A list of Entities.

        Returns:
            A list of Entities.
        """
        return list(set(entities_1).union(set(entities_2)))

    def _query_info(self, entities: List[str]):
        """Terminal command.

        Responds to the user with the given entities.

        Args:
            entities: A list of Entities.
        """
        self.player.respond(entities)

    def _query_similar_entities(self, entities: List[str]):
        """Unary command.

        Returns the related entities of every entity in ``entities``,
        excluding any entity that was itself passed in.

        Args:
            entities: A list of Entities.

        Returns:
            A list of entity names (duplicates are possible).
        """
        similar_entities = []
        for e in entities:
            # Don't return the artists given in the params (for the case
            # where there are multiple artists and they are related to each
            # other).
            similar_entities += [
                ent for ent in self.kb_api.get_related_entities(e)
                if ent not in entities
            ]

        return similar_entities

    def _query_songs_by_artist(self, entities: List[str]):
        """Unary command.

        Returns the songs by each artist in ``entities``, concatenated in
        input order.

        Args:
            entities: A list of artist names.

        Returns:
            A list of the items returned by
            ``KnowledgeBaseAPI.get_songs_by_artist``.
        """
        songs = []
        for e in entities:
            songs += self.kb_api.get_songs_by_artist(e)

        return songs

    def _query_artist_by_song(self, entities: List[str]):
        """Unary command.

        Returns the artist name of every song matching each entity.

        Args:
            entities: A list of song names.

        Returns:
            A list of artist names.
        """
        artists = []
        for e in entities:
            artists += [
                song.get('artist_name')
                for song in self.kb_api.get_song_data(e)
            ]

        return artists

    def _evaluate(self, tree: nltk.tree.Tree):
        """Recursively evaluate a parse tree generated by the NLP layer.

        The tree is evaluated from the top down. Each label is handled as
        follows:

        - ``Root`` holds a terminal command and an optional ``Result``. The
          command is called with the evaluated result, or with no arguments.
        - ``Result`` is an ``Entity``, a unary command followed by a result,
          or ``result binary_command result``.
        - The ``*_Command`` labels resolve their leaf to a handler.
        - ``Entity`` becomes a one-element list holding its leaf.

        Returns:
            A handler for command nodes, and a list of entities for
            ``Entity``/``Result`` nodes. Returns None for ``Root`` and for
            unhandled labels, which are also reported on stderr.
        """
        if tree.label() == "Root":
            if len(tree) == 1:
                func = self._evaluate(tree[0])
                func()
            else:
                func = self._evaluate(tree[0])
                result = self._evaluate(tree[1])
                func(result)
            return
        elif tree.label() == "Result":
            if tree[0].label() == "Entity":
                return self._evaluate(tree[0])
            if tree[0].label() == "Unary_Command":
                func = self._evaluate(tree[0])
                result = self._evaluate(tree[1])
                return func(result)
            if tree[1].label() == "Binary_Command":
                result_left = self._evaluate(tree[0])
                func = self._evaluate(tree[1])
                result_right = self._evaluate(tree[2])
                return func(result_left, result_right)
        elif tree.label() == "Unary_Command":
            func = self.unary_commands.get(tree[0])[1]
            return func
        elif tree.label() == "Terminal_Command":
            func = self.terminal_commands.get(tree[0])[1]
            return func
        elif tree.label() == "Binary_Command":
            func = self.binary_commands.get(tree[0])[1]
            return func
        elif tree.label() == "Entity":
            return [tree[0]]

        print(
            "Error: CFG label rule not defined in "
            "evaluateEngine#self._evaluate",
            file=sys.stderr)
コード例 #15
0
 def __init__(self, db_path, player_controller):
     """Store the DB path and player, and build a KnowledgeBaseAPI for the DB."""
     self.DB_path = db_path
     self.player = player_controller
     self.kb_api = KnowledgeBaseAPI(db_path)
class BOWParser:
    """
    This layer stores the NLP specific tooling and logic.

    Messages are parsed as a "bag of words": known database entities
    (subjects) are cut out of the text first, then punctuation and
    stopwords are removed, and finally every intent whose keywords occur
    in what is left is recorded.

    References:
        https://stackoverflow.com/questions/42322902/how-to-get-parse-tree-using-python-nltk
        https://www.nltk.org/book/ch08.html
    """
    extra_stopwords = {'s', 'hey', 'want', 'you'}

    def __init__(self, db_path, keywords):
        """
        Args:
            db_path: Path to the knowledge-base database.
            keywords: Mapping of intent name -> list of keyword strings
                that signal that intent.
        """
        # Download the stopwords if necessary.
        try:
            nltk.data.find('corpora/stopwords')
        except LookupError:
            nltk.download('stopwords')

        self.commands = keywords
        self.kb_api = KnowledgeBaseAPI(db_path)
        self.db_nouns = self.kb_api.get_all_music_entities()

    def _get_stop_words(self):
        """Return English stopwords plus `extra_stopwords`, minus every
        command keyword, so that keywords survive stopword filtering."""
        stop_words = set(stopwords.words('english'))
        stop_words |= BOWParser.extra_stopwords
        for words in self.commands.values():
            # difference_update silently ignores keywords that are not
            # stopwords; no need to catch exceptions per word.
            stop_words.difference_update(words)
        return stop_words

    def _gen_patterns(self):
        """Build one compiled regex per intent that matches any of its
        keywords as a whole word.

        Returns:
            dict mapping intent name -> compiled pattern.
        """
        patterns = {}
        for intent, keys in self.commands.items():
            # Escape each keyword so it is matched literally.
            alternatives = '|'.join(r'\b' + re.escape(key) + r'\b' for key in keys)
            patterns[intent] = re.compile(alternatives)
        return patterns

    def __call__(self, msg: str):
        """Parse `msg` into subjects, intents and leftover text.

        Args:
            msg: Raw user utterance.

        Returns:
            A tuple ``(subjects, intents, remaining_text)``:
            subjects are the database entities whose names occur in `msg`
            (case-insensitive), in `db_nouns` order and stripped; intents are
            the names of intents whose keywords were found; remaining_text is
            the lower-cased text left after removing subjects, punctuation,
            stopwords and keywords.
        """
        # Collect every database entity mentioned in the message and cut it
        # out, so its words are not mistaken for keywords later.
        subjects = []
        for noun in self.db_nouns:
            if noun.lower() in msg.lower():
                # Entity names may contain regex metacharacters ('$', '+',
                # '(' ...); escape them so the name is removed literally.
                pattern = re.compile(re.escape(noun), re.IGNORECASE)
                msg = pattern.sub('', msg)
                subjects.append(noun.strip())

        # Remove punctuation from the string
        msg = re.sub(r"[,.;@#?!&$']+\ *",
                     " ",
                     msg,
                     flags=re.VERBOSE)

        # Clean the stopwords from the input.
        stop_words = self._get_stop_words()
        clean_msg = ' '.join([word for word in msg.lower().split(' ')
                              if word not in stop_words])

        # Parse the keywords from the filtered input; each matched intent's
        # keywords are removed before checking the next intent.
        patterns = self._gen_patterns()
        intents = []
        for intent, pattern in patterns.items():
            sub_msg = re.sub(pattern, '', clean_msg)
            if sub_msg != clean_msg:
                intents.append(intent)
                clean_msg = sub_msg

        remaining_text = clean_msg.strip()
        return subjects, intents, remaining_text
コード例 #17
0
class BOWEvalEngine:
    """This class stores the possible interactions between the
    user and the system, and the logic to act on them.

    The `intents` property contains a mapping that stores
    signifiers of user's intent, along with the functions
    that they map to.

    This class also contains the functions that correspond to
    each of the possible user interactions.

    """
    def __init__(self, db_path, player_controller):
        self.player = player_controller
        self.DB_path = db_path
        self.kb_api = KnowledgeBaseAPI(self.DB_path)

    def __call__(
        self,
        subjects: List[str] = None,
        commands: List[str] = None,
        remaining_text: str = None,
    ):
        """Initiates a sequence of commands that will act
        on the subject parameters in order of command
        precedence.

        Args:
            subjects: List of subjects in the database.
            commands: Names of commands to enact. Names that are not in
                `command_precedence` are ignored. When none remain (or
                `commands` is None/empty), the `default` handler runs.
            remaining_text: Any remaining text not parsed by the NLP layer.

        Returns:
            None. Output is delivered through `self.player`.

        """
        # A missing command list means "no commands", not a crash.
        requested = commands or []

        # Sort the commands by precedence (this also drops duplicates and
        # unknown names).
        sorted_commands = [c for c in self.command_precedence if c in requested]

        # Call the highest precedence command.
        self._next_operation(
            subjects=subjects,
            commands=sorted_commands,
            remaining_text=remaining_text,
        )

    @property
    def intents(self):
        """Returns a mapping that stores signifiers of user's
        intent, along with the `commands` and the functions
        that they map to.

        The ordering of these commands matters, as it is used to
        store the precedence of operations. A new mapping is built
        on every access.

        Returns (OrderedDict): A mapping of intentions, their
            corresponding signifiers (as keywords), and function
            handlers.

        """
        return OrderedDict([
            ('query_commands', (['hi', 'how', 'hello'], self._query_commands)),
            ('control_stop', (['stop'], self._control_stop)),
            ('control_pause', (['pause'], self._control_pause)),
            ('control_forward', (['skip', 'next'], self._control_skip)),
            ('query_similar_entities', (['like', 'similar'],
                                        self._query_similar_entities)),
            ('control_play', (['start', 'play'], self._control_play)),
            ('query_artist', (['who', 'artist'], self._query_artist)),
            ('default', ([], self._default)),
        ])

    @property
    def keywords(self):
        """For each type of command, generate a dictionary
        of keywords that map to the specific command (intent)

        """
        return {k: v[0] for k, v in self.intents.items()}

    @property
    def actions(self):
        """Map each intent name to its handler method."""
        return {k: v[1] for k, v in self.intents.items()}

    @property
    def command_precedence(self):
        """Intent names ordered from highest to lowest precedence."""
        return list(self.intents)

    def _query_commands(
        self,
        subjects: List[str] = None,
        commands: List[str] = None,
        remaining_text: str = None,
        response_msg: str = None,
    ):
        """Terminal command. Reports on possible commands and interactions.

        """
        if subjects:
            # Handle question about specific `subjects`.
            self.player.respond("I'm sorry, I can't answer that one")
        else:
            # Handle general inquiries.
            self.player.respond(
                "Hi there! Ask me to play artists or songs. "
                "I can also find songs that are similar to other "
                "artists.")

    def _control_play(
        self,
        subjects: List[str] = None,
        commands: List[str] = None,
        remaining_text: str = None,
        response_msg: str = None,
    ):
        """Terminal command. Plays the subjects specified.

        With no subjects, apologises if there was unparsed text and
        otherwise reports that the current song is resumed.

        TODO:
            - query db for entity
                - if it's of type song, play it
                - if it's anything else, get the songs associated with
                  it and play in DESC order

        """
        if subjects:
            self.player.play(subjects)
        elif remaining_text:
            self.player.respond("I'm sorry, I couldn't find that for you.")
        else:
            self.player.respond('Resuming the current song')

    def _control_stop(
        self,
        subjects: List[str] = None,
        commands: List[str] = None,
        remaining_text: str = None,
        response_msg: str = None,
    ):
        """Terminal command. Forwards `subjects` to the player's stop."""
        self.player.stop(subjects)

    def _control_pause(
        self,
        subjects: List[str] = None,
        commands: List[str] = None,
        remaining_text: str = None,
        response_msg: str = None,
    ):
        """Terminal command. Forwards `subjects` to the player's pause."""
        self.player.pause(subjects)

    def _control_skip(
        self,
        subjects: List[str] = None,
        commands: List[str] = None,
        remaining_text: str = None,
        response_msg: str = None,
    ):
        """Terminal command. Forwards `subjects` to the player's skip."""
        # TODO: Add number parsing for "skip forward 2 songs".
        self.player.skip(subjects)

    def _control_intersection(
        self,
        subjects: List[str] = None,
        commands: List[str] = None,
        remaining_text: str = None,
        response_msg: str = None,
    ):
        """Placeholder: not registered in `intents`; currently just skips."""
        # TODO
        self.player.skip(subjects)

    def _control_union(
        self,
        subjects: List[str] = None,
        commands: List[str] = None,
        remaining_text: str = None,
        response_msg: str = None,
    ):
        """Placeholder: not registered in `intents`; currently just skips."""
        # TODO
        self.player.skip(subjects)

    def _query_artist(
        self,
        subjects: List[str] = None,
        commands: List[str] = None,
        remaining_text: str = None,
        response_msg: str = None,
    ):
        """Terminal command. Currently echoes `subjects` back to the player."""
        # TODO: implement for fetching current state, ie "who is this artist?"
        self.player.respond(subjects)

    def _query_similar_entities(
        self,
        subjects: List[str] = None,
        commands: List[str] = None,
        remaining_text: str = None,
        response_msg: str = None,
    ):
        """Replace `subjects` with their related entities from the
        knowledge base and pass them to the next command in the chain.
        Responds with an apology if nothing related is found.
        """
        similar_entities = []
        # Tolerate a missing subject list instead of raising TypeError.
        for e in subjects or []:
            similar_entities += self.kb_api.get_related_entities(e)

        if not similar_entities:
            self.player.respond("I'm sorry, I couldn't find that for you.")
        else:
            self._next_operation(
                subjects=similar_entities,
                commands=commands,
                remaining_text=remaining_text,
                response_msg=response_msg,
            )

    def _default(
        self,
        subjects: List[str] = None,
        commands: List[str] = None,
        remaining_text: str = None,
        response_msg: str = None,
    ):
        """Fallback handler when no command is recognised."""
        self.player.respond("I'm sorry, I don't understand.")

    def _next_operation(
        self,
        subjects: List[str] = None,
        commands: List[str] = None,
        remaining_text: str = None,
        response_msg: str = None,
    ):
        """Calls the next function in the commands parameter.

        The first name is popped off `commands` (mutating the list) so
        chained handlers receive only the commands still to run. An empty
        list or an unknown name falls back to `_default`.

        """
        if commands:
            next_func = self.actions.get(commands.pop(0), self._default)
        else:
            next_func = self._default
        next_func(
            subjects=subjects,
            commands=commands,
            remaining_text=remaining_text,
            response_msg=response_msg,
        )
コード例 #18
0
ファイル: test_db_schema.py プロジェクト: okjuan/muze
class TestDbSchema(unittest.TestCase):
    """Checks that the knowledge base rejects invalid writes.

    One populated test database is shared by every test in the class:
    it is created in setUpClass and removed in tearDownClass.
    """

    @classmethod
    def setUpClass(cls):
        DB_path = test_db_utils.create_and_populate_db()
        cls.kb_api = KnowledgeBaseAPI(dbName=DB_path)

    @classmethod
    def tearDownClass(cls):
        test_db_utils.remove_db()

    def test_rejects_unknown_entity(self):
        res = self.kb_api.connect_entities("Unknown Entity", "Justin Timberlake", "similar to", 0)
        self.assertEqual(res, False,
            "Expected attempt to connect an unknown entity to fail.")

        res = self.kb_api.get_related_entities("Unknown Entity")
        self.assertEqual(res, [])

    def test_rejects_score_out_of_range(self):
        res = self.kb_api.connect_entities("Justin Timberlake", "Justin Timberlake", "similar to", -1)
        self.assertEqual(res, False,
            "Expected attempt to connect entities with score out-of-range to fail.")

        res = self.kb_api.get_related_entities("Justin Timberlake")
        self.assertEqual(len(res), 0)

    def test_rejects_duplicate_edge(self):
        res = self.kb_api.connect_entities("Justin Bieber", "Justin Timberlake", "similar to", 1)
        self.assertEqual(res, False,
            "Expected attempt to add a duplicate edge to fail.")

    def test_rejects_song_with_duplicate_id(self):
        new_song_id = self.kb_api.add_song(
            "some song",
            "Justin Bieber",
            spotify_uri="spotify:track:Despacito", # already in DB
        )
        self.assertIsNone(new_song_id, "Expected song addition to fail.")

    def test_edges_not_null_constraints(self):
        res = self.kb_api.connect_entities(None, "Justin Timberlake", "similar to", 1)
        self.assertEqual(res, False,
            "Expected 'None' value for artist to be rejected.")

        res = self.kb_api.connect_entities("U2", "U2", None, 1)
        self.assertEqual(res, False,
            "Expected 'None' value for edge type to be rejected.")

        res = self.kb_api.connect_entities("U2", "U2", "similar to", None)
        self.assertEqual(res, False,
            "Expected 'None' value for edge score to be rejected.")

    def test_entities_not_null_constraints(self):
        res = self.kb_api.add_artist(None)
        self.assertIsNone(res,
            "Expected 'None' value for artist to be rejected.")

        res = self.kb_api.add_song("Song name", None)
        self.assertIsNone(res,
            "Expected 'None' value for song to be rejected.")

        res = self.kb_api.add_song(None, "Artist name")
        self.assertIsNone(res,
            "Expected 'None' value for artist to be rejected.")

        node_id = self.kb_api._add_node(None, "artist")
        self.assertIsNone(node_id,
            "Expected 'None' value for entity name to be rejected.")

        node_id = self.kb_api._add_node("Some entity", None)
        self.assertIsNone(node_id,
            "Expected 'None' value for entity type to be rejected.")

    def test_song_audio_features_range_constraints(self):
        # NOTE: all values are at their upper limit so that we can just add 1
        #   to them and test that the schema constraints reject their addition
        #   (mode is left out because it is not numerical)
        audio_features = dict(
            acousticness=1, danceability=1, energy=1, instrumentalness=1, liveness=1,
            loudness=1, speechiness=1, valence=1, tempo=999, musical_key=11, time_signature=7,
        )

        for feature_name in self.kb_api.song_audio_features:
            if feature_name == 'mode':
                continue

            # Push only this feature out of range on a fresh copy, so no
            # manual reset is needed. subTest reports every offending
            # feature instead of stopping at the first failure.
            out_of_range = dict(audio_features)
            out_of_range[feature_name] += 1
            with self.subTest(feature=feature_name):
                new_song_id = self.kb_api.add_song(
                    "Song by Justin Bieber",
                    "Justin Bieber",
                    audio_features=out_of_range,
                )
                self.assertIsNone(
                    new_song_id,
                    f"Expected song with {feature_name} out of range to be rejected.",
                )

    def test_song_mode_constraint(self):
        new_song_id = self.kb_api.add_song(
            "Song by Justin Bieber",
            "Justin Bieber",
            audio_features=dict(mode='not major or minor'),
        )
        self.assertIsNone(new_song_id, "Expected song with invalid mode to be rejected.")