def load_dataset(self, split, epoch=0, **kwargs):
    """Load a dataset split."""
    paths = self.args.data.split(':')
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]

    def split_exists(split, src, tgt, lang):
        if src is not None:
            filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        else:
            filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, src, tgt))
        if self.args.raw_text and IndexedRawTextDataset.exists(filename):
            return True
        elif not self.args.raw_text and IndexedDataset.exists(filename):
            return True
        return False

    def indexed_dataset(path, dictionary):
        if self.args.raw_text:
            return IndexedRawTextDataset(path, dictionary)
        elif IndexedDataset.exists(path):
            if self.args.lazy_load:
                return IndexedDataset(path, fix_lua_indexing=True)
            else:
                return IndexedCachedDataset(path, fix_lua_indexing=True)
        return None

    # load parallel datasets
    src_datasets, tgt_datasets = {}, {}
    if (self.lambda_parallel > 0.0 or self.lambda_parallel_steps is not None
            or not split.startswith("train")):
        for lang_pair in self.lang_pairs:
            src, tgt = lang_pair.split('-')
            if split_exists(split, src, tgt, src):
                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))
            elif split_exists(split, tgt, src, src):
                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src))
            else:
                continue
            src_datasets[lang_pair] = indexed_dataset(prefix + src, self.dicts[src])
            tgt_datasets[lang_pair] = indexed_dataset(prefix + tgt, self.dicts[tgt])
            print('| parallel-{} {} {} examples'.format(data_path, split, len(src_datasets[lang_pair])))
        if len(src_datasets) == 0:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

    # back translation datasets
    backtranslate_datasets = {}
    if (self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None) and split.startswith("train"):
        for lang_pair in self.lang_pairs:
            src, tgt = lang_pair.split('-')
            if not split_exists(split, tgt, None, tgt):
                raise FileNotFoundError('Dataset not found: backtranslation {} ({})'.format(split, data_path))
            filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
            dataset = indexed_dataset(filename, self.dicts[tgt])
            lang_pair_dataset_tgt = LanguagePairDataset(
                dataset,
                dataset.sizes,
                self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )
            lang_pair_dataset = LanguagePairDataset(
                dataset,
                dataset.sizes,
                src_dict=self.dicts[src],
                tgt=dataset,
                tgt_sizes=dataset.sizes,
                tgt_dict=self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )
            backtranslate_datasets[lang_pair] = BacktranslationDataset(
                tgt_dataset=self.alter_dataset_langtok(
                    lang_pair_dataset_tgt,
                    src_eos=self.dicts[tgt].eos(),
                    src_lang=tgt,
                    tgt_lang=src,
                ),
                backtranslation_fn=self.backtranslators[lang_pair],
                src_dict=self.dicts[src],
                tgt_dict=self.dicts[tgt],
                output_collater=self.alter_dataset_langtok(
                    lang_pair_dataset=lang_pair_dataset,
                    src_eos=self.dicts[src].eos(),
                    src_lang=src,
                    tgt_eos=self.dicts[tgt].eos(),
                    tgt_lang=tgt,
                ).collater,
            )
            print('| backtranslate-{}: {} {} {} examples'.format(
                tgt, data_path, split, len(backtranslate_datasets[lang_pair]),
            ))
            self.backtranslate_datasets[lang_pair] = backtranslate_datasets[lang_pair]

    # denoising autoencoder
    noising_datasets = {}
    if (self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None) and split.startswith("train"):
        for lang_pair in self.lang_pairs:
            _, tgt = lang_pair.split('-')
            if not split_exists(split, tgt, None, tgt):
                continue
            filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
            tgt_dataset1 = indexed_dataset(filename, self.dicts[tgt])
            tgt_dataset2 = indexed_dataset(filename, self.dicts[tgt])
            noising_dataset = NoisingDataset(
                tgt_dataset1,
                self.dicts[tgt],
                seed=1,
                max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                word_dropout_prob=self.args.word_dropout_prob,
                word_blanking_prob=self.args.word_blanking_prob,
            )
            noising_datasets[lang_pair] = self.alter_dataset_langtok(
                LanguagePairDataset(
                    noising_dataset,
                    tgt_dataset1.sizes,
                    self.dicts[tgt],
                    tgt_dataset2,
                    tgt_dataset2.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                ),
                src_eos=self.dicts[tgt].eos(),
                src_lang=tgt,
                tgt_eos=self.dicts[tgt].eos(),
                tgt_lang=tgt,
            )
            print('| denoising-{}: {} {} {} examples'.format(
                tgt, data_path, split, len(noising_datasets[lang_pair]),
            ))

    def language_pair_dataset(lang_pair):
        src, tgt = lang_pair.split('-')
        src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
        return self.alter_dataset_langtok(
            LanguagePairDataset(
                src_dataset, src_dataset.sizes, self.dicts[src],
                tgt_dataset, tgt_dataset.sizes, self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
                max_source_positions=self.args.max_source_positions,
                max_target_positions=self.args.max_target_positions,
            ),
            self.dicts[src].eos(),
            src,
            self.dicts[tgt].eos(),
            tgt,
        )

    self.datasets[split] = RoundRobinZipDatasets(
        OrderedDict([
            (lang_pair, language_pair_dataset(lang_pair))
            for lang_pair in src_datasets.keys()
        ] + [
            (_get_bt_dataset_key(lang_pair), dataset)
            for lang_pair, dataset in backtranslate_datasets.items()
        ] + [
            (_get_denoising_dataset_key(lang_pair), dataset)
            for lang_pair, dataset in noising_datasets.items()
        ]),
        eval_key=None if self.training
        else "%s-%s" % (self.args.source_lang, self.args.target_lang),
    )
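The RoundRobinZipDatasets keys built above rely on two small module-level helpers. A minimal sketch, assuming the keys simply prefix the language pair (the exact strings may differ in a given checkout):

def _get_bt_dataset_key(lang_pair):
    # key under which the on-the-fly backtranslation dataset is registered
    return "bt:" + lang_pair


def _get_denoising_dataset_key(lang_pair):
    # key under which the denoising autoencoder dataset is registered
    return "denoising:" + lang_pair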
def _backtranslation_dataset_helper(
    self, remove_eos_from_input_src, remove_eos_from_output_src,
):
    tgt_dataset = LanguagePairDataset(
        src=self.tgt_dataset,
        src_sizes=self.tgt_dataset.sizes,
        src_dict=self.tgt_dict,
        tgt=None,
        tgt_sizes=None,
        tgt_dict=None,
    )

    generator = SequenceGenerator(
        models=[self.model],
        tgt_dict=self.tgt_dict,
        beam_size=2,
        unk_penalty=0,
        sampling=False,
    )
    if self.cuda:
        generator.cuda()

    backtranslation_dataset = BacktranslationDataset(
        tgt_dataset=TransformEosDataset(
            dataset=tgt_dataset,
            eos=self.tgt_dict.eos(),
            # remove eos from the input src
            remove_eos_from_src=remove_eos_from_input_src,
        ),
        backtranslation_fn=generator.generate,
        max_len_a=0,
        max_len_b=200,
        output_collater=TransformEosDataset(
            dataset=tgt_dataset,
            eos=self.tgt_dict.eos(),
            # if we remove eos from the input src, then we need to add it
            # back to the output tgt
            append_eos_to_tgt=remove_eos_from_input_src,
            remove_eos_from_src=remove_eos_from_output_src,
        ).collater,
        cuda=self.cuda,
    )
    dataloader = torch.utils.data.DataLoader(
        backtranslation_dataset,
        batch_size=2,
        collate_fn=backtranslation_dataset.collater,
    )
    backtranslation_batch_result = next(iter(dataloader))

    eos, pad, w1, w2 = self.tgt_dict.eos(), self.tgt_dict.pad(), self.w1, self.w2

    # Note that we sort by src_lengths and add left padding, so actually
    # ids will look like: [1, 0]
    expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]])
    if remove_eos_from_output_src:
        expected_src = expected_src[:, :-1]
    expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
    generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
    tgt_tokens = backtranslation_batch_result["target"]

    self.assertTensorEqual(expected_src, generated_src)
    self.assertTensorEqual(expected_tgt, tgt_tokens)
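The helper above is parameterized over the two EOS-handling flags, so a caller would typically exercise both configurations. A minimal sketch of such test methods (the names are illustrative, not necessarily those used in the actual test suite):

def test_backtranslation_dataset_basic(self):
    # EOS kept on the input source and on the collated output source
    self._backtranslation_dataset_helper(
        remove_eos_from_input_src=False,
        remove_eos_from_output_src=False,
    )

def test_backtranslation_dataset_remove_eos(self):
    # EOS stripped from the input source before generation and from the
    # collated output source afterwards
    self._backtranslation_dataset_helper(
        remove_eos_from_input_src=True,
        remove_eos_from_output_src=True,
    )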
def load_dataset(
    self, split, src_bin_path, tgt_bin_path, forward_model=None, backward_model=None
):
    """Load a dataset split."""
    corpus = ptt_data.ParallelCorpusConfig(
        source=ptt_data.CorpusConfig(
            dialect=self.source_lang, data_file=src_bin_path
        ),
        target=ptt_data.CorpusConfig(
            dialect=self.target_lang, data_file=tgt_bin_path
        ),
        weights_file=None,
    )

    if self.args.log_verbose:
        print("Starting to load binarized data files.", flush=True)
    data_utils.validate_corpus_exists(corpus=corpus, split=split)

    forward_tgt_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
        corpus.target.data_file
    )
    backward_tgt_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
        corpus.source.data_file
    )
    forward_src_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
        corpus.source.data_file
    )
    backward_src_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
        corpus.target.data_file
    )
    forward_parallel_dataset = weighted_data.WeightedLanguagePairDataset(
        src=forward_src_dataset,
        src_sizes=forward_src_dataset.sizes,
        src_dict=self.source_dictionary,
        tgt=forward_tgt_dataset,
        tgt_sizes=forward_tgt_dataset.sizes,
        tgt_dict=self.target_dictionary,
        remove_eos_from_source=self.remove_eos_from_source,
        append_eos_to_target=True,
    )
    backward_parallel_dataset = weighted_data.WeightedLanguagePairDataset(
        src=backward_src_dataset,
        src_sizes=backward_src_dataset.sizes,
        src_dict=self.target_dictionary,
        tgt=backward_tgt_dataset,
        tgt_sizes=backward_tgt_dataset.sizes,
        tgt_dict=self.source_dictionary,
        remove_eos_from_source=self.remove_eos_from_source,
        append_eos_to_target=True,
    )

    dataset_map = OrderedDict(
        [
            (f"{self.source_lang}-{self.target_lang}", forward_parallel_dataset),
            (f"{self.target_lang}-{self.source_lang}", backward_parallel_dataset),
        ]
    )

    assert (forward_model and backward_model) or (
        forward_model is None and backward_model is None
    ), (
        "Only one of forward or backward models can't be null;"
        " both have to be non-null or null"
    )
    if forward_model and backward_model:
        fwd_generator = beam_decode.SequenceGenerator(
            models=[forward_model], tgt_dict=self.source_dictionary
        )
        bwd_generator = beam_decode.SequenceGenerator(
            models=[backward_model], tgt_dict=self.target_dictionary
        )

        def monolingual_dataset(
            path,
            dictionary,
            is_source=False,
            num_examples_limit: Optional[int] = None,
        ):
            dataset = self.load_monolingual_dataset(
                path, is_source=is_source, num_examples_limit=num_examples_limit
            )
            return LanguagePairDataset(
                src=dataset,
                src_sizes=dataset.sizes,
                src_dict=dictionary,
                tgt=None,
                tgt_sizes=None,
                tgt_dict=None,
            )

        monolingual_num_examples_limit = None
        if self.args.monolingual_ratio is not None:
            monolingual_num_examples_limit = int(
                self.args.monolingual_ratio * len(forward_parallel_dataset)
            )

        src_dataset = monolingual_dataset(
            path=self.args.train_mono_source_binary_path,
            dictionary=self.source_dictionary,
            is_source=True,
            num_examples_limit=monolingual_num_examples_limit,
        )
        tgt_dataset = monolingual_dataset(
            path=self.args.train_mono_target_binary_path,
            dictionary=self.target_dictionary,
            is_source=False,
            num_examples_limit=monolingual_num_examples_limit,
        )

        dataset_map[
            f"{self.source_lang}-"
            f"{self.target_lang}_{constants.MONOLINGUAL_DATA_IDENTIFIER}"
        ] = BacktranslationDataset(
            tgt_dataset=TransformEosDataset(
                dataset=tgt_dataset,
                eos=self.target_dictionary.eos(),
                # Remove EOS from the input before backtranslation.
                remove_eos_from_src=True,
            ),
            backtranslation_fn=bwd_generator.generate,
            max_len_a=self.args.max_len_a,
            max_len_b=self.args.max_len_b,
            output_collater=TransformEosDataset(
                dataset=tgt_dataset,
                eos=self.target_dictionary.eos(),
                # The original input (now the target) doesn't have
                # an EOS, so we need to add one. The generated
                # backtranslation (now the source) will have an EOS,
                # so we want to remove it.
                append_eos_to_tgt=True,
                remove_eos_from_src=True,
            ).collater,
        )
        dataset_map[
            f"{self.target_lang}-"
            f"{self.source_lang}_{constants.MONOLINGUAL_DATA_IDENTIFIER}"
        ] = BacktranslationDataset(
            tgt_dataset=src_dataset,
            backtranslation_fn=fwd_generator.generate,
            max_len_a=self.args.max_len_a,
            max_len_b=self.args.max_len_b,
            output_collater=TransformEosDataset(
                dataset=src_dataset,
                eos=self.source_dictionary.eos(),
                # The original input (now the target) doesn't have
                # an EOS, so we need to add one. The generated
                # backtranslation (now the source) will have an EOS,
                # so we want to remove it.
                append_eos_to_tgt=True,
                remove_eos_from_src=True,
            ).collater,
        )

    # print before loading RoundRobinZipDatasets to help catch any bugs
    for dataset_key, dataset in dataset_map.items():
        print(f"| {split}: {dataset_key} {len(dataset)} examples in dataset")

    self.datasets[split] = RoundRobinZipDatasets(dataset_map)
    print(
        f"| {split} {len(self.datasets[split])} examples in RoundRobinZipDatasets"
    )

    if self.args.log_verbose:
        print("Finished loading dataset", flush=True)
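For debugging, a collated batch can be pulled from one of the BacktranslationDataset objects with a plain PyTorch DataLoader, following the same pattern as the test helper shown earlier. A minimal sketch, assuming bt_dataset is one of the BacktranslationDataset instances created above (for example, looked up from dataset_map before it is wrapped in RoundRobinZipDatasets):

import torch

# bt_dataset: a BacktranslationDataset built in load_dataset() above.
loader = torch.utils.data.DataLoader(
    bt_dataset,
    batch_size=2,
    collate_fn=bt_dataset.collater,
)
batch = next(iter(loader))
# The generated backtranslations become the source; the original monolingual
# sentences become the target.
print(batch["net_input"]["src_tokens"].shape)
print(batch["target"].shape)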
def load_dataset(self, split, epoch=1, **kwargs):
    """Load a dataset split."""
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]

    def split_exists(split, src, tgt, lang):
        if src is not None:
            filename = os.path.join(
                data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)
            )
        else:
            filename = os.path.join(
                data_path, "{}.{}-None.{}".format(split, src, tgt)
            )
        return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)

    def load_indexed_dataset(path, dictionary):
        return data_utils.load_indexed_dataset(
            path, dictionary, self.args.dataset_impl
        )

    # load parallel datasets
    src_datasets, tgt_datasets = {}, {}
    if (
        self.lambda_parallel > 0.0
        or self.lambda_parallel_steps is not None
        or not split.startswith("train")
    ):
        for lang_pair in self.lang_pairs:
            src, tgt = lang_pair.split("-")
            if split_exists(split, src, tgt, src):
                prefix = os.path.join(
                    data_path, "{}.{}-{}.".format(split, src, tgt)
                )
            elif split_exists(split, tgt, src, src):
                prefix = os.path.join(
                    data_path, "{}.{}-{}.".format(split, tgt, src)
                )
            else:
                continue
            src_datasets[lang_pair] = load_indexed_dataset(
                prefix + src, self.dicts[src]
            )
            tgt_datasets[lang_pair] = load_indexed_dataset(
                prefix + tgt, self.dicts[tgt]
            )
            logger.info(
                "parallel-{} {} {} examples".format(
                    data_path, split, len(src_datasets[lang_pair])
                )
            )
        if len(src_datasets) == 0:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, data_path)
            )

    # back translation datasets
    backtranslate_datasets = {}
    if (
        self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
    ) and split.startswith("train"):
        for lang_pair in self.lang_pairs:
            src, tgt = lang_pair.split("-")
            if not split_exists(split, tgt, None, tgt):
                raise FileNotFoundError(
                    "Dataset not found: backtranslation {} ({})".format(
                        split, data_path
                    )
                )
            filename = os.path.join(
                data_path, "{}.{}-None.{}".format(split, tgt, tgt)
            )
            dataset = load_indexed_dataset(filename, self.dicts[tgt])
            lang_pair_dataset_tgt = LanguagePairDataset(
                dataset,
                dataset.sizes,
                self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )
            lang_pair_dataset = LanguagePairDataset(
                dataset,
                dataset.sizes,
                src_dict=self.dicts[src],
                tgt=dataset,
                tgt_sizes=dataset.sizes,
                tgt_dict=self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )
            backtranslate_datasets[lang_pair] = BacktranslationDataset(
                tgt_dataset=self.alter_dataset_langtok(
                    lang_pair_dataset_tgt,
                    src_eos=self.dicts[tgt].eos(),
                    src_lang=tgt,
                    tgt_lang=src,
                ),
                backtranslation_fn=self.backtranslators[lang_pair],
                src_dict=self.dicts[src],
                tgt_dict=self.dicts[tgt],
                output_collater=self.alter_dataset_langtok(
                    lang_pair_dataset=lang_pair_dataset,
                    src_eos=self.dicts[src].eos(),
                    src_lang=src,
                    tgt_eos=self.dicts[tgt].eos(),
                    tgt_lang=tgt,
                ).collater,
            )
            logger.info(
                "backtranslate-{}: {} {} {} examples".format(
                    tgt,
                    data_path,
                    split,
                    len(backtranslate_datasets[lang_pair]),
                )
            )
            self.backtranslate_datasets[lang_pair] = backtranslate_datasets[
                lang_pair
            ]

    # denoising autoencoder
    noising_datasets = {}
    if (
        self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None
    ) and split.startswith("train"):
        for lang_pair in self.lang_pairs:
            _, tgt = lang_pair.split("-")
            if not split_exists(split, tgt, None, tgt):
                continue
            filename = os.path.join(
                data_path, "{}.{}-None.{}".format(split, tgt, tgt)
            )
            tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt])
            tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt])
            noising_dataset = NoisingDataset(
                tgt_dataset1,
                self.dicts[tgt],
                seed=1,
                max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                word_dropout_prob=self.args.word_dropout_prob,
                word_blanking_prob=self.args.word_blanking_prob,
            )
            noising_datasets[lang_pair] = self.alter_dataset_langtok(
                LanguagePairDataset(
                    noising_dataset,
                    tgt_dataset1.sizes,
                    self.dicts[tgt],
                    tgt_dataset2,
                    tgt_dataset2.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                ),
                src_eos=self.dicts[tgt].eos(),
                src_lang=tgt,
                tgt_eos=self.dicts[tgt].eos(),
                tgt_lang=tgt,
            )
            logger.info(
                "denoising-{}: {} {} {} examples".format(
                    tgt,
                    data_path,
                    split,
                    len(noising_datasets[lang_pair]),
                )
            )

    def language_pair_dataset(lang_pair):
        src, tgt = lang_pair.split("-")
        src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
        return self.alter_dataset_langtok(
            LanguagePairDataset(
                src_dataset,
                src_dataset.sizes,
                self.dicts[src],
                tgt_dataset,
                tgt_dataset.sizes,
                self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            ),
            self.dicts[src].eos(),
            src,
            self.dicts[tgt].eos(),
            tgt,
        )

    self.datasets[split] = RoundRobinZipDatasets(
        OrderedDict(
            [
                (lang_pair, language_pair_dataset(lang_pair))
                for lang_pair in src_datasets.keys()
            ]
            + [
                (_get_bt_dataset_key(lang_pair), dataset)
                for lang_pair, dataset in backtranslate_datasets.items()
            ]
            + [
                (_get_denoising_dataset_key(lang_pair), dataset)
                for lang_pair, dataset in noising_datasets.items()
            ]
        ),
        eval_key=None
        if self.training
        else "%s-%s" % (self.args.source_lang, self.args.target_lang),
    )
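The backtranslation_fn used above comes from self.backtranslators, which maps each language pair to a callable that produces reverse-direction hypotheses. The real task populates this mapping when the model is built; the following is only a minimal sketch of one way it could be wired, following the SequenceGenerator construction used in the test helper earlier (the `models` dict of reverse-direction models is a hypothetical stand-in):

# Hypothetical wiring; `models` maps "tgt-src" pairs to trained reverse models.
self.backtranslators = {}
for lang_pair in self.lang_pairs:
    src, tgt = lang_pair.split("-")
    generator = SequenceGenerator(
        models=[models["{}-{}".format(tgt, src)]],
        tgt_dict=self.dicts[src],
        beam_size=1,
    )
    # Called by BacktranslationDataset to turn a batch of tgt-language
    # sentences into src-language hypotheses.
    self.backtranslators[lang_pair] = generator.generate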
def load_dataset(self, split, epoch=0, **kwargs):
    """Load a dataset split."""
    paths = self.args.data.split(':')
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]

    def split_exists(split, src, tgt, lang):
        if src is not None:
            filename = os.path.join(
                data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        else:
            filename = os.path.join(
                data_path, '{}.{}-None.{}'.format(split, src, tgt))
        if self.args.raw_text and IndexedRawTextDataset.exists(filename):
            return True
        elif not self.args.raw_text and IndexedDataset.exists(filename):
            return True
        return False

    if split == 'train' and self.args.bt_parallel_update > 0:
        lang_pairs = []
        copied_lang_pairs = [p for p in self.lang_pairs]
        for lang_pair in copied_lang_pairs:
            src, tgt = lang_pair.split('-')
            key = '{}-{}'.format(tgt, src)
            lang_pairs.append(key)
            lang_pairs.append(lang_pair)
    else:
        lang_pairs = self.lang_pairs

    # load parallel datasets
    src_datasets, tgt_datasets = {}, {}
    for lang_pair in lang_pairs:
        src, tgt = lang_pair.split('-')
        if split_exists(split, src, tgt, src):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))
        elif split_exists(split, tgt, src, src):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src))
        else:
            continue
        src_datasets[lang_pair] = data_utils.load_indexed_dataset(
            prefix + src, self.dicts[src])
        tgt_datasets[lang_pair] = data_utils.load_indexed_dataset(
            prefix + tgt, self.dicts[tgt])
        print('| parallel-{} {} {} examples'.format(
            data_path, split, len(src_datasets[lang_pair])))
    if len(src_datasets) == 0:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(
            split, data_path))

    def language_pair_dataset(lang_pair):
        src, tgt = lang_pair.split('-')
        src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
        return self.alter_dataset_langtok(
            LanguagePairDataset(
                src_dataset, src_dataset.sizes, self.dicts[src],
                tgt_dataset, tgt_dataset.sizes, self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
                max_source_positions=self.args.max_source_positions,
                max_target_positions=self.args.max_target_positions,
            ),
            self.dicts[src].eos(),
            src,
            self.dicts[tgt].eos(),
            tgt,
        )

    # back translation datasets
    backtranslate_datasets = {}
    if split.startswith("train"):
        for lang_pair in self.lang_pairs:
            src, tgt = lang_pair.split('-')
            if not split_exists(split, tgt, None, tgt):
                raise FileNotFoundError(
                    'Dataset not found: backtranslation {} ({})'.format(
                        split, data_path))
            filename = os.path.join(
                data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
            dataset = data_utils.load_indexed_dataset(filename, self.dicts[tgt])
            lang_pair_dataset_tgt = LanguagePairDataset(
                dataset,
                dataset.sizes,
                self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )
            # lang_pair_dataset = LanguagePairDataset(
            #     dataset,
            #     dataset.sizes,
            #     src_dict=self.dicts[src],
            #     tgt=dataset,
            #     tgt_sizes=dataset.sizes,
            #     tgt_dict=self.dicts[tgt],
            #     left_pad_source=self.args.left_pad_source,
            #     left_pad_target=self.args.left_pad_target,
            # )
            backtranslate_datasets[lang_pair] = BacktranslationDataset(
                tgt_dataset=self.alter_dataset_langtok(
                    lang_pair_dataset_tgt,
                    src_eos=self.dicts[tgt].eos(),
                    src_lang=tgt,
                    tgt_lang=src,
                ),
                backtranslation_fn=self.backtranslators[lang_pair],
                src_dict=self.dicts[src],
                tgt_dict=self.dicts[tgt],
                output_collater=language_pair_dataset(lang_pair).collater,
                noising=self.args.noise_bt_dds,
            )
            print('| backtranslate-{}: {} {} {} examples'.format(
                tgt, data_path, split, len(backtranslate_datasets[lang_pair]),
            ))
            self.backtranslate_datasets[lang_pair] = backtranslate_datasets[lang_pair]

    if split == 'train':
        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict(
                [(lang_pair, language_pair_dataset(lang_pair))
                 for lang_pair in src_datasets.keys()]
                + [(_get_bt_dataset_key(lang_pair), dataset)
                   for lang_pair, dataset in backtranslate_datasets.items()]
            ),
            eval_key=None if self.training
            else "%s-%s" % (self.args.source_lang, self.args.target_lang),
            upsample_factor=self.args.upsample_factor,
        )
    else:
        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict([(lang_pair, language_pair_dataset(lang_pair))
                         for lang_pair in src_datasets.keys()]),
            eval_key=None if self.training
            else "%s-%s" % (self.args.source_lang, self.args.target_lang),
        )

    if split == 'valid' and self.args.bt_dds:
        if self.args.max_tokens_valid is not None:
            max_tokens_valid = self.args.max_tokens_valid / 4
        else:
            max_tokens_valid = None
        if self.args.max_sentences_valid is not None:
            max_sentences_valid = self.args.max_sentences_valid / 4
        else:
            max_sentences_valid = None
        self.dev_itr = self.get_batch_iterator(
            dataset=self.dataset('valid'),
            max_tokens=max_tokens_valid,
            max_sentences=max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                self.max_positions(),
            ),
            ignore_invalid_inputs=self.args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=self.args.required_batch_size_multiple,
            seed=self.args.seed,
            num_shards=self.args.distributed_world_size,
            shard_id=self.args.distributed_rank,
            num_workers=self.args.num_workers,
            noskip=True,
        )[0]
        self.dev_itr.next_epoch_itr(shuffle=True)
def load_dataset(self, split, src_bin_path, tgt_bin_path, forward_model=None, backward_model=None):
    """Load a dataset split."""
    corpus = ptt_data.ParallelCorpusConfig(
        source=ptt_data.CorpusConfig(dialect=self.source_lang, data_file=src_bin_path),
        target=ptt_data.CorpusConfig(dialect=self.target_lang, data_file=tgt_bin_path),
        weights_file=None,
    )

    if self.args.log_verbose:
        print("Starting to load binarized data files.", flush=True)
    data_utils.validate_corpus_exists(corpus=corpus, split=split)

    forward_tgt_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
        corpus.target.data_file)
    backward_tgt_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
        corpus.source.data_file)
    forward_src_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
        corpus.source.data_file)
    backward_src_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
        corpus.target.data_file)
    forward_parallel_dataset = weighted_data.WeightedLanguagePairDataset(
        src=forward_src_dataset,
        src_sizes=forward_src_dataset.sizes,
        src_dict=self.source_dictionary,
        tgt=forward_tgt_dataset,
        tgt_sizes=forward_tgt_dataset.sizes,
        tgt_dict=self.target_dictionary,
        remove_eos_from_source=self.remove_eos_from_source,
        append_eos_to_target=True,
    )
    backward_parallel_dataset = weighted_data.WeightedLanguagePairDataset(
        src=backward_src_dataset,
        src_sizes=backward_src_dataset.sizes,
        src_dict=self.target_dictionary,
        tgt=backward_tgt_dataset,
        tgt_sizes=backward_tgt_dataset.sizes,
        tgt_dict=self.source_dictionary,
        remove_eos_from_source=self.remove_eos_from_source,
        append_eos_to_target=True,
    )

    dataset_map = OrderedDict([
        (f"{self.source_lang}-{self.target_lang}", forward_parallel_dataset),
        (f"{self.target_lang}-{self.source_lang}", backward_parallel_dataset),
    ])

    assert (forward_model and backward_model) or (
        forward_model is None and backward_model is None), (
        "Only one of forward or backward models can't be null;"
        " both have to be non-null or null")

    if forward_model and backward_model:
        dataset_map[
            f"{self.source_lang}-"
            f"{self.target_lang}_{constants.MONOLINGUAL_DATA_IDENTIFIER}"
        ] = BacktranslationDataset(
            tgt_dataset=self.load_monolingual_dataset(
                self.args.train_mono_target_binary_path),
            tgt_dict=self.target_dictionary,
            backtranslation_model=backward_model,
            max_len_a=self.args.max_len_a,
            max_len_b=self.args.max_len_b,
            remove_eos_at_src=True,
            generator_class=beam_decode.SequenceGenerator,
        )
        dataset_map[
            f"{self.target_lang}-"
            f"{self.source_lang}_{constants.MONOLINGUAL_DATA_IDENTIFIER}"
        ] = BacktranslationDataset(
            tgt_dataset=self.load_monolingual_dataset(
                self.args.train_mono_source_binary_path),
            tgt_dict=self.source_dictionary,
            backtranslation_model=forward_model,
            max_len_a=self.args.max_len_a,
            max_len_b=self.args.max_len_b,
            remove_eos_at_src=True,
            generator_class=beam_decode.SequenceGenerator,
        )

    self.datasets[split] = RoundRobinZipDatasets(dataset_map)

    if self.args.log_verbose:
        print("Finished loading dataset", flush=True)

    print(f"| {split} {len(self.datasets[split])} datasets")