Example #1
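A fairseq-style load_dataset for semi-supervised translation: it loads parallel data for every language pair, builds on-the-fly BacktranslationDataset and NoisingDataset-based denoising pairs from monolingual target-side files, and zips everything into a RoundRobinZipDatasets.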
    def load_dataset(self, split, epoch=0, **kwargs):
        """Load a dataset split."""

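        # data may be a colon-separated list of paths; cycle through them by epoch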
        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]

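        # check whether a raw-text or binarized file exists for this split and pair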
        def split_exists(split, src, tgt, lang):
            if src is not None:
                filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
            else:
                filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, src, tgt))
            if self.args.raw_text and IndexedRawTextDataset.exists(filename):
                return True
            elif not self.args.raw_text and IndexedDataset.exists(filename):
                return True
            return False

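        # load raw-text, lazily indexed, or cached binarized data depending on the flags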
        def indexed_dataset(path, dictionary):
            if self.args.raw_text:
                return IndexedRawTextDataset(path, dictionary)
            elif IndexedDataset.exists(path):
                if self.args.lazy_load:
                    return IndexedDataset(path, fix_lua_indexing=True)
                else:
                    return IndexedCachedDataset(path, fix_lua_indexing=True)
            return None

        # load parallel datasets
        src_datasets, tgt_datasets = {}, {}
        if (self.lambda_parallel > 0.0 or self.lambda_parallel_steps is not None or not split.startswith("train")):
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split('-')
                if split_exists(split, src, tgt, src):
                    prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))
                elif split_exists(split, tgt, src, src):
                    prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src))
                else:
                    continue
                src_datasets[lang_pair] = indexed_dataset(prefix + src, self.dicts[src])
                tgt_datasets[lang_pair] = indexed_dataset(prefix + tgt, self.dicts[tgt])
                print('| parallel-{} {} {} examples'.format(data_path, split, len(src_datasets[lang_pair])))
            if len(src_datasets) == 0:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        # back translation datasets
        backtranslate_datasets = {}
        if (self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None) and split.startswith("train"):
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split('-')
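                # monolingual target-side data is stored as '{split}.{tgt}-None.{tgt}'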
                if not split_exists(split, tgt, None, tgt):
                    raise FileNotFoundError('Dataset not found: backtranslation {} ({})'.format(split, data_path))
                filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
                dataset = indexed_dataset(filename, self.dicts[tgt])
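                # target-only view: the input that the backtranslation generator translates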
                lang_pair_dataset_tgt = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
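                # src/tgt view of the same data; only its collater is used below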
                lang_pair_dataset = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    src_dict=self.dicts[src],
                    tgt=dataset,
                    tgt_sizes=dataset.sizes,
                    tgt_dict=self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                backtranslate_datasets[lang_pair] = BacktranslationDataset(
                    tgt_dataset=self.alter_dataset_langtok(
                        lang_pair_dataset_tgt,
                        src_eos=self.dicts[tgt].eos(),
                        src_lang=tgt,
                        tgt_lang=src,
                    ),
                    backtranslation_fn=self.backtranslators[lang_pair],
                    src_dict=self.dicts[src], tgt_dict=self.dicts[tgt],
                    output_collater=self.alter_dataset_langtok(
                        lang_pair_dataset=lang_pair_dataset,
                        src_eos=self.dicts[src].eos(),
                        src_lang=src,
                        tgt_eos=self.dicts[tgt].eos(),
                        tgt_lang=tgt,
                    ).collater,
                )
                print('| backtranslate-{}: {} {} {} examples'.format(
                    tgt, data_path, split, len(backtranslate_datasets[lang_pair]),
                ))
                self.backtranslate_datasets[lang_pair] = backtranslate_datasets[lang_pair]

        # denoising autoencoder
        noising_datasets = {}
        if (self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None) and split.startswith("train"):
            for lang_pair in self.lang_pairs:
                _, tgt = lang_pair.split('-')
                if not split_exists(split, tgt, None, tgt):
                    continue
                filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
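                # load the file twice so the noised source and the clean target stay independent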
                tgt_dataset1 = indexed_dataset(filename, self.dicts[tgt])
                tgt_dataset2 = indexed_dataset(filename, self.dicts[tgt])
                noising_dataset = NoisingDataset(
                    tgt_dataset1,
                    self.dicts[tgt],
                    seed=1,
                    max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                    word_dropout_prob=self.args.word_dropout_prob,
                    word_blanking_prob=self.args.word_blanking_prob,
                )
                noising_datasets[lang_pair] = self.alter_dataset_langtok(
                    LanguagePairDataset(
                        noising_dataset,
                        tgt_dataset1.sizes,
                        self.dicts[tgt],
                        tgt_dataset2,
                        tgt_dataset2.sizes,
                        self.dicts[tgt],
                        left_pad_source=self.args.left_pad_source,
                        left_pad_target=self.args.left_pad_target,
                    ),
                    src_eos=self.dicts[tgt].eos(),
                    src_lang=tgt,
                    tgt_eos=self.dicts[tgt].eos(),
                    tgt_lang=tgt,
                )
                print('| denoising-{}: {} {} {} examples'.format(
                    tgt, data_path, split, len(noising_datasets[lang_pair]),
                ))

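        # wrap each parallel pair with language-token handling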
        def language_pair_dataset(lang_pair):
            src, tgt = lang_pair.split('-')
            src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
            return self.alter_dataset_langtok(
                LanguagePairDataset(
                    src_dataset, src_dataset.sizes, self.dicts[src],
                    tgt_dataset, tgt_dataset.sizes, self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                    max_source_positions=self.args.max_source_positions,
                    max_target_positions=self.args.max_target_positions,
                ),
                self.dicts[src].eos(),
                src,
                self.dicts[tgt].eos(),
                tgt,
            )

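        # combine parallel, backtranslation, and denoising datasets under distinct keys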
        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict([
                (lang_pair, language_pair_dataset(lang_pair))
                for lang_pair in src_datasets.keys()
            ] + [
                (_get_bt_dataset_key(lang_pair), dataset)
                for lang_pair, dataset in backtranslate_datasets.items()
            ] + [
                (_get_denoising_dataset_key(lang_pair), dataset)
                for lang_pair, dataset in noising_datasets.items()
            ]),
            eval_key=None if self.training else "%s-%s" % (self.args.source_lang, self.args.target_lang),
        )
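Example #2
A test helper that exercises BacktranslationDataset end to end: monolingual target data is wrapped as a source-only dataset, backtranslated with a beam-2 SequenceGenerator, and the collated source/target tensors are checked against expected values.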
    def _backtranslation_dataset_helper(
        self, remove_eos_from_input_src, remove_eos_from_output_src,
    ):
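        # treat the monolingual target-side data as a source-only dataset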
        tgt_dataset = LanguagePairDataset(
            src=self.tgt_dataset,
            src_sizes=self.tgt_dataset.sizes,
            src_dict=self.tgt_dict,
            tgt=None,
            tgt_sizes=None,
            tgt_dict=None,
        )

        generator = SequenceGenerator(
            models=[self.model],
            tgt_dict=self.tgt_dict,
            beam_size=2,
            unk_penalty=0,
            sampling=False,
        )
        if self.cuda:
            generator.cuda()

        backtranslation_dataset = BacktranslationDataset(
            tgt_dataset=TransformEosDataset(
                dataset=tgt_dataset,
                eos=self.tgt_dict.eos(),
                # remove eos from the input src
                remove_eos_from_src=remove_eos_from_input_src,
            ),
            backtranslation_fn=generator.generate,
            max_len_a=0,
            max_len_b=200,
            output_collater=TransformEosDataset(
                dataset=tgt_dataset,
                eos=self.tgt_dict.eos(),
                # if we remove eos from the input src, then we need to add it
                # back to the output tgt
                append_eos_to_tgt=remove_eos_from_input_src,
                remove_eos_from_src=remove_eos_from_output_src,
            ).collater,
            cuda=self.cuda,
        )
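        # backtranslations are generated lazily inside the dataset's collater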
        dataloader = torch.utils.data.DataLoader(
            backtranslation_dataset,
            batch_size=2,
            collate_fn=backtranslation_dataset.collater,
        )
        backtranslation_batch_result = next(iter(dataloader))

        eos, pad, w1, w2 = self.tgt_dict.eos(), self.tgt_dict.pad(), self.w1, self.w2

        # Note that we sort by src_lengths and add left padding, so actually
        # ids will look like: [1, 0]
        expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]])
        if remove_eos_from_output_src:
            expected_src = expected_src[:, :-1]
        expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
        generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
        tgt_tokens = backtranslation_batch_result["target"]

        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, tgt_tokens)
Example #3
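A pytorch_translate-style load_dataset for dual learning: it builds weighted forward and backward parallel datasets from the same binarized files and, when both models are supplied, adds a monolingual BacktranslationDataset for each direction.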
    def load_dataset(
        self, split, src_bin_path, tgt_bin_path, forward_model=None, backward_model=None
    ):
        """Load a dataset split."""

        corpus = ptt_data.ParallelCorpusConfig(
            source=ptt_data.CorpusConfig(
                dialect=self.source_lang, data_file=src_bin_path
            ),
            target=ptt_data.CorpusConfig(
                dialect=self.target_lang, data_file=tgt_bin_path
            ),
            weights_file=None,
        )

        if self.args.log_verbose:
            print("Starting to load binarized data files.", flush=True)
        data_utils.validate_corpus_exists(corpus=corpus, split=split)

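        # forward and backward directions reuse the same binarized files with the roles swapped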
        forward_tgt_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
            corpus.target.data_file
        )
        backward_tgt_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
            corpus.source.data_file
        )
        forward_src_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
            corpus.source.data_file
        )
        backward_src_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
            corpus.target.data_file
        )
        forward_parallel_dataset = weighted_data.WeightedLanguagePairDataset(
            src=forward_src_dataset,
            src_sizes=forward_src_dataset.sizes,
            src_dict=self.source_dictionary,
            tgt=forward_tgt_dataset,
            tgt_sizes=forward_tgt_dataset.sizes,
            tgt_dict=self.target_dictionary,
            remove_eos_from_source=self.remove_eos_from_source,
            append_eos_to_target=True,
        )
        backward_parallel_dataset = weighted_data.WeightedLanguagePairDataset(
            src=backward_src_dataset,
            src_sizes=backward_src_dataset.sizes,
            src_dict=self.target_dictionary,
            tgt=backward_tgt_dataset,
            tgt_sizes=backward_tgt_dataset.sizes,
            tgt_dict=self.source_dictionary,
            remove_eos_from_source=self.remove_eos_from_source,
            append_eos_to_target=True,
        )

        dataset_map = OrderedDict(
            [
                (f"{self.source_lang}-{self.target_lang}", forward_parallel_dataset),
                (f"{self.target_lang}-{self.source_lang}", backward_parallel_dataset),
            ]
        )

        assert (forward_model and backward_model) or (
            forward_model is None and backward_model is None
        ), (
            "Forward and backward models must both be provided"
            " or both be None"
        )
        if forward_model and backward_model:
            fwd_generator = beam_decode.SequenceGenerator(
                models=[forward_model], tgt_dict=self.source_dictionary
            )
            bwd_generator = beam_decode.SequenceGenerator(
                models=[backward_model], tgt_dict=self.target_dictionary
            )

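            # wrap monolingual data as a source-only LanguagePairDataset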
            def monolingual_dataset(
                path,
                dictionary,
                is_source=False,
                num_examples_limit: Optional[int] = None,
            ):
                dataset = self.load_monolingual_dataset(
                    path, is_source=is_source, num_examples_limit=num_examples_limit
                )
                return LanguagePairDataset(
                    src=dataset,
                    src_sizes=dataset.sizes,
                    src_dict=dictionary,
                    tgt=None,
                    tgt_sizes=None,
                    tgt_dict=None,
                )

            monolingual_num_examples_limit = None
            if self.args.monolingual_ratio is not None:
                monolingual_num_examples_limit = int(
                    self.args.monolingual_ratio * len(forward_parallel_dataset)
                )

            src_dataset = monolingual_dataset(
                path=self.args.train_mono_source_binary_path,
                dictionary=self.source_dictionary,
                is_source=True,
                num_examples_limit=monolingual_num_examples_limit,
            )
            tgt_dataset = monolingual_dataset(
                path=self.args.train_mono_target_binary_path,
                dictionary=self.target_dictionary,
                is_source=False,
                num_examples_limit=monolingual_num_examples_limit,
            )

            dataset_map[
                f"{self.source_lang}-"
                f"{self.target_lang}_{constants.MONOLINGUAL_DATA_IDENTIFIER}"
            ] = BacktranslationDataset(
                tgt_dataset=TransformEosDataset(
                    dataset=tgt_dataset,
                    eos=self.target_dictionary.eos(),
                    # Remove EOS from the input before backtranslation.
                    remove_eos_from_src=True,
                ),
                backtranslation_fn=bwd_generator.generate,
                max_len_a=self.args.max_len_a,
                max_len_b=self.args.max_len_b,
                output_collater=TransformEosDataset(
                    dataset=tgt_dataset,
                    eos=self.target_dictionary.eos(),
                    # The original input (now the target) doesn't have
                    # an EOS, so we need to add one. The generated
                    # backtranslation (now the source) will have an EOS,
                    # so we want to remove it.
                    append_eos_to_tgt=True,
                    remove_eos_from_src=True,
                ).collater,
            )
            dataset_map[
                f"{self.target_lang}-"
                f"{self.source_lang}_{constants.MONOLINGUAL_DATA_IDENTIFIER}"
            ] = BacktranslationDataset(
                tgt_dataset=src_dataset,
                backtranslation_fn=fwd_generator.generate,
                max_len_a=self.args.max_len_a,
                max_len_b=self.args.max_len_b,
                output_collater=TransformEosDataset(
                    dataset=src_dataset,
                    eos=self.source_dictionary.eos(),
                    # The original input (now the target) doesn't have
                    # an EOS, so we need to add one. The generated
                    # backtranslation (now the source) will have an EOS,
                    # so we want to remove it.
                    append_eos_to_tgt=True,
                    remove_eos_from_src=True,
                ).collater,
            )

        # print before loading RoundRobinZipDatasets to help catch any bugs
        for dataset_key, dataset in dataset_map.items():
            print(f"| {split}: {dataset_key} {len(dataset)} examples in dataset")

        self.datasets[split] = RoundRobinZipDatasets(dataset_map)
        print(
            f"| {split} {len(self.datasets[split])} examples in RoundRobinZipDatasets"
        )

        if self.args.log_verbose:
            print("Finished loading dataset", flush=True)
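Example #4
An updated variant of Example #1: the same structure, but paths are resolved with utils.split_paths, any dataset implementation is supported through indexed_dataset.dataset_exists and data_utils.load_indexed_dataset, and logging goes through logger instead of print.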
    def load_dataset(self, split, epoch=1, **kwargs):
        """Load a dataset split."""
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
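        # epochs are 1-based here, hence the (epoch - 1) offset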

        def split_exists(split, src, tgt, lang):
            if src is not None:
                filename = os.path.join(
                    data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)
                )
            else:
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, src, tgt)
                )
            return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)

        def load_indexed_dataset(path, dictionary):
            return data_utils.load_indexed_dataset(
                path, dictionary, self.args.dataset_impl
            )

        # load parallel datasets
        src_datasets, tgt_datasets = {}, {}
        if (
            self.lambda_parallel > 0.0
            or self.lambda_parallel_steps is not None
            or not split.startswith("train")
        ):
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                if split_exists(split, src, tgt, src):
                    prefix = os.path.join(
                        data_path, "{}.{}-{}.".format(split, src, tgt)
                    )
                elif split_exists(split, tgt, src, src):
                    prefix = os.path.join(
                        data_path, "{}.{}-{}.".format(split, tgt, src)
                    )
                else:
                    continue
                src_datasets[lang_pair] = load_indexed_dataset(
                    prefix + src, self.dicts[src]
                )
                tgt_datasets[lang_pair] = load_indexed_dataset(
                    prefix + tgt, self.dicts[tgt]
                )
                logger.info(
                    "parallel-{} {} {} examples".format(
                        data_path, split, len(src_datasets[lang_pair])
                    )
                )
            if len(src_datasets) == 0:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

        # back translation datasets
        backtranslate_datasets = {}
        if (
            self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
        ) and split.startswith("train"):
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                if not split_exists(split, tgt, None, tgt):
                    raise FileNotFoundError(
                        "Dataset not found: backtranslation {} ({})".format(
                            split, data_path
                        )
                    )
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, tgt, tgt)
                )
                dataset = load_indexed_dataset(filename, self.dicts[tgt])
                lang_pair_dataset_tgt = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                lang_pair_dataset = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    src_dict=self.dicts[src],
                    tgt=dataset,
                    tgt_sizes=dataset.sizes,
                    tgt_dict=self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                backtranslate_datasets[lang_pair] = BacktranslationDataset(
                    tgt_dataset=self.alter_dataset_langtok(
                        lang_pair_dataset_tgt,
                        src_eos=self.dicts[tgt].eos(),
                        src_lang=tgt,
                        tgt_lang=src,
                    ),
                    backtranslation_fn=self.backtranslators[lang_pair],
                    src_dict=self.dicts[src],
                    tgt_dict=self.dicts[tgt],
                    output_collater=self.alter_dataset_langtok(
                        lang_pair_dataset=lang_pair_dataset,
                        src_eos=self.dicts[src].eos(),
                        src_lang=src,
                        tgt_eos=self.dicts[tgt].eos(),
                        tgt_lang=tgt,
                    ).collater,
                )
                logger.info(
                    "backtranslate-{}: {} {} {} examples".format(
                        tgt,
                        data_path,
                        split,
                        len(backtranslate_datasets[lang_pair]),
                    )
                )
                self.backtranslate_datasets[lang_pair] = backtranslate_datasets[
                    lang_pair
                ]

        # denoising autoencoder
        noising_datasets = {}
        if (
            self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None
        ) and split.startswith("train"):
            for lang_pair in self.lang_pairs:
                _, tgt = lang_pair.split("-")
                if not split_exists(split, tgt, None, tgt):
                    continue
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, tgt, tgt)
                )
                tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt])
                tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt])
                noising_dataset = NoisingDataset(
                    tgt_dataset1,
                    self.dicts[tgt],
                    seed=1,
                    max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                    word_dropout_prob=self.args.word_dropout_prob,
                    word_blanking_prob=self.args.word_blanking_prob,
                )
                noising_datasets[lang_pair] = self.alter_dataset_langtok(
                    LanguagePairDataset(
                        noising_dataset,
                        tgt_dataset1.sizes,
                        self.dicts[tgt],
                        tgt_dataset2,
                        tgt_dataset2.sizes,
                        self.dicts[tgt],
                        left_pad_source=self.args.left_pad_source,
                        left_pad_target=self.args.left_pad_target,
                    ),
                    src_eos=self.dicts[tgt].eos(),
                    src_lang=tgt,
                    tgt_eos=self.dicts[tgt].eos(),
                    tgt_lang=tgt,
                )
                logger.info(
                    "denoising-{}: {} {} {} examples".format(
                        tgt,
                        data_path,
                        split,
                        len(noising_datasets[lang_pair]),
                    )
                )

        def language_pair_dataset(lang_pair):
            src, tgt = lang_pair.split("-")
            src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
            return self.alter_dataset_langtok(
                LanguagePairDataset(
                    src_dataset,
                    src_dataset.sizes,
                    self.dicts[src],
                    tgt_dataset,
                    tgt_dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                ),
                self.dicts[src].eos(),
                src,
                self.dicts[tgt].eos(),
                tgt,
            )

        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict(
                [
                    (lang_pair, language_pair_dataset(lang_pair))
                    for lang_pair in src_datasets.keys()
                ]
                + [
                    (_get_bt_dataset_key(lang_pair), dataset)
                    for lang_pair, dataset in backtranslate_datasets.items()
                ]
                + [
                    (_get_denoising_dataset_key(lang_pair), dataset)
                    for lang_pair, dataset in noising_datasets.items()
                ]
            ),
            eval_key=None
            if self.training
            else "%s-%s" % (self.args.source_lang, self.args.target_lang),
        )
Example #5
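A variant of Example #1 geared toward differentiable data selection: bt_parallel_update can add the reversed direction of each pair, BacktranslationDataset takes a noising flag (noise_bt_dds), and a separate dev iterator is built when bt_dds is set.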
    def load_dataset(self, split, epoch=0, **kwargs):
        """Load a dataset split."""

        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]

        def split_exists(split, src, tgt, lang):
            if src is not None:
                filename = os.path.join(
                    data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
            else:
                filename = os.path.join(
                    data_path, '{}.{}-None.{}'.format(split, src, tgt))
            if self.args.raw_text and IndexedRawTextDataset.exists(filename):
                return True
            elif not self.args.raw_text and IndexedDataset.exists(filename):
                return True
            return False

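        # with bt_parallel_update, also load the reversed direction of each pair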
        if split == 'train' and self.args.bt_parallel_update > 0:
            lang_pairs = []
            copied_lang_pairs = [p for p in self.lang_pairs]
            for lang_pair in copied_lang_pairs:
                src, tgt = lang_pair.split('-')
                key = '{}-{}'.format(tgt, src)
                lang_pairs.append(key)
                lang_pairs.append(lang_pair)
        else:
            lang_pairs = self.lang_pairs

        # load parallel datasets
        src_datasets, tgt_datasets = {}, {}
        for lang_pair in lang_pairs:
            src, tgt = lang_pair.split('-')
            if split_exists(split, src, tgt, src):
                prefix = os.path.join(data_path,
                                      '{}.{}-{}.'.format(split, src, tgt))
            elif split_exists(split, tgt, src, src):
                prefix = os.path.join(data_path,
                                      '{}.{}-{}.'.format(split, tgt, src))
            else:
                continue

            src_datasets[lang_pair] = data_utils.load_indexed_dataset(
                prefix + src, self.dicts[src])
            tgt_datasets[lang_pair] = data_utils.load_indexed_dataset(
                prefix + tgt, self.dicts[tgt])
            print('| parallel-{} {} {} examples'.format(
                data_path, split, len(src_datasets[lang_pair])))
        if len(src_datasets) == 0:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                split, data_path))

        def language_pair_dataset(lang_pair):
            src, tgt = lang_pair.split('-')
            src_dataset = src_datasets[lang_pair]
            tgt_dataset = tgt_datasets[lang_pair]
            return self.alter_dataset_langtok(
                LanguagePairDataset(
                    src_dataset,
                    src_dataset.sizes,
                    self.dicts[src],
                    tgt_dataset,
                    tgt_dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                    max_source_positions=self.args.max_source_positions,
                    max_target_positions=self.args.max_target_positions,
                ),
                self.dicts[src].eos(),
                src,
                self.dicts[tgt].eos(),
                tgt,
            )

        # back translation datasets
        backtranslate_datasets = {}
        if split.startswith("train"):
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split('-')
                if not split_exists(split, tgt, None, tgt):
                    raise FileNotFoundError(
                        'Dataset not found: backtranslation {} ({})'.format(
                            split, data_path))
                filename = os.path.join(
                    data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
                dataset = data_utils.load_indexed_dataset(
                    filename, self.dicts[tgt])
                lang_pair_dataset_tgt = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                #lang_pair_dataset = LanguagePairDataset(
                #    dataset,
                #    dataset.sizes,
                #    src_dict=self.dicts[src],
                #    tgt=dataset,
                #    tgt_sizes=dataset.sizes,
                #    tgt_dict=self.dicts[tgt],
                #    left_pad_source=self.args.left_pad_source,
                #    left_pad_target=self.args.left_pad_target,
                #)
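                # the output collater comes from the langtok-wrapped parallel dataset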
                backtranslate_datasets[lang_pair] = BacktranslationDataset(
                    tgt_dataset=self.alter_dataset_langtok(
                        lang_pair_dataset_tgt,
                        src_eos=self.dicts[tgt].eos(),
                        src_lang=tgt,
                        tgt_lang=src,
                    ),
                    backtranslation_fn=self.backtranslators[lang_pair],
                    src_dict=self.dicts[src],
                    tgt_dict=self.dicts[tgt],
                    output_collater=language_pair_dataset(lang_pair).collater,
                    noising=self.args.noise_bt_dds,
                )
                print('| backtranslate-{}: {} {} {} examples'.format(
                    tgt,
                    data_path,
                    split,
                    len(backtranslate_datasets[lang_pair]),
                ))
                self.backtranslate_datasets[lang_pair] = (
                    backtranslate_datasets[lang_pair])

        if split == 'train':
            self.datasets[split] = RoundRobinZipDatasets(
                OrderedDict(
                    [(lang_pair, language_pair_dataset(lang_pair))
                     for lang_pair in src_datasets.keys()] +
                    [(_get_bt_dataset_key(lang_pair), dataset)
                     for lang_pair, dataset in backtranslate_datasets.items()]),
                eval_key=None if self.training
                else "%s-%s" % (self.args.source_lang, self.args.target_lang),
                upsample_factor=self.args.upsample_factor,
            )
        else:
            self.datasets[split] = RoundRobinZipDatasets(
                OrderedDict([(lang_pair, language_pair_dataset(lang_pair))
                             for lang_pair in src_datasets.keys()]),
                eval_key=None if self.training
                else "%s-%s" % (self.args.source_lang, self.args.target_lang),
            )

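        # with bt_dds, pre-build a shuffled dev iterator at a quarter of the usual batch budget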
        if split == 'valid' and self.args.bt_dds:
            if self.args.max_tokens_valid is not None:
                max_tokens_valid = self.args.max_tokens_valid // 4
            else:
                max_tokens_valid = None
            if self.args.max_sentences_valid is not None:
                max_sentences_valid = self.args.max_sentences_valid // 4
            else:
                max_sentences_valid = None
            self.dev_itr = self.get_batch_iterator(
                dataset=self.dataset('valid'),
                max_tokens=max_tokens_valid,
                max_sentences=max_sentences_valid,
                max_positions=utils.resolve_max_positions(self.max_positions()),
                ignore_invalid_inputs=self.args.skip_invalid_size_inputs_valid_test,
                required_batch_size_multiple=self.args.required_batch_size_multiple,
                seed=self.args.seed,
                num_shards=self.args.distributed_world_size,
                shard_id=self.args.distributed_rank,
                num_workers=self.args.num_workers,
                noskip=True,
            )[0]
            self.dev_itr.next_epoch_itr(shuffle=True)
Example #6
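An older pytorch_translate variant of Example #3: instead of prebuilt generators and TransformEosDataset wrappers, it hands the forward and backward models to BacktranslationDataset, which builds its own beam_decode.SequenceGenerator.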
    def load_dataset(self,
                     split,
                     src_bin_path,
                     tgt_bin_path,
                     forward_model=None,
                     backward_model=None):
        """Load a dataset split."""

        corpus = ptt_data.ParallelCorpusConfig(
            source=ptt_data.CorpusConfig(dialect=self.source_lang,
                                         data_file=src_bin_path),
            target=ptt_data.CorpusConfig(dialect=self.target_lang,
                                         data_file=tgt_bin_path),
            weights_file=None,
        )

        if self.args.log_verbose:
            print("Starting to load binarized data files.", flush=True)
        data_utils.validate_corpus_exists(corpus=corpus, split=split)

        forward_tgt_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
            corpus.target.data_file)
        backward_tgt_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
            corpus.source.data_file)
        forward_src_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
            corpus.source.data_file)
        backward_src_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
            corpus.target.data_file)
        forward_parallel_dataset = weighted_data.WeightedLanguagePairDataset(
            src=forward_src_dataset,
            src_sizes=forward_src_dataset.sizes,
            src_dict=self.source_dictionary,
            tgt=forward_tgt_dataset,
            tgt_sizes=forward_tgt_dataset.sizes,
            tgt_dict=self.target_dictionary,
            remove_eos_from_source=self.remove_eos_from_source,
            append_eos_to_target=True,
        )
        backward_parallel_dataset = weighted_data.WeightedLanguagePairDataset(
            src=backward_src_dataset,
            src_sizes=backward_src_dataset.sizes,
            src_dict=self.target_dictionary,
            tgt=backward_tgt_dataset,
            tgt_sizes=backward_tgt_dataset.sizes,
            tgt_dict=self.source_dictionary,
            remove_eos_from_source=self.remove_eos_from_source,
            append_eos_to_target=True,
        )

        dataset_map = OrderedDict([
            (f"{self.source_lang}-{self.target_lang}",
             forward_parallel_dataset),
            (f"{self.target_lang}-{self.source_lang}",
             backward_parallel_dataset),
        ])

        assert (forward_model and backward_model) or (
            forward_model is None and backward_model is None), (
                "Forward and backward models must both be provided"
                " or both be None")
        if forward_model and backward_model:
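            # this older API hands the models over directly; BacktranslationDataset
            # builds its own SequenceGenerator via generator_class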
            dataset_map[
                f"{self.source_lang}-"
                f"{self.target_lang}_{constants.MONOLINGUAL_DATA_IDENTIFIER}"
            ] = BacktranslationDataset(
                tgt_dataset=self.load_monolingual_dataset(
                    self.args.train_mono_target_binary_path),
                tgt_dict=self.target_dictionary,
                backtranslation_model=backward_model,
                max_len_a=self.args.max_len_a,
                max_len_b=self.args.max_len_b,
                remove_eos_at_src=True,
                generator_class=beam_decode.SequenceGenerator,
            )
            dataset_map[
                f"{self.target_lang}-"
                f"{self.source_lang}_{constants.MONOLINGUAL_DATA_IDENTIFIER}"
            ] = BacktranslationDataset(
                tgt_dataset=self.load_monolingual_dataset(
                    self.args.train_mono_source_binary_path),
                tgt_dict=self.source_dictionary,
                backtranslation_model=forward_model,
                max_len_a=self.args.max_len_a,
                max_len_b=self.args.max_len_b,
                remove_eos_at_src=True,
                generator_class=beam_decode.SequenceGenerator,
            )
        self.datasets[split] = RoundRobinZipDatasets(dataset_map)

        if self.args.log_verbose:
            print("Finished loading dataset", flush=True)

        print(f"| {split} {len(self.datasets[split])} examples in RoundRobinZipDatasets")