コード例 #1
0
ファイル: mslr10k.py プロジェクト: ahoyosid/pytorchltr
    def __init__(self,
                 location: str = dataset_dir("MSLR10K"),
                 split: str = "train",
                 fold: int = 1,
                 normalize: bool = True,
                 filter_queries: Optional[bool] = None,
                 download: bool = True,
                 validate_checksums: bool = True):
        """Load a single split of one cross-validation fold of MSLR10K.

        Args:
            location: Directory that holds (or will receive) the dataset.
            split: Which split to read; must be a key of ``MSLR10K.splits``
                ("train", "test" or "vali").
            fold: Which fold to read; must be a key of
                ``MSLR10K.per_fold_expected_files`` (1...5).
            normalize: Whether to apply query-level feature normalization.
            filter_queries: Whether to drop queries that have no relevant
                items. When left as ``None`` it is enabled for every split
                except "train".
            download: Whether a missing dataset may be downloaded.
            validate_checksums: Whether to verify the dataset files via
                sha256.

        Raises:
            ValueError: If ``split`` or ``fold`` is not recognized.
        """
        # Reject unknown arguments before touching the filesystem.
        known_splits = MSLR10K.splits
        if split not in known_splits:
            raise ValueError("unrecognized data split '%s'" % str(split))
        fold_files = MSLR10K.per_fold_expected_files
        if fold not in fold_files:
            raise ValueError("unrecognized data fold '%s'" % str(fold))

        # Check that this fold's files exist and are correct; a downloader
        # is only supplied when downloading is allowed.
        fetcher = None
        if download:
            fetcher = MSLR10K.downloader
        validate_and_download(
            location=location,
            expected_files=fold_files[fold],
            downloader=fetcher,
            validate_checksums=validate_checksums)

        # Default: keep every query when training, filter on other splits.
        if filter_queries is None:
            filter_queries = split != "train"

        # Each fold lives in its own "Fold<N>" sub-directory.
        fold_dir = os.path.join(location, "Fold%d" % fold)
        super().__init__(file=os.path.join(fold_dir, known_splits[split]),
                         sparse=False,
                         normalize=normalize,
                         filter_queries=filter_queries,
                         zero_based="auto")
コード例 #2
0
    def __init__(self, location: str = dataset_dir("example3"),
                 split: str = "train",
                 normalize: bool = True, filter_queries: Optional[bool] = None,
                 download: bool = True, validate_checksums: bool = True):
        """Load one split of the Example3 dataset.

        Args:
            location: Directory where the dataset is located.
            split: The data split to load ("train" or "test"); must be a
                key of ``Example3.splits``.
            normalize: Whether to perform query-level feature
                normalization.
            filter_queries: Whether to filter out queries that
                have no relevant items. If not given this will filter queries
                for every split except "train".
            download: Whether to download the dataset if it does not
                exist.
            validate_checksums: Whether to validate the dataset files
                via sha256.

        Raises:
            ValueError: If ``split`` is not a recognized data split.
        """
        # Check if specified split exists. ``str(split)`` is required for the
        # %-formatting: a tuple ``split`` would otherwise be unpacked as
        # several format arguments and raise TypeError (or print a
        # misleading message) instead of the intended ValueError.
        if split not in Example3.splits:
            raise ValueError("unrecognized data split '%s'" % str(split))

        # Validate dataset exists and is correct, or download it.
        validate_and_download(
            location=location,
            expected_files=Example3.expected_files,
            downloader=Example3.downloader if download else None,
            validate_checksums=validate_checksums)

        # Only filter queries on non-train splits.
        if filter_queries is None:
            filter_queries = split != "train"

        # Initialize the dataset.
        super().__init__(file=os.path.join(location, Example3.splits[split]),
                         sparse=False, normalize=normalize,
                         filter_queries=filter_queries, zero_based="auto")