Example #1
def main():
    args, parser = read_arguments()
    if "dataset" in args and args.dataset is not None:
        args.dataset = get_datasetname(args.dataset)

    if args.command is None:
        parser.print_help()
    else:
        downloaded_datasets = get_downloaded_datasets()
        processed_datasets = get_processed_datasets()

        if args.command == "download":
            handle_dowload(args, downloaded_datasets)
        elif args.command == "process":
            handle_process(args, downloaded_datasets, processed_datasets)
        else:
            handle_info(args, downloaded_datasets, processed_datasets)

def test_dataloader():
    config_ids = get_processed_datasets()[args.dataset]
    for cid in config_ids:
        for DataLoader in [
                srdatasets.dataloader.DataLoader,
                srdatasets.dataloader_pytorch.DataLoader
        ]:
            dataloader = DataLoader(args.dataset,
                                    cid,
                                    batch_size=32,
                                    negatives_per_target=5,
                                    include_timestamp=True,
                                    drop_last=True)
            if len(dataloader.dataset[0]) > 5:
                for users, pre_sess_items, cur_sess_items, target_items, pre_sess_timestamps, cur_sess_timestamps, \
                     target_timestamps, negatives in dataloader:
                    assert users.shape == (32, )
                    assert pre_sess_items.shape == (32, args.pre_sessions *
                                                    args.max_session_len)
                    assert cur_sess_items.shape == (32, args.max_session_len -
                                                    args.target_len)
                    assert target_items.shape == (32, args.target_len)
                    assert pre_sess_timestamps.shape == (32,
                                                         args.pre_sessions *
                                                         args.max_session_len)
                    assert cur_sess_timestamps.shape == (32,
                                                         args.max_session_len -
                                                         args.target_len)
                    assert target_timestamps.shape == (32, args.target_len)
                    assert negatives.shape == (32, args.target_len, 5)
            else:
                for users, in_items, out_items, in_timestamps, out_timestamps, negatives in dataloader:
                    assert users.shape == (32, )
                    assert in_items.shape == (32, args.input_len)
                    assert out_items.shape == (32, args.target_len)
                    assert in_timestamps.shape == (32, args.input_len)
                    assert out_timestamps.shape == (32, args.target_len)
                    assert negatives.shape == (32, args.target_len, 5)
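
The assertions above spell out exactly what a training loop receives from the loader. Below is a minimal usage sketch for the short-term case, assuming a processed short-term config; the dataset name and config id are hypothetical placeholders, not values prescribed by the source.

from srdatasets.dataloader_pytorch import DataLoader

# Hypothetical dataset name and config id; use the values reported by
# `python -m srdatasets info`.
dataloader = DataLoader("Movielens20M",
                        "c1234567890",
                        batch_size=32,
                        train=True,
                        negatives_per_target=5,
                        include_timestamp=True,
                        drop_last=True)

for users, in_items, out_items, in_timestamps, out_timestamps, negatives in dataloader:
    # Short-term task: every tensor carries a leading batch dimension of 32,
    # and negatives has shape (32, target_len, 5) as asserted above.
    print(users.shape, in_items.shape, out_items.shape, negatives.shape)
    break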
Example #3
def main():
    parser = argparse.ArgumentParser("srdatasets | python -m srdatasets")
    subparsers = parser.add_subparsers(help="commands", dest="command")
    # info
    parser_i = subparsers.add_parser("info", help="print local datasets info")
    parser_i.add_argument("--dataset", type=str, default=None, help="dataset name")

    # download
    parser_d = subparsers.add_parser("download", help="download datasets")
    parser_d.add_argument("--dataset", type=str, required=True, help="dataset name")

    # process
    parser_g = subparsers.add_parser(
        "process",
        help="process datasets",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser_g.add_argument("--dataset", type=str, required=True, help="dataset name")
    parser_g.add_argument(
        "--min-freq-item", type=int, default=5, help="minimum occurrence times of item"
    )
    parser_g.add_argument(
        "--min-freq-user", type=int, default=5, help="minimum occurrence times of user"
    )
    parser_g.add_argument(
        "--task",
        type=str,
        choices=["short", "long-short"],
        default="short",
        help="short-term task or long-short-term task",
    )
    parser_g.add_argument(
        "--split-by",
        type=str,
        choices=["user", "time"],
        default="user",
        help="user-based or time-based dataset splitting",
    )
    parser_g.add_argument(
        "--dev-split",
        type=float,
        default=0.1,
        help="[user-split] the fraction of developemnt dataset",
    )
    parser_g.add_argument(
        "--test-split",
        type=float,
        default=0.2,
        help="[user-split] the fraction of test dataset",
    )
    parser_g.add_argument(
        "--input-len", type=int, default=5, help="[short] input sequence length"
    )
    parser_g.add_argument(
        "--target-len", type=int, default=1, help="target sequence length"
    )
    parser_g.add_argument(
        "--no-augment", action="store_true", help="do not use data augmentation"
    )
    parser_g.add_argument(
        "--remove-duplicates",
        action="store_true",
        help="remove duplicate items in user sequence",
    )
    parser_g.add_argument(
        "--session-interval",
        type=int,
        default=0,
        help="[short-optional, long-short-required] split user sequences into sessions (minutes)",
    )
    parser_g.add_argument(
        "--max-session-len", type=int, default=20, help="max session length"
    )
    parser_g.add_argument(
        "--min-session-len", type=int, default=2, help="min session length"
    )
    parser_g.add_argument(
        "--pre-sessions",
        type=int,
        default=10,
        help="[long-short] number of previous sessions",
    )
    parser_g.add_argument(
        "--pick-targets",
        type=str,
        choices=["last", "random"],
        default="random",
        help="[long-short] pick T random or last items from current session as targets",
    )
    parser_g.add_argument(
        "--rating-threshold",
        type=int,
        default=4,
        help="[Amazon-X, Movielens20M, Yelp] ratings great or equal than this are treated as valid",
    )
    parser_g.add_argument(
        "--item-type",
        type=str,
        choices=["song", "artist"],
        default="song",
        help="[Lastfm1K] set item to song or artist",
    )
    args = parser.parse_args()

    if "dataset" in args and args.dataset is not None:
        args.dataset = get_datasetname(args.dataset)

    if args.command is None:
        parser.print_help()
    else:
        downloaded_datasets = get_downloaded_datasets()
        processed_datasets = get_processed_datasets()

        if args.command == "download":
            if args.dataset not in __datasets__:
                raise ValueError("Supported datasets: {}".format(", ".join(__datasets__)))
            if args.dataset in downloaded_datasets:
                raise ValueError("{} has been downloaded".format(args.dataset))
            _download(args.dataset)
        elif args.command == "process":
            if args.dataset not in __datasets__:
                raise ValueError("Supported datasets: {}".format(", ".join(__datasets__)))
            if args.dataset not in downloaded_datasets:
                raise ValueError("{} has not been downloaded".format(args.dataset))

            if args.split_by == "user":
                if args.dev_split <= 0 or args.dev_split >= 1:
                    raise ValueError("dev split ratio should be in (0, 1)")
                if args.test_split <= 0 or args.test_split >= 1:
                    raise ValueError("test split ratio should be in (0, 1)")

            if args.task == "short":
                if args.input_len <= 0:
                    raise ValueError("input length must > 0")
                if args.session_interval < 0:
                    raise ValueError("session interval must >= 0 minutes")
            else:
                if args.session_interval <= 0:
                    raise ValueError("session interval must > 0 minutes")
                if args.pre_sessions < 1:
                    raise ValueError("number of previous sessions must > 0")

            if args.target_len <= 0:
                raise ValueError("target length must > 0")

            if args.session_interval > 0:
                if args.min_session_len <= args.target_len:
                    raise ValueError("min session length must > target length")
                if args.max_session_len < args.min_session_len:
                    raise ValueError("max session length must >= min session length")

            if args.dataset in processed_datasets:
                # TODO Improve processed check when some arguments are not used
                time_splits = {}
                for c in processed_datasets[args.dataset]:
                    config = read_json(
                        __warehouse__.joinpath(args.dataset, "processed", c, "config.json")
                    )
                    if args.split_by == "user" and all(
                        [args.__dict__[k] == v for k, v in config.items()]
                    ):
                        print("You have run this config, the config id is: {}".format(c))
                        sys.exit(1)
                    if args.split_by == "time" and all(
                        [
                            args.__dict__[k] == v
                            for k, v in config.items()
                            if k not in ["dev_split", "test_split"]
                        ]
                    ):
                        time_splits[(config["dev_split"], config["test_split"])] = c
                args.time_splits = time_splits
            _process(args)
        else:
            if args.dataset is None:
                table = [
                    [
                        d,
                        "Y" if d in downloaded_datasets else "",
                        len(processed_datasets[d]) if d in processed_datasets else "",
                    ]
                    for d in __datasets__
                ]
                print(
                    tabulate(
                        table,
                        headers=["name", "downloaded", "processed configs"],
                        tablefmt="psql",
                    )
                )
            else:
                if args.dataset not in __datasets__:
                    raise ValueError(
                        "Supported datasets: {}".format(", ".join(__datasets__))
                    )
                if args.dataset not in downloaded_datasets:
                    print("{} has not been downloaded".format(args.dataset))
                else:
                    if args.dataset not in processed_datasets:
                        print("{} has not been processed".format(args.dataset))
                    else:
                        configs = json_normalize(
                            [
                                read_json(
                                    __warehouse__.joinpath(
                                        args.dataset, "processed", c, "config.json"
                                    )
                                )
                                for c in processed_datasets[args.dataset]
                            ]
                        )
                        print("Configs")
                        configs_part1 = configs.iloc[:, :8]
                        configs_part1.insert(
                            0, "config id", processed_datasets[args.dataset]
                        )
                        print(
                            tabulate(
                                configs_part1,
                                headers="keys",
                                showindex=False,
                                tablefmt="psql",
                            )
                        )
                        print()
                        configs_part2 = configs.iloc[:, 8:]
                        configs_part2.insert(
                            0, "config id", processed_datasets[args.dataset]
                        )
                        print(
                            tabulate(
                                configs_part2,
                                headers="keys",
                                showindex=False,
                                tablefmt="psql",
                            )
                        )
                        print("\nStats")
                        stats = json_normalize(
                            [
                                read_json(
                                    __warehouse__.joinpath(
                                        args.dataset, "processed", c, m, "stats.json"
                                    )
                                )
                                for c in processed_datasets[args.dataset]
                                for m in ["dev", "test"]
                            ]
                        )
                        modes = ["development", "test"] * len(
                            processed_datasets[args.dataset]
                        )
                        stats.insert(0, "mode", modes)
                        ids = []
                        for c in processed_datasets[args.dataset]:
                            ids.extend([c, ""])
                        stats.insert(0, "config id", ids)
                        print(
                            tabulate(
                                stats, headers="keys", showindex=False, tablefmt="psql"
                            )
                        )
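
For reference, the sub-commands wired up by this parser are normally driven from the command line. A minimal sketch of a typical download-then-process run, assuming the package is installed; the dataset name and option values are illustrative only, and the full set of flags is documented by the parser above.

import subprocess

# Download a dataset, then process it for the long-short-term task.
subprocess.run(["python", "-m", "srdatasets", "download", "--dataset", "Movielens20M"],
               check=True)
subprocess.run(
    [
        "python", "-m", "srdatasets", "process",
        "--dataset", "Movielens20M",
        "--task", "long-short",
        "--split-by", "user",
        "--session-interval", "60",
        "--pre-sessions", "10",
        "--target-len", "2",
    ],
    check=True,
)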
                 dev_split=0.1,
                 test_split=0.2,
                 input_len=9,
                 target_len=1,
                 session_interval=0,
                 max_session_len=10,
                 min_session_len=2,
                 pre_sessions=10,
                 pick_targets="random",
                 no_augment=False,
                 remove_duplicates=False)

if args.dataset not in get_downloaded_datasets():
    _download(args.dataset)

if args.dataset in get_processed_datasets():
    shutil.rmtree(__warehouse__.joinpath(args.dataset, "processed"))

# For short term task
short_args = copy.deepcopy(args)
_process(short_args)

# For long-short term task
long_short_args = copy.deepcopy(args)
long_short_args.task = "long-short"
long_short_args.session_interval = 60
_process(long_short_args)


def test_dataloader():
    config_ids = get_processed_datasets()[args.dataset]
Example #5
    def __init__(self,
                 dataset_name: str,
                 config_id: str,
                 batch_size: int = 1,
                 train: bool = True,
                 development: bool = False,
                 negatives_per_target: int = 0,
                 include_timestamp: bool = False,
                 **kwargs):
        """Loader of sequential recommendation datasets

        Args:
            dataset_name (str): dataset name
            config_id (str): dataset config id
            batch_size (int): batch size
            train (bool, optional): load training dataset
            development (bool, optional): use the dataset for hyperparameter tuning
            negatives_per_target (int, optional): number of negative samples per target
            include_timestamp (bool, optional): add timestamps to batch data
        
        Note: training data is shuffled automatically.
        """
        dataset_name = get_datasetname(dataset_name)

        if dataset_name not in __datasets__:
            raise ValueError(
                "Unrecognized dataset, currently supported datasets: {}".
                format(", ".join(__datasets__)))

        _processed_datasets = get_processed_datasets()
        if dataset_name not in _processed_datasets:
            raise ValueError(
                "{} has not been processed, currently processed datasets: {}".
                format(
                    dataset_name, ", ".join(_processed_datasets)
                    if _processed_datasets else "none"))

        if config_id not in _processed_datasets[dataset_name]:
            raise ValueError(
                "Unrecognized config id, existing config ids: {}".format(
                    ", ".join(_processed_datasets[dataset_name])))

        if negatives_per_target < 0:
            negatives_per_target = 0
            logger.warning(
                "Number of negative samples per target should >= 0, reset to 0"
            )

        if not train and negatives_per_target > 0:
            logger.warning(
                "Negative samples are used for training, set negatives_per_target has no effect when testing"
            )

        self.train = train
        self.include_timestamp = include_timestamp
        self.negatives_per_target = negatives_per_target

        self.dataset = Dataset(dataset_name, config_id, train, development)
        if train:
            self.item_counts = torch.tensor([
                self.dataset.item_counts[i]
                for i in range(max(self.dataset.item_counts.keys()) + 1)
            ],
                                            dtype=torch.float)

        super().__init__(self.dataset,
                         batch_size=batch_size,
                         shuffle=train,
                         collate_fn=self.collate_fn,
                         **kwargs)
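
One detail worth noting in the PyTorch-based constructor above is the item_counts tensor built from the training split. A minimal sketch, with a hypothetical dataset name and config id, of how that tensor could be turned into an item-popularity distribution, e.g. for custom loss weighting or extra negative sampling outside the loader:

import torch
from srdatasets.dataloader_pytorch import DataLoader

# Hypothetical dataset name and config id.
loader = DataLoader("Movielens20M", "c1234567890", batch_size=64, train=True)

# Empirical popularity over item ids 0..max_id (item_counts is only built when train=True).
popularity = loader.item_counts / loader.item_counts.sum()
extra_negatives = torch.multinomial(popularity, num_samples=10, replacement=True)
print(extra_negatives)

The pure-Python variant below builds the equivalent item_counts as a NumPy array, using a collections.Counter over the training samples.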
    def __init__(
        self,
        dataset_name: str,
        config_id: str,
        batch_size: int = 1,
        train: bool = True,
        development: bool = False,
        negatives_per_target: int = 0,
        include_timestamp: bool = False,
        drop_last: bool = False,
    ):
        """Loader of sequential recommendation datasets

        Args:
            dataset_name (str): dataset name.
            config_id (str): dataset config id
            batch_size (int): batch size
            train (bool, optional): load training data
            development (bool, optional): use the dataset for hyperparameter tuning
            negatives_per_target (int, optional): number of negative samples per target
            include_timestamp (bool, optional): add timestamps to batch data
            drop_last (bool, optional): drop last incomplete batch

        Note: training data is shuffled automatically.
        """
        dataset_name = get_datasetname(dataset_name)

        if dataset_name not in __datasets__:
            raise ValueError(
                "Unrecognized dataset, currently supported datasets: {}".
                format(", ".join(__datasets__)))

        _processed_datasets = get_processed_datasets()
        if dataset_name not in _processed_datasets:
            raise ValueError(
                "{} has not been processed, currently processed datasets: {}".
                format(
                    dataset_name,
                    ", ".join(_processed_datasets)
                    if _processed_datasets else "none",
                ))

        if config_id not in _processed_datasets[dataset_name]:
            raise ValueError(
                "Unrecognized config id, existing config ids: {}".format(
                    ", ".join(_processed_datasets[dataset_name])))

        if negatives_per_target < 0:
            negatives_per_target = 0
            logger.warning(
                "Number of negative samples per target should >= 0, reset to 0"
            )

        if not train and negatives_per_target > 0:
            logger.warning(
                "Negative samples are used for training, set negatives_per_target has no effect when testing"
            )

        dataset_dir = __warehouse__.joinpath(dataset_name, "processed",
                                             config_id,
                                             "dev" if development else "test")
        with open(dataset_dir.joinpath("stats.json"), "r") as f:
            self.stats = json.load(f)

        dataset_path = dataset_dir.joinpath(
            "train.pkl" if train else "test.pkl")
        with open(dataset_path, "rb") as f:
            self.dataset = pickle.load(f)  # list

        if train:
            counter = Counter()
            for data in self.dataset:
                if len(data) > 5:
                    counter.update(data[1] + data[2] + data[3])
                else:
                    counter.update(data[1] + data[2])
            self.item_counts = np.array(
                [counter[i] for i in range(max(counter.keys()) + 1)])

        if batch_size <= 0:
            raise ValueError("batch_size should >= 1")
        if batch_size > len(self.dataset):
            raise ValueError("batch_size exceeds the dataset size")

        self.batch_size = batch_size
        self.train = train
        self.include_timestamp = include_timestamp
        self.negatives_per_target = negatives_per_target
        self.drop_last = drop_last
        self._batch_idx = 0
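
The constructor above only prepares state (self.dataset, batch_size, drop_last, _batch_idx); the batching itself happens in the iterator methods, which are not shown in this snippet. The following is a sketch of what such an iterator could look like given these attributes. It is an assumption for illustration only, not the library's implementation: the real loader additionally collates the per-sample fields into arrays and attaches timestamps and negative samples as configured.

import random

class _BatchingSketch:
    """Illustrative only: relies on the attributes set in the __init__ above."""

    def __iter__(self):
        self._batch_idx = 0
        if self.train:
            random.shuffle(self.dataset)  # training data is shuffled automatically
        return self

    def __next__(self):
        start = self._batch_idx * self.batch_size
        batch = self.dataset[start:start + self.batch_size]
        if not batch or (self.drop_last and len(batch) < self.batch_size):
            raise StopIteration
        self._batch_idx += 1
        return batch  # the real loader also stacks fields and builds negatives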