예제 #1
0
def test_validate_and_download_fails_without_downloader():
    """Missing files plus no downloader must surface as FileNotFoundError."""
    wanted = ["file1.txt"]

    with tempfile.TemporaryDirectory() as empty_dir:
        # The fresh directory holds none of the wanted files and there is no
        # downloader to fetch them, so validation is expected to fail.
        with pytest.raises(FileNotFoundError):
            validate_and_download(
                empty_dir,
                wanted,
                downloader=None,
                validate_checksums=False,
            )
예제 #2
0
def test_validate_and_download_succeeds_without_downloader():
    """When every expected file is already present, no downloader is needed."""
    wanted = ["file1.txt"]
    present = ["file1.txt"]

    with tempfile.TemporaryDirectory() as workdir:
        # Put every file in place up front so validation can find them.
        for name in present:
            Path(workdir, name).touch()

        # Passing means this call returns without raising anything.
        validate_and_download(
            workdir,
            wanted,
            downloader=None,
            validate_checksums=False,
        )
예제 #3
0
    def __init__(self,
                 location: str = dataset_dir("MSLR10K"),
                 split: str = "train",
                 fold: int = 1,
                 normalize: bool = True,
                 filter_queries: Optional[bool] = None,
                 download: bool = True,
                 validate_checksums: bool = True):
        """Load a single split of a single fold of the MSLR10K dataset.

        Args:
            location: Directory that holds (or will receive) the dataset.
            split: Which data split to load: "train", "test" or "vali".
            fold: Which data fold to load (1...5).
            normalize: If true, apply query-level feature normalization.
            filter_queries: If true, drop queries without any relevant
                items. When left as None, queries are filtered for every
                split except "train".
            download: If true, fetch the dataset when it is missing.
            validate_checksums: If true, verify the dataset files via
                sha256.

        Raises:
            ValueError: If ``split`` or ``fold`` is not a recognized key.
        """
        known_splits = MSLR10K.splits
        fold_files = MSLR10K.per_fold_expected_files

        # Reject unknown split / fold values before touching the filesystem.
        if split not in known_splits:
            raise ValueError("unrecognized data split '%s'" % str(split))
        if fold not in fold_files:
            raise ValueError("unrecognized data fold '%s'" % str(fold))

        # Ensure this fold's files are present (optionally checksummed),
        # fetching them with the class downloader only if allowed.
        required = fold_files[fold]
        fetcher = MSLR10K.downloader if download else None
        validate_and_download(
            location=location,
            expected_files=required,
            downloader=fetcher,
            validate_checksums=validate_checksums)

        # By default, only non-train splits get their queries filtered.
        if filter_queries is None:
            filter_queries = split != "train"

        fold_dir = os.path.join(location, "Fold%d" % fold)
        datafile = os.path.join(fold_dir, known_splits[split])
        super().__init__(
            file=datafile,
            sparse=False,
            normalize=normalize,
            filter_queries=filter_queries,
            zero_based="auto",
        )
예제 #4
0
def test_validate_and_download_fails_after_download_fails():
    """A download that produces no files must still end in FileNotFoundError.

    The mock downloader has no side effect configured, so calling its
    ``download`` creates nothing and the expected file never appears. The
    test also checks that a download was actually attempted.
    """
    downloader = mock.MagicMock()
    expected_files = ["file1.txt"]

    with tempfile.TemporaryDirectory() as tmpdir:
        # Call validate_and_download and assert that a file not found error
        # is raised when no files are actually downloaded.
        with pytest.raises(FileNotFoundError):
            validate_and_download(tmpdir,
                                  expected_files,
                                  downloader=downloader,
                                  validate_checksums=False)

        # Assert download was called. This must sit outside the
        # ``pytest.raises`` block: any statement after the raising call
        # inside that block is never executed.
        downloader.download.assert_called_once_with(tmpdir)
예제 #5
0
def test_validate_and_download_calls_download():
    """Missing files trigger exactly one download call with the location."""
    fake_downloader = mock.MagicMock()
    expected_files = ["file1.txt"]
    produced = ["file1.txt"]

    def fake_download(location):
        # Simulate a successful download by creating every file.
        for name in produced:
            Path(location, name).touch()

    fake_downloader.download.side_effect = fake_download

    with tempfile.TemporaryDirectory() as target:
        # The directory starts empty, so a download must be triggered.
        validate_and_download(
            target,
            expected_files,
            downloader=fake_downloader,
            validate_checksums=False,
        )
        fake_downloader.download.assert_called_once_with(target)
예제 #6
0
def test_validate_and_download_skips_download():
    """If all expected files already exist, the downloader is never used."""
    fake_downloader = mock.MagicMock()
    expected_files = ["file1.txt"]
    existing = ["file1.txt"]

    def populate(location):
        # Create every file, as a real download would.
        for name in existing:
            Path(location, name).touch()

    fake_downloader.download.side_effect = populate

    with tempfile.TemporaryDirectory() as target:
        # Pre-populate the directory so there is nothing left to fetch.
        populate(target)

        validate_and_download(
            target,
            expected_files,
            downloader=fake_downloader,
            validate_checksums=False,
        )
        fake_downloader.download.assert_not_called()
예제 #7
0
    def __init__(self, location: str = dataset_dir("example3"),
                 split: str = "train",
                 normalize: bool = True, filter_queries: Optional[bool] = None,
                 download: bool = True, validate_checksums: bool = True):
        """
        Args:
            location: Directory where the dataset is located.
            split: The data split to load ("train" or "test")
            normalize: Whether to perform query-level feature
                normalization.
            filter_queries: Whether to filter out queries that
                have no relevant items. If not given this will filter queries
                for the test set but not the train set.
            download: Whether to download the dataset if it does not
                exist.
            validate_checksums: Whether to validate the dataset files
                via sha256.

        Raises:
            ValueError: If ``split`` is not one of the known splits.
        """
        # Check if specified split exists. ``str(split)`` keeps the error
        # formatting safe for non-string values: a bare ``% split`` would
        # unpack a tuple and raise TypeError instead of this ValueError.
        if split not in Example3.splits:
            raise ValueError("unrecognized data split '%s'" % str(split))

        # Validate dataset exists and is correct, or download it.
        validate_and_download(
            location=location,
            expected_files=Example3.expected_files,
            downloader=Example3.downloader if download else None,
            validate_checksums=validate_checksums)

        # Only filter queries on non-train splits unless told otherwise.
        if filter_queries is None:
            filter_queries = split != "train"

        # Initialize the dataset.
        datafile = os.path.join(location, Example3.splits[split])
        super().__init__(file=datafile,
                         sparse=False, normalize=normalize,
                         filter_queries=filter_queries, zero_based="auto")