Example 1
def build_train_valid_test_datasets(
    cfg, trainer, data_prefix, data_impl, splits_string, train_valid_test_num_samples, seq_length, seed, skip_warmup
):
    """Build train, valid, and test datasets."""

    # Single dataset.
    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(
            cfg,
            trainer,
            data_prefix[0],
            data_impl,
            splits_string,
            train_valid_test_num_samples,
            seq_length,
            seed,
            skip_warmup,
        )

    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            cfg,
            trainer,
            prefixes[i],
            data_impl,
            splits_string,
            datasets_train_valid_test_num_samples[i],
            seq_length,
            seed,
            skip_warmup,
        )
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = None
    if train_datasets:
        blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = None
    if valid_datasets:
        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = None
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset, blending_test_dataset)
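A hypothetical invocation of the blended path follows. Every value is a placeholder, and the alternating [weight, prefix, weight, prefix, ...] layout of data_prefix is an assumption based on the helper it is parsed with, not something shown in this snippet:

# Hypothetical call; cfg and trainer are assumed to exist, and data_prefix
# is assumed to alternate weights and dataset prefixes.
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
    cfg=cfg,
    trainer=trainer,
    data_prefix=[0.7, '/data/corpus_a_text_document', 0.3, '/data/corpus_b_text_document'],
    data_impl='mmap',
    splits_string='98,1,1',
    train_valid_test_num_samples=[1_000_000, 10_000, 10_000],
    seq_length=2048,
    seed=1234,
    skip_warmup=True,
)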
Example 2
def build_train_valid_test_datasets(
    cfg,
    trainer,
    data_prefix,
    data_impl,
    splits_string,
    train_valid_test_num_samples,
    max_seq_length,
    masked_lm_prob,
    short_seq_prob,
    seed,
    skip_warmup,
    binary_head=False,
    max_seq_length_dec=None,
    dataset_type='standard_bert',
    tokenizer=None,
):
    """Build train, valid, and test datasets."""

    # Single dataset.
    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(
            cfg,
            trainer,
            data_prefix[0],
            data_impl,
            splits_string,
            train_valid_test_num_samples,
            max_seq_length,
            masked_lm_prob,
            short_seq_prob,
            seed,
            skip_warmup,
            binary_head,
            max_seq_length_dec,
            dataset_type=dataset_type,
            tokenizer=tokenizer,
        )
    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(
        data_prefix, train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            cfg,
            trainer,
            prefixes[i],
            data_impl,
            splits_string,
            datasets_train_valid_test_num_samples[i],
            max_seq_length,
            masked_lm_prob,
            short_seq_prob,
            seed,
            skip_warmup,
            binary_head,
            max_seq_length_dec,
            dataset_type=dataset_type,
            tokenizer=tokenizer,
        )
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = None
    if train_datasets:
        blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = None
    if valid_datasets:
        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = None
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset,
            blending_test_dataset)
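All of these variants delegate the weight parsing to get_datasets_weights_and_num_samples. A minimal sketch of what such a helper does, assuming the alternating [weight, prefix, ...] convention; this is a reconstruction, not the actual implementation:

import math

def sketch_get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples):
    # data_prefix is assumed to alternate weights and prefixes:
    # [w0, p0, w1, p1, ...].
    num_datasets = len(data_prefix) // 2
    weights = [float(data_prefix[2 * i]) for i in range(num_datasets)]
    prefixes = [str(data_prefix[2 * i + 1]) for i in range(num_datasets)]
    # Normalize the weights so they sum to 1.
    total = sum(weights)
    weights = [w / total for w in weights]
    # Oversample each dataset slightly beyond its weighted share so the
    # blended dataset never exhausts any constituent.
    datasets_num_samples = [
        [int(math.ceil(n * w * 1.005)) for n in train_valid_test_num_samples]
        for w in weights
    ]
    return prefixes, weights, datasets_num_samples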
Example 3
def build_train_valid_test_datasets(
    cfg,
    trainer,
    data_prefix,
    data_impl,
    splits_string,
    train_valid_test_num_samples,
    max_seq_length,
    masked_lm_prob,
    short_seq_prob,
    seed,
    skip_warmup,
    binary_head=False,
    max_seq_length_dec=None,
    dataset_type='standard_bert',
    tokenizer=None,
    max_ngram_size=3,
    mean_ngram_size=None,
    geometric_dist=True,
    permutation=False,
    whole_word_masking=True,
    favor_long_ngrams=False,
    delete_mask_prob=0,
):
    """Build train, valid, and test datasets."""

    # Single dataset.
    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(
            cfg,
            trainer,
            data_prefix[0],
            data_impl,
            splits_string,
            train_valid_test_num_samples,
            max_seq_length,
            masked_lm_prob,
            short_seq_prob,
            seed,
            skip_warmup,
            binary_head,
            max_seq_length_dec,
            dataset_type=dataset_type,
            tokenizer=tokenizer,
            max_ngram_size=max_ngram_size,
            mean_ngram_size=mean_ngram_size,
            geometric_dist=geometric_dist,
            permutation=permutation,
            whole_word_masking=whole_word_masking,
            favor_long_ngrams=favor_long_ngrams,
            delete_mask_prob=delete_mask_prob,
        )
    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(
        data_prefix, train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            cfg,
            trainer,
            prefixes[i],
            data_impl,
            splits_string,
            datasets_train_valid_test_num_samples[i],
            max_seq_length,
            masked_lm_prob,
            short_seq_prob,
            seed,
            skip_warmup,
            binary_head,
            max_seq_length_dec,
            dataset_type=dataset_type,
            tokenizer=tokenizer,
            max_ngram_size=max_ngram_size,
            mean_ngram_size=mean_ngram_size,
            geometric_dist=geometric_dist,
            permutation=permutation,
            whole_word_masking=whole_word_masking,
            favor_long_ngrams=favor_long_ngrams,
            delete_mask_prob=delete_mask_prob,
        )
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = None
    if train_datasets:
        blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = None
    if valid_datasets:
        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = None
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset,
            blending_test_dataset)
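A hypothetical call for this masked-LM variant, keeping the default 'standard_bert' dataset_type; all values are placeholders and the tokenizer object is assumed to exist:

# Hypothetical call with BERT-style masking parameters; placeholders only.
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
    cfg=cfg,
    trainer=trainer,
    data_prefix=['/data/corpus_text_document'],
    data_impl='mmap',
    splits_string='949,50,1',
    train_valid_test_num_samples=[100_000, 1_000, 1_000],
    max_seq_length=512,
    masked_lm_prob=0.15,
    short_seq_prob=0.1,
    seed=1234,
    skip_warmup=True,
    binary_head=True,
    tokenizer=tokenizer,
)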
Example 4
    def build_memmap_dataset_from_config(self, cfg: DictConfig):
        """Builds a memmap dataset from a existing binary based o nthe provided config."""
        is_src_listconfig = isinstance(cfg.src_file_name, ListConfig)
        is_tgt_listconfig = isinstance(cfg.tgt_file_name, ListConfig)
        # If multilingual, make sure both source and target are list configs
        if self.multilingual:
            if not (is_src_listconfig and is_tgt_listconfig):
                raise ValueError(
                    f"Multilingual datasets must be configured with a ListConfig for both src_file_name and tgt_file_name"
                )
        # Both fields must be ListConfigs, or neither.
        if is_src_listconfig != is_tgt_listconfig:
            raise ValueError(
                "Datasets must be configured with a ListConfig for both src_file_name and tgt_file_name, or neither; found only one of them as a ListConfig."
            )

        if is_src_listconfig and is_tgt_listconfig:
            if len(cfg.src_file_name) != len(cfg.tgt_file_name):
                raise ValueError(
                    f"Datasets must have the same number of files in src_file_name and tgt_file_name"
                )

            if cfg.concat_sampling_probabilities is None or not isinstance(
                    cfg.concat_sampling_probabilities, ListConfig):
                raise ValueError(
                    f"concat_sampling_probabilities must be a ListConfig with the same number of files in src_file_name and tgt_file_name, found {cfg.concat_sampling_probabilities}"
                )
            if len(cfg.concat_sampling_probabilities) != len(
                    cfg.src_file_name):
                raise ValueError(
                    f"concat_sampling_probabilities must be of the same size as src_file_name and tgt_file_name. Provided size {len(cfg.concat_sampling_probabilities)}, number of datasets {len(cfg.src_file_name)}"
                )

            datasets = []
            for src_file, tgt_file in zip(cfg.src_file_name, cfg.tgt_file_name):
                if cfg.dataset_type == 'bin_memmap':
                    dataset = BinarizedMemmapSequenceToSequenceDataset(
                        src_dataset_prefix=src_file,
                        tgt_dataset_prefix=tgt_file,
                        src_tokenizer=self.encoder_tokenizer,
                        tgt_tokenizer=self.decoder_tokenizer,
                        max_src_seq_length=cfg.max_seq_length,
                        max_tgt_seq_length=cfg.max_seq_length,
                        start_index=0,
                        end_index=None,
                        data_impl="mmap",
                        skip_warmup=True,
                    )
                elif cfg.dataset_type == 'text_memmap':
                    dataset = TextMemmapSequenceToSequenceDataset(
                        src_file_name=src_file,
                        tgt_file_name=tgt_file,
                        src_tokenizer=self.encoder_tokenizer,
                        tgt_tokenizer=self.decoder_tokenizer,
                        max_src_seq_length=cfg.max_seq_length,
                        max_tgt_seq_length=cfg.max_seq_length,
                )
                else:
                    # Guard against an unknown dataset_type, which would
                    # otherwise leave `dataset` unbound.
                    raise ValueError(f"Unsupported dataset_type: {cfg.dataset_type}")
                datasets.append(dataset)
            dataset = BlendableDataset(
                datasets=datasets, weights=cfg.concat_sampling_probabilities)
        else:
            if cfg.dataset_type == 'bin_memmap':
                dataset = BinarizedMemmapSequenceToSequenceDataset(
                    src_dataset_prefix=cfg.src_file_name,
                    tgt_dataset_prefix=cfg.tgt_file_name,
                    src_tokenizer=self.encoder_tokenizer,
                    tgt_tokenizer=self.decoder_tokenizer,
                    max_src_seq_length=cfg.max_seq_length,
                    max_tgt_seq_length=cfg.max_seq_length,
                    start_index=0,
                    end_index=None,
                    data_impl="mmap",
                    skip_warmup=True,
                )
            elif cfg.dataset_type == 'text_memmap':
                dataset = TextMemmapSequenceToSequenceDataset(
                    src_file_name=cfg.src_file_name,
                    tgt_file_name=cfg.tgt_file_name,
                    src_tokenizer=self.encoder_tokenizer,
                    tgt_tokenizer=self.decoder_tokenizer,
                    max_src_seq_length=cfg.max_seq_length,
                    max_tgt_seq_length=cfg.max_seq_length,
                )
            else:
                # Guard against an unknown dataset_type, which would
                # otherwise leave `dataset` unbound.
                raise ValueError(f"Unsupported dataset_type: {cfg.dataset_type}")
        return dataset
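For reference, a hypothetical config that would pass the multilingual checks above, built with OmegaConf so the list fields become ListConfig instances; every path and value is a placeholder:

from omegaconf import OmegaConf

# Illustrative multilingual config: two parallel file pairs with
# matching sampling probabilities.
cfg = OmegaConf.create(
    {
        'dataset_type': 'bin_memmap',
        'src_file_name': ['train.en-de.en', 'train.en-fr.en'],
        'tgt_file_name': ['train.en-de.de', 'train.en-fr.fr'],
        'concat_sampling_probabilities': [0.5, 0.5],
        'max_seq_length': 512,
    }
)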
Example 5
def build_train_valid_test_datasets(
    cfg,
    trainer,
    data_prefix: List[str],
    data_impl: str,
    splits_string: str,
    train_valid_test_num_samples,
    seq_length: int,
    seed: int,
    skip_warmup: bool,
    tokenizer,
    retrieval_prefix: str,
    knn_map_path: List[str],
):
    """Build train, valid, and test RETRO datasets.
       There is one to one mapping between data_prefix and knn_map_path.
       Currently only supports one retrieval dataset.
    """
    # Make sure there is a one-to-one mapping between data_prefix and knn_map_path.
    assert len(data_prefix) == len(knn_map_path)

    # Single dataset.
    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(
            cfg,
            trainer,
            data_prefix[0],
            data_impl,
            splits_string,
            train_valid_test_num_samples,
            seq_length,
            seed,
            skip_warmup,
            tokenizer,
            retrieval_prefix,
            knn_map_path[0],
        )

    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(
        data_prefix, train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            cfg,
            trainer,
            prefixes[i],
            data_impl,
            splits_string,
            datasets_train_valid_test_num_samples[i],
            seq_length,
            seed,
            skip_warmup,
            tokenizer,
            retrieval_prefix,
            knn_map_path[i],
        )
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = None
    if train_datasets:
        blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = None
    if valid_datasets:
        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = None
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset,
            blending_test_dataset)
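A hypothetical call for the RETRO variant, showing the required one-to-one pairing of data_prefix entries and knn_map_path entries; cfg, trainer, and tokenizer are assumed to exist, and all paths and counts are placeholders:

# Hypothetical call: two blended corpora, each paired with its own KNN map.
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
    cfg=cfg,
    trainer=trainer,
    data_prefix=[0.5, '/data/wiki_text_document', 0.5, '/data/books_text_document'],
    data_impl='mmap',
    splits_string='98,1,1',
    train_valid_test_num_samples=[100_000, 1_000, 1_000],
    seq_length=2048,
    seed=1234,
    skip_warmup=True,
    tokenizer=tokenizer,
    retrieval_prefix='/data/retrieval_text_document',
    knn_map_path=['/data/wiki_knn_map.idx', '/data/books_knn_map.idx'],
)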