Example No. 1
    def load_dataset(self,
                     split,
                     src_bin_path,
                     tgt_bin_path,
                     weights_file=None,
                     is_train=False):
        """
        Currently this method does not support character models.
        """
        corpus = pytorch_translate_data.ParallelCorpusConfig(
            source=pytorch_translate_data.CorpusConfig(
                dialect=self.args.source_lang, data_file=src_bin_path),
            target=pytorch_translate_data.CorpusConfig(
                dialect=self.args.target_lang, data_file=tgt_bin_path),
            weights_file=weights_file,
        )

        if self.args.log_verbose:
            print("Starting to load binarized data files.", flush=True)
        data_utils.validate_corpus_exists(corpus=corpus, split=split)

        dst_dataset = pytorch_translate_data.InMemoryIndexedDataset.create_from_file(
            corpus.target.data_file)
        src_dataset = pytorch_translate_data.InMemoryIndexedDataset.create_from_file(
            corpus.source.data_file)
        if is_train:
            # Training split: use TeacherDataset, which carries the teacher
            # models and their top-k token scores alongside the parallel data.
            self.datasets[split] = TeacherDataset(
                src=src_dataset,
                src_sizes=src_dataset.sizes,
                src_dict=self.src_dict,
                tgt=dst_dataset,
                tgt_sizes=dst_dataset.sizes,
                tgt_dict=self.tgt_dict,
                top_k_probs_binary_file=self.top_k_probs_binary_file,
                teacher_models=self.teacher_models,
                top_k_teacher_tokens=self.top_k_teacher_tokens,
                top_k_teacher_scores=self.top_k_teacher_scores,
                top_k_teacher_indices=self.top_k_teacher_indices,
                left_pad_source=False,
            )
        else:
            self.datasets[split] = weighted_data.WeightedLanguagePairDataset(
                src=src_dataset,
                src_sizes=src_dataset.sizes,
                src_dict=self.src_dict,
                tgt=dst_dataset,
                tgt_sizes=dst_dataset.sizes,
                tgt_dict=self.tgt_dict,
                weights=None,
                left_pad_source=False,
            )

        if self.args.log_verbose:
            print("Finished loading dataset", flush=True)

        print(f"| {split} {len(self.datasets[split])} examples")
Example No. 2
    def _load_dataset_single_path(
        self,
        split: str,
        src_bin_path: str,
        tgt_bin_path: str,
        weights_file=None,
        is_npz=True,
    ):
        corpus = pytorch_translate_data.ParallelCorpusConfig(
            source=pytorch_translate_data.CorpusConfig(
                dialect=self.args.source_lang, data_file=src_bin_path),
            target=pytorch_translate_data.CorpusConfig(
                dialect=self.args.target_lang, data_file=tgt_bin_path),
            weights_file=weights_file,
        )

        if self.args.log_verbose:
            print("Starting to load binarized data files.", flush=True)
        data_utils.validate_corpus_exists(corpus=corpus,
                                          split=split,
                                          is_npz=is_npz)

        dst_dataset = pytorch_translate_data.InMemoryIndexedDataset.create_from_file(
            corpus.target.data_file, is_npz=is_npz)
        if getattr(self.args, "reverse_target", None):
            dst_dataset.reverse()
        weights_dataset = None
        if corpus.weights_file and os.path.exists(corpus.weights_file):
            weights_dataset = weighted_data.IndexedWeightsDataset(
                corpus.weights_file)
            assert len(dst_dataset) == len(weights_dataset)

        # A character-level source dictionary selects the word+char source dataset.
        if self.char_source_dict is not None:
            src_dataset = char_data.InMemoryNumpyWordCharDataset.create_from_file(
                corpus.source.data_file)
            self.datasets[split] = char_data.LanguagePairSourceCharDataset(
                src=src_dataset,
                src_sizes=src_dataset.sizes,
                src_dict=self.source_dictionary,
                tgt=dst_dataset,
                tgt_sizes=dst_dataset.sizes,
                tgt_dict=self.target_dictionary,
                weights=weights_dataset,
            )
        else:
            src_dataset = pytorch_translate_data.InMemoryIndexedDataset.create_from_file(
                corpus.source.data_file, is_npz=is_npz)
            self.datasets[split] = LanguagePairDataset(
                src=src_dataset,
                src_sizes=src_dataset.sizes,
                src_dict=self.source_dictionary,
                tgt=dst_dataset,
                tgt_sizes=dst_dataset.sizes,
                tgt_dict=self.target_dictionary,
                left_pad_source=False,
            )
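
A rough usage sketch for the helper above, assuming the same kind of task instance; `task` and the paths are placeholders, and the `is_npz` flag is simply forwarded to validate_corpus_exists and create_from_file as in the code:

# Hypothetical call; `task` and the paths are assumptions.
task._load_dataset_single_path(
    split="valid",
    src_bin_path="/data/valid.src.npz",
    tgt_bin_path="/data/valid.tgt.npz",
    weights_file=None,
    is_npz=True,
)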
Example No. 3
    def load_dataset(self,
                     split,
                     src_bin_path,
                     tgt_bin_path,
                     weights_file=None):
        corpus = pytorch_translate_data.ParallelCorpusConfig(
            source=pytorch_translate_data.CorpusConfig(
                dialect=self.args.source_lang, data_file=src_bin_path),
            target=pytorch_translate_data.CorpusConfig(
                dialect=self.args.target_lang, data_file=tgt_bin_path),
            weights_file=weights_file,
        )

        if self.args.log_verbose:
            print("Starting to load binarized data files.", flush=True)
        data_utils.validate_corpus_exists(corpus=corpus, split=split)

        dst_dataset = pytorch_translate_data.InMemoryNumpyDataset.create_from_file(
            corpus.target.data_file)
        weights_dataset = None
        if corpus.weights_file and os.path.exists(corpus.weights_file):
            weights_dataset = weighted_data.IndexedWeightsDataset(
                corpus.weights_file)
            assert len(dst_dataset) == len(weights_dataset)

        if self.char_source_dict is not None:
            src_dataset = char_data.InMemoryNumpyWordCharDataset.create_from_file(
                corpus.source.data_file)
            self.datasets[split] = char_data.LanguagePairSourceCharDataset(
                src=src_dataset,
                src_sizes=src_dataset.sizes,
                src_dict=self.source_dictionary,
                tgt=dst_dataset,
                tgt_sizes=dst_dataset.sizes,
                tgt_dict=self.target_dictionary,
                weights=weights_dataset,
            )
        else:
            src_dataset = pytorch_translate_data.InMemoryNumpyDataset.create_from_file(
                corpus.source.data_file)
            self.datasets[split] = weighted_data.WeightedLanguagePairDataset(
                src=src_dataset,
                src_sizes=src_dataset.sizes,
                src_dict=self.source_dictionary,
                tgt=dst_dataset,
                tgt_sizes=dst_dataset.sizes,
                tgt_dict=self.target_dictionary,
                weights=weights_dataset,
                left_pad_source=False,
            )

        if self.args.log_verbose:
            print("Finished loading dataset", flush=True)

        print(f"| {split} {len(self.datasets[split])} examples")
Example No. 4
def load_parallel_dataset(
    source_lang,
    target_lang,
    src_bin_path,
    tgt_bin_path,
    source_dictionary,
    target_dictionary,
    split,
    remove_eos_from_source,
    append_eos_to_target=True,
    char_source_dict=None,
    log_verbose=True,
):
    corpus = pytorch_translate_data.ParallelCorpusConfig(
        source=pytorch_translate_data.CorpusConfig(dialect=source_lang,
                                                   data_file=src_bin_path),
        target=pytorch_translate_data.CorpusConfig(dialect=target_lang,
                                                   data_file=tgt_bin_path),
        weights_file=None,
    )

    if log_verbose:
        print("Starting to load binarized data files.", flush=True)
    validate_corpus_exists(corpus=corpus, split=split)

    tgt_dataset = pytorch_translate_data.InMemoryIndexedDataset.create_from_file(
        corpus.target.data_file)
    if char_source_dict is not None:
        src_dataset = char_data.InMemoryNumpyWordCharDataset.create_from_file(
            corpus.source.data_file)
    else:
        src_dataset = pytorch_translate_data.InMemoryIndexedDataset.create_from_file(
            corpus.source.data_file)
    parallel_dataset = weighted_data.WeightedLanguagePairDataset(
        src=src_dataset,
        src_sizes=src_dataset.sizes,
        src_dict=source_dictionary,
        tgt=tgt_dataset,
        tgt_sizes=tgt_dataset.sizes,
        tgt_dict=target_dictionary,
        remove_eos_from_source=remove_eos_from_source,
        append_eos_to_target=append_eos_to_target,
    )
    return parallel_dataset, src_dataset, tgt_dataset
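
Since load_parallel_dataset above is a module-level helper that returns the wrapped pair dataset together with the raw source and target datasets, a call might look like the following sketch (the language codes, paths, and dictionary objects are placeholders):

# Hypothetical call; src_dict, tgt_dict, and the paths are assumptions.
parallel_dataset, src_dataset, tgt_dataset = load_parallel_dataset(
    source_lang="de",
    target_lang="en",
    src_bin_path="/data/train.de.bin",
    tgt_bin_path="/data/train.en.bin",
    source_dictionary=src_dict,
    target_dictionary=tgt_dict,
    split="train",
    remove_eos_from_source=True,
)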
Example No. 5
    def load_dataset(self, split, **kwargs):
        """Load a dataset split."""

        lang_pair_to_datasets = {}

        binary_path_arg = ("--multilingual-train-binary-path" if split
                           == "train" else "--multilingual-eval-binary-path")
        binary_path_value = (self.args.multilingual_train_binary_path
                             if split == "train" else
                             self.args.multilingual_eval_binary_path)

        format_warning = (
            f"{binary_path_arg} has to be in the format "
            "src_lang-tgt_lang:src_dataset_path,tgt_dataset_path")

        for path_config in binary_path_value:
            # path_config: str
            # in the format "src_lang-tgt_lang:src_dataset_path,tgt_dataset_path"
            assert ":" in path_config, format_warning
            lang_pair, dataset_paths = path_config.split(":")

            assert "-" in lang_pair, format_warning

            assert "," in dataset_paths, format_warning
            src_dataset_path, tgt_dataset_path = dataset_paths.split(",")

            lang_pair_to_datasets[lang_pair] = (src_dataset_path,
                                                tgt_dataset_path)

        for lang_pair in self.args.lang_pairs:
            assert (
                lang_pair in lang_pair_to_datasets
            ), "Not all language pairs have dataset binary paths specified!"

        datasets = {}
        for lang_pair in self.args.lang_pairs:
            src, tgt = lang_pair.split("-")
            src_bin_path, tgt_bin_path = lang_pair_to_datasets[lang_pair]
            corpus = pytorch_translate_data.ParallelCorpusConfig(
                source=pytorch_translate_data.CorpusConfig(
                    dialect=src, data_file=src_bin_path),
                target=pytorch_translate_data.CorpusConfig(
                    dialect=tgt, data_file=tgt_bin_path),
            )
            if self.args.log_verbose:
                print("Starting to load binarized data files.", flush=True)

            data_utils.validate_corpus_exists(corpus=corpus, split=split)

            tgt_dataset = pytorch_translate_data.InMemoryIndexedDataset.create_from_file(
                corpus.target.data_file)
            src_dataset = pytorch_translate_data.InMemoryIndexedDataset.create_from_file(
                corpus.source.data_file)
            datasets[lang_pair] = weighted_data.WeightedLanguagePairDataset(
                src=src_dataset,
                src_sizes=src_dataset.sizes,
                src_dict=self.dicts[src],
                tgt=tgt_dataset,
                tgt_sizes=tgt_dataset.sizes,
                tgt_dict=self.dicts[tgt],
                weights=None,
                left_pad_source=False,
            )
        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict([(lang_pair, datasets[lang_pair])
                         for lang_pair in self.args.lang_pairs]),
            eval_key=None if self.training else
            f"{self.args.source_lang}-{self.args.target_lang}",
        )

        if self.args.log_verbose:
            print("Finished loading dataset", flush=True)

        print(f"| {split} {len(self.datasets[split])} examples")
Example No. 6
    def load_dataset(self,
                     split,
                     src_bin_path,
                     tgt_bin_path,
                     forward_model=None,
                     backward_model=None):
        """Load a dataset split."""

        corpus = ptt_data.ParallelCorpusConfig(
            source=ptt_data.CorpusConfig(dialect=self.source_lang,
                                         data_file=src_bin_path),
            target=ptt_data.CorpusConfig(dialect=self.target_lang,
                                         data_file=tgt_bin_path),
            weights_file=None,
        )

        if self.args.log_verbose:
            print("Starting to load binarized data files.", flush=True)
        ptt_data_utils.validate_corpus_exists(corpus=corpus, split=split)

        forward_tgt_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
            corpus.target.data_file)
        backward_tgt_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
            corpus.source.data_file)
        forward_src_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
            corpus.source.data_file)
        backward_src_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
            corpus.target.data_file)
        forward_parallel_dataset = weighted_data.WeightedLanguagePairDataset(
            src=forward_src_dataset,
            src_sizes=forward_src_dataset.sizes,
            src_dict=self.source_dictionary,
            tgt=forward_tgt_dataset,
            tgt_sizes=forward_tgt_dataset.sizes,
            tgt_dict=self.target_dictionary,
            remove_eos_from_source=self.remove_eos_from_source,
            append_eos_to_target=True,
        )
        backward_parallel_dataset = weighted_data.WeightedLanguagePairDataset(
            src=backward_src_dataset,
            src_sizes=backward_src_dataset.sizes,
            src_dict=self.target_dictionary,
            tgt=backward_tgt_dataset,
            tgt_sizes=backward_tgt_dataset.sizes,
            tgt_dict=self.source_dictionary,
            remove_eos_from_source=self.remove_eos_from_source,
            append_eos_to_target=True,
        )

        dataset_map = OrderedDict([
            (f"{self.source_lang}-{self.target_lang}",
             forward_parallel_dataset),
            (f"{self.target_lang}-{self.source_lang}",
             backward_parallel_dataset),
        ])

        assert (forward_model and backward_model) or (
            forward_model is None and backward_model is None), (
                "forward_model and backward_model must both be provided"
                " or both be None; supplying only one is not supported")
        if forward_model and backward_model:
            fwd_generator = beam_decode.SequenceGenerator(
                models=[forward_model], tgt_dict=self.source_dictionary)
            bwd_generator = beam_decode.SequenceGenerator(
                models=[backward_model], tgt_dict=self.target_dictionary)

            def monolingual_dataset(
                path,
                dictionary,
                is_source=False,
                num_examples_limit: Optional[int] = None,
            ):
                dataset = self.load_monolingual_dataset(
                    path,
                    is_source=is_source,
                    num_examples_limit=num_examples_limit)
                return LanguagePairDataset(
                    src=dataset,
                    src_sizes=dataset.sizes,
                    src_dict=dictionary,
                    tgt=None,
                    tgt_sizes=None,
                    tgt_dict=None,
                )

            monolingual_num_examples_limit = None
            if self.args.monolingual_ratio is not None:
                monolingual_num_examples_limit = int(
                    self.args.monolingual_ratio *
                    len(forward_parallel_dataset))

            src_dataset = monolingual_dataset(
                path=self.args.train_mono_source_binary_path,
                dictionary=self.source_dictionary,
                is_source=True,
                num_examples_limit=monolingual_num_examples_limit,
            )
            tgt_dataset = monolingual_dataset(
                path=self.args.train_mono_target_binary_path,
                dictionary=self.target_dictionary,
                is_source=False,
                num_examples_limit=monolingual_num_examples_limit,
            )

            def generate_fn(generator):
                def _generate_fn(sample):
                    net_input = sample["net_input"]
                    maxlen = int(self.args.max_len_a *
                                 net_input["src_tokens"].size(1) +
                                 self.args.max_len_b)
                    return generator.generate(net_input, maxlen=maxlen)

                return _generate_fn

            dataset_map[
                f"{self.source_lang}-"
                f"{self.target_lang}_{constants.MONOLINGUAL_DATA_IDENTIFIER}"] = weighted_data.WeightedBacktranslationDataset(
                    dataset=BacktranslationDataset(
                        tgt_dataset=TransformEosDataset(
                            dataset=tgt_dataset,
                            eos=self.target_dictionary.eos(),
                            # Remove EOS from the input before backtranslation.
                            remove_eos_from_src=True,
                        ),
                        src_dict=self.source_dictionary,
                        backtranslation_fn=generate_fn(bwd_generator),
                        output_collater=TransformEosDataset(
                            dataset=tgt_dataset,
                            eos=self.target_dictionary.eos(),
                            # The original input (now the target) doesn't have
                            # an EOS, so we need to add one. The generated
                            # backtranslation (now the source) will have an EOS,
                            # so we want to remove it.
                            append_eos_to_tgt=True,
                            remove_eos_from_src=True,
                        ).collater,
                    ))
            dataset_map[
                f"{self.target_lang}-"
                f"{self.source_lang}_{constants.MONOLINGUAL_DATA_IDENTIFIER}"] = weighted_data.WeightedBacktranslationDataset(
                    dataset=BacktranslationDataset(
                        tgt_dataset=src_dataset,
                        src_dict=self.source_dictionary,
                        backtranslation_fn=generate_fn(fwd_generator),
                        output_collater=TransformEosDataset(
                            dataset=src_dataset,
                            eos=self.source_dictionary.eos(),
                            # The original input (now the target) doesn't have
                            # an EOS, so we need to add one. The generated
                            # backtranslation (now the source) will have an EOS,
                            # so we want to remove it.
                            append_eos_to_tgt=True,
                            remove_eos_from_src=True,
                        ).collater,
                    ))

        # print before loading RoundRobinZipDatasets to help catch any bugs
        for dataset_key, dataset in dataset_map.items():
            print(
                f"| {split}: {dataset_key} {len(dataset)} examples in dataset")

        self.datasets[split] = RoundRobinZipDatasets(dataset_map)
        print(
            f"| {split} {len(self.datasets[split])} examples in RoundRobinZipDatasets"
        )

        if self.args.log_verbose:
            print("Finished loading dataset", flush=True)