def main():
    args, parser = read_arguments()
    if "dataset" in args and args.dataset is not None:
        args.dataset = get_datasetname(args.dataset)
    if args.command is None:
        parser.print_help()
    else:
        downloaded_datasets = get_downloaded_datasets()
        processed_datasets = get_processed_datasets()
        if args.command == "download":
            handle_dowload(args, downloaded_datasets)
        elif args.command == "process":
            handle_process(args, downloaded_datasets, processed_datasets)
        else:
            handle_info(args, downloaded_datasets, processed_datasets)
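# A minimal sketch of how this dispatcher is typically wired up as a console entry
# point, assuming the module is executed via ``python -m srdatasets``. The module
# guard below is illustrative and not taken from the original source.
if __name__ == "__main__":
    main()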
def test_dataloader():
    config_ids = get_processed_datasets()[args.dataset]
    for cid in config_ids:
        for DataLoader in [
                srdatasets.dataloader.DataLoader,
                srdatasets.dataloader_pytorch.DataLoader
        ]:
            dataloader = DataLoader(args.dataset,
                                    cid,
                                    batch_size=32,
                                    negatives_per_target=5,
                                    include_timestamp=True,
                                    drop_last=True)
            if len(dataloader.dataset[0]) > 5:
                # Long-short-term task: batches carry previous-session and current-session data
                for users, pre_sess_items, cur_sess_items, target_items, pre_sess_timestamps, \
                        cur_sess_timestamps, target_timestamps, negatives in dataloader:
                    assert users.shape == (32, )
                    assert pre_sess_items.shape == (32, args.pre_sessions * args.max_session_len)
                    assert cur_sess_items.shape == (32, args.max_session_len - args.target_len)
                    assert target_items.shape == (32, args.target_len)
                    assert pre_sess_timestamps.shape == (32, args.pre_sessions * args.max_session_len)
                    assert cur_sess_timestamps.shape == (32, args.max_session_len - args.target_len)
                    assert target_timestamps.shape == (32, args.target_len)
                    assert negatives.shape == (32, args.target_len, 5)
            else:
                # Short-term task: batches carry input and target sequences
                for users, in_items, out_items, in_timestamps, out_timestamps, negatives in dataloader:
                    assert users.shape == (32, )
                    assert in_items.shape == (32, args.input_len)
                    assert out_items.shape == (32, args.target_len)
                    assert in_timestamps.shape == (32, args.input_len)
                    assert out_timestamps.shape == (32, args.target_len)
                    assert negatives.shape == (32, args.target_len, 5)
def main():
    parser = argparse.ArgumentParser("srdatasets | python -m srdatasets")
    subparsers = parser.add_subparsers(help="commands", dest="command")
    # info
    parser_i = subparsers.add_parser("info", help="print local datasets info")
    parser_i.add_argument("--dataset", type=str, default=None, help="dataset name")
    # download
    parser_d = subparsers.add_parser("download", help="download datasets")
    parser_d.add_argument("--dataset", type=str, required=True, help="dataset name")
    # process
    parser_g = subparsers.add_parser(
        "process",
        help="process datasets",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser_g.add_argument("--dataset", type=str, required=True, help="dataset name")
    parser_g.add_argument(
        "--min-freq-item", type=int, default=5, help="minimum occurrence times of item"
    )
    parser_g.add_argument(
        "--min-freq-user", type=int, default=5, help="minimum occurrence times of user"
    )
    parser_g.add_argument(
        "--task",
        type=str,
        choices=["short", "long-short"],
        default="short",
        help="short-term task or long-short-term task",
    )
    parser_g.add_argument(
        "--split-by",
        type=str,
        choices=["user", "time"],
        default="user",
        help="user-based or time-based dataset splitting",
    )
    parser_g.add_argument(
        "--dev-split",
        type=float,
        default=0.1,
        help="[user-split] the fraction of development dataset",
    )
    parser_g.add_argument(
        "--test-split",
        type=float,
        default=0.2,
        help="[user-split] the fraction of test dataset",
    )
    parser_g.add_argument(
        "--input-len", type=int, default=5, help="[short] input sequence length"
    )
    parser_g.add_argument(
        "--target-len", type=int, default=1, help="target sequence length"
    )
    parser_g.add_argument(
        "--no-augment", action="store_true", help="do not use data augmentation"
    )
    parser_g.add_argument(
        "--remove-duplicates",
        action="store_true",
        help="remove duplicate items in user sequence",
    )
    parser_g.add_argument(
        "--session-interval",
        type=int,
        default=0,
        help="[short-optional, long-short-required] split user sequences into sessions (minutes)",
    )
    parser_g.add_argument(
        "--max-session-len", type=int, default=20, help="max session length"
    )
    parser_g.add_argument(
        "--min-session-len", type=int, default=2, help="min session length"
    )
    parser_g.add_argument(
        "--pre-sessions",
        type=int,
        default=10,
        help="[long-short] number of previous sessions",
    )
    parser_g.add_argument(
        "--pick-targets",
        type=str,
        choices=["last", "random"],
        default="random",
        help="[long-short] pick T random or last items from current session as targets",
    )
    parser_g.add_argument(
        "--rating-threshold",
        type=int,
        default=4,
        help="[Amazon-X, Movielens20M, Yelp] ratings greater than or equal to this are treated as valid",
    )
    parser_g.add_argument(
        "--item-type",
        type=str,
        choices=["song", "artist"],
        default="song",
        help="[Lastfm1K] set item to song or artist",
    )
    args = parser.parse_args()

    if "dataset" in args and args.dataset is not None:
        args.dataset = get_datasetname(args.dataset)

    if args.command is None:
        parser.print_help()
    else:
        downloaded_datasets = get_downloaded_datasets()
        processed_datasets = get_processed_datasets()
        if args.command == "download":
            if args.dataset not in __datasets__:
                raise ValueError("Supported datasets: {}".format(", ".join(__datasets__)))
            if args.dataset in downloaded_datasets:
                raise ValueError("{} has been downloaded".format(args.dataset))
            _download(args.dataset)
        elif args.command == "process":
            if args.dataset not in __datasets__:
                raise ValueError("Supported datasets: {}".format(", ".join(__datasets__)))
            if args.dataset not in downloaded_datasets:
                raise ValueError("{} has not been downloaded".format(args.dataset))
            if args.split_by == "user":
                if args.dev_split <= 0 or args.dev_split >= 1:
                    raise ValueError("dev split ratio should be in (0, 1)")
                if args.test_split <= 0 or args.test_split >= 1:
                    raise ValueError("test split ratio should be in (0, 1)")
            if args.task == "short":
                if args.input_len <= 0:
                    raise ValueError("input length must be > 0")
                if args.session_interval < 0:
                    raise ValueError("session interval must be >= 0 minutes")
            else:
                if args.session_interval <= 0:
                    raise ValueError("session interval must be > 0 minutes")
                if args.pre_sessions < 1:
                    raise ValueError("number of previous sessions must be > 0")
            if args.target_len <= 0:
                raise ValueError("target length must be > 0")
            if args.session_interval > 0:
                if args.min_session_len <= args.target_len:
                    raise ValueError("min session length must be > target length")
                if args.max_session_len < args.min_session_len:
                    raise ValueError("max session length must be >= min session length")
            if args.dataset in processed_datasets:
                # TODO Improve processed check when some arguments are not used
                time_splits = {}
                for c in processed_datasets[args.dataset]:
                    config = read_json(
                        __warehouse__.joinpath(args.dataset, "processed", c, "config.json")
                    )
                    if args.split_by == "user" and all(
                        [args.__dict__[k] == v for k, v in config.items()]
                    ):
                        print("You have run this config, the config id is: {}".format(c))
                        sys.exit(1)
                    if args.split_by == "time" and all(
                        [
                            args.__dict__[k] == v
                            for k, v in config.items()
                            if k not in ["dev_split", "test_split"]
                        ]
                    ):
                        time_splits[(config["dev_split"], config["test_split"])] = c
                args.time_splits = time_splits
            _process(args)
        else:
            # info
            if args.dataset is None:
                table = [
                    [
                        d,
                        "Y" if d in downloaded_datasets else "",
                        len(processed_datasets[d]) if d in processed_datasets else "",
                    ]
                    for d in __datasets__
                ]
                print(
                    tabulate(
                        table,
                        headers=["name", "downloaded", "processed configs"],
                        tablefmt="psql",
                    )
                )
            else:
                if args.dataset not in __datasets__:
                    raise ValueError(
                        "Supported datasets: {}".format(", ".join(__datasets__))
                    )
                if args.dataset not in downloaded_datasets:
                    print("{} has not been downloaded".format(args.dataset))
                else:
                    if args.dataset not in processed_datasets:
                        print("{} has not been processed".format(args.dataset))
                    else:
                        configs = json_normalize(
                            [
                                read_json(
                                    __warehouse__.joinpath(
                                        args.dataset, "processed", c, "config.json"
                                    )
                                )
                                for c in processed_datasets[args.dataset]
                            ]
                        )
                        print("Configs")
                        configs_part1 = configs.iloc[:, :8]
                        configs_part1.insert(
                            0, "config id", processed_datasets[args.dataset]
                        )
                        print(
                            tabulate(
                                configs_part1,
                                headers="keys",
                                showindex=False,
                                tablefmt="psql",
                            )
                        )
                        print()
                        configs_part2 = configs.iloc[:, 8:]
                        configs_part2.insert(
                            0, "config id", processed_datasets[args.dataset]
                        )
                        print(
                            tabulate(
                                configs_part2,
                                headers="keys",
                                showindex=False,
                                tablefmt="psql",
                            )
                        )
                        print("\nStats")
                        stats = json_normalize(
                            [
                                read_json(
                                    __warehouse__.joinpath(
                                        args.dataset, "processed", c, m, "stats.json"
                                    )
                                )
                                for c in processed_datasets[args.dataset]
                                for m in ["dev", "test"]
                            ]
                        )
                        modes = ["development", "test"] * len(
                            processed_datasets[args.dataset]
                        )
                        stats.insert(0, "mode", modes)
                        ids = []
                        for c in processed_datasets[args.dataset]:
                            ids.extend([c, ""])
                        stats.insert(0, "config id", ids)
                        print(
                            tabulate(
                                stats, headers="keys", showindex=False, tablefmt="psql"
                            )
                        )
                 dev_split=0.1,
                 test_split=0.2,
                 input_len=9,
                 target_len=1,
                 session_interval=0,
                 max_session_len=10,
                 min_session_len=2,
                 pre_sessions=10,
                 pick_targets="random",
                 no_augment=False,
                 remove_duplicates=False)

if args.dataset not in get_downloaded_datasets():
    _download(args.dataset)

if args.dataset in get_processed_datasets():
    shutil.rmtree(__warehouse__.joinpath(args.dataset, "processed"))

# For short term task
short_args = copy.deepcopy(args)
_process(short_args)

# For long-short term task
long_short_args = copy.deepcopy(args)
long_short_args.task = "long-short"
long_short_args.session_interval = 60
_process(long_short_args)
def __init__(self,
             dataset_name: str,
             config_id: str,
             batch_size: int = 1,
             train: bool = True,
             development: bool = False,
             negatives_per_target: int = 0,
             include_timestamp: bool = False,
             **kwargs):
    """Loader of sequential recommendation datasets

    Args:
        dataset_name (str): dataset name
        config_id (str): dataset config id
        batch_size (int): batch size
        train (bool, optional): load training dataset
        development (bool, optional): use the dataset for hyperparameter tuning
        negatives_per_target (int, optional): number of negative samples per target
        include_timestamp (bool, optional): add timestamps to batch data

    Note: training data is shuffled automatically.
    """
    dataset_name = get_datasetname(dataset_name)

    if dataset_name not in __datasets__:
        raise ValueError(
            "Unrecognized dataset, currently supported datasets: {}".format(
                ", ".join(__datasets__)))

    _processed_datasets = get_processed_datasets()
    if dataset_name not in _processed_datasets:
        raise ValueError(
            "{} has not been processed, currently processed datasets: {}".format(
                dataset_name,
                ", ".join(_processed_datasets) if _processed_datasets else "none"))

    if config_id not in _processed_datasets[dataset_name]:
        raise ValueError(
            "Unrecognized config id, existing config ids: {}".format(
                ", ".join(_processed_datasets[dataset_name])))

    if negatives_per_target < 0:
        negatives_per_target = 0
        logger.warning(
            "Number of negative samples per target should be >= 0, reset to 0")

    if not train and negatives_per_target > 0:
        logger.warning(
            "Negative samples are only used for training; negatives_per_target has no effect when testing")

    self.train = train
    self.include_timestamp = include_timestamp
    self.negatives_per_target = negatives_per_target
    self.dataset = Dataset(dataset_name, config_id, train, development)

    if train:
        # Per-item occurrence counts over the training data, indexed by item id
        self.item_counts = torch.tensor(
            [
                self.dataset.item_counts[i]
                for i in range(max(self.dataset.item_counts.keys()) + 1)
            ],
            dtype=torch.float)

    super().__init__(self.dataset,
                     batch_size=batch_size,
                     shuffle=train,
                     collate_fn=self.collate_fn,
                     **kwargs)
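# A hedged usage sketch of this PyTorch-backed loader, mirroring the configuration
# exercised in test_dataloader above. The dataset name "Movielens20M" and the
# config id are placeholders, and the unpacked batch layout assumes the short-term
# task with include_timestamp=True and negatives_per_target > 0.
#
#   from srdatasets.dataloader_pytorch import DataLoader
#
#   loader = DataLoader("Movielens20M", "c1234567890", batch_size=32, train=True,
#                       negatives_per_target=5, include_timestamp=True)
#   for users, in_items, out_items, in_timestamps, out_timestamps, negatives in loader:
#       ...  # feed the batch tensors to a model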
def __init__(
    self,
    dataset_name: str,
    config_id: str,
    batch_size: int = 1,
    train: bool = True,
    development: bool = False,
    negatives_per_target: int = 0,
    include_timestamp: bool = False,
    drop_last: bool = False,
):
    """Loader of sequential recommendation datasets

    Args:
        dataset_name (str): dataset name
        config_id (str): dataset config id
        batch_size (int): batch size
        train (bool, optional): load training data
        development (bool, optional): use the dataset for hyperparameter tuning
        negatives_per_target (int, optional): number of negative samples per target
        include_timestamp (bool, optional): add timestamps to batch data
        drop_last (bool, optional): drop the last incomplete batch

    Note: training data is shuffled automatically.
    """
    dataset_name = get_datasetname(dataset_name)

    if dataset_name not in __datasets__:
        raise ValueError(
            "Unrecognized dataset, currently supported datasets: {}".format(
                ", ".join(__datasets__)))

    _processed_datasets = get_processed_datasets()
    if dataset_name not in _processed_datasets:
        raise ValueError(
            "{} has not been processed, currently processed datasets: {}".format(
                dataset_name,
                ", ".join(_processed_datasets) if _processed_datasets else "none",
            ))

    if config_id not in _processed_datasets[dataset_name]:
        raise ValueError(
            "Unrecognized config id, existing config ids: {}".format(
                ", ".join(_processed_datasets[dataset_name])))

    if negatives_per_target < 0:
        negatives_per_target = 0
        logger.warning(
            "Number of negative samples per target should be >= 0, reset to 0")

    if not train and negatives_per_target > 0:
        logger.warning(
            "Negative samples are only used for training; negatives_per_target has no effect when testing")

    dataset_dir = __warehouse__.joinpath(dataset_name, "processed", config_id,
                                         "dev" if development else "test")

    with open(dataset_dir.joinpath("stats.json"), "r") as f:
        self.stats = json.load(f)

    dataset_path = dataset_dir.joinpath("train.pkl" if train else "test.pkl")
    with open(dataset_path, "rb") as f:
        self.dataset = pickle.load(f)  # list

    if train:
        # Count item occurrences over the training sequences
        counter = Counter()
        for data in self.dataset:
            if len(data) > 5:
                # long-short-term format: previous-session, current-session, and target items
                counter.update(data[1] + data[2] + data[3])
            else:
                # short-term format: input and target items
                counter.update(data[1] + data[2])
        self.item_counts = np.array(
            [counter[i] for i in range(max(counter.keys()) + 1)])

    if batch_size <= 0:
        raise ValueError("batch_size should be >= 1")
    if batch_size > len(self.dataset):
        raise ValueError("batch_size exceeds the dataset size")

    self.batch_size = batch_size
    self.train = train
    self.include_timestamp = include_timestamp
    self.negatives_per_target = negatives_per_target
    self.drop_last = drop_last
    self._batch_idx = 0
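# A hedged sketch of the batch count implied by the attributes set above: with
# drop_last=True the trailing partial batch is skipped, otherwise it is emitted.
# This only illustrates the drop_last semantics and is not the library's actual
# iteration code.
#
#   import math
#   num_batches = (len(self.dataset) // self.batch_size if self.drop_last
#                  else math.ceil(len(self.dataset) / self.batch_size))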