Example #1
0
def _build_train_valid_test_datasets(
    cfg,
    trainer,
    data_prefix,
    data_impl,
    splits_string,
    train_valid_test_num_samples,
    seq_length,
    seed,
    skip_warmup,
    tokenizer,
):
    """Build train, valid, and test datasets.

    The indexed dataset at ``data_prefix`` is loaded and its documents
    (counted via ``indexed_dataset.sizes``) are partitioned into three
    contiguous ranges according to ``splits_string``. A ``GPTDataset`` is
    created for every non-empty range.

    Args:
        cfg, trainer, tokenizer: Forwarded unchanged to ``GPTDataset``.
        data_prefix, data_impl, skip_warmup: Used to load the indexed
            dataset; ``data_prefix`` is also forwarded to ``GPTDataset``.
        splits_string: Split specification for
            ``get_train_valid_test_split_``.
        train_valid_test_num_samples: Per-split sample counts, indexed
            0 (train), 1 (valid), 2 (test).
        seq_length, seed: Forwarded unchanged to ``GPTDataset``.

    Returns:
        A 3-tuple ``(train, valid, test)``; an entry is ``None`` when its
        document range is empty.
    """
    indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)

    num_docs = indexed_dataset.sizes.shape[0]
    boundaries = get_train_valid_test_split_(splits_string, num_docs)

    # Report every split's document range before building anything.
    logging.info(' > dataset split:')
    for split_idx, label in enumerate(('train', 'validation', 'test')):
        lo, hi = boundaries[split_idx], boundaries[split_idx + 1]
        logging.info('    {}:'.format(label))
        logging.info('     document indices in [{}, {}) total of {} '
                     'documents'.format(lo, hi, hi - lo))

    built = []
    for split_idx, split_name in enumerate(('train', 'valid', 'test')):
        lo, hi = boundaries[split_idx], boundaries[split_idx + 1]
        if hi > lo:
            doc_ids = np.arange(start=lo, stop=hi, step=1, dtype=np.int32)
            built.append(
                GPTDataset(
                    cfg,
                    trainer,
                    tokenizer,
                    split_name,
                    data_prefix,
                    doc_ids,
                    indexed_dataset,
                    train_valid_test_num_samples[split_idx],
                    seq_length,
                    seed,
                )
            )
        else:
            # Empty document range -> no dataset for this split.
            built.append(None)

    return tuple(built)
Example #2
0
def build_mock_train_valid_test_datasets(
    cfg,
    trainer,
    splits_string,
    tokenizer,
    mock_data_size,
):
    """Build train, valid, and test datasets.

    ``mock_data_size`` items are partitioned into three contiguous ranges
    according to ``splits_string``. A ``MockRETRODataset`` whose size equals
    the range length is created for every non-empty range.

    Args:
        cfg, trainer, tokenizer: Forwarded unchanged to ``MockRETRODataset``.
        splits_string: Split specification for
            ``get_train_valid_test_split_``.
        mock_data_size: Total number of mock documents to split.

    Returns:
        A 3-tuple ``(train, valid, test)``; an entry is ``None`` when its
        range is empty.
    """
    boundaries = get_train_valid_test_split_(splits_string, mock_data_size)

    # Report every split's document range before building anything.
    logging.info(' > dataset split:')
    for split_idx, label in enumerate(('train', 'validation', 'test')):
        lo, hi = boundaries[split_idx], boundaries[split_idx + 1]
        logging.info('    {}:'.format(label))
        logging.info('     document indices in [{}, {}) total of {} '
                     'documents'.format(lo, hi, hi - lo))

    built = []
    for split_idx, split_name in enumerate(('train', 'valid', 'test')):
        lo, hi = boundaries[split_idx], boundaries[split_idx + 1]
        if hi > lo:
            built.append(MockRETRODataset(cfg, trainer, tokenizer, split_name, hi - lo))
        else:
            # Empty range -> no dataset for this split.
            built.append(None)

    return tuple(built)
Example #3
0
def _build_train_valid_test_datasets(
    cfg,
    trainer,
    data_prefix,
    data_impl,
    splits_string,
    train_valid_test_num_samples,
    max_seq_length,
    masked_lm_prob,
    short_seq_prob,
    seed,
    skip_warmup,
    binary_head,
    max_seq_length_dec,
    dataset_type='standard_bert',
    tokenizer=None,
    max_ngram_size=3,
    mean_ngram_size=None,
    geometric_dist=True,
    permutation=False,
    whole_word_masking=True,
    favor_long_ngrams=False,
    delete_mask_prob=0,  # This flag is used in BART only, and will not have effect on T5/BERT
):
    """Build train, valid, and test datasets for BERT / T5 / T5-LM / BART.

    The documents of the indexed dataset are split into three contiguous
    ranges according to ``splits_string``. For every non-empty range the
    dataset class selected by ``dataset_type`` is built. During construction
    the indexed dataset's ``doc_idx`` temporarily points at that range; it is
    always restored afterwards.

    Args:
        cfg, trainer, tokenizer: Forwarded to the dataset constructors
            (``tokenizer`` must not be None for T5 and BART).
        data_prefix, data_impl, skip_warmup: Used to load the indexed
            dataset via ``get_indexed_dataset_``; ``data_prefix`` is also
            forwarded to every dataset.
        splits_string: Split specification for
            ``get_train_valid_test_split_``.
        train_valid_test_num_samples: Per-split sample counts, indexed
            0 (train), 1 (valid), 2 (test).
        max_seq_length, seed: Forwarded to every dataset.
        masked_lm_prob, short_seq_prob, binary_head, max_seq_length_dec,
        max_ngram_size, mean_ngram_size, geometric_dist, permutation,
        whole_word_masking, favor_long_ngrams, delete_mask_prob:
            Model-specific options forwarded only to the dataset types that
            accept them.
        dataset_type: Must be a member of ``DSET_TYPES``.

    Returns:
        A 3-tuple ``(train_dataset, valid_dataset, test_dataset)``; an entry
        is ``None`` when its split contains no documents.

    Raises:
        ValueError: If ``dataset_type`` is not in ``DSET_TYPES``.
        NotImplementedError: If ``dataset_type`` is the ICT type, or when a
            non-empty split is built for a type that has no builder below.
    """
    if dataset_type not in DSET_TYPES:
        raise ValueError("Invalid dataset_type: {}. Expected one of {}".format(
            dataset_type, DSET_TYPES))

    if dataset_type == DSET_TYPE_ICT:
        # ICT would also need a separate titles dataset. That dataset used to
        # be read from an undefined global `args`, which raised NameError, and
        # no ICT builder exists below. Fail fast before doing any I/O.
        raise NotImplementedError("ICT dataset is not implemented yet.")

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)

    # Get start and end indices of train/valid/test into doc-idx.
    # doc-idx is designed to have num-docs + 1 entries so we can easily
    # iterate over it; hence the "- 1" below.
    total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Print stats about the splits.
    logging.info(' > dataset split:')

    def print_split_stats(name, index):
        logging.info('    {}:'.format(name))
        logging.info('     document indices in [{}, {}) total of {} '
                     'documents'.format(splits[index], splits[index + 1],
                                        splits[index + 1] - splits[index]))
        start_index = indexed_dataset.doc_idx[splits[index]]
        end_index = indexed_dataset.doc_idx[splits[index + 1]]
        logging.info('     sentence indices in [{}, {}) total of {} '
                     'sentences'.format(start_index, end_index,
                                        end_index - start_index))

    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)

    def build_dataset(index, name):
        """Build the dataset for split ``index``, or return None if it is empty."""
        from nemo.collections.nlp.data.language_modeling.megatron.bert_dataset import BertDataset
        from nemo.collections.nlp.data.language_modeling.megatron.t5_dataset import T5Dataset
        from nemo.collections.nlp.data.language_modeling.megatron.bart_dataset import BARTDataset

        if splits[index + 1] <= splits[index]:
            return None

        # Arguments shared by every dataset type.
        kwargs = dict(
            name=name,
            data_prefix=data_prefix,
            num_epochs=None,
            max_num_samples=int(train_valid_test_num_samples[index]),
            max_seq_length=max_seq_length,
            seed=seed,
        )

        # Keep the original doc-idx so it can be restored afterwards.
        doc_idx_ptr = indexed_dataset.get_doc_idx()
        # Slice the doc-idx. The +1 keeps the end boundary of the last document.
        start_index = splits[index]
        end_index = splits[index + 1] + 1
        indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
        try:
            if dataset_type == DSET_TYPE_T5:
                assert tokenizer is not None, "Tokenizer is required for T5 dataset"
                logging.info("Instatiating T5 Dataset ...")
                dataset = T5Dataset(
                    cfg=cfg,
                    trainer=trainer,
                    tokenizer=tokenizer,
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    max_seq_length_dec=max_seq_length_dec,
                    short_seq_prob=short_seq_prob,
                    max_ngram_size=max_ngram_size,
                    mean_ngram_size=mean_ngram_size,
                    geometric_dist=geometric_dist,
                    permutation=permutation,
                    whole_word_masking=whole_word_masking,
                    favor_long_ngrams=favor_long_ngrams,
                    **kwargs,
                )
            elif dataset_type == DSET_TYPE_BERT:
                logging.info("Instatiating BERT Dataset ...")
                dataset = BertDataset(
                    cfg=cfg,
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    binary_head=binary_head,
                    tokenizer=tokenizer,
                    **kwargs,
                )
            elif dataset_type == DSET_TYPE_T5_LM:
                documents = np.arange(start=splits[index],
                                      stop=splits[index + 1],
                                      step=1,
                                      dtype=np.int32)
                logging.info("Instatiating T5 Prefix-LM Dataset ...")
                dataset = T5LMAdaptedDataset(
                    cfg=cfg,
                    trainer=trainer,
                    tokenizer=tokenizer,
                    documents=documents,
                    indexed_dataset=indexed_dataset,
                    num_samples=int(train_valid_test_num_samples[index]),
                    **kwargs,
                )
            elif dataset_type == DSET_TYPE_BART:
                assert tokenizer is not None, "Tokenizer is required for BART dataset"
                logging.info("Instatiating BART Dataset ...")
                dataset = BARTDataset(
                    cfg=cfg,
                    trainer=trainer,
                    tokenizer=tokenizer,
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    max_ngram_size=max_ngram_size,
                    mean_ngram_size=mean_ngram_size,
                    geometric_dist=geometric_dist,
                    permutation=permutation,
                    whole_word_masking=whole_word_masking,
                    favor_long_ngrams=favor_long_ngrams,
                    delete_mask_prob=delete_mask_prob,
                    **kwargs,
                )
            else:
                raise NotImplementedError(
                    "Dataset type not fully implemented.")
        finally:
            # Restore the original pointer even if construction failed, so the
            # shared indexed dataset is never left restricted to this split.
            indexed_dataset.set_doc_idx(doc_idx_ptr)

        # Sanity checks: the full doc-idx (num-docs + 1 entries) is back in place.
        assert indexed_dataset.doc_idx[0] == 0
        assert indexed_dataset.doc_idx.shape[0] == (
            total_num_of_documents + 1)
        return dataset

    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)