def test_download_and_transform():
    rawdir = __warehouse__.joinpath("CiteULike", "raw")
    os.makedirs(rawdir, exist_ok=True)
    citeulike = CiteULike(rawdir)
    citeulike.download()
    assert rawdir.joinpath(citeulike.__corefile__).exists()
    df = citeulike.transform()
    assert all(c in df.columns for c in ["user_id", "item_id", "timestamp"])
    assert len(df) > 0
def test_download_and_transform():
    rawdir = __warehouse__.joinpath("Lastfm1K", "raw")
    os.makedirs(rawdir, exist_ok=True)
    lastfm1k = Lastfm1K(rawdir)
    lastfm1k.download()
    assert rawdir.joinpath(lastfm1k.__corefile__).exists()
    df = lastfm1k.transform("song")
    assert all(c in df.columns for c in ["user_id", "item_id", "timestamp"])
    assert len(df) > 0
Example #3
def test_download_and_transform():
    rawdir = __warehouse__.joinpath("TaFeng", "raw")
    os.makedirs(rawdir, exist_ok=True)
    tafeng = TaFeng(rawdir)
    tafeng.download()
    assert all(rawdir.joinpath(cf).exists() for cf in tafeng.__corefile__)
    df = tafeng.transform()
    assert all(c in df.columns for c in ["user_id", "item_id", "timestamp"])
    assert len(df) > 0
Example #4
def test_download_and_transform():
    rawdir = __warehouse__.joinpath("MovieLens20M", "raw")
    os.makedirs(rawdir, exist_ok=True)
    movielens20m = MovieLens20M(rawdir)
    movielens20m.download()
    assert rawdir.joinpath(movielens20m.__corefile__).exists()
    df = movielens20m.transform(4)
    assert all(c in df.columns for c in ["user_id", "item_id", "timestamp"])
    assert len(df) > 0
def test_download_and_transform():
    rawdir = __warehouse__.joinpath("Gowalla", "raw")
    os.makedirs(rawdir, exist_ok=True)
    gowalla = Gowalla(rawdir)
    gowalla.download()
    assert rawdir.joinpath(gowalla.__corefile__).exists()
    df = gowalla.transform()
    assert all(c in df.columns for c in ["user_id", "item_id", "timestamp"])
    assert len(df) > 0
Example #6
def test_download_and_transform():
    rawdir = __warehouse__.joinpath("Amazon", "raw")
    os.makedirs(rawdir, exist_ok=True)
    amazon = Amazon(rawdir)
    category = "Pet"
    amazon.download(category)
    assert rawdir.joinpath(amazon.__corefile__[category]).exists()
    df = amazon.transform(category, 4)
    assert all(c in df.columns for c in ["user_id", "item_id", "timestamp"])
    assert len(df) > 0
Example #7
def _download(dataset_name):
    _rawdir = __warehouse__.joinpath(dataset_name, "raw")
    os.makedirs(_rawdir, exist_ok=True)

    if dataset_name.startswith("Amazon"):
        dataset_classes["Amazon"](_rawdir).download(dataset_name.split("-")[1])
    elif dataset_name.startswith("FourSquare"):
        dataset_classes["FourSquare"](_rawdir).download()
    else:
        dataset_classes[dataset_name](_rawdir).download()
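A brief usage sketch for the helper above; the dataset names are taken from the other examples and are only illustrative:
# Each call creates __warehouse__/<name>/raw and delegates to the matching dataset class.
_download("Amazon-Pet")      # Amazon downloader, category "Pet"
_download("FourSquare-NYC")  # FourSquare downloader (cities are handled at transform time)
_download("Gowalla")         # generic branch: dataset_classes["Gowalla"](...).download()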
Example #8
def _process(args):
    if "-" in args.dataset:
        classname, sub = args.dataset.split("-")
    else:
        classname = args.dataset
    d = dataset_classes[classname](__warehouse__.joinpath(args.dataset, "raw"))

    config = {
        "min_freq_user": args.min_freq_user,
        "min_freq_item": args.min_freq_item,
        "input_len": args.input_len,
        "target_len": args.target_len,
        "no_augment": args.no_augment,
        "remove_duplicates": args.remove_duplicates,
        "session_interval": args.session_interval,
        "min_session_len": args.min_session_len,
        "max_session_len": args.max_session_len,
        "split_by": args.split_by,
        "dev_split": args.dev_split,
        "test_split": args.test_split,
        "task": args.task,
        "pre_sessions": args.pre_sessions,
        "pick_targets": args.pick_targets
    }
    if classname in ["Amazon", "MovieLens20M", "Yelp"]:
        config["rating_threshold"] = args.rating_threshold
    elif classname == "Lastfm1K":
        config["item_type"] = args.item_type

    logger.info("Transforming...")
    if classname == "Amazon":
        df = d.transform(sub, args.rating_threshold)
    elif classname in ["MovieLens20M", "Yelp"]:
        df = d.transform(args.rating_threshold)
    elif classname == "FourSquare":
        df = d.transform(sub)
    elif classname == "Lastfm1K":
        df = d.transform(args.item_type)
    else:
        df = d.transform()

    if args.split_by == "time":
        config["dev_split"], config["test_split"] = access_split_days(df)
        # Processed check
        if ("time_splits" in args and
            (config["dev_split"], config["test_split"]) in args.time_splits):
            logger.warning(
                "You have run this config, the config id is {}".format(
                    args.time_splits[(config["dev_split"],
                                      config["test_split"])]))
            sys.exit(1)
        config["max_timestamp"] = df["timestamp"].max()

    preprocess_and_save(df, args.dataset, config)
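A small illustration of the name handling above; the compound name is only an example:
# Compound dataset names carry the sub-dataset after a dash; _process uses the part
# before the dash to pick the class and the part after it as `sub`.
classname, sub = "Amazon-Pet".split("-")
print(classname, sub)  # -> Amazon Pet  (Amazon then calls d.transform(sub, args.rating_threshold))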
Example #9
def handle_process(args, downloaded_datasets, processed_datasets):
    if args.dataset not in __datasets__:
        raise ValueError("Supported datasets: {}".format(
            ", ".join(__datasets__)))
    if args.dataset not in downloaded_datasets:
        raise ValueError("{} has not been downloaded".format(args.dataset))

    if args.split_by == "user":
        if args.dev_split <= 0 or args.dev_split >= 1:
            raise ValueError("dev split ratio should be in (0, 1)")
        if args.test_split <= 0 or args.test_split >= 1:
            raise ValueError("test split ratio should be in (0, 1)")

    if args.task == "short":
        if args.input_len <= 0:
            raise ValueError("input length must > 0")
        if args.session_interval < 0:
            raise ValueError("session interval must >= 0 minutes")
    else:
        if args.session_interval <= 0:
            raise ValueError("session interval must > 0 minutes")
        if args.pre_sessions < 1:
            raise ValueError("number of previous sessions must > 0")

    if args.target_len <= 0:
        raise ValueError("target length must > 0")

    if args.session_interval > 0:
        if args.min_session_len <= args.target_len:
            raise ValueError("min session length must > target length")
        if args.max_session_len < args.min_session_len:
            raise ValueError("max session length must >= min session length")

    if args.dataset in processed_datasets:
        # TODO Improve processed check when some arguments are not used
        time_splits = {}
        for c in processed_datasets[args.dataset]:
            config = read_json(
                __warehouse__.joinpath(args.dataset, "processed", c,
                                       "config.json"))
            if args.split_by == "user" and all(
                [args.__dict__[k] == v for k, v in config.items()]):
                print(
                    "You have run this config, the config id is: {}".format(c))
                sys.exit(1)
            if args.split_by == "time" and all([
                    args.__dict__[k] == v for k, v in config.items()
                    if k not in ["dev_split", "test_split"]
            ]):
                time_splits[(config["dev_split"], config["test_split"])] = c
        args.time_splits = time_splits
    _process(args)
def test_download_and_transform():
    rawdir = __warehouse__.joinpath("FourSquare", "raw")
    os.makedirs(rawdir, exist_ok=True)
    foursquare = FourSquare(rawdir)
    cities = ["NYC", "Tokyo"]
    foursquare.download()
    for c in cities:
        assert rawdir.joinpath(foursquare.__corefile__[c]).exists()
    for c in cities:
        df = foursquare.transform(c)
        assert all(col in df.columns
                   for col in ["user_id", "item_id", "timestamp"])
        assert len(df) > 0
Example #11
    def __init__(self, name: str, config_id: str, train: bool,
                 development: bool):
        super(Dataset, self).__init__()
        datadir = __warehouse__.joinpath(name, "processed", config_id,
                                         "dev" if development else "test")
        datapath = datadir.joinpath("train.pkl" if train else "test.pkl")
        if datapath.exists():
            with open(datapath, "rb") as f:
                self.dataset = pickle.load(f)
        else:
            raise ValueError("{} does not exist!".format(datapath))
        with open(datadir.joinpath("stats.json"), "r") as f:
            self.stats = json.load(f)

        if train:
            self.item_counts = Counter()
            for data in self.dataset:
                if len(data) > 5:
                    self.item_counts.update(data[1] + data[2] + data[3])
                else:
                    self.item_counts.update(data[1] + data[2])
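A toy run of the counting logic above; the record layouts are made up, but mirror the two branches handled in the constructor (positions 1-2 hold item-id lists, with an extra list in position 3 for the longer records):
from collections import Counter

# Made-up records: the short layout has 5 fields, the longer layout more than 5.
short_record = (0, [1, 2, 3], [4], 10.0, 20.0)
long_record = (0, [1, 2], [3, 4], [5], 10.0, 20.0, 30.0)

item_counts = Counter()
for data in [short_record, long_record]:
    if len(data) > 5:
        item_counts.update(data[1] + data[2] + data[3])
    else:
        item_counts.update(data[1] + data[2])

print(item_counts)  # items 1-4 are counted twice, item 5 once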
Example #12
def main():
    parser = argparse.ArgumentParser("srdatasets | python -m srdatasets")
    subparsers = parser.add_subparsers(help="commands", dest="command")
    # info
    parser_i = subparsers.add_parser("info", help="print local datasets info")
    parser_i.add_argument("--dataset", type=str, default=None, help="dataset name")

    # download
    parser_d = subparsers.add_parser("download", help="download datasets")
    parser_d.add_argument("--dataset", type=str, required=True, help="dataset name")

    # process
    parser_g = subparsers.add_parser(
        "process",
        help="process datasets",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser_g.add_argument("--dataset", type=str, required=True, help="dataset name")
    parser_g.add_argument(
        "--min-freq-item", type=int, default=5, help="minimum occurrence times of item"
    )
    parser_g.add_argument(
        "--min-freq-user", type=int, default=5, help="minimum occurrence times of user"
    )
    parser_g.add_argument(
        "--task",
        type=str,
        choices=["short", "long-short"],
        default="short",
        help="short-term task or long-short-term task",
    )
    parser_g.add_argument(
        "--split-by",
        type=str,
        choices=["user", "time"],
        default="user",
        help="user-based or time-based dataset splitting",
    )
    parser_g.add_argument(
        "--dev-split",
        type=float,
        default=0.1,
        help="[user-split] the fraction of developemnt dataset",
    )
    parser_g.add_argument(
        "--test-split",
        type=float,
        default=0.2,
        help="[user-split] the fraction of test dataset",
    )
    parser_g.add_argument(
        "--input-len", type=int, default=5, help="[short] input sequence length"
    )
    parser_g.add_argument(
        "--target-len", type=int, default=1, help="target sequence length"
    )
    parser_g.add_argument(
        "--no-augment", action="store_true", help="do not use data augmentation"
    )
    parser_g.add_argument(
        "--remove-duplicates",
        action="store_true",
        help="remove duplicate items in user sequence",
    )
    parser_g.add_argument(
        "--session-interval",
        type=int,
        default=0,
        help="[short-optional, long-short-required] split user sequences into sessions (minutes)",
    )
    parser_g.add_argument(
        "--max-session-len", type=int, default=20, help="max session length"
    )
    parser_g.add_argument(
        "--min-session-len", type=int, default=2, help="min session length"
    )
    parser_g.add_argument(
        "--pre-sessions",
        type=int,
        default=10,
        help="[long-short] number of previous sessions",
    )
    parser_g.add_argument(
        "--pick-targets",
        type=str,
        choices=["last", "random"],
        default="random",
        help="[long-short] pick T random or last items from current session as targets",
    )
    parser_g.add_argument(
        "--rating-threshold",
        type=int,
        default=4,
        help="[Amazon-X, Movielens20M, Yelp] ratings great or equal than this are treated as valid",
    )
    parser_g.add_argument(
        "--item-type",
        type=str,
        choices=["song", "artist"],
        default="song",
        help="[Lastfm1K] set item to song or artist",
    )
    args = parser.parse_args()

    if "dataset" in args and args.dataset is not None:
        args.dataset = get_datasetname(args.dataset)

    if args.command is None:
        parser.print_help()
    else:
        downloaded_datasets = get_downloaded_datasets()
        processed_datasets = get_processed_datasets()

        if args.command == "download":
            if args.dataset not in __datasets__:
                raise ValueError("Supported datasets: {}".format(", ".join(__datasets__)))
            if args.dataset in downloaded_datasets:
                raise ValueError("{} has been downloaded".format(args.dataset))
            _download(args.dataset)
        elif args.command == "process":
            if args.dataset not in __datasets__:
                raise ValueError("Supported datasets: {}".format(", ".join(__datasets__)))
            if args.dataset not in downloaded_datasets:
                raise ValueError("{} has not been downloaded".format(args.dataset))

            if args.split_by == "user":
                if args.dev_split <= 0 or args.dev_split >= 1:
                    raise ValueError("dev split ratio should be in (0, 1)")
                if args.test_split <= 0 or args.test_split >= 1:
                    raise ValueError("test split ratio should be in (0, 1)")

            if args.task == "short":
                if args.input_len <= 0:
                    raise ValueError("input length must > 0")
                if args.session_interval < 0:
                    raise ValueError("session interval must >= 0 minutes")
            else:
                if args.session_interval <= 0:
                    raise ValueError("session interval must > 0 minutes")
                if args.pre_sessions < 1:
                    raise ValueError("number of previous sessions must > 0")

            if args.target_len <= 0:
                raise ValueError("target length must > 0")

            if args.session_interval > 0:
                if args.min_session_len <= args.target_len:
                    raise ValueError("min session length must > target length")
                if args.max_session_len < args.min_session_len:
                    raise ValueError("max session length must >= min session length")

            if args.dataset in processed_datasets:
                # TODO Improve processed check when some arguments are not used
                time_splits = {}
                for c in processed_datasets[args.dataset]:
                    config = read_json(
                        __warehouse__.joinpath(args.dataset, "processed", c, "config.json")
                    )
                    if args.split_by == "user" and all(
                        [args.__dict__[k] == v for k, v in config.items()]
                    ):
                        print("You have run this config, the config id is: {}".format(c))
                        sys.exit(1)
                    if args.split_by == "time" and all(
                        [
                            args.__dict__[k] == v
                            for k, v in config.items()
                            if k not in ["dev_split", "test_split"]
                        ]
                    ):
                        time_splits[(config["dev_split"], config["test_split"])] = c
                args.time_splits = time_splits
            _process(args)
        else:
            if args.dataset is None:
                table = [
                    [
                        d,
                        "Y" if d in downloaded_datasets else "",
                        len(processed_datasets[d]) if d in processed_datasets else "",
                    ]
                    for d in __datasets__
                ]
                print(
                    tabulate(
                        table,
                        headers=["name", "downloaded", "processed configs"],
                        tablefmt="psql",
                    )
                )
            else:
                if args.dataset not in __datasets__:
                    raise ValueError(
                        "Supported datasets: {}".format(", ".join(__datasets__))
                    )
                if args.dataset not in downloaded_datasets:
                    print("{} has not been downloaded".format(args.dataset))
                else:
                    if args.dataset not in processed_datasets:
                        print("{} has not been processed".format(args.dataset))
                    else:
                        configs = json_normalize(
                            [
                                read_json(
                                    __warehouse__.joinpath(
                                        args.dataset, "processed", c, "config.json"
                                    )
                                )
                                for c in processed_datasets[args.dataset]
                            ]
                        )
                        print("Configs")
                        configs_part1 = configs.iloc[:, :8]
                        configs_part1.insert(
                            0, "config id", processed_datasets[args.dataset]
                        )
                        print(
                            tabulate(
                                configs_part1,
                                headers="keys",
                                showindex=False,
                                tablefmt="psql",
                            )
                        )
                        print()
                        configs_part2 = configs.iloc[:, 8:]
                        configs_part2.insert(
                            0, "config id", processed_datasets[args.dataset]
                        )
                        print(
                            tabulate(
                                configs_part2,
                                headers="keys",
                                showindex=False,
                                tablefmt="psql",
                            )
                        )
                        print("\nStats")
                        stats = json_normalize(
                            [
                                read_json(
                                    __warehouse__.joinpath(
                                        args.dataset, "processed", c, m, "stats.json"
                                    )
                                )
                                for c in processed_datasets[args.dataset]
                                for m in ["dev", "test"]
                            ]
                        )
                        modes = ["development", "test"] * len(
                            processed_datasets[args.dataset]
                        )
                        stats.insert(0, "mode", modes)
                        ids = []
                        for c in processed_datasets[args.dataset]:
                            ids.extend([c, ""])
                        stats.insert(0, "config id", ids)
                        print(
                            tabulate(
                                stats, headers="keys", showindex=False, tablefmt="psql"
                            )
                        )
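For reference, a minimal sketch of driving this entry point programmatically; the dataset name and flag values are only illustrative, and it assumes the dataset was downloaded first:
import sys

# Equivalent to: python -m srdatasets process --dataset FourSquare-NYC --split-by user ...
sys.argv = [
    "srdatasets", "process",
    "--dataset", "FourSquare-NYC",  # hypothetical choice; must be one of __datasets__
    "--split-by", "user",
    "--dev-split", "0.1",
    "--test-split", "0.2",
    "--input-len", "5",
    "--target-len", "1",
]
main()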
Example #13
def preprocess_and_save(df, dname, config):
    """General preprocessing method

    Args:
        df (DataFrame): interactions with columns `user_id`, `item_id`, `timestamp`.
        dname (str): dataset name.
        config (dict): preprocessing configuration.
    """
    # Generate sequences
    logger.info("Generating user sequences...")
    seqs = generate_sequences(df, config)

    # Split sequences in different ways
    if config["session_interval"] > 0:
        split = split_sequences_session
    else:
        split = split_sequences

    logger.info("Splitting user sequences into train/test...")
    train_seqs, test_seqs = split(seqs, config, 0)

    logger.info("Splitting train into dev-train/dev-test...")
    dev_train_seqs, dev_test_seqs = split(train_seqs, config, 1)

    # Remove duplicates (optional)
    if config["remove_duplicates"]:
        logger.info("Removing duplicates...")
        train_seqs, test_seqs, dev_train_seqs, dev_test_seqs = [
            remove_duplicates(seqs, config)
            for seqs in [train_seqs, test_seqs, dev_train_seqs, dev_test_seqs]
        ]

    # Do not use data augmentation (optional)
    if config["no_augment"]:
        logger.info("Enabling no data augmentation...")
        train_seqs, test_seqs, dev_train_seqs, dev_test_seqs = [
            enable_no_augment(seqs, config)
            for seqs in [train_seqs, test_seqs, dev_train_seqs, dev_test_seqs]
        ]

    # Remove unknowns
    logger.info("Removing unknowns in test...")
    test_seqs = remove_unknowns(train_seqs, test_seqs, config)

    logger.info("Removing unknowns in dev-test...")
    dev_test_seqs = remove_unknowns(dev_train_seqs, dev_test_seqs, config)

    # Reassign user and item ids
    logger.info("Reassigning ids (train/test)...")
    train_seqs, test_seqs = reassign_ids(train_seqs, test_seqs)

    logger.info("Reassigning ids (dev-train/dev-test)...")
    dev_train_seqs, dev_test_seqs = reassign_ids(dev_train_seqs, dev_test_seqs)

    # Make datasets based on task
    if config["task"] == "short":
        make_dataset = make_dataset_short
    else:
        make_dataset = make_dataset_long_short

    logger.info("Making datasets...")
    train_data, test_data, dev_train_data, dev_test_data = [
        make_dataset(seqs, config)
        for seqs in [train_seqs, test_seqs, dev_train_seqs, dev_test_seqs]
    ]

    # Dump to disk
    logger.info("Dumping...")
    processed_path = __warehouse__.joinpath(dname, "processed",
                                            "c" + str(int(time.time() * 1000)))
    dump(processed_path, train_data, test_data, 0)
    dump(processed_path, dev_train_data, dev_test_data, 1)

    # Save config
    save_config(processed_path, config)
    logger.info("OK, the config id is: %s", processed_path.stem)
Example #14
def handle_info(args, downloaded_datasets, processed_datasets):
    if args.dataset is None:
        table = [[
            d, "Y" if d in downloaded_datasets else "",
            len(processed_datasets[d]) if d in processed_datasets else ""
        ] for d in __datasets__]
        print(
            tabulate(table,
                     headers=["name", "downloaded", "processed configs"],
                     tablefmt="psql"))
    else:
        if args.dataset not in __datasets__:
            raise ValueError("Supported datasets: {}".format(
                ", ".join(__datasets__)))
        if args.dataset not in downloaded_datasets:
            print("{} has not been downloaded".format(args.dataset))
        else:
            if args.dataset not in processed_datasets:
                print("{} has not been processed".format(args.dataset))
            else:
                configs = json_normalize([
                    read_json(
                        __warehouse__.joinpath(args.dataset, "processed", c,
                                               "config.json"))
                    for c in processed_datasets[args.dataset]
                ])
                print("Configs")
                configs_part1 = configs.iloc[:, :8]
                configs_part1.insert(0, "config id",
                                     processed_datasets[args.dataset])
                print(
                    tabulate(configs_part1,
                             headers="keys",
                             showindex=False,
                             tablefmt="psql"))
                print()
                configs_part2 = configs.iloc[:, 8:]
                configs_part2.insert(0, "config id",
                                     processed_datasets[args.dataset])
                print(
                    tabulate(configs_part2,
                             headers="keys",
                             showindex=False,
                             tablefmt="psql"))
                print("\nStats")
                stats = json_normalize([
                    read_json(
                        __warehouse__.joinpath(args.dataset, "processed", c, m,
                                               "stats.json"))
                    for c in processed_datasets[args.dataset]
                    for m in ["dev", "test"]
                ])
                modes = ["development", "test"] * len(
                    processed_datasets[args.dataset])
                stats.insert(0, "mode", modes)
                ids = []
                for c in processed_datasets[args.dataset]:
                    ids.extend([c, ""])
                stats.insert(0, "config id", ids)
                print(
                    tabulate(stats,
                             headers="keys",
                             showindex=False,
                             tablefmt="psql"))
Example #15
# NOTE: the opening lines of this snippet were not captured; only the trailing
# keyword arguments of an `args` (argparse.Namespace-style) construction survive.
                 test_split=0.2,
                 input_len=9,
                 target_len=1,
                 session_interval=0,
                 max_session_len=10,
                 min_session_len=2,
                 pre_sessions=10,
                 pick_targets="random",
                 no_augment=False,
                 remove_duplicates=False)

if args.dataset not in get_downloaded_datasets():
    _download(args.dataset)

if args.dataset in get_processed_datasets():
    shutil.rmtree(__warehouse__.joinpath(args.dataset, "processed"))

# For short term task
short_args = copy.deepcopy(args)
_process(short_args)

# For long-short term task
long_short_args = copy.deepcopy(args)
long_short_args.task = "long-short"
long_short_args.session_interval = 60
_process(long_short_args)


def test_dataloader():
    config_ids = get_processed_datasets()[args.dataset]
    for cid in config_ids:
        ...  # loop body not captured in the source snippet
Example #16
    def __init__(
        self,
        dataset_name: str,
        config_id: str,
        batch_size: int = 1,
        train: bool = True,
        development: bool = False,
        negatives_per_target: int = 0,
        include_timestamp: bool = False,
        drop_last: bool = False,
    ):
        """Loader of sequential recommendation datasets

        Args:
            dataset_name (str): dataset name.
            config_id (str): dataset config id.
            batch_size (int): number of samples per batch.
            train (bool, optional): load the training split if True, the test split otherwise.
            development (bool, optional): use the development split for hyperparameter tuning.
            negatives_per_target (int, optional): number of negative samples per target.
            include_timestamp (bool, optional): add timestamps to batch data.
            drop_last (bool, optional): drop the last incomplete batch.

        Note: training data is shuffled automatically.
        """
        dataset_name = get_datasetname(dataset_name)

        if dataset_name not in __datasets__:
            raise ValueError(
                "Unrecognized dataset, currently supported datasets: {}".
                format(", ".join(__datasets__)))

        _processed_datasets = get_processed_datasets()
        if dataset_name not in _processed_datasets:
            raise ValueError(
                "{} has not been processed, currently processed datasets: {}".
                format(
                    dataset_name,
                    ", ".join(_processed_datasets)
                    if _processed_datasets else "none",
                ))

        if config_id not in _processed_datasets[dataset_name]:
            raise ValueError(
                "Unrecognized config id, existing config ids: {}".format(
                    ", ".join(_processed_datasets[dataset_name])))

        if negatives_per_target < 0:
            negatives_per_target = 0
            logger.warning(
                "Number of negative samples per target should >= 0, reset to 0"
            )

        if not train and negatives_per_target > 0:
            logger.warning(
                "Negative samples are used for training, set negatives_per_target has no effect when testing"
            )

        dataset_dir = __warehouse__.joinpath(dataset_name, "processed",
                                             config_id,
                                             "dev" if development else "test")
        with open(dataset_dir.joinpath("stats.json"), "r") as f:
            self.stats = json.load(f)

        dataset_path = dataset_dir.joinpath(
            "train.pkl" if train else "test.pkl")
        with open(dataset_path, "rb") as f:
            self.dataset = pickle.load(f)  # list

        if train:
            counter = Counter()
            for data in self.dataset:
                if len(data) > 5:
                    counter.update(data[1] + data[2] + data[3])
                else:
                    counter.update(data[1] + data[2])
            self.item_counts = np.array(
                [counter[i] for i in range(max(counter.keys()) + 1)])

        if batch_size <= 0:
            raise ValueError("batch_size should >= 1")
        if batch_size > len(self.dataset):
            raise ValueError("batch_size exceeds the dataset size")

        self.batch_size = batch_size
        self.train = train
        self.include_timestamp = include_timestamp
        self.negatives_per_target = negatives_per_target
        self.drop_last = drop_last
        self._batch_idx = 0
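A usage sketch for the loader constructed above; the import path, class name, and config id are assumptions (config ids are printed by the info command), not taken from this snippet:
from srdatasets import DataLoader  # assumed import path and class name

loader = DataLoader(
    "MovieLens20M",
    config_id="c1234567890123",  # placeholder; use an id listed by `python -m srdatasets info`
    batch_size=32,
    train=True,
    development=True,
    negatives_per_target=5,
)
print(loader.stats)           # statistics loaded from stats.json
print(len(loader.dataset))    # number of training examples in the pickle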