Example #1
def get_pretrain_data_npz(data,
                          batch_size,
                          shuffle,
                          num_buckets,
                          vocab,
                          num_parts=1,
                          part_idx=0,
                          num_dataset_workers=1,
                          num_batch_workers=1,
                          circle_length=1,
                          repeat=1,
                          dataset_cached=False,
                          num_max_dataset_cached=0):
    """Get a data iterator from pre-processed npz files.

    Parameters
    ----------
    data : str
        The path (glob pattern) matching the pre-processed .npz files.
    batch_size : int
        The batch size per GPU.
    shuffle : bool
        Whether to shuffle the data.
    num_buckets : int
        The number of buckets for the FixedBucketSampler for training.
    vocab : Vocab
        The vocabulary.
    num_parts : int
        The number of partitions for the dataset.
    part_idx : int
        The index of the partition to read.
    num_dataset_workers : int
        The number of worker processes for dataset construction.
    num_batch_workers : int
        The number of worker processes for batch construction.
    circle_length : int, default is 1
        The number of files to be read for a single worker at the same time.
        When circle_length is larger than 1, we merge circle_length files.
    repeat : int, default is 1
        The number of times that files are repeated.
    dataset_cached : bool, default is False
        Whether to cache the last processed dataset.
        Each processed dataset can only be cached once.
        When there is no new processed dataset available to be fetched,
        we pop a cached processed dataset.
    num_max_dataset_cached : int, default is 0
        The maximum number of cached datasets. It is only valid if dataset_cached is True.
    """
    num_files = len(glob(data))
    logging.info('%d files are found.', num_files)
    assert num_files >= num_parts, \
        'The number of text files must be no less than the number of ' \
        'workers/partitions (%d). Only %d files at %s are found.' % (num_parts, num_files, data)
    split_sampler = SplitSampler(num_files,
                                 num_parts=num_parts,
                                 part_index=part_idx,
                                 repeat=repeat)
    dataset_fn = prepare_pretrain_npz_dataset
    sampler_fn = prepare_pretrain_bucket_sampler
    dataset_params = {'allow_pickle': True}
    sampler_params = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_buckets': num_buckets
    }
    batchify_fn = bf.Tuple(
        bf.Pad(val=vocab.pad_id),  # input_ids
        bf.Pad(val=0),  # segment_ids
        bf.Stack(),  # valid_lengths
    )
    dataloader = DatasetLoader(data,
                               file_sampler=split_sampler,
                               dataset_fn=dataset_fn,
                               batch_sampler_fn=sampler_fn,
                               dataset_params=dataset_params,
                               batch_sampler_params=sampler_params,
                               batchify_fn=batchify_fn,
                               num_dataset_workers=num_dataset_workers,
                               num_batch_workers=num_batch_workers,
                               pin_memory=False,
                               circle_length=circle_length)
    return dataloader
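
Below is a minimal usage sketch (not part of the original source): the file pattern, `vocab`, and all argument values are placeholders, and the snippet above is assumed to live in a module that already imports `glob` and `logging`, the batchify module `bf`, `SplitSampler`, and `DatasetLoader`, and that defines the `prepare_pretrain_*` helpers. Each yielded batch is the padded tuple built by `batchify_fn`.

# Hypothetical usage sketch; `vocab` is assumed to be a vocabulary object
# exposing `pad_id`, built elsewhere.
train_loader = get_pretrain_data_npz(
    data='prepared_pretrain/*.npz',  # glob pattern matching the .npz shards
    batch_size=8,
    shuffle=True,
    num_buckets=1,
    vocab=vocab,
    num_parts=1,                     # total number of data partitions
    part_idx=0,                      # the partition read by this process
    num_dataset_workers=2,
    num_batch_workers=2)

for input_ids, segment_ids, valid_lengths in train_loader:
    pass  # feed the padded batch to the model
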
Example #2
def get_pretrain_data_text(data,
                           batch_size,
                           shuffle,
                           num_buckets,
                           tokenizer,
                           vocab,
                           max_seq_length,
                           short_seq_prob=0.05,
                           num_parts=1,
                           part_idx=0,
                           num_dataset_workers=1,
                           num_batch_workers=1,
                           circle_length=1,
                           repeat=1,
                           cached_file_path=None):
    """Get a data iterator from raw text documents.

    Parameters
    ----------
    data : str
        The path (glob pattern) matching the raw text files.
    batch_size : int
        The batch size per GPU.
    shuffle : bool
        Whether to shuffle the data.
    num_buckets : int
        The number of buckets for the FixedBucketSampler for training.
    vocab : Vocab
        The vocabulary.
    tokenizer : HuggingFaceWordPieceTokenizer or SentencepieceTokenizer
        The tokenizer.
    max_seq_length : int
        The hard limit on the maximum sequence length of sentence pairs.
    short_seq_prob : float
        The probability of sampling sequences shorter than the max_seq_length.
    num_parts : int
        The number of partitions for the dataset.
    part_idx : int
        The index of the partition to read.
    num_dataset_workers : int
        The number of worker processes for dataset construction.
    num_batch_workers : int
        The number of worker processes for batch construction.
    circle_length : int, default is 1
        The number of files to be read for a single worker at the same time.
        When circle_length is larger than 1, we merge circle_length files.
    repeat : int, default is 1
        The number of times that files are repeated.
    cached_file_path : str, default is None
        The directory for saving preprocessed features.
    """
    num_files = len(glob(data))
    logging.info('%d files are found.', num_files)
    assert num_files >= num_parts, \
        'The number of text files must be no less than the number of ' \
        'workers/partitions (%d). Only %d files at %s are found.' % (num_parts, num_files, data)
    split_sampler = SplitSampler(num_files,
                                 num_parts=num_parts,
                                 part_index=part_idx,
                                 repeat=repeat)
    dataset_fn = prepare_pretrain_text_dataset
    sampler_fn = prepare_pretrain_bucket_sampler
    dataset_params = {
        'tokenizer': tokenizer,
        'max_seq_length': max_seq_length,
        'short_seq_prob': short_seq_prob,
        'cached_file_path': cached_file_path
    }
    sampler_params = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_buckets': num_buckets
    }
    batchify_fn = bf.Tuple(
        bf.Pad(val=vocab.pad_id),  # input_ids
        bf.Pad(val=0),  # segment_ids
        bf.Stack(),  # valid_lengths
    )

    dataloader = DatasetLoader(data,
                               file_sampler=split_sampler,
                               dataset_fn=dataset_fn,
                               batch_sampler_fn=sampler_fn,
                               dataset_params=dataset_params,
                               batch_sampler_params=sampler_params,
                               batchify_fn=batchify_fn,
                               num_dataset_workers=num_dataset_workers,
                               num_batch_workers=num_batch_workers,
                               pin_memory=False,
                               circle_length=circle_length)
    return dataloader
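
A similar hedged usage sketch for the raw-text variant (all names and values are placeholders, not from the original source): here the documents are tokenized on the fly, and the processed features can optionally be cached under `cached_file_path`. The batches have the same three fields as the npz variant.

# Hypothetical usage sketch; `tokenizer` and `vocab` are placeholders for
# objects constructed elsewhere (e.g. from a pretrained model's resources).
train_loader = get_pretrain_data_text(
    data='corpus/*.txt',             # glob pattern, one document shard per file
    batch_size=8,
    shuffle=True,
    num_buckets=1,
    tokenizer=tokenizer,
    vocab=vocab,
    max_seq_length=128,
    short_seq_prob=0.05,
    num_dataset_workers=2,
    num_batch_workers=2,
    cached_file_path=None)           # or a directory to cache processed features

for input_ids, segment_ids, valid_lengths in train_loader:
    pass  # same (input_ids, segment_ids, valid_lengths) batches as above
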
Example #3
def get_pretrain_data_text(data, batch_size, shuffle, num_buckets, vocab, tokenizer,
                           max_seq_length, short_seq_prob, masked_lm_prob,
                           max_predictions_per_seq, whole_word_mask, random_next_sentence,
                           num_parts=1, part_idx=0, num_dataset_workers=1, num_batch_workers=1,
                           circle_length=1, repeat=1,
                           dataset_cached=False, num_max_dataset_cached=0):
    """Get a data iterator from raw text documents.

    Parameters
    ----------
    data : str
        The path (glob pattern) matching the raw text files.
    batch_size : int
        The batch size per GPU.
    shuffle : bool
        Whether to shuffle the data.
    num_buckets : int
        The number of buckets for the FixedBucketSampler for training.
    vocab : Vocab
        The vocabulary.
    tokenizer : BaseTokenizer
        The tokenizer.
    max_seq_length : int
        The hard limit on the maximum sequence length of sentence pairs.
    short_seq_prob : float
        The probability of sampling sequences shorter than the max_seq_length.
    masked_lm_prob : float
        The probability of selecting each token for masked language modeling
        (selected tokens are replaced with masks, random words, or kept as the original words).
    max_predictions_per_seq : int
        The hard limit on the number of masked-word predictions per sequence.
    whole_word_mask : bool
        Whether to use whole word masking.
    random_next_sentence : bool
        Whether to sample a random sentence as the next sentence for the
        next sentence prediction task.
    num_parts : int
        The number of partitions for the dataset.
    part_idx : int
        The index of the partition to read.
    num_dataset_workers : int
        The number of worker processes for dataset construction.
    num_batch_workers : int
        The number of worker processes for batch construction.
    circle_length : int, default is 1
        The number of files to be read for a single worker at the same time.
        When circle_length is larger than 1, we merge circle_length files.
    repeat : int, default is 1
        The number of times that files are repeated.
    dataset_cached : bool, default is False
        Whether to cache the last processed dataset.
        Each processed dataset can only be cached once.
        When there is no new processed dataset available to be fetched,
        we pop a cached processed dataset.
    num_max_dataset_cached : int, default is 0
        The maximum number of cached datasets. It is only valid if dataset_cached is True.
    """
    num_files = len(glob(data))
    logging.info('%d files are found.', num_files)
    assert num_files >= num_parts, \
        'The number of text files must be no less than the number of ' \
        'workers/partitions (%d). Only %d files at %s are found.' % (num_parts, num_files, data)
    dataset_params = {'tokenizer': tokenizer, 'max_seq_length': max_seq_length,
                      'short_seq_prob': short_seq_prob, 'masked_lm_prob': masked_lm_prob,
                      'max_predictions_per_seq': max_predictions_per_seq, 'vocab': vocab,
                      'whole_word_mask': whole_word_mask, 'random_next_sentence': random_next_sentence}
    sampler_params = {'batch_size': batch_size, 'shuffle': shuffle, 'num_buckets': num_buckets}
    dataset_fn = prepare_pretrain_text_dataset
    sampler_fn = prepare_pretrain_bucket_sampler
    pad_val = vocab.pad_id
    batchify_fn = bf.Tuple(
        bf.Pad(val=pad_val, round_to=8),  # input_id
        bf.Pad(val=pad_val),  # masked_id
        bf.Pad(val=0),  # masked_position
        bf.Pad(val=0),  # masked_weight
        bf.Stack(),  # next_sentence_label
        bf.Pad(val=0, round_to=8),  # segment_id
        bf.Stack())  # valid_lengths
    split_sampler = SplitSampler(num_files, num_parts=num_parts,
                                 part_index=part_idx, repeat=repeat)
    dataloader = DatasetLoader(data,
                               file_sampler=split_sampler,
                               dataset_fn=dataset_fn,
                               batch_sampler_fn=sampler_fn,
                               dataset_params=dataset_params,
                               batch_sampler_params=sampler_params,
                               batchify_fn=batchify_fn,
                               num_dataset_workers=num_dataset_workers,
                               num_batch_workers=num_batch_workers,
                               pin_memory=False,
                               circle_length=circle_length,
                               dataset_cached=dataset_cached,
                               num_max_dataset_cached=num_max_dataset_cached)
    return dataloader
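
Finally, a hedged usage sketch for the masked-LM/next-sentence variant (placeholder values only, not from the original source). Each batch is the 7-tuple built by `batchify_fn` above, ready for the masked language modeling and next sentence prediction losses.

# Hypothetical usage sketch; argument values are illustrative only.
train_loader = get_pretrain_data_text(
    data='corpus/*.txt',
    batch_size=8,
    shuffle=True,
    num_buckets=1,
    vocab=vocab,
    tokenizer=tokenizer,
    max_seq_length=128,
    short_seq_prob=0.05,
    masked_lm_prob=0.15,             # fraction of tokens selected for masking
    max_predictions_per_seq=20,
    whole_word_mask=True,
    random_next_sentence=True)

for (input_id, masked_id, masked_position, masked_weight,
     next_sentence_label, segment_id, valid_lengths) in train_loader:
    pass  # inputs for the masked-LM and next-sentence-prediction objectives
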