Example #1
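Both snippets are methods of a fairseq translation task class from Facebook's pytorch_translate, and they reference several module-level names. A sketch of the imports they assume, based on a layout where char_data, weighted_data, and the data/utils modules live under pytorch_translate.data (exact module paths may differ between revisions of the repo):

import os

from fairseq.data import LanguagePairDataset
from pytorch_translate.data import (
    char_data,
    data as pytorch_translate_data,
    utils as data_utils,
    weighted_data,
)
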
    def _load_dataset_single_path(
        self,
        split: str,
        src_bin_path: str,
        tgt_bin_path: str,
        weights_file=None,
        is_npz=True,
    ):
        # Describe the binarized source/target files (plus optional
        # per-example weights) that make up this split.
        corpus = pytorch_translate_data.ParallelCorpusConfig(
            source=pytorch_translate_data.CorpusConfig(
                dialect=self.args.source_lang, data_file=src_bin_path),
            target=pytorch_translate_data.CorpusConfig(
                dialect=self.args.target_lang, data_file=tgt_bin_path),
            weights_file=weights_file,
        )

        if self.args.log_verbose:
            print("Starting to load binarized data files.", flush=True)
        # Fail fast if any of the expected binarized files are missing.
        data_utils.validate_corpus_exists(corpus=corpus,
                                          split=split,
                                          is_npz=is_npz)

        # Load the full target side into memory; optionally reverse each
        # target sentence (e.g., for right-to-left target training).
        dst_dataset = pytorch_translate_data.InMemoryIndexedDataset.create_from_file(
            corpus.target.data_file, is_npz=is_npz)
        if getattr(self.args, "reverse_target", None):
            dst_dataset.reverse()
        # Attach optional per-example weights; the weights file must contain
        # exactly one entry per target example.
        weights_dataset = None
        if corpus.weights_file and os.path.exists(corpus.weights_file):
            weights_dataset = weighted_data.IndexedWeightsDataset(
                corpus.weights_file)
            assert len(dst_dataset) == len(weights_dataset)

        if self.char_source_dict is not None:
            # A character-level source dictionary is in use: load the
            # word+character representation of the source side.
            src_dataset = char_data.InMemoryNumpyWordCharDataset.create_from_file(
                corpus.source.data_file)
            self.datasets[split] = char_data.LanguagePairSourceCharDataset(
                src=src_dataset,
                src_sizes=src_dataset.sizes,
                src_dict=self.source_dictionary,
                tgt=dst_dataset,
                tgt_sizes=dst_dataset.sizes,
                tgt_dict=self.target_dictionary,
                weights=weights_dataset,
            )
        else:
            # Word-level source: load it the same way as the target side.
            src_dataset = pytorch_translate_data.InMemoryIndexedDataset.create_from_file(
                corpus.source.data_file, is_npz=is_npz)
            self.datasets[split] = LanguagePairDataset(
                src=src_dataset,
                src_sizes=src_dataset.sizes,
                src_dict=self.source_dictionary,
                tgt=dst_dataset,
                tgt_sizes=dst_dataset.sizes,
                tgt_dict=self.target_dictionary,
                left_pad_source=False,
            )
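
A minimal, hypothetical call site for this helper; the task instance, split name, and file paths are placeholders rather than anything from the original:

# Hypothetical usage: `task` is an instance of the enclosing task class
# and the .npz files were produced by the pytorch_translate binarizer.
task._load_dataset_single_path(
    split="train",
    src_bin_path="data/train.src.npz",
    tgt_bin_path="data/train.tgt.npz",
    weights_file=None,
    is_npz=True,
)
train_data = task.datasets["train"]
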
Example #2
    def load_dataset(self,
                     split,
                     src_bin_path,
                     tgt_bin_path,
                     weights_file=None):
        corpus = pytorch_translate_data.ParallelCorpusConfig(
            source=pytorch_translate_data.CorpusConfig(
                dialect=self.args.source_lang, data_file=src_bin_path),
            target=pytorch_translate_data.CorpusConfig(
                dialect=self.args.target_lang, data_file=tgt_bin_path),
            weights_file=weights_file,
        )

        if self.args.log_verbose:
            print("Starting to load binarized data files.", flush=True)
        data_utils.validate_corpus_exists(corpus=corpus, split=split)

        # This variant always reads .npz data via InMemoryNumpyDataset.
        dst_dataset = pytorch_translate_data.InMemoryNumpyDataset.create_from_file(
            corpus.target.data_file)
        weights_dataset = None
        if corpus.weights_file and os.path.exists(corpus.weights_file):
            weights_dataset = weighted_data.IndexedWeightsDataset(
                corpus.weights_file)
            assert len(dst_dataset) == len(weights_dataset)

        if self.char_source_dict is not None:
            src_dataset = char_data.InMemoryNumpyWordCharDataset.create_from_file(
                corpus.source.data_file)
            self.datasets[split] = char_data.LanguagePairSourceCharDataset(
                src=src_dataset,
                src_sizes=src_dataset.sizes,
                src_dict=self.source_dictionary,
                tgt=dst_dataset,
                tgt_sizes=dst_dataset.sizes,
                tgt_dict=self.target_dictionary,
                weights=weights_dataset,
            )
        else:
            src_dataset = pytorch_translate_data.InMemoryNumpyDataset.create_from_file(
                corpus.source.data_file)
            # Unlike Example #1, the word-level branch wraps the pair in a
            # WeightedLanguagePairDataset so the optional per-example weights
            # are available at training time.
            self.datasets[split] = weighted_data.WeightedLanguagePairDataset(
                src=src_dataset,
                src_sizes=src_dataset.sizes,
                src_dict=self.source_dictionary,
                tgt=dst_dataset,
                tgt_sizes=dst_dataset.sizes,
                tgt_dict=self.target_dictionary,
                weights=weights_dataset,
                left_pad_source=False,
            )

        if self.args.log_verbose:
            print("Finished loading dataset", flush=True)

        print(f"| {split} {len(self.datasets[split])} examples")