Example #1
def get_pretrain_data_text(data, batch_size, num_ctxes, shuffle,
                           num_buckets, vocab, tokenizer, max_seq_length, short_seq_prob,
                           masked_lm_prob, max_predictions_per_seq, whole_word_mask,
                           num_parts=1, part_idx=0, num_workers=1):
    """Get a data iterator from raw text documents.

    Parameters
    ----------
    data : str
        The path to the raw training text files; wildcards (glob patterns) are supported.
    batch_size : int
        The batch size per GPU.
    num_ctxes : int
        The number of GPUs.
    shuffle : bool
        Whether to shuffle the data.
    num_buckets : int
        The number of buckets for the FixedBucketSampler for training.
    vocab : BERTVocab
        The vocabulary.
    tokenizer : BERTTokenizer or BERTSPTokenizer
        The tokenizer.
    max_seq_length : int
        The hard limit on the maximum length of a sentence-pair sequence.
    short_seq_prob : float
        The probability of sampling sequences shorter than max_seq_length.
    masked_lm_prob : float
        The probability that a token is selected for masked LM prediction
        (to be replaced with a mask token, a random word, or kept unchanged).
    max_predictions_per_seq : int
        The hard limit on the number of masked-word predictions per sequence.
    whole_word_mask : bool
        Whether to use whole word masking.
    num_parts : int
        The number of partitions for the dataset.
    part_idx : int
        The index of the partition to read.
    num_workers : int
        The number of worker processes for dataset construction.
    """
    num_files = len(nlp.utils.glob(data))
    logging.info('%d files are found.', num_files)
    assert num_files >= num_parts, \
        'The number of training text files must be no less than the number of ' \
        'workers/partitions (%d). Only %d files were found at %s.' % (num_parts, num_files, data)
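    # Keyword arguments used to build a BERTPretrainDataset from each text file.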
    dataset_params = {'tokenizer': tokenizer, 'max_seq_length': max_seq_length,
                      'short_seq_prob': short_seq_prob, 'masked_lm_prob': masked_lm_prob,
                      'max_predictions_per_seq': max_predictions_per_seq, 'vocab': vocab,
                      'whole_word_mask': whole_word_mask}
    dataset_fn = SimpleDatasetFn(BERTPretrainDataset, dataset_params)
    sampler_fn = BERTSamplerFn(batch_size, shuffle, num_ctxes, num_buckets)
    dataloader_fn = BERTDataLoaderFn(num_ctxes, vocab)

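    # Assign a disjoint subset of the input files to this worker/partition.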
    split_sampler = nlp.data.SplitSampler(num_files, num_parts=num_parts, part_index=part_idx)
    dataloader = DatasetLoader(data, split_sampler, dataset_fn, sampler_fn, dataloader_fn,
                               num_dataset_workers=num_workers)
    return dataloader
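
Below is a minimal usage sketch for the function above. The import, model name, corpus glob, and hyperparameter values are illustrative assumptions, not part of the original script, and the helper classes (SimpleDatasetFn, BERTSamplerFn, etc.) are assumed to be defined in the same script:

import gluonnlp as nlp

# Fetch a BERT vocabulary; the model weights themselves are not needed here.
_, vocab = nlp.model.get_model('bert_12_768_12',
                               dataset_name='book_corpus_wiki_en_uncased',
                               pretrained=False)
tokenizer = nlp.data.BERTTokenizer(vocab, lower=True)
train_data = get_pretrain_data_text(
    data='/path/to/corpus/*.txt', batch_size=8, num_ctxes=1, shuffle=True,
    num_buckets=1, vocab=vocab, tokenizer=tokenizer, max_seq_length=128,
    short_seq_prob=0.1, masked_lm_prob=0.15, max_predictions_per_seq=20,
    whole_word_mask=False)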
Example #2
def get_pretrain_data_npz(data,
                          batch_size,
                          num_ctxes,
                          shuffle,
                          num_buckets,
                          vocab,
                          num_parts=1,
                          part_idx=0,
                          num_workers=1):
    """Get a data iterator from pre-processed npz files.

    Parameters
    ----------
    data : str
        The path to the pre-processed npz files; wildcards (glob patterns) are supported.
    batch_size : int
        The batch size per GPU.
    num_ctxes : int
        The number of GPUs.
    shuffle : bool
        Whether to shuffle the data.
    num_buckets : int
        The number of buckets for the FixedBucketSampler for training.
    vocab : BERTVocab
        The vocabulary.
    num_parts : int
        The number of partitions for the dataset.
    part_idx : int
        The index of the partition to read.
    num_workers : int
        The number of worker processes for dataset construction.
    """
    num_files = len(nlp.utils.glob(data))
    logging.info('%d files are found.', num_files)
    assert num_files >= num_parts, \
        'The number of pre-processed npz files must be no less than the number of ' \
        'workers/partitions (%d). Only %d files were found at %s.' % (num_parts, num_files, data)
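    # Load each pre-processed .npz shard as a NumpyDataset.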
    dataset_params = {'allow_pickle': True}
    dataset_fn = SimpleDatasetFn(nlp.data.NumpyDataset, dataset_params)
    sampler_fn = BERTSamplerFn(batch_size, shuffle, num_ctxes, num_buckets)
    dataloader_fn = BERTDataLoaderFn(num_ctxes, vocab)

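    # Assign a disjoint subset of the input files to this worker/partition.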
    split_sampler = nlp.data.SplitSampler(num_files,
                                          num_parts=num_parts,
                                          part_index=part_idx)
    dataloader = DatasetLoader(data,
                               split_sampler,
                               dataset_fn,
                               sampler_fn,
                               dataloader_fn,
                               num_dataset_workers=num_workers)
    return dataloader
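
Below is a matching usage sketch for the npz variant. The glob and hyperparameter values are illustrative assumptions, and the vocabulary is obtained the same way as in Example #1:

import gluonnlp as nlp

_, vocab = nlp.model.get_model('bert_12_768_12',
                               dataset_name='book_corpus_wiki_en_uncased',
                               pretrained=False)
train_data = get_pretrain_data_npz(
    data='/path/to/prepared/*.npz', batch_size=8, num_ctxes=1, shuffle=True,
    num_buckets=1, vocab=vocab)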