def load_dataset(
    self, split, src_bin_path, tgt_bin_path, seed=None, use_noiser=False
):
    """Load a dataset split.

    Seed and noiser are only used for loading train data, not eval data.
    """
    parallel_dataset, src_dataset, tgt_dataset = data_utils.load_parallel_dataset(
        source_lang=self.source_lang,
        target_lang=self.target_lang,
        src_bin_path=src_bin_path,
        tgt_bin_path=tgt_bin_path,
        source_dictionary=self.source_dictionary,
        target_dictionary=self.target_dictionary,
        split=split,
        remove_eos_from_source=not self.args.append_eos_to_source,
        append_eos_to_target=True,
        char_source_dict=self.char_source_dict,
        log_verbose=self.args.log_verbose,
    )
    dataset_map = OrderedDict(
        [(f"{self.source_lang}-{self.target_lang}", parallel_dataset)]
    )

    # Cap the number of monolingual examples relative to the parallel data size.
    monolingual_num_examples_limit = None
    if self.args.monolingual_ratio is not None:
        monolingual_num_examples_limit = int(
            self.args.monolingual_ratio * len(parallel_dataset)
        )

    if use_noiser:
        # Denoising autoencoder dataset over the source side of the parallel data.
        if getattr(self.args, "denoising_source_parallel", False):
            dataset_map[
                f"{self.source_lang}-{self.source_lang}"
            ] = weighted_data.WeightedLanguagePairDataset(
                src=noising.NoisingDataset(
                    src_dataset=src_dataset,
                    src_dict=self.source_dictionary,
                    seed=seed,
                    noiser=self.source_noiser,
                ),
                tgt=src_dataset,
                src_sizes=src_dataset.sizes,
                src_dict=self.source_dictionary,
                remove_eos_from_source=not self.args.append_eos_to_source,
                append_eos_to_target=True,
            )
        # Denoising autoencoder dataset over the target side of the parallel data.
        if getattr(self.args, "denoising_target_parallel", False):
            dataset_map[
                f"{self.target_lang}-{self.target_lang}"
            ] = weighted_data.WeightedLanguagePairDataset(
                src=noising.NoisingDataset(
                    src_dataset=tgt_dataset,
                    src_dict=self.target_dictionary,
                    seed=seed,
                    noiser=self.target_noiser,
                ),
                tgt=tgt_dataset,
                src_sizes=tgt_dataset.sizes,
                src_dict=self.target_dictionary,
                remove_eos_from_source=not self.args.append_eos_to_source,
                append_eos_to_target=True,
            )
        # Denoising autoencoder dataset over monolingual source data.
        if getattr(self.args, "denoising_source_mono", False):
            source_mono_dataset = self.load_monolingual_dataset(
                bin_path=self.args.train_mono_source_binary_path,
                is_source=True,
                num_examples_limit=monolingual_num_examples_limit,
            )
            dataset_map[
                f"{self.source_lang}-{self.source_lang}_"
                f"{constants.MONOLINGUAL_DATA_IDENTIFIER}"
            ] = weighted_data.WeightedLanguagePairDataset(
                src=noising.NoisingDataset(
                    src_dataset=source_mono_dataset,
                    src_dict=self.source_dictionary,
                    seed=seed,
                    noiser=self.source_noiser,
                ),
                tgt=source_mono_dataset,
                src_sizes=source_mono_dataset.sizes,
                src_dict=self.source_dictionary,
                remove_eos_from_source=not self.args.append_eos_to_source,
                append_eos_to_target=True,
            )
        # Denoising autoencoder dataset over monolingual target data.
        if getattr(self.args, "denoising_target_mono", False):
            target_mono_dataset = self.load_monolingual_dataset(
                bin_path=self.args.train_mono_target_binary_path,
                is_source=False,
                num_examples_limit=monolingual_num_examples_limit,
            )
            dataset_map[
                f"{self.target_lang}-{self.target_lang}_"
                f"{constants.MONOLINGUAL_DATA_IDENTIFIER}"
            ] = weighted_data.WeightedLanguagePairDataset(
                src=noising.NoisingDataset(
                    src_dataset=target_mono_dataset,
                    src_dict=self.target_dictionary,
                    seed=seed,
                    noiser=self.target_noiser,
                ),
                tgt=target_mono_dataset,
                src_sizes=target_mono_dataset.sizes,
                src_dict=self.target_dictionary,
                remove_eos_from_source=not self.args.append_eos_to_source,
                append_eos_to_target=True,
            )

    # Print before loading RoundRobinZipDatasets to help catch any bugs.
    for dataset_key, dataset in dataset_map.items():
        print(f"| {split}: {dataset_key} {len(dataset)} examples in dataset")

    self.datasets[split] = RoundRobinZipDatasets(dataset_map)
    print(
        f"| {split} {len(self.datasets[split])} examples in RoundRobinZipDatasets"
    )

    if self.args.log_verbose:
        print("Finished loading dataset", flush=True)

    print(f"| {split} {len(self.datasets[split])} datasets")
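# A minimal usage sketch for the method above (hypothetical task/args wiring, not
# part of this file; the `args` attribute names are illustrative assumptions):
# the train split passes a seed and enables noising so the denoising datasets are
# built, while the eval split skips both, matching the docstring.
#
#   task.load_dataset(
#       split="train",
#       src_bin_path=args.train_source_binary_path,
#       tgt_bin_path=args.train_target_binary_path,
#       seed=args.seed,
#       use_noiser=True,
#   )
#   task.load_dataset(
#       split="valid",
#       src_bin_path=args.eval_source_binary_path,
#       tgt_bin_path=args.eval_target_binary_path,
#   )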
def load_dataset(self, split, seed=None):
    """Load split, which is train (monolingual data, optional parallel data)
    or eval (always parallel data).
    """
    if split == self.args.valid_subset:
        # The tune set is always parallel.
        primal_parallel, _, _ = data_utils.load_parallel_dataset(
            source_lang=self.source_lang,
            target_lang=self.target_lang,
            src_bin_path=self.args.forward_eval_source_binary_path,
            tgt_bin_path=self.args.forward_eval_target_binary_path,
            source_dictionary=self.primal_src_dict,
            target_dictionary=self.primal_tgt_dict,
            split=split,
            remove_eos_from_source=not self.args.append_eos_to_source,
            append_eos_to_target=True,
            char_source_dict=None,
            log_verbose=self.args.log_verbose,
        )
        # Now just flip the source and target.
        dual_parallel, _, _ = data_utils.load_parallel_dataset(
            source_lang=self.target_lang,
            target_lang=self.source_lang,
            src_bin_path=self.args.backward_eval_source_binary_path,
            tgt_bin_path=self.args.backward_eval_target_binary_path,
            source_dictionary=self.dual_src_dict,
            target_dictionary=self.dual_tgt_dict,
            split=split,
            remove_eos_from_source=not self.args.append_eos_to_source,
            append_eos_to_target=True,
            char_source_dict=None,
            log_verbose=self.args.log_verbose,
        )
        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict(
                [
                    ("primal_parallel", primal_parallel),
                    ("dual_parallel", dual_parallel),
                ]
            )
        )
    elif split == self.args.train_subset:
        # Monolingual data: both sides are loaded as source-side data, since each
        # serves as encoder input to its respective (primal or dual) model.
        src_dataset = data_utils.load_monolingual_dataset(
            self.args.train_mono_source_binary_path, is_source=True
        )
        tgt_dataset = data_utils.load_monolingual_dataset(
            self.args.train_mono_target_binary_path, is_source=True
        )
        primal_source_mono = LanguagePairDataset(
            src=src_dataset,
            src_sizes=src_dataset.sizes,
            src_dict=self.primal_src_dict,
            tgt=None,
            tgt_sizes=None,
            tgt_dict=None,
        )
        dual_source_mono = LanguagePairDataset(
            src=tgt_dataset,
            src_sizes=tgt_dataset.sizes,
            src_dict=self.dual_src_dict,
            tgt=None,
            tgt_sizes=None,
            tgt_dict=None,
        )
        primal_parallel, _, _ = data_utils.load_parallel_dataset(
            source_lang=self.source_lang,
            target_lang=self.target_lang,
            src_bin_path=self.args.forward_train_source_binary_path,
            tgt_bin_path=self.args.forward_train_target_binary_path,
            source_dictionary=self.primal_src_dict,
            target_dictionary=self.primal_tgt_dict,
            split=split,
            remove_eos_from_source=not self.args.append_eos_to_source,
            append_eos_to_target=True,
            char_source_dict=None,
            log_verbose=self.args.log_verbose,
        )
        dual_parallel, _, _ = data_utils.load_parallel_dataset(
            source_lang=self.target_lang,
            target_lang=self.source_lang,
            src_bin_path=self.args.backward_train_source_binary_path,
            tgt_bin_path=self.args.backward_train_target_binary_path,
            source_dictionary=self.dual_src_dict,
            target_dictionary=self.dual_tgt_dict,
            split=split,
            remove_eos_from_source=not self.args.append_eos_to_source,
            append_eos_to_target=True,
            char_source_dict=None,
            log_verbose=self.args.log_verbose,
        )
        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict(
                [
                    ("primal_parallel", primal_parallel),
                    ("dual_parallel", dual_parallel),
                    ("primal_source", primal_source_mono),
                    ("dual_source", dual_source_mono),
                ]
            )
        )
    else:
        raise ValueError("Invalid data split.")
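# A minimal usage sketch for the dual-learning variant above (hypothetical
# task/args wiring, not part of this file): the split name selects the branch,
# and the resulting RoundRobinZipDatasets is keyed by the names used above.
#
#   task.load_dataset(split=args.train_subset, seed=args.seed)
#   train_data = task.datasets[args.train_subset]
#   # keys: "primal_parallel", "dual_parallel", "primal_source", "dual_source"
#
#   task.load_dataset(split=args.valid_subset)
#   valid_data = task.datasets[args.valid_subset]
#   # keys: "primal_parallel", "dual_parallel"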