Example #1
    def from_tsv(cls,
                 root: str,
                 cfg: S2TDataConfig,
                 splits: str,
                 tgt_dict,
                 pre_tokenizer,
                 bpe_tokenizer,
                 is_train_split: bool,
                 epoch: int,
                 seed: int,
                 n_frames_per_step: int = 1,
                 speaker_to_id=None) -> SpeechToTextDataset:
        datasets = [
            cls._from_tsv(root, cfg, split, tgt_dict, is_train_split,
                          pre_tokenizer, bpe_tokenizer, n_frames_per_step,
                          speaker_to_id) for split in splits.split(",")
        ]

        if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
            # temperature-based sampling
            size_ratios = cls.get_size_ratios(datasets,
                                              alpha=cfg.sampling_alpha)
            datasets = [
                ResamplingDataset(d,
                                  size_ratio=r,
                                  seed=seed,
                                  epoch=epoch,
                                  replace=(r >= 1.0))
                for r, d in zip(size_ratios, datasets)
            ]

        return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
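
Example #1 delegates the temperature scaling to `cls.get_size_ratios`, which is not shown here. As a minimal, self-contained sketch (assuming the usual temperature-based sampling formulation p_i ∝ n_i^alpha; the function name below is made up, not the fairseq helper):

import numpy as np

def get_size_ratios_sketch(sizes, alpha=1.0):
    """Per-split up/down-sampling ratios for temperature-based sampling."""
    sizes = np.asarray(sizes, dtype=float)
    prob = sizes / sizes.sum()            # empirical split distribution
    smoothed = prob ** alpha              # alpha < 1 boosts small splits
    smoothed = smoothed / smoothed.sum()
    return smoothed * sizes.sum() / sizes

# e.g. two splits of 9000 and 1000 examples with alpha=0.5
print(get_size_ratios_sketch([9000, 1000], alpha=0.5))  # roughly [0.83, 2.5]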
Example #2
    def resample_datasets(self, lang_datasets, lang_pairs_all, epoch):
        # For train subset, additionally up or down sample languages.
        if self.args.multilang_sampling_alpha == 1.0:
            return lang_datasets

        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        sample_probs = self._get_sample_prob(dataset_lengths)
        logger.info("Sample probability by language pair: {}".format({
            lp: "{0:.4f}".format(sample_probs[id])
            for id, lp in enumerate(lang_pairs_all)
        }))
        size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
        logger.info("Up/Down Sampling ratio by language: {}".format({
            lp: "{0:.2f}".format(size_ratio[id])
            for id, lp in enumerate(lang_pairs_all)
        }))

        resampled_lang_datasets = [
            ResamplingDataset(
                lang_datasets[i],
                size_ratio=size_ratio[i],
                seed=self.args.seed,
                epoch=epoch,
                replace=size_ratio[i] >= 1.0,
            ) for i, d in enumerate(lang_datasets)
        ]
        return resampled_lang_datasets
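
Example #2 depends on `self._get_sample_prob`, which is not included in the snippet. A minimal sketch of the usual smoothing step, assuming p_i = n_i^alpha / sum_j n_j^alpha (a standalone illustration, not the task's actual method):

import numpy as np

def get_sample_prob_sketch(dataset_lengths, alpha):
    """Smooth the empirical language distribution with temperature alpha."""
    prob = dataset_lengths / dataset_lengths.sum()
    smoothed = prob ** alpha
    return smoothed / smoothed.sum()

The snippet above then turns these probabilities into per-dataset scaling factors via size_ratio = sample_probs * dataset_lengths.sum() / dataset_lengths, so each language is up- or down-sampled toward its smoothed share.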
Example #3
    def from_tsv(
        cls,
        root: str,
        data_cfg: S2TDataConfigSrc,
        splits: str,
        tgt_dict,
        src_dict,
        pre_tokenizer,
        bpe_tokenizer,
        is_train_split: bool,
        epoch: int,
        seed: int,
    ) -> SpeechToTextDatasetWithSrc:
        samples = []
        _splits = splits.split(",")
        for split in _splits:
            tsv_path = op.join(root, f"{split}.tsv")
            if not op.isfile(tsv_path):
                raise FileNotFoundError(f"Dataset not found: {tsv_path}")
            with open(tsv_path) as f:
                reader = csv.DictReader(
                    f,
                    delimiter="\t",
                    quotechar=None,
                    doublequote=False,
                    lineterminator="\n",
                    quoting=csv.QUOTE_NONE,
                )
                samples.append([dict(e) for e in reader])
                assert len(samples[-1]) > 0  # the split's manifest must not be empty

        datasets = [
            cls._from_list(
                name,
                is_train_split,
                [s],
                data_cfg,
                tgt_dict,
                src_dict,
                pre_tokenizer,
                bpe_tokenizer,
            ) for name, s in zip(_splits, samples)
        ]

        if (is_train_split and len(_splits) > 1
                and data_cfg.sampling_alpha != 1.0):
            # temperature-based sampling
            size_ratios = cls._get_size_ratios(_splits,
                                               [len(s) for s in samples],
                                               alpha=data_cfg.sampling_alpha)
            datasets = [
                ResamplingDataset(d,
                                  size_ratio=r,
                                  seed=seed,
                                  epoch=epoch,
                                  replace=(r >= 1.0))
                for d, r in zip(datasets, size_ratios)
            ]
        return ConcatDataset(datasets)
Example #4
    def load_langpair_dataset(self,
                              prepend_tgt_lang_tag=False,
                              sampling_alpha=1.0,
                              epoch=0):
        lang_pairs = []
        text_dataset = None
        split = "train"
        for lp in self.args.langpairs.split(","):
            src, tgt = lp.split("-")
            text_dataset = load_langpair_dataset(
                self.args.parallel_text_data,
                split,
                src,
                self.src_dict,
                tgt,
                self.tgt_dict,
                combine=True,
                dataset_impl=None,
                upsample_primary=1,
                left_pad_source=False,
                left_pad_target=False,
                max_source_positions=self.args.max_positions_text,
                max_target_positions=self.args.max_target_positions,
                load_alignments=False,
                truncate_source=False,
            )
            if prepend_tgt_lang_tag:
                # TODO
                text_dataset = TransformEosLangPairDataset(
                    text_dataset,
                    src_eos=self.src_dict.eos(),
                    # 'prev_output_tokens' starts with eos
                    tgt_bos=self.tgt_dict.eos(),
                    new_tgt_bos=self.tgt_dict.index(
                        LANG_TAG_TEMPLATE.format(tgt)),
                )
            lang_pairs.append(text_dataset)
        if len(lang_pairs) > 1:
            if sampling_alpha != 1.0:
                size_ratios = SpeechToTextDatasetCreator.get_size_ratios(
                    self.args.langpairs.split(","),
                    [len(s) for s in lang_pairs],
                    alpha=sampling_alpha,
                )
                lang_pairs = [
                    ResamplingDataset(d,
                                      size_ratio=r,
                                      epoch=epoch,
                                      replace=(r >= 1.0))
                    for d, r in zip(lang_pairs, size_ratios)
                ]
            return ConcatDataset(lang_pairs)
        return text_dataset
Example #5
    def from_tsv(
        cls,
        root: str,
        cfg: S2TJointDataConfig,
        splits: str,
        tgt_dict,
        src_dict,
        pre_tokenizer,
        bpe_tokenizer,
        src_pre_tokenizer,
        src_bpe_tokenizer,
        is_train_split: bool,
        epoch: int,
        seed: int,
        append_eos: Optional[bool] = True,
        use_src_lang_id: Optional[int] = 0,
    ) -> SpeechToTextJointDataset:
        datasets = [
            cls._from_tsv(
                root,
                cfg,
                split,
                tgt_dict,
                src_dict,
                is_train_split,
                pre_tokenizer,
                bpe_tokenizer,
                src_pre_tokenizer,
                src_bpe_tokenizer,
                append_eos=append_eos,
                use_src_lang_id=use_src_lang_id,
            ) for split in splits.split(",")
        ]

        if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
            # temperature-based sampling
            size_ratios = cls.get_size_ratios(datasets,
                                              alpha=cfg.sampling_alpha)
            datasets = [
                ResamplingDataset(d,
                                  size_ratio=r,
                                  seed=seed,
                                  epoch=epoch,
                                  replace=(r >= 1.0))
                for r, d in zip(size_ratios, datasets)
            ]

        return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
Example #6
    def test_resampling_dataset_batch_by_size_true(self):
        resampling_dataset = ResamplingDataset(
            self.dataset,
            self.weights,
            size_ratio=self.size_ratio,
            batch_by_size=True,
            seed=0,
        )

        results = self._test_common(resampling_dataset, iters=1000)

        # For batch_by_size = True, the batches should be returned in
        # increasing order of size.
        assert results["ordered_by_size"]

        # Allow tolerance in distribution error of 2%.
        assert results["max_distribution_diff"] < 0.02
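
The test only checks batch ordering and the sampled distribution. As a rough, self-contained illustration of what per-epoch resampling does conceptually (a toy sketch, not the internals of fairseq's ResamplingDataset): each epoch, indices into the base dataset are drawn according to the optional weights, yielding roughly size_ratio * len(dataset) virtual items, and the epoch-dependent seed produces a fresh sample every epoch.

import numpy as np

def resample_indices_sketch(base_len, size_ratio, weights=None, seed=0, epoch=1, replace=True):
    """Toy per-epoch index resampling (illustrative only)."""
    rng = np.random.RandomState(seed + epoch)   # reproducible, but new per epoch
    num = int(round(base_len * size_ratio))     # size of the virtual dataset
    return rng.choice(base_len, size=num, replace=replace, p=weights)

# Downsample a 1000-item dataset to ~250 items, uniformly, for epoch 3.
idx = resample_indices_sketch(1000, 0.25, seed=42, epoch=3)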
Example #7
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = self.args.data.split(":")
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        if self.langs is None:
            languages = sorted(
                [
                    name
                    for name in os.listdir(data_path)
                    if os.path.isdir(os.path.join(data_path, name))
                ]
            )
        else:
            languages = self.langs.split(",")
            for name in languages:
                p = os.path.join(data_path, name)
                assert os.path.exists(p), "data not found: {}".format(p)

        logger.info("Training on {0} languages: {1}".format(len(languages), languages))
        logger.info(
            "Language to id mapping: ", {lang: id for id, lang in enumerate(languages)}
        )

        mask_whole_words = get_whole_word_mask(self.args, self.dictionary)
        language_without_segmentations = self.args.no_whole_word_mask_langs.split(",")
        lang_datasets = []
        for language in languages:
            split_path = os.path.join(data_path, language, split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                self.source_dictionary,
                self.args.dataset_impl,
                combine=combine,
            )
            if dataset is None:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, split_path)
                )

            end_token = (
                self.source_dictionary.index("[{}]".format(language))
                if self.args.add_lang_token
                else self.source_dictionary.eos()
            )

            # create continuous blocks of tokens
            dataset = TokenBlockDataset(
                dataset,
                dataset.sizes,
                self.args.tokens_per_sample - 2,  # minus 2: room for the prepended <s> and the appended end token
                pad=self.source_dictionary.pad(),
                eos=end_token,
                break_mode=self.args.sample_break_mode,
            )
            logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

            # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
            dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
            dataset = AppendTokenDataset(dataset, end_token)

            lang_mask_whole_words = (
                mask_whole_words
                if language not in language_without_segmentations
                else None
            )
            lang_dataset = DenoisingDataset(
                dataset,
                dataset.sizes,
                self.dictionary,
                self.mask_idx,
                lang_mask_whole_words,
                shuffle=self.args.shuffle_instance,
                seed=self.seed,
                args=self.args,
                eos=None
                if not self.args.add_lang_token
                else self.source_dictionary.index("[{}]".format(language)),
            )
            lang_datasets.append(lang_dataset)

        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        logger.info(
            "loaded total {} blocks for all languages".format(
                int(dataset_lengths.sum()),
            )
        )
        if split == self.args.train_subset:
            # For train subset, additionally up or down sample languages.
            sample_probs = self._get_sample_prob(dataset_lengths)
            logger.info(
                "Sample probability by language: {}".format(
                    {
                        lang: "{0:.4f}".format(sample_probs[id])
                        for id, lang in enumerate(languages)
                    }
                )
            )
            size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
            logger.info(
                "Up/Down Sampling ratio by language: {}".format(
                    {
                        lang: "{0:.2f}".format(size_ratio[id])
                        for id, lang in enumerate(languages)
                    }
                )
            )

            resampled_lang_datasets = [
                ResamplingDataset(
                    lang_datasets[i],
                    size_ratio=size_ratio[i],
                    seed=self.args.seed,
                    epoch=epoch,
                    replace=size_ratio[i] >= 1.0,
                )
                for i, d in enumerate(lang_datasets)
            ]
            dataset = ConcatDataset(
                resampled_lang_datasets,
            )
        else:
            dataset = ConcatDataset(lang_datasets)
            lang_splits = [split]
            for lang_id, lang_dataset in enumerate(lang_datasets):
                split_name = split + "_" + languages[lang_id]
                lang_splits.append(split_name)
                self.datasets[split_name] = lang_dataset

            if split in self.args.valid_subset:
                self.args.valid_subset = self.args.valid_subset.replace(
                    split, ",".join(lang_splits)
                )

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(dataset))

        self.datasets[split] = SortDataset(
            dataset,
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )
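
In the non-training branch of Example #7, each language additionally gets its own split entry (e.g. valid_en, valid_fr) and args.valid_subset is rewritten so validation metrics can be reported per language. A tiny illustration of that string expansion with made-up values:

split = "valid"
languages = ["en", "fr"]

lang_splits = [split] + [f"{split}_{lang}" for lang in languages]
valid_subset = "valid"
if split in valid_subset:
    valid_subset = valid_subset.replace(split, ",".join(lang_splits))

print(valid_subset)  # valid,valid_en,valid_fr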
Example #8
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        languages = sorted(name for name in os.listdir(data_path)
                           if os.path.isdir(os.path.join(data_path, name)))

        logger.info("Training on {0} languages: {1}".format(
            len(languages), languages))
        logger.info("Language to id mapping: ",
                    {lang: id
                     for id, lang in enumerate(languages)})

        mask_whole_words = self._get_whole_word_mask()
        lang_datasets = []
        for lang_id, language in enumerate(languages):
            split_path = os.path.join(data_path, language, split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                self.source_dictionary,
                self.args.dataset_impl,
                combine=combine,
            )
            if dataset is None:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    split, split_path))

            # create continuous blocks of tokens
            dataset = TokenBlockDataset(
                dataset,
                dataset.sizes,
                self.args.tokens_per_sample - 1,  # one less for <s>
                pad=self.source_dictionary.pad(),
                eos=self.source_dictionary.eos(),
                break_mode=self.args.sample_break_mode,
            )
            logger.info('loaded {} blocks from: {}'.format(
                len(dataset), split_path))

            # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
            dataset = PrependTokenDataset(dataset,
                                          self.source_dictionary.bos())

            src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
                dataset,
                self.source_dictionary,
                pad_idx=self.source_dictionary.pad(),
                mask_idx=self.mask_idx,
                seed=self.args.seed,
                mask_prob=self.args.mask_prob,
                leave_unmasked_prob=self.args.leave_unmasked_prob,
                random_token_prob=self.args.random_token_prob,
                freq_weighted_replacement=self.args.freq_weighted_replacement,
                mask_whole_words=mask_whole_words,
            )

            lang_dataset = NestedDictionaryDataset(
                {
                    'net_input': {
                        'src_tokens':
                        PadDataset(
                            src_dataset,
                            pad_idx=self.source_dictionary.pad(),
                            left_pad=False,
                        ),
                        'src_lengths':
                        NumelDataset(src_dataset, reduce=False),
                    },
                    'target':
                    PadDataset(
                        tgt_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'nsentences':
                    NumSamplesDataset(),
                    'ntokens':
                    NumelDataset(src_dataset, reduce=True),
                    'lang_id':
                    RawLabelDataset([lang_id] * src_dataset.sizes.shape[0]),
                },
                sizes=[src_dataset.sizes],
            )
            lang_datasets.append(lang_dataset)

        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        logger.info('loaded total {} blocks for all languages'.format(
            dataset_lengths.sum(), ))
        if split == self.args.train_subset:
            # For train subset, additionally up or down sample languages.
            sample_probs = self._get_sample_prob(dataset_lengths)
            logger.info(
                "Sample probability by language: ", {
                    lang: "{0:.4f}".format(sample_probs[id])
                    for id, lang in enumerate(languages)
                })
            size_ratio = (sample_probs *
                          dataset_lengths.sum()) / dataset_lengths
            logger.info(
                "Up/Down Sampling ratio by language: ", {
                    lang: "{0:.2f}".format(size_ratio[id])
                    for id, lang in enumerate(languages)
                })

            resampled_lang_datasets = [
                ResamplingDataset(
                    lang_datasets[i],
                    size_ratio=size_ratio[i],
                    seed=self.args.seed,
                    epoch=epoch,
                    replace=size_ratio[i] >= 1.0,
                ) for i, d in enumerate(lang_datasets)
            ]
            dataset = ConcatDataset(resampled_lang_datasets)
        else:
            dataset = ConcatDataset(lang_datasets)
            lang_splits = [split]
            for lang_id, lang_dataset in enumerate(lang_datasets):
                split_name = split + '_' + languages[lang_id]
                lang_splits.append(split_name)
                self.datasets[split_name] = lang_dataset

            # [TODO]: This is hacky for now to print validation ppl for each
            # language individually. Maybe need task API changes to allow it
            # in more generic ways.
            if split in self.args.valid_subset:
                self.args.valid_subset = self.args.valid_subset.replace(
                    split, ','.join(lang_splits))

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(dataset))

        self.datasets[split] = SortDataset(
            dataset,
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )
Example #9
    def load_dataset(self, split: str, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        languages, data_path = MultilingualLanguageModelingTask._get_langs(
            self.args, epoch)
        lang_to_offline_shard_ratio = None
        if self.args.lang_to_offline_shard_ratio != "":
            lang_to_offline_shard_ratio = {}
            assert os.path.exists(
                self.args.lang_to_offline_shard_ratio
            ), "provided offline shard ratio file doesn't exist: {0}".format(
                self.args.lang_to_offline_shard_ratio)
            with open(self.args.lang_to_offline_shard_ratio) as fin:
                for line in fin:
                    lang, ratio = line.strip().split("\t")
                    ratio = float(ratio)
                    lang_to_offline_shard_ratio[lang] = ratio

            logger.info(
                "Found offline sharded ratio: %s",
                lang_to_offline_shard_ratio,
            )

        if split == self.args.train_subset:
            logger.info("Training on {0} languages: {1}".format(
                len(languages), languages))
        else:
            logger.info("Evaluating on {0} languages: {1}".format(
                len(languages), languages))

        tokens_per_sample = self.args.tokens_per_sample - int(
            self.args.add_bos_token)

        fixed_pad_length = None
        if self.args.pad_to_fixed_length:
            fixed_pad_length = self.args.tokens_per_sample

        pad_to_bsz = None
        if self.args.pad_to_fixed_bsz:
            pad_to_bsz = (self.args.batch_size_valid
                          if "valid" in split else self.args.batch_size)

        lang_datasets = []
        for lang_id, language in enumerate(languages):
            split_path = os.path.join(data_path, language, split)
            dataset = data_utils.load_indexed_dataset(split_path,
                                                      self.dictionary,
                                                      self.args.dataset_impl,
                                                      combine=combine)
            # print('len(dataset) =', len(dataset))
            if dataset is None:
                raise FileNotFoundError("Dataset not found: {} ({})".format(
                    split, split_path))

            dataset = maybe_shorten_dataset(
                dataset,
                split,
                self.args.shorten_data_split_list,
                self.args.shorten_method,
                tokens_per_sample,
                self.args.seed,
            )

            dataset = TokenBlockDataset(
                dataset,
                dataset.sizes,
                tokens_per_sample,
                pad=self.dictionary.pad(),
                eos=self.dictionary.eos(),
                break_mode=self.args.sample_break_mode,
                include_targets=True,
            )

            add_eos_for_other_targets = (
                self.args.sample_break_mode is not None
                and self.args.sample_break_mode != "none")
            src_lang_idx, tgt_lang_idx = None, None
            if self.args.add_bos_token:
                src_lang_idx = self.dictionary.index(lang_token(language))
                tgt_lang_idx = self.output_dictionary.index(
                    lang_token(language))

            lang_datasets.append(
                MonolingualDataset(
                    dataset=dataset,
                    sizes=dataset.sizes,
                    src_vocab=self.dictionary,
                    tgt_vocab=self.output_dictionary,
                    add_eos_for_other_targets=add_eos_for_other_targets,
                    shuffle=True,
                    targets=self.targets,
                    fixed_pad_length=fixed_pad_length,
                    pad_to_bsz=pad_to_bsz,
                    add_bos_token=self.args.add_bos_token,
                    src_lang_idx=src_lang_idx,
                    tgt_lang_idx=tgt_lang_idx,
                ))

        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        logger.info("loaded total {} blocks for all languages".format(
            dataset_lengths.sum(), ))
        if split == self.args.train_subset:
            dataset_lengths_ratio_multiplier = np.ones(len(dataset_lengths))
            if lang_to_offline_shard_ratio is not None:
                dataset_lengths_ratio_multiplier = []
                for lang in languages:
                    assert (
                        lang in lang_to_offline_shard_ratio
                    ), "Lang: {0} missing in offline shard ratio file: {1}".format(
                        lang,
                        self.args.lang_to_offline_shard_ratio,
                    )
                    dataset_lengths_ratio_multiplier.append(
                        lang_to_offline_shard_ratio[lang])
                dataset_lengths_ratio_multiplier = np.array(
                    dataset_lengths_ratio_multiplier)
                true_dataset_lengths = (dataset_lengths *
                                        dataset_lengths_ratio_multiplier)
            else:
                true_dataset_lengths = dataset_lengths
            # For train subset, additionally up or down sample languages.
            sample_probs = self._get_sample_prob(true_dataset_lengths)

            logger.info(
                "Sample probability by language: %s",
                {
                    lang: "{0:.4f}".format(sample_probs[id])
                    for id, lang in enumerate(languages)
                },
            )
            size_ratio = (sample_probs *
                          true_dataset_lengths.sum()) / dataset_lengths
            # TODO: add an option for shrinking all size ratios to below 1
            # if self.args.multilang_sampling_alpha != 1:
            #   size_ratio /= size_ratio.max()

            # Fix numeric errors in size ratio computation
            #   0.999999999999999999 -> 1
            #   1.000000000000000002 -> 1
            for i in range(len(size_ratio)):
                size_ratio[i] = round(size_ratio[i], 8)

            logger.info(
                "Up/Down Sampling ratio by language: %s",
                {
                    lang: "{0:.2f}".format(size_ratio[id])
                    for id, lang in enumerate(languages)
                },
            )
            logger.info(
                "Actual dataset size by language: %s",
                {
                    lang: "{0:.2f}".format(len(lang_datasets[id]))
                    for id, lang in enumerate(languages)
                },
            )
            resampled_lang_datasets = [
                ResamplingDataset(
                    lang_datasets[i],
                    size_ratio=size_ratio[i],
                    seed=self.args.seed,
                    epoch=epoch,
                    replace=size_ratio[i] > 1.0,
                ) for i, d in enumerate(lang_datasets)
            ]
            logger.info(
                "Resampled dataset size by language: %s",
                {
                    lang: "{0:.2f}".format(len(resampled_lang_datasets[id]))
                    for id, lang in enumerate(languages)
                },
            )
            dataset = ConcatDataset(resampled_lang_datasets)
        else:
            dataset = ConcatDataset(lang_datasets)
            lang_splits = [split]
            for lang_id, lang_dataset in enumerate(lang_datasets):
                split_name = split + "_" + languages[lang_id]
                lang_splits.append(split_name)
                self.datasets[split_name] = lang_dataset

            # [TODO]: This is hacky for now to print validation ppl for each
            # language individually. Maybe need task API changes to allow it
            # in more generic ways.
            if split in self.args.valid_subset:
                self.args.valid_subset = self.args.valid_subset.replace(
                    split, ",".join(lang_splits))

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(dataset))

        self.datasets[split] = SortDataset(
            dataset,
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )
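
The lang_to_offline_shard_ratio handling above corrects for languages whose on-disk shard is only a fraction of the full corpus: sampling probabilities are computed from the rescaled ("true") lengths, while the final size_ratio still divides by the on-disk lengths, since those are what actually get resampled. A toy numeric sketch with made-up sizes and ratios, following the same arithmetic:

import numpy as np

dataset_lengths = np.array([1000.0, 1000.0])   # sizes of the shards actually loaded
shard_ratio = np.array([1.0, 4.0])             # multiplier from shard size to full corpus size
true_lengths = dataset_lengths * shard_ratio   # [1000., 4000.]

alpha = 0.7
prob = true_lengths / true_lengths.sum()
smoothed = prob ** alpha
sample_probs = smoothed / smoothed.sum()

# Resampling acts on the loaded shards, so divide by dataset_lengths here.
size_ratio = sample_probs * true_lengths.sum() / dataset_lengths
print(size_ratio.round(2))                     # roughly [1.37, 3.63]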
Example #10
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode

        lg_datasets = []
        for lg in self.gt_langs:
            src, tgt = lg, lg
            bos_id = self.tgt_dict.index('[{}]'.format(lg))
            data_path_lg = os.path.join(data_path, lg)
            dataset = load_generation_pair_dataset(
                data_path_lg,
                split,
                tgt,
                self.src_dict,
                self.tgt_dict,
                combine=combine,
                dataset_impl=self.args.dataset_impl,
                upsample_primary=self.args.upsample_primary,
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
                max_source_positions=getattr(self.args, 'max_source_positions',
                                             1024),
                max_target_positions=getattr(self.args, 'max_target_positions',
                                             1024),
                load_alignments=self.args.load_alignments,
                prepend_bos=getattr(self.args, 'preprend_bos', False),
                append_source_id=True,
                common_eos=self.args.common_eos,
                lg_id=bos_id)
            lg_datasets.append(dataset)

        dataset_lengths = np.array([len(d) for d in lg_datasets], dtype=float)

        sample_probs = self._get_sample_prob(dataset_lengths)
        logger.info(
            "| Sample probability by language: ", {
                lang: "{0:.4f}".format(sample_probs[id])
                for id, lang in enumerate(self.gt_langs)
            })
        size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
        logger.info(
            "| Up/Down Sampling ratio by language: ", {
                lang: "{0:.2f}".format(size_ratio[id])
                for id, lang in enumerate(self.gt_langs)
            })
        if split == getattr(self.args, "train_subset", "train"):
            resampled_lang_datasets = [
                ResamplingDataset(
                    lg_datasets[i],
                    size_ratio=size_ratio[i],
                    seed=self.args.seed,
                    epoch=epoch,
                    replace=size_ratio[i] >= 1.0,
                ) for i, d in enumerate(lg_datasets)
            ]
            dataset = ConcatDataset(resampled_lang_datasets)
        else:
            dataset = ConcatDataset(lg_datasets)
            lang_splits = [split]
            for lang_id, lang_dataset in enumerate(lg_datasets):
                split_name = split + '_' + self.gt_langs[lang_id]
                lang_splits.append(split_name)
                self.datasets[split_name] = lang_dataset

            if hasattr(self.args, "valid_subset"):
                if split in self.args.valid_subset:
                    self.args.valid_subset = self.args.valid_subset.replace(
                        split, ','.join(lang_splits))

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(dataset))
        self.datasets[split] = SortDataset(
            dataset,
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )
Example #11
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0

        # if not training data set, use the first shard for valid and test
        if split != getattr(self.args, "train_subset", None):
            paths = paths[:1]
        data_path = paths[(epoch - 1) % len(paths)]

        lang_datasets = []
        for lang_pair in self.args.langs.split(","):
            [src, tgt] = lang_pair.split("-")

            # if multi-valid : valid.zh-en
            if "valid" in split and split != "valid":
                split_lang = split.split(".")[-1]
                if lang_pair != split_lang and tgt + "-" + src != split_lang:
                    continue
                # special for (fil and fi) langs.
                if (lang_pair != split_lang
                        and tgt + "-" + src == split_lang
                        and "fil" in split_lang):
                    continue

            data_path_temp = data_path
            if self.args.add_lang_token:
                add_langs = (src, tgt)
            else:
                add_langs = None
            dataset = load_langpair_dataset(
                data_path_temp,
                split,
                src,
                self.src_dict,
                tgt,
                self.tgt_dict,
                combine=combine,
                dataset_impl=self.args.dataset_impl,
                upsample_primary=self.args.upsample_primary,
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
                max_source_positions=self.args.max_source_positions,
                max_target_positions=self.args.max_target_positions,
                load_alignments=self.args.load_alignments,
                truncate_source=self.args.truncate_source,
                num_buckets=self.args.num_batch_buckets,
                shuffle=(split != "test"),
                pad_to_multiple=self.args.required_seq_len_multiple,
                plus_encoder_loss=self.args.plus_encoder_loss,
                add_langs=add_langs,
                shuffle_lang_pair=self.args.shuffle_lang_pair,
                args=self.args,
                word_trans_dict=self.word_trans_dict,
                word_align_dict=self.word_align,
                policy_ratio_dicts=self.policy_ratio_dicts
            )
            lang_datasets.append(dataset)

        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        logger.info(
            "loaded total {} blocks for all languages".format(
                int(dataset_lengths.sum()),
            )
        )
        if split == self.args.train_subset:
            # For train subset, additionally up or down sample languages.
            sample_probs = self._get_sample_prob(dataset_lengths)
            logger.info(
                "Sample probability by language: {}".format(
                    {
                        lang: "{0:.4f}".format(sample_probs[id])
                        for id, lang in enumerate(self.args.langs.split(","))
                    }
                )
            )
            size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
            logger.info(
                "Up/Down Sampling ratio by language: {}".format(
                    {
                        lang: "{0:.2f}".format(size_ratio[id])
                        for id, lang in enumerate(self.args.langs.split(","))
                    }
                )
            )

            resampled_lang_datasets = [
                ResamplingDataset(
                    lang_datasets[i],
                    size_ratio=size_ratio[i],
                    seed=self.args.seed,
                    epoch=epoch,
                    replace=size_ratio[i] >= 1.0,
                )
                for i, d in enumerate(lang_datasets)
            ]
            dataset = ConcatPairDataset(
                resampled_lang_datasets,
            )
        else:
            dataset = ConcatPairDataset(lang_datasets)
        self.datasets[split] = dataset