Code Example #1
File: test_core.py Project: MTG/mirdata
def test_dataset():
    dataset = mirdata.Dataset("guitarset")
    assert isinstance(dataset, core.Dataset)

    dataset = mirdata.Dataset("rwc_jazz")
    assert isinstance(dataset, core.Dataset)

    dataset = mirdata.Dataset("ikala")
    assert isinstance(dataset, core.Dataset)

    print(dataset)  # test that repr doesn't fail
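
For context, a minimal sketch of the API these tests exercise (the dataset names are assumed to be registered loaders in the installed mirdata version):

import mirdata

print(mirdata.DATASETS)             # names of all registered loaders
dataset = mirdata.Dataset("ikala")  # construct a loader by name
print(dataset)                      # the repr the test above exercises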
Code Example #2
File: test_loaders.py Project: MTG/mirdata
def test_load_and_trackids():
    for dataset_name in DATASETS:
        data_home = os.path.join("tests/resources/mir_datasets", dataset_name)
        dataset = mirdata.Dataset(dataset_name, data_home=data_home)
        dataset_default = mirdata.Dataset(dataset_name, data_home=None)
        try:
            track_ids = dataset.track_ids
        except:
            assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])

        assert type(
            track_ids) is list, "{}.track_ids should return a list".format(
                dataset_name)
        trackid_len = len(track_ids)

        # if the dataset has tracks, test the loaders
        if dataset._track_object is not None:

            try:
                choice_track = dataset.choice_track()
            except:
                assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])
            assert isinstance(
                choice_track, core.Track
            ), "{}.choice_track must return an instance of type core.Track".format(
                dataset_name)

            try:
                dataset_data = dataset.load_tracks()
            except:
                assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])

            assert (
                type(dataset_data) is dict
            ), "{}.load_tracks should return a dictionary".format(dataset_name)
            assert (
                len(dataset_data.keys()) == trackid_len
            ), "the dictionary returned by {}.load_tracks() does not have the same number of elements as {}.track_ids".format(
                dataset_name, dataset_name)

            try:
                dataset_data_default = dataset_default.load_tracks()
            except:
                assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])

            assert (
                type(dataset_data_default) is dict
            ), "{}.load_tracks should return a dictionary".format(dataset_name)
            assert (
                len(dataset_data_default.keys()) == trackid_len
            ), "the dictionary returned by {}.load_tracks() does not have the same number of elements as {}.track_ids".format(
                dataset_name, dataset_name)
Code Example #3
File: test_full_dataset.py Project: MTG/mirdata
def dataset(test_dataset):
    if test_dataset == "":
        return None
    elif test_dataset not in mirdata.DATASETS:
        raise ValueError("{} is not a dataset in mirdata".format(test_dataset))
    data_home = os.path.join("tests/resources/mir_datasets_full", test_dataset)
    return mirdata.Dataset(test_dataset, data_home)
Code Example #4
File: test_loaders.py Project: MTG/mirdata
def test_cite():
    for dataset_name in DATASETS:
        dataset = mirdata.Dataset(dataset_name)
        # swap stdout for a StringIO so the citation text doesn't clutter the test log
        text_trap = io.StringIO()
        sys.stdout = text_trap
        dataset.cite()
        sys.stdout = sys.__stdout__
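
Note that an exception inside cite() would leave sys.stdout swapped out, since the manual restore above only runs if nothing raises; contextlib.redirect_stdout (stdlib) restores it even on error. A sketch of the same capture:

import contextlib
import io

buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    dataset.cite()  # citation text is captured in buf instead of printed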
Code Example #5
File: test_loaders.py Project: MTG/mirdata
def test_load_methods():
    for dataset_name in DATASETS:
        dataset = mirdata.Dataset(dataset_name)
        all_methods = dir(dataset)
        load_methods = [
            getattr(dataset, m) for m in all_methods if m.startswith("load_")
        ]
        for load_method in load_methods:
            method_name = load_method.__name__

            # skip default methods
            if method_name == "load_tracks":
                continue
            params = [
                p for p in signature(load_method).parameters.values()
                if p.default == inspect._empty
            ]  # get list of parameters that don't have defaults

            # add to the EXCEPTIONS dictionary above if your load_* function needs
            # more than one argument.
            if dataset_name in EXCEPTIONS and method_name in EXCEPTIONS[
                    dataset_name]:
                extra_params = EXCEPTIONS[dataset_name][method_name]
                with pytest.raises(IOError):
                    load_method("a/fake/filepath", **extra_params)
            else:
                with pytest.raises(IOError):
                    load_method("a/fake/filepath")
Code Example #6
File: test_loaders.py Project: MTG/mirdata
def test_multitracks():
    data_home_dir = "tests/resources/mir_datasets"

    for dataset_name in DATASETS:
        dataset = mirdata.Dataset(dataset_name)

        # TODO this is currently an opt-in test. Make it an opt out test
        # once #265 is addressed
        if dataset_name in CUSTOM_TEST_MTRACKS:
            mtrack_id = CUSTOM_TEST_MTRACKS[dataset_name]
        else:
            # there are no multitracks
            continue

        try:
            mtrack_default = dataset.MultiTrack(mtrack_id)
        except:
            assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])

        # test data home specified
        data_home = os.path.join(data_home_dir, dataset_name)
        dataset_specific = mirdata.Dataset(dataset_name, data_home=data_home)
        try:
            mtrack_test = dataset_specific.MultiTrack(mtrack_id,
                                                      data_home=data_home)
        except:
            assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])

        assert isinstance(
            mtrack_test, core.MultiTrack
        ), "{}.MultiTrack must be an instance of type core.MultiTrack".format(
            dataset_name)

        assert hasattr(
            mtrack_test,
            "to_jams"), "{}.MultiTrack must have a to_jams method".format(
                dataset_name)

        # Validate JSON schema
        try:
            jam = mtrack_test.to_jams()
        except:
            assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])

        assert jam.validate(
        ), "Jams validation failed for {}.MultiTrack({})".format(
            dataset_name, mtrack_id)
Code Example #7
File: test_loaders.py Project: MTG/mirdata
def test_download(mocker):
    for dataset_name in DATASETS:
        print(dataset_name)
        dataset = mirdata.Dataset(dataset_name)

        # test parameters & defaults
        assert callable(
            dataset._download_fn), "{}.download is not callable".format(
                dataset_name)
        params = signature(dataset._download_fn).parameters
        expected_params = [
            "save_dir",
            "remotes",
            "partial_download",
            "info_message",
            "force_overwrite",
            "cleanup",
        ]
        assert set(params) == set(
            expected_params), "{}.download must have parameters {}".format(
                dataset_name, expected_params)

        # check that the download method can be called without errors
        if dataset._remotes != {}:
            mock_downloader = mocker.patch.object(dataset, "_remotes")
            if dataset_name not in DOWNLOAD_EXCEPTIONS:
                try:
                    dataset.download()
                except:
                    assert False, "{}: {}".format(dataset_name,
                                                  sys.exc_info()[0])

                mocker.resetall()

            # check that links are online
            for key in dataset._remotes:
                # skip this test if it's in known issues
                if dataset_name in KNOWN_ISSUES and key in KNOWN_ISSUES[
                        dataset_name]:
                    continue

                url = dataset._remotes[key].url
                try:
                    request = requests.head(url)
                    assert request.ok, "Link {} for {} does not return OK".format(
                        url, dataset_name)
                except requests.exceptions.ConnectionError:
                    assert False, "Link {} for {} is unreachable".format(
                        url, dataset_name)
                except:
                    assert False, "{}: {}".format(dataset_name,
                                                  sys.exc_info()[0])
        else:
            try:
                dataset.download()
            except:
                assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])
Code Example #8
def update_index(all_indexes):
    """Function to update indexes to new format.
    Parameters
    ----------
    all_indexes (list): list of all current dataset indexes


    """

    for index_name in tqdm(all_indexes):
        module = index_name.replace('_index.json', '')

        # load old index
        old_index = mirdata.Dataset(module)._index

        # avoid modifying when running multiple times
        if 'tracks' in old_index.keys():
            old_index = old_index['tracks']

        data_home = mirdata.Dataset(module).data_home

        # get metadata checksum
        metadata_files = get_metadata_paths(module)
        metadata_checksums = None

        if metadata_files is not None:
            metadata_checksums = {key: [metadata_files[key],
                                        md5(os.path.join(data_home, metadata_files[key]))]
                                  for key in metadata_files.keys()}

        # get version of dataset
        version = get_dataset_version(module)

        # Some datasets have a single metadata file, some have multiple.
        # The computation of the checksum should be customized in the make_index
        # of each dataset. This is a patch to convert previous indexes to the new format.
        new_index = {'version': version,
                     'tracks': old_index}

        if metadata_files is not None:
            new_index['metadata'] = metadata_checksums

        with open(os.path.join(INDEXES_PATH, index_name), 'w') as fhandle:
            json.dump(new_index, fhandle, indent=2)
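
For reference, the index written above has roughly the following shape (the track entry, paths, and checksums are illustrative, not taken from a real dataset):

new_index = {
    "version": "1.0",  # from get_dataset_version(module)
    "tracks": {
        # the old flat index, nested one level down
        "track_id_1": {"audio": ["audio/track_id_1.wav", "<md5 checksum>"]},
    },
    # present only when the dataset ships metadata files
    "metadata": {"metadata": ["annotations/metadata.csv", "<md5 checksum>"]},
}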
Code Example #9
def test_track_load(dataset_names):
    """Function to test all loaders work and indexes are fine (run locally)
    Parameters
    ----------
    dataset_names (list): list of dataset names

    """
    for module in dataset_names:
        dataset = mirdata.Dataset(module)
        dataset.load_tracks()
Code Example #10
File: test_loaders.py Project: MTG/mirdata
def test_validate(skip_local):
    for dataset_name in DATASETS:
        data_home = os.path.join("tests/resources/mir_datasets", dataset_name)
        dataset = mirdata.Dataset(dataset_name, data_home=data_home)
        try:
            dataset.validate()
        except:
            assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])

        try:
            dataset.validate(verbose=False)
        except:
            assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])

        dataset_default = mirdata.Dataset(dataset_name, data_home=None)
        try:
            dataset_default.validate(verbose=False)
        except:
            assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])
Code Example #11
def test_index(dataset_names):
    """ Test if updated indexes are as expected.
    Parameters
    ----------
    dataset_names (list): list of dataset names

    """

    mandatory_keys = ['version']
    for module in dataset_names:
        index = mirdata.Dataset(module)._index
        assert type(index['tracks']) == dict
        assert set(mandatory_keys) <= set([*index.keys()])
Code Example #12
def main():

    print(DATASETS)
    # Download metadata from all datasets for computing metadata checksums
    for module in DATASETS:
        if module not in ['dali', 'beatles', 'groove_midi']:
            dataset = mirdata.Dataset(module)
            if dataset._remotes is not None:
                dataset.download(partial_download=[
                    'metadata' if 'metadata' in dataset._remotes else key
                    for key in dataset._remotes
                    if key != 'audio' and 'training' not in key and 'testing' not in key
                ])

    # Update index to new format
    print('Updating indexes...\n')
    update_index(ALL_INDEXES)
    # Check new indexes are shaped as expected
    print('Quick check on datasets...\n')
    test_index(DATASETS)
    test_track_load(DATASETS)
Code Example #13
File: test_loaders.py Project: MTG/mirdata
def test_dataset_attributes():
    for dataset_name in DATASETS:
        dataset = mirdata.Dataset(dataset_name)
        assert (dataset.name == dataset_name
                ), "{}.name attribute does not match dataset name".format(
                    dataset_name)
        assert (dataset.bibtex is not None
                ), "No BIBTEX information provided for {}".format(dataset_name)
        assert (isinstance(dataset._remotes, dict) or dataset._remotes is None
                ), "{}.REMOTES must be a dictionary".format(dataset_name)
        assert isinstance(
            dataset._index,
            dict), "{}.DATA is not properly set".format(dataset_name)
        assert (isinstance(dataset._download_info, str)
                or dataset._download_info is None
                ), "{}.DOWNLOAD_INFO must be a string".format(dataset_name)
        assert type(dataset._track_object) == type(
            core.Track), "{}.Track must be an instance of core.Track".format(
                dataset_name)
        assert callable(
            dataset._download_fn), "{}._download is not a function".format(
                dataset_name)
        assert dataset.readme != "", "{} has no module readme".format(
            dataset_name)
Code Example #14
File: test_loaders.py Project: MTG/mirdata
def test_track():
    data_home_dir = "tests/resources/mir_datasets"

    for dataset_name in DATASETS:

        data_home = os.path.join(data_home_dir, dataset_name)
        dataset = mirdata.Dataset(dataset_name, data_home=data_home)
        dataset_default = mirdata.Dataset(dataset_name, data_home=None)

        # if the dataset doesn't have a track object, make sure track() raises
        # NotImplementedError and move on to the next dataset
        if dataset._track_object is None:
            with pytest.raises(NotImplementedError):
                dataset.track("~faketrackid~?!")
            continue

        if dataset_name in CUSTOM_TEST_TRACKS:
            trackid = CUSTOM_TEST_TRACKS[dataset_name]
        else:
            trackid = dataset.track_ids[0]

        try:
            track_default = dataset_default.track(trackid)
        except:
            assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])

        assert track_default._data_home == os.path.join(
            DEFAULT_DATA_HOME, dataset.name
        ), "{}: Track._data_home path is not set as expected".format(
            dataset_name)

        # test data home specified
        try:
            track_test = dataset.track(trackid)
        except:
            assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])

        assert isinstance(
            track_test, core.Track
        ), "{}.track must be an instance of type core.Track".format(
            dataset_name)

        assert hasattr(
            track_test,
            "to_jams"), "{}.track must have a to_jams method".format(
                dataset_name)

        # Validate JSON schema
        try:
            jam = track_test.to_jams()
        except:
            assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])

        assert jam.validate(
        ), "Jams validation failed for {}.track({})".format(
            dataset_name, trackid)

        # will fail if something goes wrong with __repr__
        try:
            text_trap = io.StringIO()
            sys.stdout = text_trap
            print(track_test)
            sys.stdout = sys.__stdout__
        except:
            assert False, "{}: {}".format(dataset_name, sys.exc_info()[0])

        with pytest.raises(ValueError):
            dataset.track("~faketrackid~?!")