def instantiate_pipeline(dataset, data_dir, params, constructor_type=None,
                         deterministic=False):
  # type: (str, str, dict, typing.Optional[str], bool) -> (int, int, data_pipeline.BaseDataConstructor)
  """Load and digest data CSV into a usable form.

  Args:
    dataset: The name of the dataset to be used.
    data_dir: The root directory of the dataset.
    params: dict of parameters for the run.
    constructor_type: The name of the constructor subclass that should be used
      for the input pipeline.
    deterministic: Tell the data constructor to produce deterministically.

  Returns:
    The number of users, the number of items, and the data producer.
  """
  tf.logging.info("Beginning data preprocessing.")

  st = timeit.default_timer()
  raw_rating_path = os.path.join(data_dir, dataset, movielens.RATINGS_FILE)
  cache_path = os.path.join(data_dir, dataset, rconst.RAW_CACHE_FILE)

  raw_data, _ = _filter_index_sort(raw_rating_path, cache_path)
  user_map, item_map = raw_data["user_map"], raw_data["item_map"]
  num_users, num_items = DATASET_TO_NUM_USERS_AND_ITEMS[dataset]

  # Sanity check: the cached data must match the expected dataset statistics.
  if num_users != len(user_map):
    raise ValueError("Expected to find {} users, but found {}".format(
        num_users, len(user_map)))
  if num_items != len(item_map):
    raise ValueError("Expected to find {} items, but found {}".format(
        num_items, len(item_map)))

  producer = data_pipeline.get_constructor(constructor_type or "materialized")(
      maximum_number_epochs=params["train_epochs"],
      num_users=num_users,
      num_items=num_items,
      user_map=user_map,
      item_map=item_map,
      train_pos_users=raw_data[rconst.TRAIN_USER_KEY],
      train_pos_items=raw_data[rconst.TRAIN_ITEM_KEY],
      train_batch_size=params["batch_size"],
      batches_per_train_step=params["batches_per_step"],
      num_train_negatives=params["num_neg"],
      eval_pos_users=raw_data[rconst.EVAL_USER_KEY],
      eval_pos_items=raw_data[rconst.EVAL_ITEM_KEY],
      eval_batch_size=params["eval_batch_size"],
      batches_per_eval_step=params["batches_per_step"],
      stream_files=params["use_tpu"],
      deterministic=deterministic)

  run_time = timeit.default_timer() - st
  tf.logging.info(
      "Data preprocessing complete. Time: {:.1f} sec.".format(run_time))

  print(producer)  # Print a human-readable summary of the producer.
  return num_users, num_items, producer
def instantiate_pipeline(dataset, data_dir, params, constructor_type=None,
                         deterministic=False, epoch_dir=None):
  # type: (str, str, dict, typing.Optional[str], bool, typing.Optional[str]) -> (int, int, data_pipeline.BaseDataConstructor)
  """Load and digest data CSV into a usable form.

  Args:
    dataset: The name of the dataset to be used.
    data_dir: The root directory of the dataset.
    params: dict of parameters for the run.
    constructor_type: The name of the constructor subclass that should be used
      for the input pipeline.
    deterministic: Tell the data constructor to produce deterministically.
    epoch_dir: Directory in which to store the training epochs.

  Returns:
    The number of users, the number of items, and the data producer.
  """
  logging.info("Beginning data preprocessing.")

  st = timeit.default_timer()
  raw_rating_path = os.path.join(data_dir, dataset, movielens.RATINGS_FILE)
  cache_path = os.path.join(data_dir, dataset, rconst.RAW_CACHE_FILE)

  raw_data, _ = _filter_index_sort(raw_rating_path, cache_path)
  user_map, item_map = raw_data["user_map"], raw_data["item_map"]
  num_users, num_items = DATASET_TO_NUM_USERS_AND_ITEMS[dataset]

  # Sanity check: the cached data must match the expected dataset statistics.
  if num_users != len(user_map):
    raise ValueError("Expected to find {} users, but found {}".format(
        num_users, len(user_map)))
  if num_items != len(item_map):
    raise ValueError("Expected to find {} items, but found {}".format(
        num_items, len(item_map)))

  producer = data_pipeline.get_constructor(constructor_type or "materialized")(
      maximum_number_epochs=params["train_epochs"],
      num_users=num_users,
      num_items=num_items,
      user_map=user_map,
      item_map=item_map,
      train_pos_users=raw_data[rconst.TRAIN_USER_KEY],
      train_pos_items=raw_data[rconst.TRAIN_ITEM_KEY],
      train_batch_size=params["batch_size"],
      batches_per_train_step=params["batches_per_step"],
      num_train_negatives=params["num_neg"],
      eval_pos_users=raw_data[rconst.EVAL_USER_KEY],
      eval_pos_items=raw_data[rconst.EVAL_ITEM_KEY],
      eval_batch_size=params["eval_batch_size"],
      batches_per_eval_step=params["batches_per_step"],
      stream_files=params["use_tpu"],
      deterministic=deterministic,
      epoch_dir=epoch_dir)

  run_time = timeit.default_timer() - st
  logging.info(
      "Data preprocessing complete. Time: {:.1f} sec.".format(run_time))

  print(producer)  # Print a human-readable summary of the producer.
  return num_users, num_items, producer
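# Illustrative usage sketch, not part of the original module: the params keys
# below are exactly the ones instantiate_pipeline reads, but the dataset name,
# data_dir path, and hyperparameter values are assumed placeholders.
def _example_instantiate_pipeline():
  """Hypothetical helper showing a typical call to instantiate_pipeline."""
  params = {
      "train_epochs": 2,        # maximum_number_epochs for the producer
      "batch_size": 256,        # training batch size
      "eval_batch_size": 1024,  # evaluation batch size
      "batches_per_step": 1,    # batches per train/eval step
      "num_neg": 4,             # negatives sampled per positive example
      "use_tpu": False,         # stream_files is only enabled on TPU
  }
  num_users, num_items, producer = instantiate_pipeline(
      dataset="ml-1m", data_dir="/tmp/movielens-data", params=params,
      constructor_type="materialized", deterministic=True)
  return num_users, num_items, producer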