def language_pair_dataset(lang_pair):
    src, tgt = lang_pair.split('-')
    langpair_dataset = load_langpair_dataset(
        data_path, split, src, self.dicts[src], tgt, self.dicts[tgt],
        combine=True, dataset_impl=self.args.dataset_impl,
        upsample_primary=self.args.upsample_primary,
        left_pad_source=self.args.left_pad_source,
        left_pad_target=self.args.left_pad_target,
        max_source_positions=self.args.max_source_positions,
        max_target_positions=self.args.max_target_positions,
    )
    return self.alter_dataset_langtok(
        langpair_dataset,
        src_eos=self.dicts[tgt].eos(),
        src_lang=src,
        tgt_lang=tgt,
    )
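In upstream fairseq's multilingual translation task, a helper like this is applied to every configured language pair and the per-pair datasets are cycled round-robin. A sketch of that call site, assuming `self.lang_pairs` holds strings such as 'de-en' as in the upstream task; the fork this snippet comes from may differ:

from collections import OrderedDict
from fairseq.data import RoundRobinZipDatasets

# one language-pair dataset per configured pair, cycled round-robin
self.datasets[split] = RoundRobinZipDatasets(
    OrderedDict(
        (lang_pair, language_pair_dataset(lang_pair))
        for lang_pair in self.lang_pairs
    ),
    eval_key=None if self.training
    else '%s-%s' % (self.args.source_lang, self.args.target_lang),
)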
Example #2
    def _load_dataset(self, ds_name, split, epoch=0, combine=False, **kwargs):
        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]
        data_path = os.path.join(data_path, ds_name)

        # infer langcode
        src, tgt = self.args.source_lang, self.args.target_lang

        return load_langpair_dataset(
            data_path, split, src, self.src_dicts[ds_name], tgt, self.tgt_dicts[ds_name],
            combine=combine, dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
        )
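Note the shard indexing: this example rotates shards with `epoch % len(paths)` (0-based epochs), while the examples below use fairseq's 1-based convention `(epoch - 1) % len(paths)`. A quick illustration of the 1-based rotation over a colon-separated --data value (hypothetical shard names):

paths = 'shard0:shard1:shard2'.split(':')
for epoch in range(1, 5):
    print(epoch, paths[(epoch - 1) % len(paths)])
# 1 shard0
# 2 shard1
# 3 shard2
# 4 shard0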
Example #3
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        src, tgt = self.args.source_lang, self.args.target_lang

        self.datasets[split] = load_langpair_dataset(
            data_path, split, src, self.src_dict, tgt, self.tgt_dict,
            combine=combine, dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            prepend_bos=True,
        )
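`prepend_bos=True` asks `load_langpair_dataset` to prepend the beginning-of-sentence token to each example. In upstream fairseq this is done by wrapping both sides in `PrependTokenDataset`, roughly as below (a sketch of the relevant step, not the exact source; `src_dataset`/`tgt_dataset` are the indexed datasets built earlier inside that function):

from fairseq.data import PrependTokenDataset

# prepend <bos> to every source (and, if present, target) sentence
src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
if tgt_dataset is not None:
    tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())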
Example #4
def language_pair_dataset(lang_pair):
    logger.info('loading language pair {}'.format(lang_pair))
    src, tgt = lang_pair.split('-')
    return load_langpair_dataset(
        data_path,
        split,
        src,
        self.src_dict,
        tgt,
        self.tgt_dict,
        combine=combine,
        dataset_impl=self.args.dataset_impl,
        upsample_primary=self.args.upsample_primary,
        left_pad_source=self.args.left_pad_source,
        left_pad_target=self.args.left_pad_target,
        max_source_positions=self.args.max_source_positions,
        max_target_positions=self.args.max_target_positions,
        load_alignments=self.args.load_alignments,
        load_cls_labels=self.args.load_cls_labels,
        load_cls_indices=self.args.load_cls_indices,
        load_sample_weights=self.args.load_sample_weights,
        truncate_source=self.args.truncate_source,
        shuffle=False,
    )
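Note that `load_cls_labels`, `load_cls_indices`, and `load_sample_weights` are not keyword arguments of upstream fairseq's `load_langpair_dataset`; they appear to be extensions specific to the fork this example was taken from, so the snippet will not run against stock fairseq without that fork's version of the loader.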
Example #5
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        if split != getattr(self.args, "train_subset", None):
            # if not the training split, use only the first shard for valid and test
            paths = paths[:1]
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        src, tgt = self.args.source_lang, self.args.target_lang

        self.datasets[split] = load_langpair_dataset(
            data_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            load_alignments=self.args.load_alignments,
            truncate_source=self.args.truncate_source,
            num_buckets=self.args.num_batch_buckets,
            shuffle=(split != "test"),
            pad_to_multiple=self.args.required_seq_len_multiple,
        )
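`pad_to_multiple` (fed from `--required-seq-len-multiple`) pads each batch's sequence length up to a fixed multiple, which keeps tensor shapes friendly to tensor-core kernels. The rounding rule itself is just the following (illustration only, not fairseq source):

def pad_length(n, multiple):
    # round n up to the next multiple, e.g. 37 -> 40 for multiple=8
    return -(-n // multiple) * multiple

assert pad_length(37, 8) == 40
assert pad_length(40, 8) == 40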
Example #6
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        logger.info("To load the dataset {}".format(split))
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        if split != getattr(self.args, "train_subset", None):
            # if not the training split, use only the first shard for valid and test
            paths = paths[:1]
        data_path = paths[(epoch - 1) % len(paths)]

        mono_paths = utils.split_paths(self.args.mono_data)

        # infer langcode
        src, tgt = self.args.source_lang, self.args.target_lang

        parallel_data = load_langpair_dataset(
            data_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            load_alignments=self.args.load_alignments,
            num_buckets=self.args.num_batch_buckets,
            shuffle=(split != "test"),
            pad_to_multiple=self.args.required_seq_len_multiple,
        )
        if split == "train":
            parallel_data = SubsampleLanguagePairDataset(
                parallel_data,
                size_ratio=self.args.parallel_ratio,
                seed=self.args.seed,
                epoch=epoch)
            if self.args.mono_one_split_each_epoch:
                # rotate through the mono shards, one shard per epoch
                mono_path = mono_paths[(epoch - 1) % len(mono_paths)]
                mono_data = load_langpair_dataset(
                    mono_path,
                    split,
                    src,
                    self.src_dict,
                    tgt,
                    self.tgt_dict,
                    combine=combine,
                    dataset_impl=self.args.dataset_impl,
                    upsample_primary=self.args.upsample_primary,
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                    max_source_positions=self.args.max_source_positions,
                    max_target_positions=self.args.max_target_positions,
                    shuffle=(split != "test"),
                )
                mono_data = SubsampleLanguagePairDataset(
                    mono_data,
                    size_ratio=self.args.mono_ratio,
                    seed=self.args.seed,
                    epoch=epoch)
                all_dataset = [parallel_data, mono_data]
            else:
                mono_datas = []
                for mono_path in mono_paths:
                    mono_data = load_langpair_dataset(
                        mono_path,
                        split,
                        src,
                        self.src_dict,
                        tgt,
                        self.tgt_dict,
                        combine=combine,
                        dataset_impl=self.args.dataset_impl,
                        upsample_primary=self.args.upsample_primary,
                        left_pad_source=self.args.left_pad_source,
                        left_pad_target=self.args.left_pad_target,
                        max_source_positions=self.args.max_source_positions,
                        max_target_positions=self.args.max_target_positions,
                        shuffle=(split != "test"),
                    )
                    mono_data = SubsampleLanguagePairDataset(
                        mono_data,
                        size_ratio=self.args.mono_ratio,
                        seed=self.args.seed,
                        epoch=epoch)
                    mono_datas.append(mono_data)
                all_dataset = [parallel_data] + mono_datas
            self.datasets[split] = ConcatDataset(all_dataset)
        else:
            self.datasets[split] = parallel_data
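The training branch above keeps a seeded, per-epoch fraction of each corpus (`parallel_ratio`, `mono_ratio`) and concatenates the pieces into one training set. A minimal sketch of the subsampling idea follows; this is my illustration of the behavior, not fairseq's `SubsampleLanguagePairDataset`:

import numpy as np

def subsample_indices(n, size_ratio, seed, epoch):
    # deterministic per (seed, epoch): each epoch draws a fresh
    # size_ratio fraction of the n examples, without replacement
    rng = np.random.RandomState((seed + epoch) % 2**32)
    k = int(n * size_ratio)
    return rng.choice(n, size=k, replace=False)

# e.g. mono_ratio=0.25 over 1,000,000 mono sentences keeps 250,000
idx = subsample_indices(1_000_000, 0.25, seed=1, epoch=3)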