def get_pretrain_data_text(data, batch_size, num_ctxes, shuffle, num_buckets,
                           vocab, tokenizer, max_seq_length, short_seq_prob,
                           masked_lm_prob, max_predictions_per_seq, whole_word_mask,
                           num_parts=1, part_idx=0, num_workers=1):
    """Get a data iterator from raw text documents.

    Parameters
    ----------
    data : str
        The path (or glob pattern) to the raw text files.
    batch_size : int
        The batch size per GPU.
    num_ctxes : int
        The number of GPUs.
    shuffle : bool
        Whether to shuffle the data.
    num_buckets : int
        The number of buckets for the FixedBucketSampler for training.
    vocab : BERTVocab
        The vocabulary.
    tokenizer : BERTTokenizer or BERTSPTokenizer
        The tokenizer.
    max_seq_length : int
        The hard limit on the maximum sequence length of sentence pairs.
    short_seq_prob : float
        The probability of sampling sequences shorter than max_seq_length.
    masked_lm_prob : float
        The probability of replacing tokens with masks/random words/original words.
    max_predictions_per_seq : int
        The hard limit on the number of predictions for masked words.
    whole_word_mask : bool
        Whether to use whole word masking.
    num_parts : int
        The number of partitions for the dataset.
    part_idx : int
        The index of the partition to read.
    num_workers : int
        The number of worker processes for dataset construction.
    """
    num_files = len(nlp.utils.glob(data))
    logging.info('%d files are found.', num_files)
    assert num_files >= num_parts, \
        'The number of training text files must be no less than the number of ' \
        'workers/partitions (%d). Only %d files at %s are found.' % (num_parts, num_files, data)
    dataset_params = {'tokenizer': tokenizer, 'max_seq_length': max_seq_length,
                      'short_seq_prob': short_seq_prob, 'masked_lm_prob': masked_lm_prob,
                      'max_predictions_per_seq': max_predictions_per_seq, 'vocab': vocab,
                      'whole_word_mask': whole_word_mask}
    dataset_fn = SimpleDatasetFn(BERTPretrainDataset, dataset_params)
    sampler_fn = BERTSamplerFn(batch_size, shuffle, num_ctxes, num_buckets)
    dataloader_fn = BERTDataLoaderFn(num_ctxes, vocab)

    split_sampler = nlp.data.SplitSampler(num_files, num_parts=num_parts, part_index=part_idx)
    dataloader = DatasetLoader(data, split_sampler, dataset_fn, sampler_fn, dataloader_fn,
                               num_dataset_workers=num_workers)
    return dataloader
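
# Hedged usage sketch (not taken from this repository): one way to construct the
# text-based pretraining dataloader. The glob pattern, tokenizer setup and
# hyperparameter values below are illustrative assumptions; `import gluonnlp as nlp`
# is assumed at module level, and whether get_model returns the vocab in your
# GluonNLP version should be verified.
#
#   _, vocab = nlp.model.get_model('bert_12_768_12',
#                                  dataset_name='book_corpus_wiki_en_uncased',
#                                  pretrained=False)
#   tokenizer = nlp.data.BERTTokenizer(vocab, lower=True)
#   dataloader = get_pretrain_data_text(
#       data='corpus/*.txt', batch_size=8, num_ctxes=1, shuffle=True,
#       num_buckets=1, vocab=vocab, tokenizer=tokenizer, max_seq_length=128,
#       short_seq_prob=0.1, masked_lm_prob=0.15, max_predictions_per_seq=20,
#       whole_word_mask=True)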
def get_pretrain_data_npz(data, batch_size, num_ctxes, shuffle, num_buckets,
                          vocab, num_parts=1, part_idx=0, num_workers=1):
    """Get a data iterator from pre-processed npz files.

    Parameters
    ----------
    data : str
        The path (or glob pattern) to the pre-processed npz files.
    batch_size : int
        The batch size per GPU.
    num_ctxes : int
        The number of GPUs.
    shuffle : bool
        Whether to shuffle the data.
    num_buckets : int
        The number of buckets for the FixedBucketSampler for training.
    vocab : BERTVocab
        The vocabulary.
    num_parts : int
        The number of partitions for the dataset.
    part_idx : int
        The index of the partition to read.
    num_workers : int
        The number of worker processes for dataset construction.
    """
    num_files = len(nlp.utils.glob(data))
    logging.info('%d files are found.', num_files)
    assert num_files >= num_parts, \
        'The number of pre-processed npz files must be no less than the number of ' \
        'workers/partitions (%d). Only %d files at %s are found.' % (num_parts, num_files, data)
    dataset_params = {'allow_pickle': True}
    dataset_fn = SimpleDatasetFn(nlp.data.NumpyDataset, dataset_params)
    sampler_fn = BERTSamplerFn(batch_size, shuffle, num_ctxes, num_buckets)
    dataloader_fn = BERTDataLoaderFn(num_ctxes, vocab)

    split_sampler = nlp.data.SplitSampler(num_files, num_parts=num_parts, part_index=part_idx)
    dataloader = DatasetLoader(data, split_sampler, dataset_fn, sampler_fn, dataloader_fn,
                               num_dataset_workers=num_workers)
    return dataloader
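
# Hedged usage sketch (not taken from this repository): loading pre-processed npz
# shards instead of raw text, so no tokenizer is needed at this stage. The glob
# pattern and hyperparameter values are illustrative assumptions; `vocab` is the
# same BERTVocab used when the shards were created.
#
#   dataloader = get_pretrain_data_npz(
#       data='out_dir/*.npz', batch_size=8, num_ctxes=1, shuffle=True,
#       num_buckets=1, vocab=vocab, num_parts=1, part_idx=0, num_workers=4)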