def _get_noising_dataset_batch(
    self,
    src_tokens_no_pad,
    src_dict,
    append_eos_to_tgt=False,
):
    """
    Constructs a NoisingDataset and the corresponding
    ``LanguagePairDataset(NoisingDataset(src), src)``. If
    *append_eos_to_tgt* is True, wrap the source dataset in
    :class:`TransformEosDataset` to append EOS to the clean source when
    using it as the target.
    """
    src_dataset = test_utils.TestDataset(data=src_tokens_no_pad)

    noising_dataset = noising.NoisingDataset(
        src_dataset=src_dataset,
        src_dict=src_dict,
        seed=1234,
        max_word_shuffle_distance=3,
        word_dropout_prob=0.2,
        word_blanking_prob=0.2,
        noising_class=noising.UnsupervisedMTNoising,
    )
    tgt = src_dataset
    language_pair_dataset = LanguagePairDataset(
        src=noising_dataset, tgt=tgt, src_sizes=None, src_dict=src_dict
    )
    language_pair_dataset = TransformEosDataset(
        language_pair_dataset,
        src_dict.eos(),
        append_eos_to_tgt=append_eos_to_tgt,
    )

    dataloader = torch.utils.data.DataLoader(
        dataset=language_pair_dataset,
        batch_size=2,
        collate_fn=language_pair_dataset.collater,
    )
    denoising_batch_result = next(iter(dataloader))
    return denoising_batch_result
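
# A minimal sketch of how the helper above might be exercised in a test.
# The _get_test_data() helper and the exact assertions are hypothetical; the
# sketch only illustrates that append_eos_to_tgt controls whether the clean
# target sequences end in EOS.
def test_noising_dataset_appends_eos_to_tgt(self):
    src_dict, src_tokens_no_pad = self._get_test_data()  # hypothetical helper
    batch = self._get_noising_dataset_batch(
        src_tokens_no_pad=src_tokens_no_pad,
        src_dict=src_dict,
        append_eos_to_tgt=True,
    )
    eos, pad = src_dict.eos(), src_dict.pad()
    # Each clean target sequence (the un-noised source) should now end with EOS.
    for tgt_seq in batch["target"]:
        non_pad = tgt_seq[tgt_seq.ne(pad)]
        self.assertEqual(non_pad[-1].item(), eos)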
def build_dataset_for_inference(self, src_tokens, src_lengths):
    return TransformEosDataset(
        MonolingualDataset(
            TokenBlockDataset(
                src_tokens,
                src_lengths,
                block_size=None,
                pad=self.source_dictionary.pad(),
                eos=self.source_dictionary.eos(),
                break_mode='eos',
                include_targets=False,
            ),
            src_lengths,
            self.source_dictionary,
            self.target_dictionary,
            add_eos_for_other_targets=False,
            shuffle=False,
        ),
        eos=self.source_dictionary.eos(),
        # remove EOS since this will be used as a prefix for generation
        remove_eos_from_src=True,
        has_target=False,
    )
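
# A sketch of a possible call site (illustrative only; `task` is assumed to be
# an instance of the task above, and the token IDs are made up): wrap a single
# pre-tokenized prefix so it can seed generation.
import torch

def make_inference_dataset(task):
    prefix = torch.LongTensor([31, 7, 52, task.source_dictionary.eos()])
    # The TransformEosDataset wrapper strips the trailing EOS, so the raw
    # prefix, not prefix + EOS, is fed to the generator.
    return task.build_dataset_for_inference(
        src_tokens=[prefix], src_lengths=[prefix.numel()]
    )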
def _backtranslation_dataset_helper(
    self, remove_eos_from_input_src, remove_eos_from_output_src,
):
    tgt_dataset = LanguagePairDataset(
        src=self.tgt_dataset,
        src_sizes=self.tgt_dataset.sizes,
        src_dict=self.tgt_dict,
        tgt=None,
        tgt_sizes=None,
        tgt_dict=None,
    )

    generator = SequenceGenerator(
        models=[self.model],
        tgt_dict=self.tgt_dict,
        beam_size=2,
        unk_penalty=0,
        sampling=False,
    )
    if self.cuda:
        generator.cuda()

    backtranslation_dataset = BacktranslationDataset(
        tgt_dataset=TransformEosDataset(
            dataset=tgt_dataset,
            eos=self.tgt_dict.eos(),
            # remove eos from the input src
            remove_eos_from_src=remove_eos_from_input_src,
        ),
        backtranslation_fn=generator.generate,
        max_len_a=0,
        max_len_b=200,
        output_collater=TransformEosDataset(
            dataset=tgt_dataset,
            eos=self.tgt_dict.eos(),
            # if we remove eos from the input src, then we need to add it
            # back to the output tgt
            append_eos_to_tgt=remove_eos_from_input_src,
            remove_eos_from_src=remove_eos_from_output_src,
        ).collater,
        cuda=self.cuda,
    )
    dataloader = torch.utils.data.DataLoader(
        backtranslation_dataset,
        batch_size=2,
        collate_fn=backtranslation_dataset.collater,
    )
    backtranslation_batch_result = next(iter(dataloader))

    eos, pad, w1, w2 = self.tgt_dict.eos(), self.tgt_dict.pad(), self.w1, self.w2

    # Note that we sort by src_lengths and add left padding, so actually
    # ids will look like: [1, 0]
    expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]])
    if remove_eos_from_output_src:
        expected_src = expected_src[:, :-1]
    expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])

    generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
    tgt_tokens = backtranslation_batch_result["target"]

    self.assertTensorEqual(expected_src, generated_src)
    self.assertTensorEqual(expected_tgt, tgt_tokens)
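
# The helper above would typically be driven by small test cases along these
# lines (a sketch; the test names are illustrative):
def test_backtranslation_dataset_no_eos_in_output_src(self):
    self._backtranslation_dataset_helper(
        remove_eos_from_input_src=False, remove_eos_from_output_src=True,
    )

def test_backtranslation_dataset_with_eos_in_output_src(self):
    self._backtranslation_dataset_helper(
        remove_eos_from_input_src=False, remove_eos_from_output_src=False,
    )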
def load_dataset(
    self, split, src_bin_path, tgt_bin_path, forward_model=None, backward_model=None
):
    """Load a dataset split."""
    corpus = ptt_data.ParallelCorpusConfig(
        source=ptt_data.CorpusConfig(
            dialect=self.source_lang, data_file=src_bin_path
        ),
        target=ptt_data.CorpusConfig(
            dialect=self.target_lang, data_file=tgt_bin_path
        ),
        weights_file=None,
    )

    if self.args.log_verbose:
        print("Starting to load binarized data files.", flush=True)
    data_utils.validate_corpus_exists(corpus=corpus, split=split)

    forward_tgt_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
        corpus.target.data_file
    )
    backward_tgt_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
        corpus.source.data_file
    )
    forward_src_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
        corpus.source.data_file
    )
    backward_src_dataset = ptt_data.InMemoryNumpyDataset.create_from_file(
        corpus.target.data_file
    )
    forward_parallel_dataset = weighted_data.WeightedLanguagePairDataset(
        src=forward_src_dataset,
        src_sizes=forward_src_dataset.sizes,
        src_dict=self.source_dictionary,
        tgt=forward_tgt_dataset,
        tgt_sizes=forward_tgt_dataset.sizes,
        tgt_dict=self.target_dictionary,
        remove_eos_from_source=self.remove_eos_from_source,
        append_eos_to_target=True,
    )
    backward_parallel_dataset = weighted_data.WeightedLanguagePairDataset(
        src=backward_src_dataset,
        src_sizes=backward_src_dataset.sizes,
        src_dict=self.target_dictionary,
        tgt=backward_tgt_dataset,
        tgt_sizes=backward_tgt_dataset.sizes,
        tgt_dict=self.source_dictionary,
        remove_eos_from_source=self.remove_eos_from_source,
        append_eos_to_target=True,
    )

    dataset_map = OrderedDict(
        [
            (f"{self.source_lang}-{self.target_lang}", forward_parallel_dataset),
            (f"{self.target_lang}-{self.source_lang}", backward_parallel_dataset),
        ]
    )

    assert (forward_model and backward_model) or (
        forward_model is None and backward_model is None
    ), (
        "forward_model and backward_model must either both be provided or"
        " both be None; exactly one of them cannot be null."
    )

    if forward_model and backward_model:
        fwd_generator = beam_decode.SequenceGenerator(
            models=[forward_model], tgt_dict=self.source_dictionary
        )
        bwd_generator = beam_decode.SequenceGenerator(
            models=[backward_model], tgt_dict=self.target_dictionary
        )

        def monolingual_dataset(
            path,
            dictionary,
            is_source=False,
            num_examples_limit: Optional[int] = None,
        ):
            dataset = self.load_monolingual_dataset(
                path, is_source=is_source, num_examples_limit=num_examples_limit
            )
            return LanguagePairDataset(
                src=dataset,
                src_sizes=dataset.sizes,
                src_dict=dictionary,
                tgt=None,
                tgt_sizes=None,
                tgt_dict=None,
            )

        monolingual_num_examples_limit = None
        if self.args.monolingual_ratio is not None:
            monolingual_num_examples_limit = int(
                self.args.monolingual_ratio * len(forward_parallel_dataset)
            )

        src_dataset = monolingual_dataset(
            path=self.args.train_mono_source_binary_path,
            dictionary=self.source_dictionary,
            is_source=True,
            num_examples_limit=monolingual_num_examples_limit,
        )
        tgt_dataset = monolingual_dataset(
            path=self.args.train_mono_target_binary_path,
            dictionary=self.target_dictionary,
            is_source=False,
            num_examples_limit=monolingual_num_examples_limit,
        )
        dataset_map[
            f"{self.source_lang}-"
            f"{self.target_lang}_{constants.MONOLINGUAL_DATA_IDENTIFIER}"
        ] = BacktranslationDataset(
            tgt_dataset=TransformEosDataset(
                dataset=tgt_dataset,
                eos=self.target_dictionary.eos(),
                # Remove EOS from the input before backtranslation.
                remove_eos_from_src=True,
            ),
            backtranslation_fn=bwd_generator.generate,
            max_len_a=self.args.max_len_a,
            max_len_b=self.args.max_len_b,
            output_collater=TransformEosDataset(
                dataset=tgt_dataset,
                eos=self.target_dictionary.eos(),
                # The original input (now the target) doesn't have an EOS,
                # so we need to add one. The generated backtranslation
                # (now the source) will have an EOS, so we want to remove it.
                append_eos_to_tgt=True,
                remove_eos_from_src=True,
            ).collater,
        )
        dataset_map[
            f"{self.target_lang}-"
            f"{self.source_lang}_{constants.MONOLINGUAL_DATA_IDENTIFIER}"
        ] = BacktranslationDataset(
            tgt_dataset=src_dataset,
            backtranslation_fn=fwd_generator.generate,
            max_len_a=self.args.max_len_a,
            max_len_b=self.args.max_len_b,
            output_collater=TransformEosDataset(
                dataset=src_dataset,
                eos=self.source_dictionary.eos(),
                # The original input (now the target) doesn't have an EOS,
                # so we need to add one. The generated backtranslation
                # (now the source) will have an EOS, so we want to remove it.
                append_eos_to_tgt=True,
                remove_eos_from_src=True,
            ).collater,
        )

    # print before loading RoundRobinZipDatasets to help catch any bugs
    for dataset_key, dataset in dataset_map.items():
        print(f"| {split}: {dataset_key} {len(dataset)} examples in dataset")

    self.datasets[split] = RoundRobinZipDatasets(dataset_map)
    print(
        f"| {split} {len(self.datasets[split])} examples in RoundRobinZipDatasets"
    )

    if self.args.log_verbose:
        print("Finished loading dataset", flush=True)
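
# A sketch of how this split-loading method might be invoked (the paths and
# the surrounding `task`/model variables are illustrative assumptions, not the
# task's actual CLI flags):
task.load_dataset(
    split="train",
    src_bin_path="/data/train.src.npz",
    tgt_bin_path="/data/train.tgt.npz",
    forward_model=forward_model,
    backward_model=backward_model,
)
# With both dual models supplied, the split also contains the two monolingual
# backtranslation datasets keyed with constants.MONOLINGUAL_DATA_IDENTIFIER;
# with neither supplied, only the forward and backward parallel datasets are
# loaded.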