def language_pair_dataset(lang_pair):
    """Build the dataset for one "src-tgt" language pair and wrap it
    with language-token alteration via ``self.alter_dataset_langtok``.
    """
    src, tgt = lang_pair.split('-')
    pair_dataset = load_langpair_dataset(
        data_path,
        split,
        src,
        self.dicts[src],
        tgt,
        self.dicts[tgt],
        combine=True,
        dataset_impl=self.args.dataset_impl,
        upsample_primary=self.args.upsample_primary,
        left_pad_source=self.args.left_pad_source,
        left_pad_target=self.args.left_pad_target,
        max_source_positions=self.args.max_source_positions,
        max_target_positions=self.args.max_target_positions,
    )
    # NOTE(review): src_eos is taken from the *target* dictionary —
    # presumably so the source EOS can be replaced by a target-language
    # token; confirm against alter_dataset_langtok before changing.
    return self.alter_dataset_langtok(
        pair_dataset,
        src_eos=self.dicts[tgt].eos(),
        src_lang=src,
        tgt_lang=tgt,
    )
def _load_dataset(self, ds_name, split, epoch=0, combine=False, **kwargs):
    """Load the ``split`` of the named sub-dataset ``ds_name``.

    ``self.args.data`` may be a ':'-separated list of shard roots; the
    shard for this epoch is chosen round-robin (0-based ``epoch``).
    Returns the language-pair dataset built from that shard.
    """
    shards = self.args.data.split(':')
    assert len(shards) > 0
    # Pick this epoch's shard, then descend into the sub-dataset folder.
    data_path = os.path.join(shards[epoch % len(shards)], ds_name)
    # infer langcode
    src, tgt = self.args.source_lang, self.args.target_lang
    return load_langpair_dataset(
        data_path,
        split,
        src,
        self.src_dicts[ds_name],
        tgt,
        self.tgt_dicts[ds_name],
        combine=combine,
        dataset_impl=self.args.dataset_impl,
        upsample_primary=self.args.upsample_primary,
        left_pad_source=self.args.left_pad_source,
        left_pad_target=self.args.left_pad_target,
        max_source_positions=self.args.max_source_positions,
        max_target_positions=self.args.max_target_positions,
    )
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): 1-based epoch number used to rotate data shards
        combine (bool): combine multiple shard files into one dataset
    """
    shards = utils.split_paths(self.args.data)
    assert len(shards) > 0
    # One shard per epoch, rotating round-robin (epoch is 1-based).
    data_path = shards[(epoch - 1) % len(shards)]
    # infer langcode
    source, target = self.args.source_lang, self.args.target_lang
    self.datasets[split] = load_langpair_dataset(
        data_path,
        split,
        source,
        self.src_dict,
        target,
        self.tgt_dict,
        combine=combine,
        dataset_impl=self.args.dataset_impl,
        upsample_primary=self.args.upsample_primary,
        left_pad_source=self.args.left_pad_source,
        left_pad_target=self.args.left_pad_target,
        max_source_positions=self.args.max_source_positions,
        max_target_positions=self.args.max_target_positions,
        prepend_bos=True,
    )
def language_pair_dataset(lang_pair):
    """Return the (unshuffled) dataset for one "src-tgt" language pair,
    including alignments, cls labels/indices, and sample weights as
    configured on ``self.args``.
    """
    logger.info('loading language pair {}'.format(lang_pair))
    src, tgt = lang_pair.split('-')
    return load_langpair_dataset(
        data_path,
        split,
        src,
        self.src_dict,
        tgt,
        self.tgt_dict,
        combine=combine,
        dataset_impl=self.args.dataset_impl,
        upsample_primary=self.args.upsample_primary,
        left_pad_source=self.args.left_pad_source,
        left_pad_target=self.args.left_pad_target,
        max_source_positions=self.args.max_source_positions,
        max_target_positions=self.args.max_target_positions,
        load_alignments=self.args.load_alignments,
        load_cls_labels=self.args.load_cls_labels,
        load_cls_indices=self.args.load_cls_indices,
        load_sample_weights=self.args.load_sample_weights,
        truncate_source=self.args.truncate_source,
        # keep example order deterministic for this pair
        shuffle=False,
    )
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): 1-based epoch number used to rotate data shards
        combine (bool): combine multiple shard files into one dataset
    """
    shards = utils.split_paths(self.args.data)
    assert len(shards) > 0
    # Only the training subset cycles through all shards; valid/test
    # always read from the first shard.
    if split != getattr(self.args, "train_subset", None):
        shards = shards[:1]
    data_path = shards[(epoch - 1) % len(shards)]
    # infer langcode
    source, target = self.args.source_lang, self.args.target_lang
    self.datasets[split] = load_langpair_dataset(
        data_path,
        split,
        source,
        self.src_dict,
        target,
        self.tgt_dict,
        combine=combine,
        dataset_impl=self.args.dataset_impl,
        upsample_primary=self.args.upsample_primary,
        left_pad_source=self.args.left_pad_source,
        left_pad_target=self.args.left_pad_target,
        max_source_positions=self.args.max_source_positions,
        max_target_positions=self.args.max_target_positions,
        load_alignments=self.args.load_alignments,
        truncate_source=self.args.truncate_source,
        num_buckets=self.args.num_batch_buckets,
        # test data keeps its original order
        shuffle=(split != "test"),
        pad_to_multiple=self.args.required_seq_len_multiple,
    )
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a given dataset split, mixing parallel and monolingual data.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): epoch number used to rotate data shards and to seed
            subsampling. NOTE(review): the body indexes shards with
            ``(epoch - 1) % len(paths)`` (1-based convention, as in the
            sibling ``load_dataset`` definitions that default to
            ``epoch=1``), so the default of 0 selects the *last* shard —
            the default is kept for interface compatibility but looks
            like a latent bug; confirm callers always pass epoch.
        combine (bool): combine multiple shard files into one dataset

    For the train split, the parallel corpus is subsampled by
    ``--parallel-ratio`` and concatenated with monolingual data
    subsampled by ``--mono-ratio``; other splits use parallel data only.
    """
    logger.info("To load the dataset {}".format(split))
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    if split != getattr(self.args, "train_subset", None):
        # if not training data set, use the first shard for valid and test
        paths = paths[:1]
    data_path = paths[(epoch - 1) % len(paths)]
    mono_paths = utils.split_paths(self.args.mono_data)
    # infer langcode
    src, tgt = self.args.source_lang, self.args.target_lang

    parallel_data = load_langpair_dataset(
        data_path,
        split,
        src,
        self.src_dict,
        tgt,
        self.tgt_dict,
        combine=combine,
        dataset_impl=self.args.dataset_impl,
        upsample_primary=self.args.upsample_primary,
        left_pad_source=self.args.left_pad_source,
        left_pad_target=self.args.left_pad_target,
        max_source_positions=self.args.max_source_positions,
        max_target_positions=self.args.max_target_positions,
        load_alignments=self.args.load_alignments,
        num_buckets=self.args.num_batch_buckets,
        shuffle=(split != "test"),
        pad_to_multiple=self.args.required_seq_len_multiple,
    )

    if split != "train":
        # valid/test: parallel data only, no subsampling or mono mixing
        self.datasets[split] = parallel_data
        return

    def _load_mono(mono_path):
        # Load one monolingual shard (stored in language-pair format)
        # and subsample it by --mono-ratio, seeded per epoch.
        mono_data = load_langpair_dataset(
            mono_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            shuffle=(split != "test"),
            max_target_positions=self.args.max_target_positions,
        )
        return SubsampleLanguagePairDataset(
            mono_data,
            size_ratio=self.args.mono_ratio,
            seed=self.args.seed,
            epoch=epoch,
        )

    parallel_data = SubsampleLanguagePairDataset(
        parallel_data,
        size_ratio=self.args.parallel_ratio,
        seed=self.args.seed,
        epoch=epoch,
    )
    if self.args.mono_one_split_each_epoch:
        # Rotate through the mono shards, one shard per epoch.
        mono_datasets = [_load_mono(mono_paths[(epoch - 1) % len(mono_paths)])]
    else:
        # Use every mono shard each epoch.
        mono_datasets = [_load_mono(p) for p in mono_paths]
    self.datasets[split] = ConcatDataset([parallel_data] + mono_datasets)