def get_pretrain_data_npz(data, batch_size, shuffle, num_buckets, vocab,
                          num_parts=1, part_idx=0,
                          num_dataset_workers=1, num_batch_workers=1,
                          circle_length=1, repeat=1,
                          dataset_cached=False, num_max_dataset_cached=0):
    """Get a data iterator from pre-processed npz files.

    Parameters
    ----------
    data : str
        The path to the dataset directory.
    batch_size : int
        The batch size per GPU.
    shuffle : bool
        Whether to shuffle the data.
    num_buckets : int
        The number of buckets for the FixedBucketSampler for training.
    vocab : Vocab
        The vocabulary.
    num_parts : int
        The number of partitions for the dataset.
    part_idx : int
        The index of the partition to read.
    num_dataset_workers : int
        The number of worker processes for dataset construction.
    num_batch_workers : int
        The number of worker processes for batch construction.
    circle_length : int, default is 1
        The number of files to be read for a single worker at the same time.
        When circle_length is larger than 1, we merge circle_length files.
    repeat : int, default is 1
        The number of times that files are repeated.
    dataset_cached : bool, default is False
        Whether or not to cache the last processed dataset. Each processed
        dataset can only be cached once. When there is no new available
        processed dataset to be fetched, we pop a cached processed dataset.
    num_max_dataset_cached : int, default is 0
        Maximum number of cached datasets. It is valid only if
        dataset_cached is True.
    """
    num_files = len(glob(data))
    logging.info('%d files are found.', num_files)
    assert num_files >= num_parts, \
        'The number of text files must be no less than the number of ' \
        'workers/partitions (%d). Only %d files at %s are found.' % (num_parts, num_files, data)
    split_sampler = SplitSampler(num_files, num_parts=num_parts,
                                 part_index=part_idx, repeat=repeat)
    dataset_fn = prepare_pretrain_npz_dataset
    sampler_fn = prepare_pretrain_bucket_sampler
    dataset_params = {'allow_pickle': True}
    sampler_params = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_buckets': num_buckets,
    }
    batchify_fn = bf.Tuple(
        bf.Pad(val=vocab.pad_id),  # input_ids
        bf.Pad(val=0),             # segment_ids
        bf.Stack(),                # valid_lengths
    )
    dataloader = DatasetLoader(data,
                               file_sampler=split_sampler,
                               dataset_fn=dataset_fn,
                               batch_sampler_fn=sampler_fn,
                               dataset_params=dataset_params,
                               batch_sampler_params=sampler_params,
                               batchify_fn=batchify_fn,
                               num_dataset_workers=num_dataset_workers,
                               num_batch_workers=num_batch_workers,
                               pin_memory=False,
                               circle_length=circle_length,
                               dataset_cached=dataset_cached,
                               num_max_dataset_cached=num_max_dataset_cached)
    return dataloader
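
# A minimal usage sketch (an assumption, not part of the original script): it
# assumes a `vocab` object with a `pad_id` attribute is already built and uses
# a hypothetical glob pattern of preprocessed npz shards. Each batch follows
# the batchify_fn above: (input_ids, segment_ids, valid_lengths).
def _example_iterate_npz_batches(vocab, data_pattern='prepared_data/*.npz'):
    """Iterate over batches produced by get_pretrain_data_npz."""
    dataloader = get_pretrain_data_npz(data_pattern, batch_size=32, shuffle=True,
                                       num_buckets=10, vocab=vocab)
    for input_ids, segment_ids, valid_lengths in dataloader:
        # input_ids and segment_ids are padded to the longest sample in the
        # bucketed batch; valid_lengths gives the unpadded lengths.
        pass
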
def get_pretrain_data_text(data, batch_size, shuffle, num_buckets, tokenizer, vocab,
                           max_seq_length, short_seq_prob=0.05,
                           num_parts=1, part_idx=0,
                           num_dataset_workers=1, num_batch_workers=1,
                           circle_length=1, repeat=1,
                           cached_file_path=None):
    """Get a data iterator from raw text documents.

    Parameters
    ----------
    data : str
        The path to the dataset directory.
    batch_size : int
        The batch size per GPU.
    shuffle : bool
        Whether to shuffle the data.
    num_buckets : int
        The number of buckets for the FixedBucketSampler for training.
    vocab : Vocab
        The vocabulary.
    tokenizer : HuggingFaceWordPieceTokenizer or SentencepieceTokenizer
        The tokenizer.
    max_seq_length : int
        The hard limit of maximum sequence length of sentence pairs.
    short_seq_prob : float
        The probability of sampling sequences shorter than the max_seq_length.
    num_parts : int
        The number of partitions for the dataset.
    part_idx : int
        The index of the partition to read.
    num_dataset_workers : int
        The number of worker processes for dataset construction.
    num_batch_workers : int
        The number of worker processes for batch construction.
    circle_length : int, default is 1
        The number of files to be read for a single worker at the same time.
        When circle_length is larger than 1, we merge circle_length files.
    repeat : int, default is 1
        The number of times that files are repeated.
    cached_file_path : str, default is None
        Directory for saving preprocessed features.
    """
    num_files = len(glob(data))
    logging.info('%d files are found.', num_files)
    assert num_files >= num_parts, \
        'The number of text files must be no less than the number of ' \
        'workers/partitions (%d). Only %d files at %s are found.' % (num_parts, num_files, data)
    split_sampler = SplitSampler(num_files, num_parts=num_parts,
                                 part_index=part_idx, repeat=repeat)
    dataset_fn = prepare_pretrain_text_dataset
    sampler_fn = prepare_pretrain_bucket_sampler
    dataset_params = {
        'tokenizer': tokenizer,
        'max_seq_length': max_seq_length,
        'short_seq_prob': short_seq_prob,
        'cached_file_path': cached_file_path,
    }
    sampler_params = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_buckets': num_buckets,
    }
    batchify_fn = bf.Tuple(
        bf.Pad(val=vocab.pad_id),  # input_ids
        bf.Pad(val=0),             # segment_ids
        bf.Stack(),                # valid_lengths
    )
    dataloader = DatasetLoader(data,
                               file_sampler=split_sampler,
                               dataset_fn=dataset_fn,
                               batch_sampler_fn=sampler_fn,
                               dataset_params=dataset_params,
                               batch_sampler_params=sampler_params,
                               batchify_fn=batchify_fn,
                               num_dataset_workers=num_dataset_workers,
                               num_batch_workers=num_batch_workers,
                               pin_memory=False,
                               circle_length=circle_length)
    return dataloader
def get_pretrain_data_text(data, batch_size, shuffle, num_buckets, vocab, tokenizer,
                           max_seq_length, short_seq_prob, masked_lm_prob,
                           max_predictions_per_seq, whole_word_mask,
                           random_next_sentence,
                           num_parts=1, part_idx=0,
                           num_dataset_workers=1, num_batch_workers=1,
                           circle_length=1, repeat=1,
                           dataset_cached=False, num_max_dataset_cached=0):
    """Get a data iterator from raw text documents.

    Parameters
    ----------
    data : str
        The path to the dataset directory.
    batch_size : int
        The batch size per GPU.
    shuffle : bool
        Whether to shuffle the data.
    num_buckets : int
        The number of buckets for the FixedBucketSampler for training.
    vocab : Vocab
        The vocabulary.
    tokenizer : BaseTokenizer
        The tokenizer.
    max_seq_length : int
        The hard limit of maximum sequence length of sentence pairs.
    short_seq_prob : float
        The probability of sampling sequences shorter than the max_seq_length.
    masked_lm_prob : float
        The probability of replacing texts with masks/random words/original words.
    max_predictions_per_seq : int
        The hard limit of the number of predictions for masked words.
    whole_word_mask : bool
        Whether to use whole word masking.
    random_next_sentence : bool
        Whether to sample a random sentence as the next sentence when
        generating next sentence prediction labels.
    num_parts : int
        The number of partitions for the dataset.
    part_idx : int
        The index of the partition to read.
    num_dataset_workers : int
        The number of worker processes for dataset construction.
    num_batch_workers : int
        The number of worker processes for batch construction.
    circle_length : int, default is 1
        The number of files to be read for a single worker at the same time.
        When circle_length is larger than 1, we merge circle_length files.
    repeat : int, default is 1
        The number of times that files are repeated.
    dataset_cached : bool, default is False
        Whether or not to cache the last processed dataset. Each processed
        dataset can only be cached once. When there is no new available
        processed dataset to be fetched, we pop a cached processed dataset.
    num_max_dataset_cached : int, default is 0
        Maximum number of cached datasets. It is valid only if
        dataset_cached is True.
    """
    num_files = len(glob(data))
    logging.info('%d files are found.', num_files)
    assert num_files >= num_parts, \
        'The number of text files must be no less than the number of ' \
        'workers/partitions (%d). Only %d files at %s are found.' % (num_parts, num_files, data)
    dataset_params = {
        'tokenizer': tokenizer,
        'max_seq_length': max_seq_length,
        'short_seq_prob': short_seq_prob,
        'masked_lm_prob': masked_lm_prob,
        'max_predictions_per_seq': max_predictions_per_seq,
        'vocab': vocab,
        'whole_word_mask': whole_word_mask,
        'random_next_sentence': random_next_sentence,
    }
    sampler_params = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_buckets': num_buckets,
    }
    dataset_fn = prepare_pretrain_text_dataset
    sampler_fn = prepare_pretrain_bucket_sampler
    pad_val = vocab.pad_id
    batchify_fn = bf.Tuple(
        bf.Pad(val=pad_val, round_to=8),  # input_id
        bf.Pad(val=pad_val),              # masked_id
        bf.Pad(val=0),                    # masked_position
        bf.Pad(val=0),                    # masked_weight
        bf.Stack(),                       # next_sentence_label
        bf.Pad(val=0, round_to=8),        # segment_id
        bf.Stack())                       # valid_lengths
    split_sampler = SplitSampler(num_files, num_parts=num_parts,
                                 part_index=part_idx, repeat=repeat)
    dataloader = DatasetLoader(data,
                               file_sampler=split_sampler,
                               dataset_fn=dataset_fn,
                               batch_sampler_fn=sampler_fn,
                               dataset_params=dataset_params,
                               batch_sampler_params=sampler_params,
                               batchify_fn=batchify_fn,
                               num_dataset_workers=num_dataset_workers,
                               num_batch_workers=num_batch_workers,
                               pin_memory=False,
                               circle_length=circle_length,
                               dataset_cached=dataset_cached,
                               num_max_dataset_cached=num_max_dataset_cached)
    return dataloader
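
# A minimal usage sketch for the masked-LM variant above (an assumption, not
# part of the original script; the glob pattern and hyperparameters are
# hypothetical placeholders). Each batch follows the seven-field batchify_fn:
# token ids, masked-token ids, masked positions, masked weights, the
# next-sentence label, segment ids, and valid lengths.
def _example_iterate_mlm_batches(tokenizer, vocab, data_pattern='raw_text/*.txt'):
    """Iterate over masked-LM pretraining batches built from raw text."""
    dataloader = get_pretrain_data_text(
        data_pattern, batch_size=32, shuffle=True, num_buckets=10,
        vocab=vocab, tokenizer=tokenizer, max_seq_length=512,
        short_seq_prob=0.1, masked_lm_prob=0.15, max_predictions_per_seq=80,
        whole_word_mask=True, random_next_sentence=False)
    for (input_id, masked_id, masked_position, masked_weight,
         next_sentence_label, segment_id, valid_lengths) in dataloader:
        pass  # feed the batch to the masked language model
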