def get_text_classification(dataset_name, config):
    """Build a text-classification dataset from *config*.

    NOTE(review): this helper reads the batch size from config["bs"],
    whereas the sibling helpers in this file use config["batch_size"] —
    callers pass different config schemas, so the key is kept as-is.
    """
    return datasets.random_slice_text_data(
        dataset_name=dataset_name,
        batch_size=config["bs"],
        patch_length=config["patch_length"],
        num_train=config["num_train"],
        num_per_valid=3000,
        cache_dataset=True,
    )
def lm1b_byte(batch_size, patch_length):
    """Byte-level LM1B data, cached, with a 100k-element shuffle buffer."""
    return datasets.random_slice_text_data(
        dataset_name="lm1b/bytes",
        batch_size=batch_size,
        patch_length=patch_length,
        shuffle_buffer=100000,
        cache_dataset=True,
    )
def _make(config):
    """Build the dataset described by *config*, optionally train-split only.

    NOTE(review): `dataset_name` is a free variable here — presumably bound
    by an enclosing scope (e.g. a per-dataset factory); confirm at the
    definition site before moving this function.
    """
    data = datasets.random_slice_text_data(
        dataset_name=dataset_name,
        batch_size=config["batch_size"],
        patch_length=config["patch_length"],
        num_train=config["num_train"],
        num_per_valid=10000,
        shuffle_buffer=10000,
        cache_dataset=True,
    )
    return _make_just_train(data, config["just_train"])
def get_byte_dataset(config, name):
    """Return the Datasets object for dataset *name* as described by *config*."""
    raw = datasets.random_slice_text_data(
        dataset_name=name,
        batch_size=config["batch_size"],
        patch_length=config["patch_length"],
        num_train=config["num_train"],
        num_per_valid=3000,
        shuffle_buffer=10000,
        cache_dataset=True,
    )
    return _make_just_train(raw, config["just_train"])
def imdb_subword(batch_size, patch_length):
    """Subword-tokenized (8k vocab) IMDB reviews data, cached."""
    return datasets.random_slice_text_data(
        dataset_name="imdb_reviews/subwords8k",
        batch_size=batch_size,
        patch_length=patch_length,
        cache_dataset=True,
    )