def _load_dataset_multi_path(
    self,
    split: str,
    src_multiple_bin_paths: Dict[str, str],
    tgt_multiple_bin_paths: Dict[str, str],
    dataset_upsampling: Optional[Dict[str, float]] = None,
    dataset_relative_ratio: Optional[Tuple[str, float]] = None,
    seed: Optional[int] = None,
    noiser: Optional[Dict[str, UnsupervisedMTNoising]] = None,
):
    corpora_map = pytorch_translate_data.ParallelCorporaMapConfig(
        src_files=src_multiple_bin_paths, tgt_files=tgt_multiple_bin_paths
    )
    datasets = OrderedDict()
    for key in corpora_map.src_files:
        src, tgt = corpora_map.src_files[key], corpora_map.tgt_files[key]
        src_dataset, tgt_dataset = (
            pytorch_translate_data.InMemoryNumpyDataset.create_from_file(src),
            pytorch_translate_data.InMemoryNumpyDataset.create_from_file(tgt),
        )
        src_sizes = src_dataset.sizes
        if noiser is not None and key in noiser:
            src_dataset = NoisingDataset(
                src_dataset=src_dataset,
                src_dict=self.source_dictionary,
                seed=seed,
                noiser=noiser[key],
            )
        datasets[key] = LanguagePairDataset(
            src=src_dataset,
            src_sizes=src_sizes,
            src_dict=self.source_dictionary,
            tgt=tgt_dataset,
            tgt_sizes=tgt_dataset.sizes,
            tgt_dict=self.target_dictionary,
            left_pad_source=False,
        )
    total_line_count = sum(len(datasets[key]) for key in datasets)
    if dataset_relative_ratio is not None:
        ds, ratio = dataset_relative_ratio
        line_count = len(datasets[ds])
        # By definition, ratio = u * line_count / sum(#lines of other datasets),
        # so solve for the upsampling factor u of dataset `ds`. Note: the factor
        # must be keyed by `ds`, not by the stale loop variable `key`, which
        # still points at the last corpus iterated above.
        u = (total_line_count - line_count) / line_count * ratio
        dataset_upsampling = {ds: u}
    dataset_weights = {
        key: 1.0 * len(datasets[key]) / total_line_count
        for key in src_multiple_bin_paths.keys()
    }
    if dataset_upsampling is not None:
        for k, v in dataset_upsampling.items():
            dataset_weights[k] *= v
    print(f"| dataset_weights: {dataset_weights}")
    self.datasets[split] = MultiCorpusSampledDataset(
        datasets=datasets,
        default_key=list(dataset_weights.keys())[0],
        sampling_func=self._normalized_weighted_sampling(dataset_weights),
    )
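
# A minimal standalone sketch (hypothetical line counts) of the arithmetic in the
# dataset_relative_ratio branch above: with ("in_domain", 0.5), the in-domain
# corpus is upsampled so that it contributes ratio-times as many samples as all
# other corpora combined. The helper name and numbers are illustrative only.
def relative_ratio_to_upsampling(line_counts, ds, ratio):
    # Solve ratio = u * line_count / sum(#lines of other datasets) for u.
    total = sum(line_counts.values())
    line_count = line_counts[ds]
    u = (total - line_count) / line_count * ratio
    return {ds: u}

assert relative_ratio_to_upsampling(
    {"in_domain": 1000, "general": 9000}, "in_domain", 0.5
) == {"in_domain": 4.5}  # 1000 * 4.5 = 4500 = 0.5 * 9000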
def _load_dataset_multi_path_helper(
    self,
    split: str,
    src_multiple_bin_paths: Dict[str, str],
    tgt_multiple_bin_paths: Dict[str, str],
    dataset_upsampling: Optional[Dict[str, float]] = None,
    dataset_relative_ratio: Optional[Tuple[str, float]] = None,
    seed: Optional[int] = None,
    noiser: Optional[Dict[str, UnsupervisedMTNoising]] = None,
):
    corpora_map = pytorch_translate_data.ParallelCorporaMapConfig(
        src_files=src_multiple_bin_paths, tgt_files=tgt_multiple_bin_paths
    )
    datasets = OrderedDict()
    for key in corpora_map.src_files:
        src, tgt = corpora_map.src_files[key], corpora_map.tgt_files[key]
        src_dataset, tgt_dataset = (
            pytorch_translate_data.InMemoryNumpyDataset.create_from_file(src),
            pytorch_translate_data.InMemoryNumpyDataset.create_from_file(tgt),
        )
        src_sizes = src_dataset.sizes
        if noiser is not None and key in noiser:
            src_dataset = NoisingDataset(
                src_dataset=src_dataset,
                src_dict=self.source_dictionary,
                seed=seed,
                noiser=noiser[key],
            )
        datasets[key] = LanguagePairDataset(
            src=src_dataset,
            src_sizes=src_sizes,
            src_dict=self.source_dictionary,
            tgt=tgt_dataset,
            tgt_sizes=tgt_dataset.sizes,
            tgt_dict=self.target_dictionary,
            left_pad_source=False,
        )
    total_line_count = sum(len(datasets[key]) for key in datasets)
    if dataset_relative_ratio:
        ds, ratio = dataset_relative_ratio
        line_count = len(datasets[ds])
        # By definition, ratio = u * line_count / sum(#lines of other datasets),
        # so solve for the upsampling factor u of dataset `ds` (keyed by `ds`,
        # not by the stale loop variable `key`).
        u = (total_line_count - line_count) / line_count * ratio
        dataset_upsampling = {ds: u}
    elif not dataset_upsampling:
        dataset_upsampling = {}
    print(f"| dataset upsampling: {dataset_upsampling}")
    ds_list = []
    sample_ratios = []
    for key, val in datasets.items():
        ds_list.append(val)
        sample_ratios.append(dataset_upsampling.get(key, 1.0))
    # Pass the list built above so the dataset order matches sample_ratios.
    self.datasets[split] = ConcatDataset(
        datasets=ds_list, sample_ratios=sample_ratios
    )
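
# A minimal pure-Python sketch mimicking fairseq's ConcatDataset.cumsum, showing
# how sample_ratios stretch each corpus: a ratio r makes a corpus contribute
# int(r * len(corpus)) examples per epoch, repeating items as needed. The helper
# name and numbers are hypothetical.
def effective_sizes(lengths, sample_ratios):
    return [int(r * n) for n, r in zip(lengths, sample_ratios)]

# With the upsampling factor computed in the example above, the in-domain corpus
# is repeated 4.5x, so it contributes half as many examples as the general one.
assert effective_sizes([1000, 9000], [4.5, 1.0]) == [4500, 9000]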
def load_denoise_dataset(self, data_path: str, lang: str) -> FairseqDataset:
    """Classic denoising dataset"""
    dataset = data_utils.load_indexed_dataset(
        data_path, self.common_dict, self.args.dataset_impl
    )
    noisy_dataset = NoisingDataset(
        dataset,
        self.dictionary,
        seed=1,
        max_word_shuffle_distance=self.args.max_word_shuffle_distance,
        word_dropout_prob=self.args.word_dropout_prob,
        word_blanking_prob=self.args.word_blanking_prob,
    )
    noisy_dataset = PrependTokenDataset(
        noisy_dataset, _lang_token_index(self.dictionary, lang)
    )
    clean_dataset = data_utils.load_indexed_dataset(
        data_path, self.common_dict, self.args.dataset_impl
    )
    denoising_dataset = self._langpair_dataset(noisy_dataset, clean_dataset)
    denoising_dataset = self._prepend_lang_bos_to_target(denoising_dataset, lang)
    return denoising_dataset
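
# A minimal pure-Python sketch of the three noise operations that the
# UnsupervisedMTNoising-backed NoisingDataset applies to the source side (word
# shuffle within a bounded window, word dropout, word blanking). Illustrative
# only; the function and its defaults are assumptions, not the fairseq code.
import random

def noise_tokens(tokens, max_shuffle_distance=3, dropout_prob=0.1,
                 blanking_prob=0.2, blank="<blank>", seed=1):
    rng = random.Random(seed)
    # Word shuffle: add uniform noise to each position and re-sort, so a token
    # moves at most max_shuffle_distance positions from where it started.
    keys = [i + rng.uniform(0, max_shuffle_distance) for i in range(len(tokens))]
    shuffled = [t for _, t in sorted(zip(keys, tokens))]
    out = []
    for t in shuffled:
        r = rng.random()
        if r < dropout_prob:
            continue  # word dropout: drop the token entirely
        # word blanking: replace the token with a placeholder symbol
        out.append(blank if r < dropout_prob + blanking_prob else t)
    return out

print(noise_tokens("the quick brown fox jumps".split()))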
def load_dataset(self, split, epoch=0, **kwargs):
    """Load a dataset split."""
    paths = self.args.data.split(':')
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]

    def split_exists(split, src, tgt, lang):
        if src is not None:
            filename = os.path.join(
                data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        else:
            filename = os.path.join(
                data_path, '{}.{}-None.{}'.format(split, src, tgt))
        if self.args.raw_text and IndexedRawTextDataset.exists(filename):
            return True
        elif not self.args.raw_text and IndexedDataset.exists(filename):
            return True
        return False

    def indexed_dataset(path, dictionary):
        if self.args.raw_text:
            return IndexedRawTextDataset(path, dictionary)
        elif IndexedDataset.exists(path):
            if self.args.lazy_load:
                return IndexedDataset(path, fix_lua_indexing=True)
            else:
                return IndexedCachedDataset(path, fix_lua_indexing=True)
        return None

    # load parallel datasets
    src_datasets, tgt_datasets = {}, {}
    if (self.lambda_parallel > 0.0
            or self.lambda_parallel_steps is not None
            or not split.startswith("train")):
        for lang_pair in self.lang_pairs:
            src, tgt = lang_pair.split('-')
            if split_exists(split, src, tgt, src):
                prefix = os.path.join(
                    data_path, '{}.{}-{}.'.format(split, src, tgt))
            elif split_exists(split, tgt, src, src):
                prefix = os.path.join(
                    data_path, '{}.{}-{}.'.format(split, tgt, src))
            else:
                continue
            src_datasets[lang_pair] = indexed_dataset(prefix + src, self.dicts[src])
            tgt_datasets[lang_pair] = indexed_dataset(prefix + tgt, self.dicts[tgt])
            print('| parallel-{} {} {} examples'.format(
                data_path, split, len(src_datasets[lang_pair])))
        if len(src_datasets) == 0:
            raise FileNotFoundError(
                'Dataset not found: {} ({})'.format(split, data_path))

    # back translation datasets
    backtranslate_datasets = {}
    if (self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None) \
            and split.startswith("train"):
        for lang_pair in self.lang_pairs:
            src, tgt = lang_pair.split('-')
            if not split_exists(split, tgt, None, tgt):
                raise FileNotFoundError(
                    'Dataset not found: backtranslation {} ({})'.format(
                        split, data_path))
            filename = os.path.join(
                data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
            dataset = indexed_dataset(filename, self.dicts[tgt])
            lang_pair_dataset_tgt = LanguagePairDataset(
                dataset,
                dataset.sizes,
                self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )
            lang_pair_dataset = LanguagePairDataset(
                dataset,
                dataset.sizes,
                src_dict=self.dicts[src],
                tgt=dataset,
                tgt_sizes=dataset.sizes,
                tgt_dict=self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )
            backtranslate_datasets[lang_pair] = BacktranslationDataset(
                tgt_dataset=self.alter_dataset_langtok(
                    lang_pair_dataset_tgt,
                    src_eos=self.dicts[tgt].eos(),
                    src_lang=tgt,
                    tgt_lang=src,
                ),
                backtranslation_fn=self.backtranslators[lang_pair],
                src_dict=self.dicts[src],
                tgt_dict=self.dicts[tgt],
                output_collater=self.alter_dataset_langtok(
                    lang_pair_dataset=lang_pair_dataset,
                    src_eos=self.dicts[src].eos(),
                    src_lang=src,
                    tgt_eos=self.dicts[tgt].eos(),
                    tgt_lang=tgt,
                ).collater,
            )
            print('| backtranslate-{}: {} {} {} examples'.format(
                tgt, data_path, split, len(backtranslate_datasets[lang_pair]),
            ))
            self.backtranslate_datasets[lang_pair] = backtranslate_datasets[lang_pair]

    # denoising autoencoder
    noising_datasets = {}
    if (self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None) \
            and split.startswith("train"):
        for lang_pair in self.lang_pairs:
            _, tgt = lang_pair.split('-')
            if not split_exists(split, tgt, None, tgt):
                continue
            filename = os.path.join(
                data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
            tgt_dataset1 = indexed_dataset(filename, self.dicts[tgt])
            tgt_dataset2 = indexed_dataset(filename, self.dicts[tgt])
            noising_dataset = NoisingDataset(
                tgt_dataset1,
                self.dicts[tgt],
                seed=1,
                max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                word_dropout_prob=self.args.word_dropout_prob,
                word_blanking_prob=self.args.word_blanking_prob,
            )
            noising_datasets[lang_pair] = self.alter_dataset_langtok(
                LanguagePairDataset(
                    noising_dataset,
                    tgt_dataset1.sizes,
                    self.dicts[tgt],
                    tgt_dataset2,
                    tgt_dataset2.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                ),
                src_eos=self.dicts[tgt].eos(),
                src_lang=tgt,
                tgt_eos=self.dicts[tgt].eos(),
                tgt_lang=tgt,
            )
            print('| denoising-{}: {} {} {} examples'.format(
                tgt, data_path, split, len(noising_datasets[lang_pair]),
            ))

    def language_pair_dataset(lang_pair):
        src, tgt = lang_pair.split('-')
        src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
        return self.alter_dataset_langtok(
            LanguagePairDataset(
                src_dataset,
                src_dataset.sizes,
                self.dicts[src],
                tgt_dataset,
                tgt_dataset.sizes,
                self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
                max_source_positions=self.args.max_source_positions,
                max_target_positions=self.args.max_target_positions,
            ),
            self.dicts[src].eos(),
            src,
            self.dicts[tgt].eos(),
            tgt,
        )

    self.datasets[split] = RoundRobinZipDatasets(
        OrderedDict(
            [
                (lang_pair, language_pair_dataset(lang_pair))
                for lang_pair in src_datasets.keys()
            ] + [
                (_get_bt_dataset_key(lang_pair), dataset)
                for lang_pair, dataset in backtranslate_datasets.items()
            ] + [
                (_get_denoising_dataset_key(lang_pair), dataset)
                for lang_pair, dataset in noising_datasets.items()
            ]
        ),
        eval_key=None if self.training
        else "%s-%s" % (self.args.source_lang, self.args.target_lang),
    )
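
# For reference, the key helpers used in the RoundRobinZipDatasets dict above
# prefix each auxiliary dataset so a parallel, a back-translation, and a
# denoising entry can coexist per language pair. These match the upstream
# fairseq helpers as we read them:
def _get_bt_dataset_key(lang_pair):
    return "bt:" + lang_pair

def _get_denoising_dataset_key(lang_pair):
    return "denoising:" + lang_pair

# e.g. for lang_pair "de-en" the round-robin keys are
# "de-en", "bt:de-en", and "denoising:de-en".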
def load_dataset(self, split, epoch=1, **kwargs):
    """Load a dataset split."""
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]

    def split_exists(split, src, tgt, lang):
        if src is not None:
            filename = os.path.join(
                data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)
            )
        else:
            filename = os.path.join(
                data_path, "{}.{}-None.{}".format(split, src, tgt)
            )
        return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)

    def load_indexed_dataset(path, dictionary):
        return data_utils.load_indexed_dataset(
            path, dictionary, self.args.dataset_impl
        )

    # load parallel datasets
    src_datasets, tgt_datasets = {}, {}
    if (
        self.lambda_parallel > 0.0
        or self.lambda_parallel_steps is not None
        or not split.startswith("train")
    ):
        for lang_pair in self.lang_pairs:
            src, tgt = lang_pair.split("-")
            if split_exists(split, src, tgt, src):
                prefix = os.path.join(
                    data_path, "{}.{}-{}.".format(split, src, tgt)
                )
            elif split_exists(split, tgt, src, src):
                prefix = os.path.join(
                    data_path, "{}.{}-{}.".format(split, tgt, src)
                )
            else:
                continue
            src_datasets[lang_pair] = load_indexed_dataset(
                prefix + src, self.dicts[src]
            )
            tgt_datasets[lang_pair] = load_indexed_dataset(
                prefix + tgt, self.dicts[tgt]
            )
            logger.info(
                "parallel-{} {} {} examples".format(
                    data_path, split, len(src_datasets[lang_pair])
                )
            )
        if len(src_datasets) == 0:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, data_path)
            )

    # back translation datasets
    backtranslate_datasets = {}
    if (
        self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
    ) and split.startswith("train"):
        for lang_pair in self.lang_pairs:
            src, tgt = lang_pair.split("-")
            if not split_exists(split, tgt, None, tgt):
                raise FileNotFoundError(
                    "Dataset not found: backtranslation {} ({})".format(
                        split, data_path
                    )
                )
            filename = os.path.join(
                data_path, "{}.{}-None.{}".format(split, tgt, tgt)
            )
            dataset = load_indexed_dataset(filename, self.dicts[tgt])
            lang_pair_dataset_tgt = LanguagePairDataset(
                dataset,
                dataset.sizes,
                self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )
            lang_pair_dataset = LanguagePairDataset(
                dataset,
                dataset.sizes,
                src_dict=self.dicts[src],
                tgt=dataset,
                tgt_sizes=dataset.sizes,
                tgt_dict=self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )
            backtranslate_datasets[lang_pair] = BacktranslationDataset(
                tgt_dataset=self.alter_dataset_langtok(
                    lang_pair_dataset_tgt,
                    src_eos=self.dicts[tgt].eos(),
                    src_lang=tgt,
                    tgt_lang=src,
                ),
                backtranslation_fn=self.backtranslators[lang_pair],
                src_dict=self.dicts[src],
                tgt_dict=self.dicts[tgt],
                output_collater=self.alter_dataset_langtok(
                    lang_pair_dataset=lang_pair_dataset,
                    src_eos=self.dicts[src].eos(),
                    src_lang=src,
                    tgt_eos=self.dicts[tgt].eos(),
                    tgt_lang=tgt,
                ).collater,
            )
            logger.info(
                "backtranslate-{}: {} {} {} examples".format(
                    tgt,
                    data_path,
                    split,
                    len(backtranslate_datasets[lang_pair]),
                )
            )
            self.backtranslate_datasets[lang_pair] = backtranslate_datasets[
                lang_pair
            ]

    # denoising autoencoder
    noising_datasets = {}
    if (
        self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None
    ) and split.startswith("train"):
        for lang_pair in self.lang_pairs:
            _, tgt = lang_pair.split("-")
            if not split_exists(split, tgt, None, tgt):
                continue
            filename = os.path.join(
                data_path, "{}.{}-None.{}".format(split, tgt, tgt)
            )
            tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt])
            tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt])
            noising_dataset = NoisingDataset(
                tgt_dataset1,
                self.dicts[tgt],
                seed=1,
                max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                word_dropout_prob=self.args.word_dropout_prob,
                word_blanking_prob=self.args.word_blanking_prob,
            )
            noising_datasets[lang_pair] = self.alter_dataset_langtok(
                LanguagePairDataset(
                    noising_dataset,
                    tgt_dataset1.sizes,
                    self.dicts[tgt],
                    tgt_dataset2,
                    tgt_dataset2.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                ),
                src_eos=self.dicts[tgt].eos(),
                src_lang=tgt,
                tgt_eos=self.dicts[tgt].eos(),
                tgt_lang=tgt,
            )
            logger.info(
                "denoising-{}: {} {} {} examples".format(
                    tgt,
                    data_path,
                    split,
                    len(noising_datasets[lang_pair]),
                )
            )

    def language_pair_dataset(lang_pair):
        src, tgt = lang_pair.split("-")
        src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
        return self.alter_dataset_langtok(
            LanguagePairDataset(
                src_dataset,
                src_dataset.sizes,
                self.dicts[src],
                tgt_dataset,
                tgt_dataset.sizes,
                self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            ),
            self.dicts[src].eos(),
            src,
            self.dicts[tgt].eos(),
            tgt,
        )

    self.datasets[split] = RoundRobinZipDatasets(
        OrderedDict(
            [
                (lang_pair, language_pair_dataset(lang_pair))
                for lang_pair in src_datasets.keys()
            ]
            + [
                (_get_bt_dataset_key(lang_pair), dataset)
                for lang_pair, dataset in backtranslate_datasets.items()
            ]
            + [
                (_get_denoising_dataset_key(lang_pair), dataset)
                for lang_pair, dataset in noising_datasets.items()
            ]
        ),
        eval_key=None
        if self.training
        else "%s-%s" % (self.args.source_lang, self.args.target_lang),
    )
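
# A minimal sketch of the epoch-to-shard mapping used by the two load_dataset
# variants above: the older one is 0-based (paths[epoch % len(paths)]) while
# this one is 1-based (paths[(epoch - 1) % len(paths)]); both cycle through
# colon-separated data shards. The shard paths below are hypothetical.
paths = "/data/shard0:/data/shard1:/data/shard2".split(":")
assert [paths[(epoch - 1) % len(paths)] for epoch in (1, 2, 3, 4)] == [
    "/data/shard0", "/data/shard1", "/data/shard2", "/data/shard0",
]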
def load_dataset(self, split, epoch=0, combine=True, **kwargs):
    """Load a dataset split."""
    paths = self.args.data.split(':')
    assert len(paths) > 0
    # data_path = paths[epoch % len(paths)]
    para_data_path = paths[0]
    dae_data_paths = paths[1:]

    def split_exists(split, src, tgt, lang, data_path):
        if src is not None:
            filename = os.path.join(
                data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        else:
            filename = os.path.join(
                data_path, '{}.{}-None.{}'.format(split, src, tgt))
        if self.args.raw_text and IndexedRawTextDataset.exists(filename):
            return True
        elif not self.args.raw_text and IndexedDataset.exists(filename):
            return True
        return False

    def indexed_dataset(path, dictionary):
        if self.args.raw_text:
            return IndexedRawTextDataset(path, dictionary)
        elif IndexedDataset.exists(path):
            if self.args.lazy_load:
                return IndexedDataset(path, fix_lua_indexing=True)
            else:
                return IndexedCachedDataset(path, fix_lua_indexing=True)
        return None

    # load parallel datasets
    para_datasets = {}
    if self.lambda_parallel > 0.0 or self.lambda_parallel_steps is not None:
        assert self.summ_pair is not None
        src, tgt = self.summ_pair.split('-')
        para_datasets[self.summ_pair] = load_langpair_dataset(
            para_data_path, split,
            src, self.dicts[src],
            tgt, self.dicts[tgt],
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            load_alignments=self.args.load_alignments,
            truncate_source=self.args.truncate_source,
            lang_pair=self.summ_pair,
        )

    # denoising autoencoder (train/valid only; the parentheses matter here,
    # since `and` binds tighter than `or`)
    noising_datasets = {}
    if split != 'test' and (self.lambda_denoising > 0.0
                            or self.lambda_denoising_steps is not None):
        for lang_pair, dae_data_path in zip(self.dae_pairs, dae_data_paths):
            _, tgt = lang_pair.split('-')
            if split_exists(split, tgt, None, tgt, dae_data_path):
                filename = os.path.join(
                    dae_data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
                tgt_dataset1 = data_utils.load_indexed_dataset(
                    filename, self.dicts[tgt], self.args.dataset_impl)
                tgt_dataset2 = data_utils.load_indexed_dataset(
                    filename, self.dicts[tgt], self.args.dataset_impl)
                noising_dataset = NoisingDataset(
                    tgt_dataset1,
                    self.dicts[tgt],
                    seed=1,
                    max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                    word_dropout_prob=self.args.word_dropout_prob,
                    word_blanking_prob=self.args.word_blanking_prob,
                )
                noising_datasets[lang_pair] = self.alter_dataset_langtok(
                    LanguagePairDataset(
                        noising_dataset,
                        tgt_dataset1.sizes,
                        self.dicts[tgt],
                        tgt_dataset2,
                        tgt_dataset2.sizes,
                        self.dicts[tgt],
                        left_pad_source=self.args.left_pad_source,
                        left_pad_target=self.args.left_pad_target,
                        lang_pair=lang_pair,
                    ),
                    src_eos=self.dicts[tgt].eos(),
                    src_lang=tgt,
                    tgt_eos=self.dicts[tgt].eos(),
                    tgt_lang=tgt,
                )
                print('| denoising-{}: {} {} {} examples'.format(
                    tgt, dae_data_path, split, len(noising_datasets[lang_pair]),
                ))
            else:
                raise ValueError(
                    'Target dataset for {} does not exist!'.format(tgt))

    self.datasets[split] = RoundRobinZipDatasets(
        OrderedDict(
            [
                (lang_pair, dataset)
                for lang_pair, dataset in para_datasets.items()
            ] + [
                (_get_denoising_dataset_key(lang_pair), dataset)
                for lang_pair, dataset in noising_datasets.items()
            ]
        ),
        eval_key=None if split != 'test'
        else (list(para_datasets.keys()) + list(noising_datasets.keys()))[0],
    )
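
# A minimal sketch of why the denoising guard above needs explicit parentheses:
# without them, `and` binding tighter than `or` means the DAE branch would also
# run on the test split whenever lambda_denoising_steps is set. Values below are
# hypothetical.
split, lambda_denoising, lambda_denoising_steps = "test", 0.0, [(0, 1.0)]
buggy = split != "test" and lambda_denoising > 0.0 or lambda_denoising_steps is not None
fixed = split != "test" and (lambda_denoising > 0.0 or lambda_denoising_steps is not None)
assert buggy is True and fixed is False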