Example #1
    def _load_single_lang_dataset(self, split, epoch):
        loaded_datasets = []

        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(data_path, split_k)

            ds = data_utils.load_indexed_dataset(path, self.dictionary,
                                                 self.args.dataset_impl)
            if ds is None:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(split, data_path))

            # Since we append the classification_token to each block, we
            # effectively need to create blocks of length
            # tokens_per_sample - 1
            loaded_datasets.append(
                TokenBlockDataset(
                    ds,
                    ds.sizes,
                    self.args.tokens_per_sample - 1,
                    pad=self.dictionary.pad(),
                    eos=self.dictionary.eos(),
                ))

            print('| {} {} {} examples'.format(data_path, split_k,
                                               len(loaded_datasets[-1])))

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        return dataset, sizes
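The block length of tokens_per_sample - 1 above leaves room for exactly one extra token per sample. Below is a minimal sketch of hypothetical caller code (not part of the example, and assuming the task's dictionary exposes a cls() index as in Example #7) that restores the full length by appending the classification token with the AppendTokenDataset wrapper used in later examples:

from fairseq.data import AppendTokenDataset

# Hypothetical usage: `task` is an instance of the task defined above.
dataset, sizes = task._load_single_lang_dataset(split="train", epoch=0)
# Appending the classification token brings each sample back up to
# tokens_per_sample; bump the sizes array by one token per sample.
dataset = AppendTokenDataset(dataset, task.dictionary.cls())
sizes = sizes + 1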
Example #2
    def from_tsv(cls, root: str, data_cfg: S2TDataConfig, splits: str,
                 tgt_dict, pre_tokenizer, bpe_tokenizer, is_train_split: bool,
                 epoch: int, seed: int, audio_dict) -> AudioDictDataset:
        samples = []
        _splits = splits.split(",")
        for split in _splits:
            tsv_path = op.join(root, f"{split}.tsv")
            if not op.isfile(tsv_path):
                raise FileNotFoundError(f"Dataset not found: {tsv_path}")
            with open(tsv_path) as f:
                reader = csv.DictReader(
                    f,
                    delimiter="\t",
                    quotechar=None,
                    doublequote=False,
                    lineterminator="\n",
                    quoting=csv.QUOTE_NONE,
                )
                samples.append([dict(e) for e in reader])
                assert len(samples) > 0

        datasets = [
            cls._from_list(name, is_train_split, [s], data_cfg, tgt_dict,
                           pre_tokenizer, bpe_tokenizer, audio_dict)
            for name, s in zip(_splits, samples)
        ]

        if is_train_split and len(
                _splits) > 1 and data_cfg.sampling_alpha != 1.0:
            # temperature-based sampling
            size_ratios = cls._get_size_ratios(_splits,
                                               [len(s) for s in samples],
                                               alpha=data_cfg.sampling_alpha)
            datasets = [
                ResamplingDataset(d,
                                  size_ratio=r,
                                  seed=seed,
                                  epoch=epoch,
                                  replace=(r >= 1.0))
                for d, r in zip(datasets, size_ratios)
            ]
        return ConcatDataset(datasets)
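cls._get_size_ratios is not shown above. The sketch below is a standalone approximation of the usual temperature-based computation (an assumption, not the library's exact code): each split's sampling probability is proportional to its size raised to sampling_alpha, and the size ratio rescales each split so the concatenated dataset follows those probabilities.

import numpy as np

def get_size_ratios_sketch(sizes, alpha=1.0):
    # p_i is proportional to n_i ** alpha (normalized); ratio_i = p_i * N / n_i,
    # so small splits get ratios > 1 (upsampled) and large splits get ratios < 1.
    sizes = np.asarray(sizes, dtype=float)
    probs = sizes ** alpha
    probs = probs / probs.sum()
    return probs * sizes.sum() / sizes

# Example: splits of 9000 and 1000 samples with alpha=0.5
# -> [0.833..., 2.5], i.e. the small split is oversampled 2.5x.
print(get_size_ratios_sketch([9000, 1000], alpha=0.5))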
Example #3
    def from_tsv(
        cls,
        root: str,
        cfg: S2TDataConfig,
        splits: str,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        is_train_split: bool,
        epoch: int,
        seed: int,
        n_frames_per_step: int = 1,
        speaker_to_id=None,
    ) -> SpeechToTextDataset:
        datasets = [
            cls._from_tsv(
                root,
                cfg,
                split,
                tgt_dict,
                is_train_split,
                pre_tokenizer,
                bpe_tokenizer,
                n_frames_per_step,
                speaker_to_id,
            ) for split in splits.split(",")
        ]

        if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
            # temperature-based sampling
            size_ratios = cls.get_size_ratios(datasets,
                                              alpha=cfg.sampling_alpha)
            datasets = [
                ResamplingDataset(d,
                                  size_ratio=r,
                                  seed=seed,
                                  epoch=epoch,
                                  replace=(r >= 1.0))
                for r, d in zip(size_ratios, datasets)
            ]

        return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
Example #4
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if self.args.dataset_from_json:
            raise NotImplementedError
        datasets = []
        for path in self.paths:
            try:
                ds = get_datasets_from_indexed_filterbanks(
                    path,
                    self.args.target_lang,
                    self.tgt_dict,
                    split,
                    self.args.dataset_impl,
                    self.args.skip_normalization,
                    self.args.legacy_audio_fix_lua_indexing)
                if self.training:
                    if self.args.context_type == 'src':
                        context_ds = FilterBanksDataset(
                            os.path.join(path, split) + ".context.npz",
                            self.args.dataset_impl == "cached",
                            self.args.legacy_audio_fix_lua_indexing)
                    else:
                        context_ds = data_utils.load_indexed_dataset(
                            os.path.join(path, split) + ".context." + self.args.target_lang,
                            self.tgt_dict,
                            self.args.dataset_impl)
                    datasets.append(ContextAwareDataset(
                        ds, context_ds, self.tgt_dict, self.args.context_type == 'src'))
                else:
                    datasets.append(ds)
            except Exception:
                logger.warning("Split {} not found in {}. Skipping...".format(split, path))
        assert len(datasets) > 0
        if len(datasets) > 1:
            self.datasets[split] = ConcatDataset(datasets)
        else:
            self.datasets[split] = datasets[0]
Example #5
    def load_dataset(lang,
                     lang_dict,
                     prefix,
                     dataset_length,
                     sample_ratios=None):
        """
        Helper to load an additional dataset and handle all of its parameters.
        Easier than copying redundant code for each dataset.
        Requires src_dataset to provide the length and sample_ratios.
        """
        lang_datasets = []
        lang_dataset = data_utils.load_indexed_dataset(prefix + lang,
                                                       lang_dict, dataset_impl)
        if lang_dataset is not None:
            lang_datasets.append(lang_dataset)
        assert dataset_length == len(lang_datasets) or len(lang_datasets) == 0
        if dataset_length == 1:
            lang_dataset = lang_datasets[0] if len(lang_datasets) > 0 else None
        else:
            assert sample_ratios is not None
            if len(lang_datasets) > 0:
                lang_dataset = ConcatDataset(lang_datasets, sample_ratios)
            else:
                lang_dataset = None
        if prepend_bos:
            assert hasattr(src_dict, "bos_index") and hasattr(
                lang_dict, "bos_index")
            if lang_dataset is not None:
                lang_dataset = PrependTokenDataset(lang_dataset,
                                                   lang_dict.bos())
        eos = None
        if append_source_id:
            if lang_dataset is not None:
                lang_dataset = AppendTokenDataset(
                    lang_dataset, lang_dict.index('[{}]'.format(lang)))

        lang_dataset_sizes = lang_dataset.sizes if lang_dataset is not None else None
        return lang_dataset, lang_dataset_sizes
Example #6
 def __load_dataset(self, split, lang_pair):
     src, tgt = lang_pair.split('-')
     datasets = []
     for path in self.paths:
         try:
             ds = get_datasets_from_indexed_filterbanks(
                 path, tgt, self.dicts[tgt], split, self.args.dataset_impl,
                 self.args.skip_normalization,
                 self.args.legacy_audio_fix_lua_indexing)
             datasets.append(ds)
         except Exception:
             logger.warning("Split {} not found in {}. Skipping...".format(
                 split, path))
     assert len(datasets) > 0
     if len(datasets) > 1:
         dataset = ConcatDataset(datasets)
     else:
         dataset = datasets[0]
     return self.alter_dataset_langtok(dataset,
                                       src_eos=None,
                                       src_lang=src,
                                       tgt_eos=self.dicts[tgt].eos(),
                                       tgt_lang=tgt)
Example #7
    def load_dataset(self, split, combine=False):
        """
        Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        loaded_datasets = []

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(self.args.data, split_k)

            if self.args.raw_text and IndexedRawTextDataset.exists(path):
                ds = IndexedRawTextDataset(path, self.dictionary)
            elif not self.args.raw_text and IndexedDataset.exists(path):
                if self.args.lazy_load:
                    ds = IndexedDataset(path, fix_lua_indexing=True)
                else:
                    ds = IndexedCachedDataset(path, fix_lua_indexing=True)
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(
                            split, self.args.data))
            with data_utils.numpy_seed(self.seed + k):
                loaded_datasets.append(
                    BlockPairDataset(
                        ds,
                        self.dictionary,
                        ds.sizes,
                        self.args.tokens_per_sample,
                        break_mode=self.args.break_mode,
                    ))

            logger.info('{} {} {} examples'.format(self.args.data, split_k,
                                                   len(loaded_datasets[-1])))

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        self.datasets[split] = MaskedLMDataset(
            dataset=dataset,
            sizes=sizes,
            vocab=self.dictionary,
            pad_idx=self.dictionary.pad(),
            mask_idx=self.dictionary.mask(),
            classif_token_idx=self.dictionary.cls(),
            sep_token_idx=self.dictionary.sep(),
            shuffle=False,
            seed=self.seed,
        )
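The split_k expression in the loop probes sharded split names; with combine=True the loop keeps loading shards until one is missing. A tiny illustration of the names it tries:

split = "train"
for k in range(3):
    print(split + (str(k) if k > 0 else ""))  # -> train, train1, train2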
Example #8
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = self.args.data.split(":")
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        if self.langs is None:
            languages = sorted(
                [
                    name
                    for name in os.listdir(data_path)
                    if os.path.isdir(os.path.join(data_path, name))
                ]
            )
        else:
            languages = self.langs.split(",")
            for name in languages:
                p = os.path.join(data_path, name)
                assert os.path.exists(p), "data not found: {}".format(p)

        logger.info("Training on {0} languages: {1}".format(len(languages), languages))
        logger.info(
            "Language to id mapping: %s", {lang: id for id, lang in enumerate(languages)}
        )

        mask_whole_words = get_whole_word_mask(self.args, self.dictionary)
        language_without_segmentations = self.args.no_whole_word_mask_langs.split(",")
        lang_datasets = []
        for language in languages:
            split_path = os.path.join(data_path, language, split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                self.source_dictionary,
                self.args.dataset_impl,
                combine=combine,
            )
            if dataset is None:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, split_path)
                )

            end_token = (
                self.source_dictionary.index("[{}]".format(language))
                if self.args.add_lang_token
                else self.source_dictionary.eos()
            )

            # create continuous blocks of tokens
            dataset = TokenBlockDataset(
                dataset,
                dataset.sizes,
                self.args.tokens_per_sample - 2,  # two less for <s> and the end token
                pad=self.source_dictionary.pad(),
                eos=end_token,
                break_mode=self.args.sample_break_mode,
            )
            logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

            # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
            dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
            dataset = AppendTokenDataset(dataset, end_token)

            lang_mask_whole_words = (
                mask_whole_words
                if language not in language_without_segmentations
                else None
            )
            lang_dataset = DenoisingDataset(
                dataset,
                dataset.sizes,
                self.dictionary,
                self.mask_idx,
                lang_mask_whole_words,
                shuffle=self.args.shuffle_instance,
                seed=self.seed,
                args=self.args,
                eos=None
                if not self.args.add_lang_token
                else self.source_dictionary.index("[{}]".format(language)),
            )
            lang_datasets.append(lang_dataset)

        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        logger.info(
            "loaded total {} blocks for all languages".format(
                int(dataset_lengths.sum()),
            )
        )
        if split == self.args.train_subset:
            # For train subset, additionally up or down sample languages.
            sample_probs = self._get_sample_prob(dataset_lengths)
            logger.info(
                "Sample probability by language: {}".format(
                    {
                        lang: "{0:.4f}".format(sample_probs[id])
                        for id, lang in enumerate(languages)
                    }
                )
            )
            size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
            logger.info(
                "Up/Down Sampling ratio by language: {}".format(
                    {
                        lang: "{0:.2f}".format(size_ratio[id])
                        for id, lang in enumerate(languages)
                    }
                )
            )

            resampled_lang_datasets = [
                ResamplingDataset(
                    lang_datasets[i],
                    size_ratio=size_ratio[i],
                    seed=self.args.seed,
                    epoch=epoch,
                    replace=size_ratio[i] >= 1.0,
                )
                for i, d in enumerate(lang_datasets)
            ]
            dataset = ConcatDataset(
                resampled_lang_datasets,
            )
        else:
            dataset = ConcatDataset(lang_datasets)
            lang_splits = [split]
            for lang_id, lang_dataset in enumerate(lang_datasets):
                split_name = split + "_" + languages[lang_id]
                lang_splits.append(split_name)
                self.datasets[split_name] = lang_dataset

            if split in self.args.valid_subset:
                self.args.valid_subset = self.args.valid_subset.replace(
                    split, ",".join(lang_splits)
                )

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(dataset))

        self.datasets[split] = SortDataset(
            dataset,
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )
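self._get_sample_prob is referenced but not defined above. A minimal sketch, assuming it implements the usual smoothed sampling distribution over languages (probability proportional to dataset size raised to a sampling alpha; the alpha argument here is hypothetical):

import numpy as np

def get_sample_prob_sketch(dataset_lengths, sampling_alpha):
    # Smooth the empirical language distribution: alpha < 1 moves probability
    # mass from high-resource languages toward low-resource ones.
    lengths = np.asarray(dataset_lengths, dtype=float)
    prob = lengths / lengths.sum()
    smoothed = prob ** sampling_alpha
    return smoothed / smoothed.sum()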
Example #9
def get_asr_dataset_from_json(
    data_path,
    split,
    tgt_dict,
    combine,
    upsample_primary,
    max_source_positions,
    max_target_positions,
    seed=1,
    specaugment_config=None,
):
    """
    Parse the data JSON and create the dataset.
    See espresso/tools/asr_prep_json.py, which packs the JSON from raw files.
    JSON example:
    {
        "011c0202": {
            "feat": "fbank/raw_fbank_pitch_train_si284.1.ark:54819",
            "token_text": "T H E <space> H O T E L",
            "utt2num_frames": "693",
        },
        "011c0203": {
            ...
        }
    }
    """
    src_datasets = []
    tgt_datasets = []
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")
        data_json_path = os.path.join(data_path, "{}.json".format(split_k))
        if not os.path.isfile(data_json_path):
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    "Dataset not found: {}".format(data_json_path))

        with open(data_json_path, "rb") as f:
            loaded_json = json.load(f, object_pairs_hook=OrderedDict)

        utt_ids, feats, token_text, utt2num_frames = [], [], [], []
        for utt_id, val in loaded_json.items():
            utt_ids.append(utt_id)
            feats.append(val["feat"])
            if "token_text" in val:
                token_text.append(val["token_text"])
            if "utt2num_frames" in val:
                utt2num_frames.append(int(val["utt2num_frames"]))

        assert len(utt2num_frames) == 0 or len(utt_ids) == len(utt2num_frames)
        src_datasets.append(
            FeatScpCachedDataset(
                utt_ids,
                feats,
                utt2num_frames=utt2num_frames,
                seed=seed,
                specaugment_config=specaugment_config
                if split == "train" else None,
                ordered_prefetch=True,
            ))
        if len(token_text) > 0:
            assert len(utt_ids) == len(token_text)
            assert tgt_dict is not None
            tgt_datasets.append(AsrTextDataset(utt_ids, token_text, tgt_dict))

        logger.info("{} {} examples".format(data_json_path,
                                            len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    feat_dim = src_datasets[0].feat_dim

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        for i in range(1, len(src_datasets)):
            assert feat_dim == src_datasets[i].feat_dim, \
                "feature dimension does not match across multiple json files"
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return AsrDataset(
        src_dataset,
        src_dataset.sizes,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=False,
        left_pad_target=False,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
    )
Example #10
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        def split_exists(split, data_type, data_path):
            filename = os.path.join(data_path, f'{split}.{data_type}')
            assert not self.args.raw_text
            exists = [
                IndexedDataset.exists(
                    os.path.join(data_path, f'{split}.{data_type}.{k}'))
                for k in DPTREE_KEYS
            ]
            if all(exists):
                return True
            else:
                print(f'Not all modality files exist: {exists}')
                return False

        # def indexed_dataset(path, dictionary):
        def indexed_dataset(path):
            assert IndexedCachedDataset.exists(
                path), f'IndexedCachedDataset.exists({path})'
            return IndexedCachedDataset(path, fix_lua_indexing=True)

        def dptree_indexed_dataset(path):
            assert DPTreeIndexedCachedDataset.exists(
                path), f'DPTreeIndexedCachedDataset.exists({path})'
            return DPTreeIndexedCachedDataset(path, fix_lua_indexing=True)

        src_datasets = []
        tgt_datasets = []
        src_datasets_dict = {k: [] for k in DPTREE_KEYS}

        # data_paths = self.args.data
        data_path = self.args.data
        print(f'| split = {split}')
        print(f'| self.args.data = {self.args.data}')
        # singular data path
        lang = self.args.source_lang

        src, tgt = 'input', 'target'

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            if split_exists(split_k, src, data_path):
                prefix = os.path.join(data_path, f'{split}.')
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(split, data_path))
            # src_datasets.append(indexed_dataset(prefix + src))
            for modality in src_datasets_dict.keys():
                src_datasets_dict[modality].append(
                    dptree_indexed_dataset(f'{prefix}{src}.{modality}'))

            tgt_datasets.append(indexed_dataset(prefix + tgt))

            print('| {} {} {} examples'.format(data_path, split_k,
                                               len(tgt_datasets[-1])))
            if not combine:
                break

        assert len(src_datasets_dict[DPTREE_KEYS[0]]) == len(tgt_datasets)

        if len(tgt_datasets) == 1:
            # src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
            src_dataset_dict = {k: v[0] for k, v in src_datasets_dict.items()}
            tgt_dataset = tgt_datasets[0]
        else:
            sample_ratios = [1] * len(src_datasets)
            sample_ratios[0] = self.args.upsample_primary
            # src_dataset = ConcatDataset(src_datasets, sample_ratios)
            src_dataset_dict = {
                k: ConcatDataset(v, sample_ratios)
                for k, v in src_datasets_dict.items()
            }
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

        src_sizes = src_dataset_dict['nodes'].sizes.reshape(-1, 2).sum(-1)
        # print(f'src_sizes::: {src_sizes}')
        self.datasets[split] = NodeStackFromDPTreeSepMonoClassificationDataset(
            # srcs, src_sizes, src_dict
            src_dataset_dict,
            src_sizes,
            self.source_dictionary,
            tgt_dataset,
            left_pad_source=self.args.left_pad_source,
            # left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            # max_target_positions=self.args.max_target_positions,
        )
Example #11
    def load_dataset(self, split, combine=False, epoch=0):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        def split_exists(split, src, tgt, lang, data_path):
            filename = os.path.join(
                data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
            if self.args.raw_text and IndexedRawTextDataset.exists(filename):
                return True
            elif not self.args.raw_text and IndexedDataset.exists(filename):
                return True
            return False

        def indexed_dataset(path, dictionary):
            if self.args.raw_text:
                return IndexedRawTextDataset(path, dictionary)
            elif IndexedDataset.exists(path):
                return IndexedCachedDataset(path, fix_lua_indexing=True)
            return None

        src_datasets = []
        tgt_datasets = []

        data_paths = self.args.data

        for dk, data_path in enumerate(data_paths):
            for k in itertools.count():
                split_k = split + (str(k) if k > 0 else '')

                # infer langcode
                src, tgt = self.args.source_lang, self.args.target_lang
                if split_exists(split_k, src, tgt, src, data_path):
                    prefix = os.path.join(
                        data_path, '{}.{}-{}.'.format(split_k, src, tgt))
                elif split_exists(split_k, tgt, src, src, data_path):
                    prefix = os.path.join(
                        data_path, '{}.{}-{}.'.format(split_k, tgt, src))
                else:
                    if k > 0 or dk > 0:
                        break
                    else:
                        raise FileNotFoundError(
                            'Dataset not found: {} ({})'.format(
                                split, data_path))

                src_datasets.append(
                    indexed_dataset(prefix + src, self.src_dict))
                tgt_datasets.append(
                    indexed_dataset(prefix + tgt, self.tgt_dict))

                print('| {} {} {} examples'.format(data_path, split_k,
                                                   len(src_datasets[-1])))

                if not combine:
                    break

        assert len(src_datasets) == len(tgt_datasets)

        if len(src_datasets) == 1:
            src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
        else:
            sample_ratios = [1] * len(src_datasets)
            sample_ratios[0] = self.args.upsample_primary
            src_dataset = ConcatDataset(src_datasets, sample_ratios)
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

        if split == "train":
            train = True
            seed = None
        elif split == "valid":
            train = True
            seed = 1
        elif split == "test":
            train = False
            seed = 1
        else:
            raise Exception('No such split: ' + str(split))

        self.datasets[split] = LanguagePairSelfDatasetMask(
            src_dataset,
            src_dataset.sizes,
            self.src_dict,
            tgt_dataset,
            tgt_dataset.sizes,
            self.tgt_dict,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            shuffle=False,
            dynamic_length=self.args.dynamic_length,
            mask_range=self.args.mask_range,
            train=train,
            seed=seed,
            full_masking=self.args.full_masking,
            dynamic_masking=self.args.dynamic_masking,
            skip_eos=self.args.skip_eos,
        )
Example #12
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        def split_exists(split, src, tgt, lang, data_path):
            filename = os.path.join(
                data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
            if self.args.raw_text and IndexedRawTextDataset.exists(filename):
                return True
            elif not self.args.raw_text and IndexedDataset.exists(filename):
                return True
            return False

        def indexed_dataset(path,
                            dictionary,
                            copy_ext_dict=False,
                            src_dataset=None):
            if self.args.raw_text:
                return IndexedRawTextDataset(path,
                                             dictionary,
                                             copy_ext_dict=copy_ext_dict,
                                             src_dataset=src_dataset)
            elif IndexedDataset.exists(path):
                if self.args.lazy_load:
                    return IndexedDataset(path, fix_lua_indexing=True)
                else:
                    return IndexedCachedDataset(path, fix_lua_indexing=True)
            return None

        def indexed_label(path):
            if IndexedRawLabelDataset.exists(path):
                return IndexedRawLabelDataset(path)
            else:
                print('Label file not found: {}'.format(path))
            return None

        src_datasets = []
        tgt_datasets = []
        src_labels = []
        tgt_labels = []

        data_paths = self.args.data
        # If there are additional files, name them train1, train2, etc.
        for dk, data_path in enumerate(data_paths):
            for k in itertools.count():
                split_k = split + (str(k) if k > 0 else '')

                # infer langcode
                src, tgt = self.args.source_lang, self.args.target_lang
                if split_exists(split_k, src, tgt, src, data_path):
                    prefix = os.path.join(
                        data_path, '{}.{}-{}.'.format(split_k, src, tgt))
                elif split_exists(split_k, tgt, src, src, data_path):
                    prefix = os.path.join(
                        data_path, '{}.{}-{}.'.format(split_k, tgt, src))
                else:
                    if k > 0 or dk > 0:
                        break
                    else:
                        raise FileNotFoundError(
                            'Dataset not found: {} ({})'.format(
                                split, data_path))
                src_dataset = indexed_dataset(prefix + src, self.src_dict,
                                              self.args.copy_ext_dict)
                tgt_dataset = indexed_dataset(prefix + tgt, self.tgt_dict,
                                              self.args.copy_ext_dict,
                                              src_dataset)
                # src_dataset includes lines, sizes, tokens_list, words_list
                src_datasets.append(src_dataset)
                tgt_datasets.append(tgt_dataset)
                # load the indexed label files
                label_prefix = os.path.join(data_path,
                                            '{}.label.'.format(split_k))
                src_label = indexed_label(label_prefix + src + '.txt')
                tgt_label = indexed_label(label_prefix + tgt + '.txt')
                src_labels.append(src_label)
                tgt_labels.append(tgt_label)

                print('| {} {} {} examples'.format(data_path, split_k,
                                                   len(src_datasets[-1])))

                if not combine:
                    break

        assert len(src_datasets) == len(tgt_datasets)

        src_label, tgt_label = None, None
        if len(src_datasets) == 1:
            src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
            src_label, tgt_label = src_labels[0], tgt_labels[0]
        else:
            sample_ratios = [1] * len(src_datasets)
            sample_ratios[0] = self.args.upsample_primary
            src_dataset = ConcatDataset(src_datasets, sample_ratios)
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        # dataset composed of the source and target data
        self.datasets[split] = LanguagePairDataset(
            src_dataset,
            src_dataset.sizes,
            self.src_dict,
            src_label,
            tgt_dataset,
            tgt_dataset.sizes,
            self.tgt_dict,
            tgt_label,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
        )
Example #13
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    add_lang_token=False,
):
    def split_exists(split, src, tgt, lang, data_path):
        logger.info(
            os.path.join(data_path,
                         "{}.{}-{}.{}".format(split, src, tgt, lang)))
        filename = os.path.join(data_path,
                                "{}.{}-{}.{}".format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    def split_exists_self(split, src, data_path):
        logger.info(
            os.path.join(data_path, "{}.{}-{}.{}".format(split, src, src,
                                                         src)))
        filename = os.path.join(data_path,
                                "{}.{}-{}.{}".format(split, src, src, src))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    def split_exists_valid(split, lang, data_path):
        logger.info(os.path.join(data_path, "{}.{}".format(split, lang)))
        filename = os.path.join(data_path, "{}.{}".format(split, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")
        # print(split_k, src, tgt, src, data_path)
        prefix_src = None
        prefix_tgt = None
        if "-" not in split_k:
            # infer langcode
            if split_exists(split_k, src, tgt, src, data_path):
                prefix = os.path.join(data_path,
                                      "{}.{}-{}.".format(split_k, src, tgt))
            elif split_exists(split_k, tgt, src, src, data_path):
                prefix = os.path.join(data_path,
                                      "{}.{}-{}.".format(split_k, tgt, src))
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        "Dataset not found: {} ({}) {} {}".format(
                            split, data_path, src, tgt))
        else:
            # infer langcode
            if split_exists_valid(split_k, src, data_path):
                prefix = os.path.join(data_path, split_k + ".")
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        "Dataset not found: {} ({}) ".format(split, data_path))
        if prefix_src is not None:
            prefix = prefix_src

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict,
                                                      dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        if prefix_tgt is not None:
            prefix = prefix_tgt
        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict,
                                                      dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info("{} {} {}-{} {} examples".format(data_path, split_k,
                                                     src, tgt,
                                                     len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        logger.info("::::data sample_ratios:{}".format(sample_ratios))
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset,
                                         src_dict.index("[{}]".format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index("[{}]".format(tgt)))
        eos = tgt_dict.index("[{}]".format(tgt))

    eos = None
    if add_lang_token:
        src_dataset = PrependTokenDataset(src_dataset,
                                          src_dict.index("[{}]".format(src)))
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(
                tgt_dataset, tgt_dict.index("[{}]".format(tgt)))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path,
                                  "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )
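When several shards are found, sample_ratios[0] = upsample_primary oversamples the first (primary) bitext inside ConcatDataset. A short sketch of the effect, assuming fairseq's ConcatDataset repeats each constituent dataset according to its ratio (the shard variables are hypothetical):

from fairseq.data import ConcatDataset

primary, extra = src_datasets[0], src_datasets[1]  # two hypothetical shards
combined = ConcatDataset([primary, extra], sample_ratios=[2, 1])
# The primary shard contributes twice as many examples to the concatenation.
assert len(combined) == 2 * len(primary) + len(extra)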
Example #14
def load_generation_pair_dataset(
    data_path, split,
    tgt,
    src_dict,
    tgt_dict,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target, max_source_positions,
    max_target_positions, prepend_bos=False, load_alignments=False,
    truncate_source=False, append_source_id=False, common_eos=None
):

    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, "src", "tgt", "src", data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, "src", "tgt"))
        elif split_exists(split_k, "tgt", "src", "src", data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, "tgt", "src"))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + "src", src_dict, dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + "tgt", tgt_dict, dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info('{} {} {}-{} {} examples'.format(
            data_path, split_k, "src", "tgt", len(src_datasets[-1])
        ))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        if common_eos is not None:
            src_dataset = AppendTokenDataset(src_dataset, src_dict.index('[{}]'.format(common_eos)))
            if tgt_dataset is not None:
                tgt_dataset = AppendTokenDataset(tgt_dataset, tgt_dict.index('[{}]'.format(common_eos)))
            eos = tgt_dict.index('[{}]'.format(common_eos))

    bos = tgt_dict.index('[{}]'.format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, "src", "tgt"))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return GenerationPairDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset_sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=align_dataset, eos=eos, bos=bos 
    )
Example #15
def pos_loader(data_path,
               split,
               src,
               src_dict,
               tgt,
               tgt_dict,
               anchor,
               anchor_dict,
               combine,
               dataset_impl,
               upsample_primary,
               left_pad_source,
               left_pad_target,
               max_source_positions,
               max_target_positions,
               prepend_bos=False,
               truncate_source=False,
               append_source_id=False):

    # Check the existence of the file
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path,
                                '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []
    anchor_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode (from a->b or from b->a)
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path,
                                  '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path,
                                  '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict,
                                                      dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict,
                                                      dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        anchor_prefix = os.path.join(data_path, anchor,
                                     '{}.{}-{}.'.format(split_k, anchor, tgt))

        anchor_dataset = data_utils.load_indexed_dataset(
            anchor_prefix + anchor, anchor_dict, dataset_impl)
        if anchor_dataset is not None:
            anchor_datasets.append(anchor_dataset)

        logger.info('{} {} {}-{} {} examples'.format(data_path, split_k,
                                                     src, tgt,
                                                     len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0
    # None is not available for anchors
    assert len(src_datasets) == len(anchor_datasets)

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
        anchor_dataset = anchor_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None
        anchor_dataset = ConcatDataset(anchor_datasets, sample_ratios)

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset,
                                         src_dict.index('[{}]'.format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index('[{}]'.format(tgt)))
        eos = tgt_dict.index('[{}]'.format(tgt))

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

    return POSGraphLanguagePairDatasetb(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        anchor_dataset,
        anchor_dataset.sizes,
        anchor_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        eos=eos)
Example #16
    def load_dataset(self, split: str, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        languages, data_path = MultilingualLanguageModelingTask._get_langs(
            self.args, epoch)
        lang_to_offline_shard_ratio = None
        if self.args.lang_to_offline_shard_ratio != "":
            lang_to_offline_shard_ratio = {}
            assert os.path.exists(
                self.args.lang_to_offline_shard_ratio
            ), "provided offline shard ratio file doesn't exist: {0}".format(
                self.args.lang_to_offline_shard_ratio)
            with open(self.args.lang_to_offline_shard_ratio) as fin:
                for line in fin:
                    lang, ratio = line.strip().split("\t")
                    ratio = float(ratio)
                    lang_to_offline_shard_ratio[lang] = ratio

            logger.info(
                "Found offline sharded ratio: %s",
                lang_to_offline_shard_ratio,
            )

        if split == self.args.train_subset:
            logger.info("Training on {0} languages: {1}".format(
                len(languages), languages))
        else:
            logger.info("Evaluating on {0} languages: {1}".format(
                len(languages), languages))

        tokens_per_sample = self.args.tokens_per_sample - int(
            self.args.add_bos_token)

        fixed_pad_length = None
        if self.args.pad_to_fixed_length:
            fixed_pad_length = self.args.tokens_per_sample

        pad_to_bsz = None
        if self.args.pad_to_fixed_bsz:
            pad_to_bsz = (self.args.batch_size_valid
                          if "valid" in split else self.args.batch_size)

        lang_datasets = []
        for lang_id, language in enumerate(languages):
            split_path = os.path.join(data_path, language, split)
            dataset = data_utils.load_indexed_dataset(split_path,
                                                      self.dictionary,
                                                      self.args.dataset_impl,
                                                      combine=combine)
            # print('len(dataset) =', len(dataset))
            if dataset is None:
                raise FileNotFoundError("Dataset not found: {} ({})".format(
                    split, split_path))

            dataset = maybe_shorten_dataset(
                dataset,
                split,
                self.args.shorten_data_split_list,
                self.args.shorten_method,
                tokens_per_sample,
                self.args.seed,
            )

            dataset = TokenBlockDataset(
                dataset,
                dataset.sizes,
                tokens_per_sample,
                pad=self.dictionary.pad(),
                eos=self.dictionary.eos(),
                break_mode=self.args.sample_break_mode,
                include_targets=True,
            )

            add_eos_for_other_targets = (
                self.args.sample_break_mode is not None
                and self.args.sample_break_mode != "none")
            src_lang_idx, tgt_lang_idx = None, None
            if self.args.add_bos_token:
                src_lang_idx = self.dictionary.index(lang_token(language))
                tgt_lang_idx = self.output_dictionary.index(
                    lang_token(language))

            lang_datasets.append(
                MonolingualDataset(
                    dataset=dataset,
                    sizes=dataset.sizes,
                    src_vocab=self.dictionary,
                    tgt_vocab=self.output_dictionary,
                    add_eos_for_other_targets=add_eos_for_other_targets,
                    shuffle=True,
                    targets=self.targets,
                    fixed_pad_length=fixed_pad_length,
                    pad_to_bsz=pad_to_bsz,
                    add_bos_token=self.args.add_bos_token,
                    src_lang_idx=src_lang_idx,
                    tgt_lang_idx=tgt_lang_idx,
                ))

        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        logger.info("loaded total {} blocks for all languages".format(
            dataset_lengths.sum(), ))
        if split == self.args.train_subset:
            dataset_lengths_ratio_multiplier = np.ones(len(dataset_lengths))
            if lang_to_offline_shard_ratio is not None:
                dataset_lengths_ratio_multiplier = []
                for lang in languages:
                    assert (
                        lang in lang_to_offline_shard_ratio
                    ), "Lang: {0} missing in offline shard ratio file: {1}".format(
                        lang,
                        self.args.lang_to_offline_shard_ratio,
                    )
                    dataset_lengths_ratio_multiplier.append(
                        lang_to_offline_shard_ratio[lang])
                dataset_lengths_ratio_multiplier = np.array(
                    dataset_lengths_ratio_multiplier)
                true_dataset_lengths = (dataset_lengths *
                                        dataset_lengths_ratio_multiplier)
            else:
                true_dataset_lengths = dataset_lengths
            # For train subset, additionally up or down sample languages.
            sample_probs = self._get_sample_prob(true_dataset_lengths)

            logger.info(
                "Sample probability by language: %s",
                {
                    lang: "{0:.4f}".format(sample_probs[id])
                    for id, lang in enumerate(languages)
                },
            )
            size_ratio = (sample_probs *
                          true_dataset_lengths.sum()) / dataset_lengths
            # TODO: add an option for shrinking all size ratios to below 1
            # if self.args.multilang_sampling_alpha != 1:
            #   size_ratio /= size_ratio.max()

            # Fix numeric errors in size ratio computation
            #   0.999999999999999999 -> 1
            #   1.000000000000000002 -> 1
            for i in range(len(size_ratio)):
                size_ratio[i] = round(size_ratio[i], 8)

            logger.info(
                "Up/Down Sampling ratio by language: %s",
                {
                    lang: "{0:.2f}".format(size_ratio[id])
                    for id, lang in enumerate(languages)
                },
            )
            logger.info(
                "Actual dataset size by language: %s",
                {
                    lang: "{0:.2f}".format(len(lang_datasets[id]))
                    for id, lang in enumerate(languages)
                },
            )
            resampled_lang_datasets = [
                ResamplingDataset(
                    lang_datasets[i],
                    size_ratio=size_ratio[i],
                    seed=self.args.seed,
                    epoch=epoch,
                    replace=size_ratio[i] > 1.0,
                ) for i, d in enumerate(lang_datasets)
            ]
            logger.info(
                "Resampled dataset size by language: %s",
                {
                    lang: "{0:.2f}".format(len(resampled_lang_datasets[id]))
                    for id, lang in enumerate(languages)
                },
            )
            dataset = ConcatDataset(resampled_lang_datasets)
        else:
            dataset = ConcatDataset(lang_datasets)
            lang_splits = [split]
            for lang_id, lang_dataset in enumerate(lang_datasets):
                split_name = split + "_" + languages[lang_id]
                lang_splits.append(split_name)
                self.datasets[split_name] = lang_dataset

            # [TODO]: This is hacky for now to print validation ppl for each
            # language individually. Maybe need task API changes to allow it
            # in more generic ways.
            if split in self.args.valid_subset:
                self.args.valid_subset = self.args.valid_subset.replace(
                    split, ",".join(lang_splits))

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(dataset))

        self.datasets[split] = SortDataset(
            dataset,
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )
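The resampling above follows the usual temperature-based formula: each language's probability is its empirical share raised to the power alpha (the task's multilang_sampling_alpha) and renormalised, and the size ratio then rescales every per-language dataset so the concatenation matches those probabilities. A minimal standalone sketch of just that arithmetic, with `alpha` as an assumed stand-in for the task argument:

import numpy as np

def temperature_size_ratios(dataset_lengths, alpha=0.7):
    """Per-language size ratios for temperature-based resampling.

    dataset_lengths: number of examples per language.
    alpha: sampling temperature; 1.0 keeps the natural distribution,
           values < 1.0 upsample low-resource languages.
    """
    lengths = np.asarray(dataset_lengths, dtype=float)
    prob = lengths / lengths.sum()                # empirical distribution
    smoothed = prob ** alpha
    sample_probs = smoothed / smoothed.sum()      # target distribution
    size_ratio = sample_probs * lengths.sum() / lengths
    return np.round(size_ratio, 8)                # trim numeric noise, as above

print(temperature_size_ratios([1_000_000, 10_000], alpha=0.3))
# the small language gets a ratio well above 1 (upsampled), the large one below 1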
Example #17
def load_langpair_dataset(
    data_path, split,
    src, src_dict,
    tgt, tgt_dict,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target, max_source_positions,
    max_target_positions, prepend_bos=False, load_alignments=False,
    truncate_source=False, srcda=False, srcda_choice='uniform', 
    tgtda=False, tgtda_choice='uniform'
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)
        tgt_datasets.append(
            data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)
        )

        print('| {} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)

    return LanguagePairDatasetDA(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset.sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=align_dataset,
        srcda=srcda, srcda_choice=srcda_choice,
        tgtda=tgtda, tgtda_choice=tgtda_choice
    )
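When `combine` picks up extra shards (train1, train2, ...), the sample_ratios passed to ConcatDataset above repeat the primary shard upsample_primary times relative to the rest. A toy illustration of the effect on plain lists; this only mimics the behaviour and is not fairseq's ConcatDataset, which achieves the same thing through index arithmetic without copying data:

def concat_with_ratios(datasets, sample_ratios):
    # Repeat dataset i sample_ratios[i] times before concatenating.
    out = []
    for ds, ratio in zip(datasets, sample_ratios):
        out.extend(list(ds) * int(ratio))
    return out

primary = ["p0", "p1"]        # e.g. examples from "train"
extra = ["e0", "e1", "e2"]    # e.g. examples from "train1"
upsample_primary = 2

print(concat_with_ratios([primary, extra], [upsample_primary, 1]))
# ['p0', 'p1', 'p0', 'p1', 'e0', 'e1', 'e2']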
Example #18
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0

        if self.args.multiple_datasets:
            if len(paths) == 1:
                paths = [
                    os.path.join(paths[0], p)
                    for p in next(os.walk(paths[0]))[1]
                ]
            datasets = [
                ShardedDataset(
                    self.dictionary,
                    self.args.dataset_impl,
                    path,
                    split,
                    epoch - 1,
                    combine=combine,
                ) for path in paths
            ]

            ds_names = [ds.name for ds in datasets]

            if split in self.subsample_splits:
                sizes = [sum(d.sizes) for d in datasets]
                min_sz = min(sizes)
                ratios = [min_sz / sz for sz in sizes]
                datasets = [
                    SubsampleDataset(d, r) if r < 1 else d
                    for d, r in zip(datasets, ratios)
                ]

            dataset = ConcatDataset(datasets)
        else:
            data_path = paths[(epoch - 1) % len(paths)]
            split_path = os.path.join(data_path, split)

            dataset = data_utils.load_indexed_dataset(split_path,
                                                      self.dictionary,
                                                      self.args.dataset_impl,
                                                      combine=combine)
            if dataset is None:
                raise FileNotFoundError("Dataset not found: {} ({})".format(
                    split, split_path))
            ds_names = [None]

        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample,
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=self.args.sample_break_mode,
            include_targets=True,
        )

        if self.args.prepend_ds_name:
            dataset = PrependDataset(
                dataset,
                prepend_getter=ds_name_getter(
                    offset=0,
                    generic_ds_name_chance=self.args.generic_ds_name_chance,
                    dictionary=self.dictionary,
                ),
                ensure_first_token_is=self.dictionary.eos(),
            )

        dataset = ReplaceDataset(
            dataset,
            replace_map={
                self.dictionary.eos(): self.dictionary.indices["\\n"]
            },
            offsets=[1, -1],
        )

        add_eos_for_other_targets = (self.args.sample_break_mode is not None
                                     and self.args.sample_break_mode != "none")

        dataset = MonolingualDataset(
            dataset,
            dataset.sizes,
            self.dictionary,
            self.output_dictionary,
            add_eos_for_other_targets=add_eos_for_other_targets,
            shuffle=True,
            targets=self.targets,
            add_bos_token=self.args.add_bos_token,
        )

        if self.args.colorize_ds_name:
            ds_names.append("generic")
            min_ds = min(self.dictionary.indices[n] for n in ds_names)
            dataset = ColorizeDataset(
                dataset,
                color_getter=ds_name_getter(
                    offset=-min_ds,
                    generic_ds_name_chance=self.args.generic_ds_name_chance,
                    dictionary=self.dictionary,
                ),
            )

        self.datasets[split] = dataset
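The subsampling branch above balances heterogeneous corpora by shrinking every dataset to the token count of the smallest one. A short numeric sketch of that ratio computation (the corpus names and sizes are made up for illustration):

def subsample_ratios(token_counts):
    # Ratio that scales each corpus down to the size of the smallest one.
    min_sz = min(token_counts)
    return [min_sz / sz for sz in token_counts]

sizes = {"wiki": 9_000_000, "books": 3_000_000, "news": 12_000_000}  # hypothetical
ratios = dict(zip(sizes, subsample_ratios(list(sizes.values()))))
print(ratios)  # roughly {'wiki': 0.333, 'books': 1.0, 'news': 0.25}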
Example #19
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test)."""
        if self.cfg.data.endswith("1"):
            data_shard = (epoch - 1) % self.cfg.num_data_splits + 1
            data_path = self.cfg.data[:-1] + str(data_shard)
        else:
            data_path = self.cfg.data

        def get_path(type, data_split):
            return os.path.join(data_path, str(type), data_split)

        def make_dataset(type, dictionary, data_split, combine):
            split_path = get_path(type, data_split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                dictionary,
                combine=combine,
            )
            return dataset

        def load_split(data_split, metric):
            input_src = None
            if self.cfg.include_src:
                input_src = make_dataset("input_src",
                                         self.dictionary,
                                         data_split,
                                         combine=False)
                assert input_src is not None, "could not find dataset: {}".format(
                    get_path("input_src", data_split))

            input_tgt = make_dataset("input_tgt",
                                     self.dictionary,
                                     data_split,
                                     combine=False)
            assert input_tgt is not None, "could not find dataset: {}".format(
                get_path("input_tgt", data_split))

            label_path = f"{get_path(metric, data_split)}.{metric}"
            assert os.path.exists(
                label_path), f"could not find dataset: {label_path}"

            np_labels = np.loadtxt(label_path)
            if self.cfg.target_metric == "ter":
                np_labels = -np_labels
            label = RawLabelDataset(np_labels)

            return input_src, input_tgt, label

        src_datasets = []
        tgt_datasets = []
        label_datasets = []

        if split == self.cfg.train_subset:
            for k in itertools.count():
                split_k = "train" + (str(k) if k > 0 else "")
                prefix = os.path.join(data_path, "input_tgt", split_k)
                if not indexed_dataset.dataset_exists(prefix, impl=None):
                    if k > 0:
                        break
                    else:
                        raise FileNotFoundError(f"Dataset not found: {prefix}")
                input_src, input_tgt, label = load_split(
                    split_k, self.cfg.target_metric)
                src_datasets.append(input_src)
                tgt_datasets.append(input_tgt)
                label_datasets.append(label)
        else:
            input_src, input_tgt, label = load_split(split,
                                                     self.cfg.target_metric)
            src_datasets.append(input_src)
            tgt_datasets.append(input_tgt)
            label_datasets.append(label)

        if len(tgt_datasets) == 1:
            input_tgt, label = tgt_datasets[0], label_datasets[0]
            if self.cfg.include_src:
                input_src = src_datasets[0]
        else:
            input_tgt = ConcatDataset(tgt_datasets)
            label = ConcatDataset(label_datasets)
            if self.cfg.include_src:
                input_src = ConcatDataset(src_datasets)

        input_tgt = TruncateDataset(input_tgt, self.cfg.max_positions)
        if self.cfg.include_src:
            input_src = PrependTokenDataset(input_src, self.dictionary.bos())
            input_src = TruncateDataset(input_src, self.cfg.max_positions)
            src_lengths = NumelDataset(input_src, reduce=False)
            src_tokens = ConcatSentencesDataset(input_src, input_tgt)
        else:
            src_tokens = PrependTokenDataset(input_tgt, self.dictionary.bos())
            src_lengths = NumelDataset(src_tokens, reduce=False)

        dataset = {
            "id": IdDataset(),
            "net_input": {
                "src_tokens":
                RightPadDataset(
                    src_tokens,
                    pad_idx=self.source_dictionary.pad(),
                ),
                "src_lengths":
                src_lengths,
            },
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_tokens, reduce=True),
            "target": label,
        }

        dataset = NestedDictionaryDataset(
            dataset,
            sizes=[src_tokens.sizes],
        )

        assert len(dataset) % self.cfg.mt_beam == 0, (
            "dataset size (%d) is not a multiple of beam size (%d)" %
            (len(dataset), self.cfg.mt_beam))

        # no need to shuffle valid/test sets
        if not self.cfg.no_shuffle and split == self.cfg.train_subset:

            # need to keep all hypothese together
            start_idx = np.arange(0, len(dataset), self.cfg.mt_beam)
            with data_utils.numpy_seed(self.cfg.seed + epoch):
                np.random.shuffle(start_idx)

            idx = np.arange(0, self.cfg.mt_beam)
            shuffle = np.tile(idx, (len(start_idx), 1)).reshape(-1) + np.tile(
                start_idx, (self.cfg.mt_beam, 1)).transpose().reshape(-1)

            dataset = SortDataset(
                dataset,
                sort_order=[shuffle],
            )

        logger.info(f"Loaded {split} with #samples: {len(dataset)}")

        self.datasets[split] = dataset
        return self.datasets[split]
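The shuffle at the end of this loader keeps every group of mt_beam consecutive hypotheses together and only permutes the order of the groups, so all hypotheses of one source sentence stay adjacent for the reranker. A standalone sketch of that index arithmetic, using broadcasting instead of the np.tile construction above, with a fixed seed and a made-up beam size:

import numpy as np

def block_shuffle(n_items, beam, seed=0):
    """Shuffle indices in blocks of `beam` consecutive items."""
    assert n_items % beam == 0
    start_idx = np.arange(0, n_items, beam)   # first index of each block
    rng = np.random.RandomState(seed)
    rng.shuffle(start_idx)                    # permute block order only
    offsets = np.arange(beam)
    # Expand each shuffled block start into its `beam` member indices.
    return (start_idx[:, None] + offsets[None, :]).reshape(-1)

print(block_shuffle(12, beam=3, seed=0))
# e.g. [ 9 10 11  0  1  2  6  7  8  3  4  5] -- each block stays contiguous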
Example #20
    def load_dataset(self, split, combine=False, **kwargs):
        def split_exists(split, src, tgt, lang, data_path):
            filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
            if self.args.raw_text and IndexedRawTextDataset.exists(filename):
                return True
            elif not self.args.raw_text and IndexedDataset.exists(filename):
                return True
            return False

        def indexed_dataset(path, dictionary):
            if self.args.raw_text:
                raise NotImplementedError
            elif IndexedDataset.exists(path):
                return DPTreeIndexedCachedDataset(path, fix_lua_indexing=True)
            return None

        src_datasets_dict = {k: [] for k in NSTACK_KEYS}
        tgt_datasets = []

        data_paths = self.args.data
        print(f'| split = {split}')
        print(f'| self.args.data = {self.args.data}')

        for dk, data_path in enumerate(data_paths):
            for k in itertools.count():
                split_k = split + (str(k) if k > 0 else '')

                # infer langcode
                src, tgt = self.args.source_lang, self.args.target_lang
                if split_exists(split_k, src, tgt, tgt, data_path):
                    prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
                elif split_exists(split_k, tgt, src, tgt, data_path):
                    prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
                else:
                    if k > 0 or dk > 0:
                        break
                    else:
                        raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

                for modality in src_datasets_dict.keys():
                    src_datasets_dict[modality].append(indexed_dataset(f'{prefix}{src}.{modality}', self.src_dict))

                # src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
                tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))

                print('| {} {} {} examples'.format(data_path, split_k, len(tgt_datasets[-1])))

                if not combine:
                    break

        assert len(src_datasets_dict[NSTACK_KEYS[0]]) == len(tgt_datasets)

        if len(tgt_datasets) == 1:
            # src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]

            src_dataset_dict = {k: v[0] for k, v in src_datasets_dict.items()}
            tgt_dataset = tgt_datasets[0]
        else:
            sample_ratios = [1] * len(tgt_datasets)
            sample_ratios[0] = self.args.upsample_primary
            # src_dataset = ConcatDataset(src_datasets, sample_ratios)

            src_dataset_dict = {k: ConcatDataset(v, sample_ratios) for k, v in src_datasets_dict.items()}
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

        # src_sizes = src_dataset_dict['nodes'].sizes
        # src_sizes = src_dataset_dict['nodes'].sizes.reshape(-1, 2).sum(-1)
        leave_shape = src_dataset_dict['leaves'].sizes.reshape(-1, 2)
        node_shape = src_dataset_dict['nodes'].sizes.reshape(-1, 2)
        # leaves_sizes = leave_shape.sum(-1)
        # nodes_sizes = node_shape.sum(-1)
        leaves_sizes = leave_shape.prod(-1)
        nodes_sizes = node_shape.prod(-1)
        # print(f'| FIXED VERSION, size must be prod(-1)')
        src_sizes = leaves_sizes + nodes_sizes
        src_nsents = leave_shape[:, 0]
        # print(f'Some leave_size: {leave_shape[:10]}')
        # print(f'Some src_nsent: ({src_nsents[:10]})')

        self.datasets[split] = Nstack2SeqPairDataset(
            src_dataset_dict, src_sizes, self.src_dict, src_nsents,
            tgt_dataset, tgt_dataset.sizes, self.tgt_dict,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            remove_eos_from_source=self.args.remove_eos_from_source,
            append_eos_to_target=self.args.append_eos_to_target,
            input_feeding=self.args.input_feeding,
            is_infer=self.args.infer_mode
        )
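The size bookkeeping above assumes `sizes` stores a flat (rows, cols) pair per example for each tree component, so the per-example token count is the product of the pair (hence prod(-1) rather than sum(-1)) and the leaf rows double as the sentence count. A tiny numeric check of that reshape step with made-up shapes:

import numpy as np

# Flat sizes array: two entries (rows, cols) per example, two hypothetical examples.
leaves_sizes_flat = np.array([3, 7, 5, 4])
nodes_sizes_flat = np.array([2, 7, 4, 4])

leave_shape = leaves_sizes_flat.reshape(-1, 2)   # (n_examples, 2)
node_shape = nodes_sizes_flat.reshape(-1, 2)

src_sizes = leave_shape.prod(-1) + node_shape.prod(-1)
src_nsents = leave_shape[:, 0]                   # rows = number of sentences

print(src_sizes)   # [35 36]  -> 3*7 + 2*7, 5*4 + 4*4
print(src_nsents)  # [3 5]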
Example #21
    def load_lang_dataset(
        self,
        data_path,
        split,
        src,
        src_dict,
        tgt,
        tgt_dict,
        combine,
        dataset_impl,
        upsample_primary,
        max_source_positions,
        prepend_bos=False,
        load_alignments=False,
        truncate_source=False,
    ):

        src_datasets = []
        tgt_datasets = []

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else "")

            # infer langcode
            if self.split_exists(split_k, src, tgt, src, data_path, dataset_impl):
                prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
            elif self.split_exists(split_k, tgt, src, src, data_path, dataset_impl):
                prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
            else:
                if k > 0:
                    break
                else:
                    logger.error(
                        f"Dataset not found: {data_path}, {split_k}, {src}, {tgt}"
                    )
                    raise FileNotFoundError(
                        "Dataset not found: {} ({})".format(split, data_path)
                    )

            src_dataset = self.load_data(prefix + src, src_dict, dataset_impl)
            if truncate_source:
                src_dataset = AppendTokenDataset(
                    TruncateDataset(
                        StripTokenDataset(src_dataset, src_dict.eos()),
                        max_source_positions - 1,
                    ),
                    src_dict.eos(),
                )
            src_datasets.append(src_dataset)
            tgt_datasets.append(self.load_data(prefix + tgt, tgt_dict, dataset_impl))

            logger.info(
                "{} {} {}-{} {} examples".format(
                    data_path, split_k, src, tgt, len(src_datasets[-1])
                )
            )

            if not combine:
                break

        assert len(src_datasets) == len(tgt_datasets)

        if len(src_datasets) == 1:
            src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
        else:
            sample_ratios = [1] * len(src_datasets)
            sample_ratios[0] = upsample_primary
            src_dataset = ConcatDataset(src_datasets, sample_ratios)
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

        if prepend_bos:
            assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
            src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

        align_dataset = None
        if load_alignments:
            align_path = os.path.join(
                data_path, "{}.align.{}-{}".format(split, src, tgt)
            )
            if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
                align_dataset = data_utils.load_indexed_dataset(
                    align_path, None, dataset_impl
                )

        return src_dataset, tgt_dataset, align_dataset
Example #22
def get_asr_dataset_from_json(
    data_path,
    split,
    tgt_dict,
    combine,
    upsample_primary=1,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    seed=1,
    global_cmvn_stats_path=None,
    specaugment_config=None,
):
    """
    Parse data json and create dataset.
    See espresso/tools/asr_prep_json.py which pack json from raw files
    Json example:
    {
        "011c0202": {
            "feat": "fbank/raw_fbank_pitch_train_si284.1.ark:54819" or
            "wave": "/export/corpora5/LDC/LDC93S6B/11-1.1/wsj0/si_tr_s/011/011c0202.wv1" or
            "command": "sph2pipe -f wav /export/corpora5/LDC/LDC93S6B/11-1.1/wsj0/si_tr_s/011/011c0202.wv1 |",
            "text": "THE HOTEL",
            "utt2num_frames": "693",
        },
        "011c0203": {
            ...
        }
    }
    """
    src_datasets = []
    tgt_datasets = []
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")
        data_json_path = os.path.join(data_path, "{}.json".format(split_k))
        if not os.path.isfile(data_json_path):
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    "Dataset not found: {}".format(data_json_path)
                )

        with open(data_json_path, "rb") as f:
            loaded_json = json.load(f, object_pairs_hook=OrderedDict)

        utt_ids, audios, texts, utt2num_frames = [], [], [], []
        for utt_id, val in loaded_json.items():
            utt_ids.append(utt_id)
            if "feat" in val:
                audio = val["feat"]
            elif "wave" in val:
                audio = val["wave"]
            elif "command" in val:
                audio = val["command"]
            else:
                raise KeyError(
                    f"'feat', 'wave' or 'command' should be present as a field for the entry {utt_id} in {data_json_path}"
                )
            audios.append(audio)
            if "text" in val:
                texts.append(val["text"])
            if "utt2num_frames" in val:
                utt2num_frames.append(int(val["utt2num_frames"]))

        assert len(utt2num_frames) == 0 or len(utt_ids) == len(utt2num_frames)
        if "feat" in next(iter(loaded_json.items())):
            extra_kwargs = {}
        else:
            extra_kwargs = {"feat_dim": 80, "feature_type": "fbank"}
            if global_cmvn_stats_path is not None:
                feature_transforms_config = {
                    "transforms": ["global_cmvn"],
                    "global_cmvn": {"stats_npz_path": global_cmvn_stats_path}
                }
                extra_kwargs["feature_transforms_config"] = feature_transforms_config
        src_datasets.append(AudioFeatDataset(
            utt_ids, audios, utt2num_frames=utt2num_frames, seed=seed,
            specaugment_config=specaugment_config if split == "train" else None,
            **extra_kwargs
        ))
        if len(texts) > 0:
            assert len(utt_ids) == len(texts)
            assert tgt_dict is not None
            tgt_datasets.append(AsrTextDataset(utt_ids, texts, tgt_dict))

        logger.info("{} {} examples".format(data_json_path, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    feat_dim = src_datasets[0].feat_dim

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        for i in range(1, len(src_datasets)):
            assert (
                feat_dim == src_datasets[i].feat_dim
            ), "feature dimension does not match across multiple json files"
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return AsrDataset(
        src_dataset,
        src_dataset.sizes,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=False,
        left_pad_target=False,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )
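For reference, a throwaway snippet that writes a two-utterance JSON in the layout the docstring above describes and reads it back the same way the loader does; the paths, command, and frame counts are invented for illustration:

import json
from collections import OrderedDict

data = OrderedDict([
    ("utt0001", {"wave": "/data/wav/utt0001.wav",                 # hypothetical path
                 "text": "HELLO WORLD",
                 "utt2num_frames": "412"}),
    ("utt0002", {"command": "sox /data/raw/utt0002.sph -t wav - |",
                 "text": "GOOD MORNING",
                 "utt2num_frames": "388"}),
])

with open("train.json", "w") as f:
    json.dump(data, f, indent=4)

with open("train.json", "rb") as f:
    loaded = json.load(f, object_pairs_hook=OrderedDict)

for utt_id, val in loaded.items():
    audio = val.get("feat") or val.get("wave") or val.get("command")
    print(utt_id, audio, val["text"], int(val["utt2num_frames"]))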
Example #23
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """

        loaded_datasets = []

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(self.args.data, split_k)

            if self.args.raw_text and IndexedRawTextDataset.exists(path):
                ds = IndexedRawTextDataset(path, self.dictionary)
            elif not self.args.raw_text and IndexedDataset.exists(path):
                if self.args.lazy_load:
                    ds = IndexedDataset(path, fix_lua_indexing=True)
                else:
                    ds = IndexedCachedDataset(path, fix_lua_indexing=True)
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(
                            split, self.args.data))

            loaded_datasets.append(
                TokenBlockDataset(
                    ds,
                    ds.sizes,
                    self.args.tokens_per_sample,
                    pad=self.dictionary.pad(),
                    eos=self.dictionary.eos(),
                    break_mode=self.args.sample_break_mode,
                    include_targets=True,
                ))

            print('| {} {} {} examples'.format(self.args.data, split_k,
                                               len(loaded_datasets[-1])))

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'

        self.datasets[split] = MonolingualDataset(
            dataset,
            sizes,
            self.dictionary,
            self.output_dictionary,
            add_eos_for_other_targets=add_eos_for_other_targets,
            shuffle=True,
            targets=self.targets,
        )
Example #24
def load_langpair_dataset(
    data_path, split,
    src, src_dict,
    tgt, tgt_dict, ter,
    xml_dico, xml_params,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target, max_source_positions, max_target_positions, shuffle=True,
    task='translation_qe'
):
    def split_exists(split, src, tgt, lang, data_path):
        # filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        # return indexed_dataset.dataset_exists(filename, impl=dataset_impl)
        filename = os.path.join(data_path, split + '.')
        print(filename + src)
        return os.path.exists(filename + src)

    src_datasets = []
    tgt_datasets = []
    ter_datasets = []
    xml_datasets = []
    word_tag_datasets = []
    gap_tag_datasets = []
    bpe_tag_datasets = []
    xml_bpe_tag_datasets = []
    src_word_tag_datasets = []
    src_bpe_tag_datasets = []
    xml_src_bpe_tag_datasets = []
    # tgt_datasets_xml = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        filename = os.path.join(data_path, split + '.')
        print(filename + src)
       
        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            # prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
            prefix = os.path.join(data_path, split_k + '.')
            prefix_xml = os.path.join(data_path, 'xml_data/' + split_k + '.')
        elif split_exists(split_k, tgt, src, src, data_path):
            # prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
            prefix = os.path.join(data_path, split_k + '.')
            prefix_xml = os.path.join(data_path, 'xml_data/' + split_k + '.')
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))


        src_datasets.append(
            data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
        )
        tgt_datasets.append(
            data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)
        )

        ter_datasets.append(
            torch.from_numpy(np.loadtxt(prefix + ter))
        )

        xml_datasets.append(
            data_utils.load_indexed_dataset(prefix_xml + src, xml_dico, dataset_impl, path_xml=prefix_xml + tgt)
        )

        if task == 'translation_qe_word':
            word_tag_datasets.append(
                data_utils.load_word_qe_tags(prefix + 'word_tags')
            )

            gap_tag_datasets.append(
                data_utils.load_word_qe_tags(prefix + 'gap_tags')
            )

            src_word_tag_datasets.append(
                data_utils.load_word_qe_tags(prefix + 'src_word_tags')
            )

            bpe_tag_datasets.append(
                data_utils.load_bpe_tags(prefix + 'bpe')
            )

            src_bpe_tag_datasets.append(
                data_utils.load_bpe_tags(prefix + 'src_bpe')
            )

            xml_bpe_tag_datasets.append(
                data_utils.load_bpe_tags(prefix_xml + 'bpe')
            )

            xml_src_bpe_tag_datasets.append(
                data_utils.load_bpe_tags(prefix_xml + 'src_bpe')
            )

        # tgt_datasets_xml.append(
        #     data_utils.load_indexed_dataset(prefix_xml + tgt, xml_dico, dataset_impl)
        # )

        print('| {} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)
    assert len(src_datasets) == len(ter_datasets)

    if len(src_datasets) == 1:
        if task == 'translation_qe_word':
            src_dataset, tgt_dataset, ter_dataset, xml_dataset, \
            word_tag_dataset, gap_tag_dataset, bpe_tag_dataset, xml_bpe_tag_dataset, \
            src_word_tag_dataset, src_bpe_tag_dataset, xml_src_bpe_tag_dataset \
                        = src_datasets[0], tgt_datasets[0], ter_datasets[0], xml_datasets[0], \
                          word_tag_datasets[0], gap_tag_datasets[0], bpe_tag_datasets[0], xml_bpe_tag_datasets[0], \
                          src_word_tag_datasets[0], src_bpe_tag_datasets[0], xml_src_bpe_tag_datasets[0]
        else:
            src_dataset, tgt_dataset, ter_dataset, xml_dataset = src_datasets[0], tgt_datasets[0], ter_datasets[0], xml_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        ter_dataset = ConcatDataset(ter_datasets, sample_ratios)
        xml_dataset = ConcatDataset(xml_datasets, sample_ratios)

    if task == 'translation_qe_word':
        return LanguagePairWordDataset(
            src_dataset, src_dataset.sizes, src_dict,
            tgt_dataset, tgt_dataset.sizes, tgt_dict,
            ter_dataset, xml_dataset, xml_dico, xml_params, xml_pad_indx=xml_params.pad_index,
            word_tag=word_tag_dataset, gap_tag=gap_tag_dataset,
            bpe_tag=bpe_tag_dataset, xml_bpe_tag=xml_bpe_tag_dataset,
            src_word_tag=src_word_tag_dataset, src_bpe_tag=src_bpe_tag_dataset,
            xml_src_bpe_tag=xml_src_bpe_tag_dataset,
            left_pad_source=left_pad_source,
            left_pad_target=left_pad_target,
            max_source_positions=max_source_positions,
            max_target_positions=max_target_positions,
            shuffle=shuffle
        )
    else:
        return LanguagePairDataset(
            src_dataset, src_dataset.sizes, src_dict,
            tgt_dataset, tgt_dataset.sizes, tgt_dict,
            ter_dataset, xml_dataset, xml_dico, xml_params, xml_pad_indx=xml_params.pad_index,
            left_pad_source=left_pad_source,
            left_pad_target=left_pad_target,
            max_source_positions=max_source_positions,
            max_target_positions=max_target_positions,
            shuffle=shuffle
        )
Example #25
    def load_dataset(self, split, combine=False):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        def split_exists(split, src, tgt, lang, data_path):
            filename = os.path.join(
                data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
            if self.args.raw_text and IndexedRawTextDataset.exists(filename):
                return True
            elif not self.args.raw_text and IndexedDataset.exists(filename):
                return True
            return False

        def indexed_dataset(path, dictionary):
            if self.args.raw_text and IndexedRawTextDataset.exists(path):
                return IndexedRawTextDataset(path, dictionary)
            elif not self.args.raw_text and IndexedInMemoryDataset.exists(
                    path):
                return IndexedDataset(path, fix_lua_indexing=False)
            return None

        src_datasets = []
        tgt_datasets = []

        data_paths = self.args.data

        for dk, data_path in enumerate(data_paths):
            for k in itertools.count():
                split_k = split + (str(k) if k > 0 else '')

                # infer langcode
                src, tgt = self.args.source_lang, self.args.target_lang
                if split_exists(split_k, src, tgt, src, data_path):
                    prefix = os.path.join(
                        data_path, '{}.{}-{}.'.format(split_k, src, tgt))
                elif split_exists(split_k, tgt, src, src, data_path):
                    prefix = os.path.join(
                        data_path, '{}.{}-{}.'.format(split_k, tgt, src))
                else:
                    if k > 0 or dk > 0:
                        break
                    else:
                        raise FileNotFoundError(
                            'Dataset not found: {} ({})'.format(
                                split, data_path))

                src_datasets.append(
                    indexed_dataset(prefix + src, self.src_dict))
                tgt_datasets.append(
                    indexed_dataset(prefix + tgt, self.tgt_dict))

                print('| {} {} {} examples'.format(data_path, split_k,
                                                   len(src_datasets[-1])))

                if not combine:
                    break

        assert len(src_datasets) == len(tgt_datasets)

        if len(src_datasets) == 1:
            src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
        else:
            sample_ratios = [1] * len(src_datasets)
            sample_ratios[0] = self.args.upsample_primary
            src_dataset = ConcatDataset(src_datasets, sample_ratios)
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

        self.datasets[split] = SummerizationLanguagePairDataset(
            src_dataset,
            src_dataset.sizes,
            self.src_dict,
            tgt_dataset,
            tgt_dataset.sizes,
            self.tgt_dict,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            with_target=(split != 'test'))
Example #26
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """

        loaded_datasets = []

        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(data_path, split_k)
            ds = indexed_dataset.make_dataset(path,
                                              impl=self.args.dataset_impl,
                                              fix_lua_indexing=True,
                                              dictionary=self.dictionary)

            if ds is None:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(split, data_path))

            loaded_datasets.append(
                TokenBlockDataset(
                    ds,
                    ds.sizes,
                    self.args.tokens_per_sample,
                    pad=self.dictionary.pad(),
                    eos=self.dictionary.eos(),
                    break_mode=self.args.sample_break_mode,
                    include_targets=True,
                ))

            print('| {} {} {} examples'.format(data_path, split_k,
                                               len(loaded_datasets[-1])))

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'

        self.datasets[split] = MonolingualDataset(
            dataset,
            sizes,
            self.dictionary,
            self.output_dictionary,
            add_eos_for_other_targets=add_eos_for_other_targets,
            shuffle=True,
            targets=self.targets,
            add_bos_token=self.args.add_bos_token,
        )
Example #27
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        def indexed_dataset(path, dictionary):
            return IndexedCachedDataset(path, fix_lua_indexing=True)

        src_datasets = []
        tgt_datasets = []
        src_lngs = []
        tgt_lngs = []
        dataset_ids = []
        dataset_names = []
        lng_borders = [0]
        data_path = self.args.data[0]
        fns = glob.glob(os.path.join(data_path, f'{split}.*'))
        lng_pairs = list(set([f.split('.')[1] for f in fns]))
        lng_pairs = sorted(lng_pairs)
        ds_idx = 0
        sources = [s for s in self.args.sources.split(",") if s != '']
        targets = [t for t in self.args.targets.split(",") if t != '']

        is_distill = self.args.criterion == 'distill_label_smoothed_cross_entropy' and split == 'train'
        topk_idxs = []
        topk_probs = []
        expert_scores = []

        for idx, lng_pair in enumerate(lng_pairs):
            src, tgt = lng_pair.split('-')
            prefix = os.path.join(data_path,
                                  '{}.{}-{}.'.format(split, src, tgt))

            def add_dataset(src, tgt):
                if (src not in sources
                        and len(sources) > 0) or (tgt not in targets
                                                  and len(targets) > 0):
                    return 0
                if not os.path.exists(prefix + src + ".bin") or \
                        not os.path.exists(prefix + tgt + ".bin"):
                    return 0

                if is_distill and not os.path.exists(
                        os.path.join(self.args.data[0],
                                     '{}_{}_topk_idx.idx'.format(src, tgt))):
                    return 0

                src_ds = indexed_dataset(prefix + src, self.src_dict)
                src_datasets.append(src_ds)
                tgt_ds = indexed_dataset(prefix + tgt, self.tgt_dict)
                tgt_datasets.append(tgt_ds)

                l = len(src_ds)
                if self.args.data_limit != '' \
                        and src + "-" + tgt == self.args.data_limit.split(':')[0] \
                        and l > int(self.args.data_limit.split(':')[1]):
                    l = int(self.args.data_limit.split(':')[1])
                    src_datasets[-1].size = l
                    tgt_datasets[-1].size = l
                    l = len(src_ds)

                print("| Add dataset {} -> {}. size:{}".format(src, tgt, l))
                lng_borders.append(lng_borders[-1] + l)
                dataset_names.append(f"{src}_{tgt}")
                for i in range(l):
                    src_lngs.append(self.lng2id[src])
                    tgt_lngs.append(self.lng2id[tgt])
                    dataset_ids.append(ds_idx)

                if is_distill:
                    assert self.args.data_limit == ''
                    path = os.path.join(self.args.data[0],
                                        '{}_{}_topk_idx'.format(src, tgt))
                    topk_idxs.append(TeacherOutputDataset(path))
                    path = os.path.join(self.args.data[0],
                                        '{}_{}_topk_prob'.format(src, tgt))
                    topk_probs.append(TeacherOutputDataset(path))
                    expert_bleu = os.path.join(
                        self.args.data[0],
                        'expert_bleu_{}_{}.json'.format(src, tgt))
                    expert_bleu = json.load(open(expert_bleu))
                    expert_scores.append(expert_bleu[f"bleu_{src}_{tgt}"])
                return 1

            ds_idx += add_dataset(src, tgt)
            ds_idx += add_dataset(tgt, src)

        src_dataset = ConcatDataset(src_datasets)
        tgt_dataset = ConcatDataset(tgt_datasets)
        src_sizes = np.concatenate([ds.sizes for ds in src_datasets])
        tgt_sizes = np.concatenate([ds.sizes for ds in tgt_datasets])

        topk_idx_dataset = None
        topk_probs_dataset = None
        if is_distill:
            topk_idx_dataset = ConcatDataset(topk_idxs)
            topk_probs_dataset = ConcatDataset(topk_probs)
            assert len(topk_probs_dataset) == len(tgt_dataset), (
                len(topk_probs_dataset), len(tgt_dataset))
            assert len(topk_idx_dataset) == len(tgt_dataset)

        self.datasets[split] = UniversalDataset(
            self.args,
            src_dataset,
            src_sizes,
            self.src_dict,
            src_lngs,
            tgt_lngs,
            tgt_dataset,
            tgt_sizes,
            self.tgt_dict,
            dataset_ids,
            lng_borders,
            dataset_names,
            topk_idxs=topk_idx_dataset,
            topk_probs=topk_probs_dataset,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            expert_scores=expert_scores,
            is_train=split == 'train')
Example #28
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        languages = sorted(name for name in os.listdir(data_path)
                           if os.path.isdir(os.path.join(data_path, name)))

        logger.info("Training on {0} languages: {1}".format(
            len(languages), languages))
        logger.info("Language to id mapping: ",
                    {lang: id
                     for id, lang in enumerate(languages)})

        mask_whole_words = self._get_whole_word_mask()
        lang_datasets = []
        for lang_id, language in enumerate(languages):
            split_path = os.path.join(data_path, language, split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                self.source_dictionary,
                self.args.dataset_impl,
                combine=combine,
            )
            if dataset is None:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    split, split_path))

            # create continuous blocks of tokens
            dataset = TokenBlockDataset(
                dataset,
                dataset.sizes,
                self.args.tokens_per_sample - 1,  # one less for <s>
                pad=self.source_dictionary.pad(),
                eos=self.source_dictionary.eos(),
                break_mode=self.args.sample_break_mode,
            )
            logger.info('loaded {} blocks from: {}'.format(
                len(dataset), split_path))

            # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
            dataset = PrependTokenDataset(dataset,
                                          self.source_dictionary.bos())

            src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
                dataset,
                self.source_dictionary,
                pad_idx=self.source_dictionary.pad(),
                mask_idx=self.mask_idx,
                seed=self.args.seed,
                mask_prob=self.args.mask_prob,
                leave_unmasked_prob=self.args.leave_unmasked_prob,
                random_token_prob=self.args.random_token_prob,
                freq_weighted_replacement=self.args.freq_weighted_replacement,
                mask_whole_words=mask_whole_words,
            )

            lang_dataset = NestedDictionaryDataset(
                {
                    'net_input': {
                        'src_tokens':
                        PadDataset(
                            src_dataset,
                            pad_idx=self.source_dictionary.pad(),
                            left_pad=False,
                        ),
                        'src_lengths':
                        NumelDataset(src_dataset, reduce=False),
                    },
                    'target':
                    PadDataset(
                        tgt_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'nsentences':
                    NumSamplesDataset(),
                    'ntokens':
                    NumelDataset(src_dataset, reduce=True),
                    'lang_id':
                    RawLabelDataset([lang_id] * src_dataset.sizes.shape[0]),
                },
                sizes=[src_dataset.sizes],
            )
            lang_datasets.append(lang_dataset)

        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        logger.info('loaded total {} blocks for all languages'.format(
            dataset_lengths.sum(), ))
        if split == self.args.train_subset:
            # For train subset, additionally up or down sample languages.
            sample_probs = self._get_sample_prob(dataset_lengths)
            logger.info(
                "Sample probability by language: %s", {
                    lang: "{0:.4f}".format(sample_probs[id])
                    for id, lang in enumerate(languages)
                })
            size_ratio = (sample_probs *
                          dataset_lengths.sum()) / dataset_lengths
            logger.info(
                "Up/Down Sampling ratio by language: %s", {
                    lang: "{0:.2f}".format(size_ratio[id])
                    for id, lang in enumerate(languages)
                })

            resampled_lang_datasets = [
                ResamplingDataset(
                    lang_datasets[i],
                    size_ratio=size_ratio[i],
                    seed=self.args.seed,
                    epoch=epoch,
                    replace=size_ratio[i] >= 1.0,
                ) for i, d in enumerate(lang_datasets)
            ]
            dataset = ConcatDataset(resampled_lang_datasets)
        else:
            dataset = ConcatDataset(lang_datasets)
            lang_splits = [split]
            for lang_id, lang_dataset in enumerate(lang_datasets):
                split_name = split + '_' + languages[lang_id]
                lang_splits.append(split_name)
                self.datasets[split_name] = lang_dataset

            # [TODO]: This is hacky for now to print validation ppl for each
            # language individually. Maybe need task API changes to allow it
            # in more generic ways.
            if split in self.args.valid_subset:
                self.args.valid_subset = self.args.valid_subset.replace(
                    split, ','.join(lang_splits))

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(dataset))

        self.datasets[split] = SortDataset(
            dataset,
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )
Example #29
    def load_dataset(self, split, epoch=1, combine=False):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        loaded_datasets = []

        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        logger.info("data_path", data_path)

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(data_path, split_k)
            ds = indexed_dataset.make_dataset(
                path,
                impl=self.args.dataset_impl,
                fix_lua_indexing=True,
                dictionary=self.dictionary,
            )

            if ds is None:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(split, data_path))

            with data_utils.numpy_seed(self.seed + k):
                loaded_datasets.append(
                    BlockPairDataset(
                        ds,
                        self.dictionary,
                        ds.sizes,
                        self.args.tokens_per_sample,
                        break_mode=self.args.break_mode,
                        doc_break_size=1,
                    ))

            logger.info('{} {} {} examples'.format(
                data_path, split_k, len(loaded_datasets[-1])))

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        self.datasets[split] = MaskedLMDataset(
            dataset=dataset,
            sizes=sizes,
            vocab=self.dictionary,
            pad_idx=self.dictionary.pad(),
            mask_idx=self.dictionary.mask(),
            classif_token_idx=self.dictionary.cls(),
            sep_token_idx=self.dictionary.sep(),
            shuffle=self.args.shuffle_dataset,
            seed=self.seed,
        )
Example #30
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path,
                                '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            src_prefix = os.path.join(data_path,
                                      '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            src_prefix = os.path.join(data_path,
                                      '{}.{}-{}.'.format(split_k, tgt, src))
        elif split_exists(split_k, src, tgt.split("_")[0], src, data_path):
            src_prefix = os.path.join(
                data_path, '{}.{}-{}.'.format(split_k, src,
                                              tgt.split("_")[0]))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    split, data_path))

        src_datasets.append(
            data_utils.load_indexed_dataset(src_prefix + src, src_dict,
                                            dataset_impl))

        if split_exists(split_k, src, tgt, tgt, data_path):
            tgt_prefix = os.path.join(data_path,
                                      '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, src, tgt.split("_")[0], tgt, data_path):
            tgt_prefix = os.path.join(
                data_path, '{}.{}-{}.'.format(split_k, src,
                                              tgt.split("_")[0]))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    split, data_path))
        tgt_datasets.append(
            data_utils.load_indexed_dataset(tgt_prefix + tgt, tgt_dict,
                                            dataset_impl))

        print('| {} {} {}-{} {} examples'.format(data_path, split_k, src, tgt,
                                                 len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset.sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
    )
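Finally, a note on the `itertools.count()` loop that nearly every loader in these examples shares: it probes `split`, `split1`, `split2`, ... until a shard is missing, which is how extra shards such as train1/train2 are picked up when combine is enabled. A toy version of just that probing, with a stub existence check standing in for the indexed-dataset lookup:

import itertools

def collect_shard_names(split, exists):
    """Return split, split1, split2, ... for as long as `exists` says so."""
    names = []
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")
        if not exists(split_k):
            if k == 0:
                raise FileNotFoundError(f"Dataset not found: {split}")
            break
        names.append(split_k)
    return names

available = {"train", "train1", "train2"}   # hypothetical shards on disk
print(collect_shard_names("train", available.__contains__))
# ['train', 'train1', 'train2']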