def __init__(self, location: str = dataset_dir("MSLR10K"), split: str = "train", fold: int = 1, normalize: bool = True, filter_queries: Optional[bool] = None, download: bool = True, validate_checksums: bool = True): """ Args: location: Directory where the dataset is located. split: The data split to load ("train", "test" or "vali") fold: Which data fold to load (1...5) normalize: Whether to perform query-level feature normalization. filter_queries: Whether to filter out queries that have no relevant items. If not given this will filter queries for the test set but not the train set. download: Whether to download the dataset if it does not exist. validate_checksums: Whether to validate the dataset files via sha256. """ # Check if specified split and fold exists. if split not in MSLR10K.splits.keys(): raise ValueError("unrecognized data split '%s'" % str(split)) if fold not in MSLR10K.per_fold_expected_files.keys(): raise ValueError("unrecognized data fold '%s'" % str(fold)) # Validate dataset exists and is correct, or download it. validate_and_download( location=location, expected_files=MSLR10K.per_fold_expected_files[fold], downloader=MSLR10K.downloader if download else None, validate_checksums=validate_checksums) # Only filter queries on non-train splits. if filter_queries is None: filter_queries = False if split == "train" else True # Initialize the dataset. datafile = os.path.join(location, "Fold%d" % fold, MSLR10K.splits[split]) super().__init__(file=datafile, sparse=False, normalize=normalize, filter_queries=filter_queries, zero_based="auto")
def __init__(self, location: str = dataset_dir("example3"), split: str = "train", normalize: bool = True, filter_queries: Optional[bool] = None, download: bool = True, validate_checksums: bool = True): """ Args: location: Directory where the dataset is located. split: The data split to load ("train" or "test") normalize: Whether to perform query-level feature normalization. filter_queries: Whether to filter out queries that have no relevant items. If not given this will filter queries for the test set but not the train set. download: Whether to download the dataset if it does not exist. validate_checksums: Whether to validate the dataset files via sha256. """ # Check if specified split exists. if split not in Example3.splits.keys(): raise ValueError("unrecognized data split '%s'" % split) # Validate dataset exists and is correct, or download it. validate_and_download( location=location, expected_files=Example3.expected_files, downloader=Example3.downloader if download else None, validate_checksums=validate_checksums) # Only filter queries on non-train splits. if filter_queries is None: filter_queries = False if split == "train" else True # Initialize the dataset. super().__init__(file=os.path.join(location, Example3.splits[split]), sparse=False, normalize=normalize, filter_queries=filter_queries, zero_based="auto")