Example #1
    def _get_noising_dataset_batch(
        self,
        src_tokens_no_pad,
        src_dict,
        append_eos_to_tgt=False,
    ):
        """
        Constructs a NoisingDataset and the corresponding
        ``LanguagePairDataset(NoisingDataset(src), src)``. If
        *append_eos_to_tgt* is True, wrap the source dataset in
        :class:`TransformEosDataset` to append EOS to the clean source when
        using it as the target.
        """
        src_dataset = test_utils.TestDataset(data=src_tokens_no_pad)

        # NoisingDataset applies word shuffling, dropout, and blanking to the
        # clean source on the fly (via UnsupervisedMTNoising).
        noising_dataset = noising.NoisingDataset(
            src_dataset=src_dataset,
            src_dict=src_dict,
            seed=1234,
            max_word_shuffle_distance=3,
            word_dropout_prob=0.2,
            word_blanking_prob=0.2,
            noising_class=noising.UnsupervisedMTNoising,
        )
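        # The un-noised source doubles as the reconstruction target.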
        tgt = src_dataset
        language_pair_dataset = LanguagePairDataset(
            src=noising_dataset, tgt=tgt, src_sizes=None, src_dict=src_dict
        )
        # Optionally append EOS to the clean target (controlled by append_eos_to_tgt).
        language_pair_dataset = TransformEosDataset(
            language_pair_dataset,
            src_dict.eos(),
            append_eos_to_tgt=append_eos_to_tgt,
        )

        dataloader = torch.utils.data.DataLoader(
            dataset=language_pair_dataset,
            batch_size=2,
            collate_fn=language_pair_dataset.collater,
        )
        denoising_batch_result = next(iter(dataloader))
        return denoising_batch_result
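The examples on this page all follow the same pattern: TransformEosDataset wraps an existing fairseq dataset, and its collater adds or strips the EOS token on the source and/or target side before delegating to the wrapped dataset's collater. Below is a minimal, self-contained sketch of that idea; it is not taken from fairseq itself, and the toy dictionary and sentences are invented for illustration.

import torch
from fairseq.data import Dictionary, LanguagePairDataset, TransformEosDataset

# Toy vocabulary and two EOS-terminated source sentences (made up for this sketch).
d = Dictionary()
w1, w2 = d.add_symbol("hello"), d.add_symbol("world")
sentences = [
    torch.LongTensor([w1, w2, d.eos()]),
    torch.LongTensor([w2, d.eos()]),
]

# Source-only pair dataset; TransformEosDataset strips EOS from each source at collate time.
pair = LanguagePairDataset(src=sentences, src_sizes=[3, 2], src_dict=d)
wrapped = TransformEosDataset(pair, eos=d.eos(), remove_eos_from_src=True, has_target=False)

batch = wrapped.collater([wrapped[i] for i in range(len(wrapped))])
print(batch["net_input"]["src_tokens"])  # left-padded sources with the trailing EOS removed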
Example #2
    def build_dataset_for_inference(self, src_tokens, src_lengths):
        return TransformEosDataset(
            MonolingualDataset(
                # break_mode='eos' keeps each input sentence as its own block
                TokenBlockDataset(
                    src_tokens,
                    src_lengths,
                    block_size=None,
                    pad=self.source_dictionary.pad(),
                    eos=self.source_dictionary.eos(),
                    break_mode='eos',
                    include_targets=False,
                ),
                src_lengths,
                self.source_dictionary,
                self.target_dictionary,
                add_eos_for_other_targets=False,
                shuffle=False,
            ),
            eos=self.source_dictionary.eos(),
            # remove EOS since this will be used as a prefix for generation
            remove_eos_from_src=True,
            has_target=False,
        )
Example #3
    def _backtranslation_dataset_helper(
        self, remove_eos_from_input_src, remove_eos_from_output_src,
    ):
        tgt_dataset = LanguagePairDataset(
            src=self.tgt_dataset,
            src_sizes=self.tgt_dataset.sizes,
            src_dict=self.tgt_dict,
            tgt=None,
            tgt_sizes=None,
            tgt_dict=None,
        )

        generator = SequenceGenerator(
            models=[self.model],
            tgt_dict=self.tgt_dict,
            beam_size=2,
            unk_penalty=0,
            sampling=False,
        )
        if self.cuda:
            generator.cuda()

        backtranslation_dataset = BacktranslationDataset(
            tgt_dataset=TransformEosDataset(
                dataset=tgt_dataset,
                eos=self.tgt_dict.eos(),
                # remove eos from the input src
                remove_eos_from_src=remove_eos_from_input_src,
            ),
            # generator.generate produces the synthetic (backtranslated) sources.
            backtranslation_fn=generator.generate,
            max_len_a=0,
            max_len_b=200,
            output_collater=TransformEosDataset(
                dataset=tgt_dataset,
                eos=self.tgt_dict.eos(),
                # if we remove eos from the input src, then we need to add it
                # back to the output tgt
                append_eos_to_tgt=remove_eos_from_input_src,
                remove_eos_from_src=remove_eos_from_output_src,
            ).collater,
            cuda=self.cuda,
        )
        dataloader = torch.utils.data.DataLoader(
            backtranslation_dataset,
            batch_size=2,
            collate_fn=backtranslation_dataset.collater,
        )
        backtranslation_batch_result = next(iter(dataloader))

        eos, pad, w1, w2 = self.tgt_dict.eos(), self.tgt_dict.pad(), self.w1, self.w2

        # Note that we sort by src_lengths and add left padding, so actually
        # ids will look like: [1, 0]
        expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]])
        if remove_eos_from_output_src:
            expected_src = expected_src[:, :-1]
        expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
        generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
        tgt_tokens = backtranslation_batch_result["target"]

        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, tgt_tokens)
Example #4
    def load_dataset(
        self, split, src_bin_path, tgt_bin_path, forward_model=None, backward_model=None
    ):
        """Load a dataset split."""

        corpus = ptt_data.ParallelCorpusConfig(
            source=ptt_data.CorpusConfig(
                dialect=self.source_lang, data_file=src_bin_path
            ),
            target=ptt_data.CorpusConfig(
                dialect=self.target_lang, data_file=tgt_bin_path
            ),
            weights_file=None,
        )

        if self.args.log_verbose:
            print("Starting to load binarized data files.", flush=True)
        data_utils.validate_corpus_exists(corpus=corpus, split=split)

        forward_tgt_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
            corpus.target.data_file
        )
        backward_tgt_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
            corpus.source.data_file
        )
        forward_src_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
            corpus.source.data_file
        )
        backward_src_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
            corpus.target.data_file
        )
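        # Build parallel datasets for both directions from the same binarized
        # files: forward is source->target, backward is target->source.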
        forward_parallel_dataset = weighted_data.WeightedLanguagePairDataset(
            src=forward_src_dataset,
            src_sizes=forward_src_dataset.sizes,
            src_dict=self.source_dictionary,
            tgt=forward_tgt_dataset,
            tgt_sizes=forward_tgt_dataset.sizes,
            tgt_dict=self.target_dictionary,
            remove_eos_from_source=self.remove_eos_from_source,
            append_eos_to_target=True,
        )
        backward_parallel_dataset = weighted_data.WeightedLanguagePairDataset(
            src=backward_src_dataset,
            src_sizes=backward_src_dataset.sizes,
            src_dict=self.target_dictionary,
            tgt=backward_tgt_dataset,
            tgt_sizes=backward_tgt_dataset.sizes,
            tgt_dict=self.source_dictionary,
            remove_eos_from_source=self.remove_eos_from_source,
            append_eos_to_target=True,
        )

        dataset_map = OrderedDict(
            [
                (f"{self.source_lang}-{self.target_lang}", forward_parallel_dataset),
                (f"{self.target_lang}-{self.source_lang}", backward_parallel_dataset),
            ]
        )

        assert (forward_model and backward_model) or (
            forward_model is None and backward_model is None
        ), (
            "forward_model and backward_model must either both be provided"
            " or both be None"
        )
        if forward_model and backward_model:
            fwd_generator = beam_decode.SequenceGenerator(
                models=[forward_model], tgt_dict=self.source_dictionary
            )
            bwd_generator = beam_decode.SequenceGenerator(
                models=[backward_model], tgt_dict=self.target_dictionary
            )

            def monolingual_dataset(
                path,
                dictionary,
                is_source=False,
                num_examples_limit: Optional[int] = None,
            ):
                dataset = self.load_monolingual_dataset(
                    path, is_source=is_source, num_examples_limit=num_examples_limit
                )
                return LanguagePairDataset(
                    src=dataset,
                    src_sizes=dataset.sizes,
                    src_dict=dictionary,
                    tgt=None,
                    tgt_sizes=None,
                    tgt_dict=None,
                )

            monolingual_num_examples_limit = None
            if self.args.monolingual_ratio is not None:
                monolingual_num_examples_limit = int(
                    self.args.monolingual_ratio * len(forward_parallel_dataset)
                )

            src_dataset = monolingual_dataset(
                path=self.args.train_mono_source_binary_path,
                dictionary=self.source_dictionary,
                is_source=True,
                num_examples_limit=monolingual_num_examples_limit,
            )
            tgt_dataset = monolingual_dataset(
                path=self.args.train_mono_target_binary_path,
                dictionary=self.target_dictionary,
                is_source=False,
                num_examples_limit=monolingual_num_examples_limit,
            )

            dataset_map[
                f"{self.source_lang}-"
                f"{self.target_lang}_{constants.MONOLINGUAL_DATA_IDENTIFIER}"
            ] = BacktranslationDataset(
                tgt_dataset=TransformEosDataset(
                    dataset=tgt_dataset,
                    eos=self.target_dictionary.eos(),
                    # Remove EOS from the input before backtranslation.
                    remove_eos_from_src=True,
                ),
                backtranslation_fn=bwd_generator.generate,
                max_len_a=self.args.max_len_a,
                max_len_b=self.args.max_len_b,
                output_collater=TransformEosDataset(
                    dataset=tgt_dataset,
                    eos=self.target_dictionary.eos(),
                    # The original input (now the target) doesn't have
                    # an EOS, so we need to add one. The generated
                    # backtranslation (now the source) will have an EOS,
                    # so we want to remove it.
                    append_eos_to_tgt=True,
                    remove_eos_from_src=True,
                ).collater,
            )
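            # Same setup for the reverse direction: source-language monolingual
            # data, backtranslated with the forward model.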
            dataset_map[
                f"{self.target_lang}-"
                f"{self.source_lang}_{constants.MONOLINGUAL_DATA_IDENTIFIER}"
            ] = BacktranslationDataset(
                tgt_dataset=src_dataset,
                backtranslation_fn=fwd_generator.generate,
                max_len_a=self.args.max_len_a,
                max_len_b=self.args.max_len_b,
                output_collater=TransformEosDataset(
                    dataset=src_dataset,
                    eos=self.source_dictionary.eos(),
                    # The original input (now the target) doesn't have
                    # an EOS, so we need to add one. The generated
                    # backtranslation (now the source) will have an EOS,
                    # so we want to remove it.
                    append_eos_to_tgt=True,
                    remove_eos_from_src=True,
                ).collater,
            )

        # print before loading RoundRobinZipDatasets to help catch any bugs
        for dataset_key, dataset in dataset_map.items():
            print(f"| {split}: {dataset_key} {len(dataset)} examples in dataset")

        self.datasets[split] = RoundRobinZipDatasets(dataset_map)
        print(
            f"| {split} {len(self.datasets[split])} examples in RoundRobinZipDatasets"
        )

        if self.args.log_verbose:
            print("Finished loading dataset", flush=True)