def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    prepend_bos_src=None,
    bert_model_name=None,
    bart_model_name=None,
    electra_model_name=None,
    electra_pretrain=False,
    denoising=False,
    masking=False,
    extra_data=False,
    input_mapping=False,
    mask_ratio=None,
    random_ratio=None,
    insert_ratio=None,
    rotate_ratio=None,
    permute_sentence_ratio=None,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []
    bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name, do_lower_case=False)
    if denoising:
        bart_tokenizer = AutoTokenizer.from_pretrained(bart_model_name, do_lower_case=False)
    if electra_pretrain:
        electra_tokenizer = ElectraTokenizer.from_pretrained(electra_model_name)

    srcbert_datasets = []
    extra_datasets = []
    extra_bert_datasets = []
    extra_bert_mapping_datasets = []
    extra_bart_datasets = []
    extra_bart_mapping_datasets = []
    if denoising:
        srcbart_datasets = []
    if electra_pretrain:
        srcelectra_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
            bertprefix = os.path.join(data_path, '{}.bert.{}-{}.'.format(split_k, src, tgt))
            bert_mapping_prefix = os.path.join(data_path, '{}.bert.map.{}-{}.'.format(split_k, src, tgt))
            if denoising:
                bartprefix = os.path.join(data_path, '{}.bart.{}-{}.'.format(split_k, src, tgt))
                bart_mapping_prefix = os.path.join(data_path, '{}.bart.map.{}-{}.'.format(split_k, src, tgt))
            if electra_pretrain:
                electraprefix = os.path.join(data_path, '{}.electra.{}-{}.'.format(split_k, src, tgt))
                electra_mapping_prefix = os.path.join(data_path, '{}.electra.map.{}-{}.'.format(split_k, src, tgt))
            if extra_data:
                extraprefix = os.path.join(data_path, '{}.extra.{}-{}.'.format(split_k, src, tgt))
                extra_bert_prefix = os.path.join(data_path, '{}.extra.bert.{}-{}.'.format(split_k, src, tgt))
                extra_bert_mapping_prefix = os.path.join(data_path, '{}.extra.bert.map.{}-{}.'.format(split_k, src, tgt))
                extra_bart_prefix = os.path.join(data_path, '{}.extra.bart.{}-{}.'.format(split_k, src, tgt))
                extra_bart_mapping_prefix = os.path.join(data_path, '{}.extra.bart.map.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
            bertprefix = os.path.join(data_path, '{}.bert.{}-{}.'.format(split_k, tgt, src))
            # NOTE: unlike `prefix`/`bertprefix` just above, the remaining
            # prefixes in this branch keep (src, tgt) order, as in the
            # original code.
            bert_mapping_prefix = os.path.join(data_path, '{}.bert.map.{}-{}.'.format(split_k, src, tgt))
            if denoising:
                bartprefix = os.path.join(data_path, '{}.bart.{}-{}.'.format(split_k, tgt, src))
                bart_mapping_prefix = os.path.join(data_path, '{}.bart.map.{}-{}.'.format(split_k, src, tgt))
            if electra_pretrain:
                electraprefix = os.path.join(data_path, '{}.electra.{}-{}.'.format(split_k, src, tgt))
                electra_mapping_prefix = os.path.join(data_path, '{}.electra.map.{}-{}.'.format(split_k, src, tgt))
            if extra_data:
                extraprefix = os.path.join(data_path, '{}.extra.{}-{}.'.format(split_k, src, tgt))
                extra_bert_prefix = os.path.join(data_path, '{}.extra.bert.{}-{}.'.format(split_k, src, tgt))
                extra_bert_mapping_prefix = os.path.join(data_path, '{}.extra.bert.map.{}-{}.'.format(split_k, src, tgt))
                extra_bart_prefix = os.path.join(data_path, '{}.extra.bart.{}-{}.'.format(split_k, src, tgt))
                extra_bart_mapping_prefix = os.path.join(data_path, '{}.extra.bart.map.{}-{}.'.format(split_k, src, tgt))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError("Dataset not found: {} ({})".format(split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        srcbert_datasets.append(
            data_utils.load_indexed_dataset(bertprefix + src, dataset_impl=dataset_impl))
        if denoising:
            srcbart_datasets.append(
                data_utils.load_indexed_dataset(bartprefix + src, dataset_impl=dataset_impl))
        if electra_pretrain:
            srcelectra_datasets.append(
                data_utils.load_indexed_dataset(electraprefix + src, dataset_impl=dataset_impl))
        if extra_data and split == 'train':
            extra_datasets.append(
                data_utils.load_indexed_dataset(extraprefix + src, dataset_impl=dataset_impl))
            extra_bert_datasets.append(
                data_utils.load_indexed_dataset(extra_bert_prefix + src, dataset_impl=dataset_impl))
            extra_bert_mapping_datasets.append(
                data_utils.load_indexed_dataset(extra_bert_mapping_prefix + src, dataset_impl=dataset_impl))
            extra_bart_datasets.append(
                data_utils.load_indexed_dataset(extra_bart_prefix + src, dataset_impl=dataset_impl))
            extra_bart_mapping_datasets.append(
                data_utils.load_indexed_dataset(extra_bart_mapping_prefix + src, dataset_impl=dataset_impl))
            assert (extra_datasets != []
                    or extra_bert_datasets != []
                    or extra_bert_mapping_datasets != []
                    or extra_bart_datasets != []
                    or extra_bart_mapping_datasets != [])

        src_datasets[-1] = PrependTokenDataset(src_datasets[-1], token=src_dict.bos_index)
        if extra_data and split == 'train':
            extra_datasets[-1] = PrependTokenDataset(extra_datasets[-1], token=src_dict.bos_index)

        if denoising:
            if input_mapping and split == 'train':
                bart_mapping_dataset = data_utils.load_indexed_dataset(
                    bart_mapping_prefix + src, dataset_impl=dataset_impl)
            else:
                bart_mapping_dataset = None
            src_datasets[-1] = DenoisingBartDataset(
                src_datasets[-1], src_datasets[-1].sizes, src_dict,
                srcbart_datasets[-1], srcbart_datasets[-1].sizes,
                bart_tokenizer,
                map_dataset=bart_mapping_dataset,
                mask_ratio=mask_ratio,
                random_ratio=random_ratio,
                insert_ratio=insert_ratio,
                rotate_ratio=rotate_ratio,
                permute_sentence_ratio=permute_sentence_ratio,
            )

        if electra_pretrain:
            if input_mapping and split == 'train':
                electra_mapping_dataset = data_utils.load_indexed_dataset(
                    electra_mapping_prefix + src, dataset_impl=dataset_impl)
            else:
                electra_mapping_dataset = None
            src_datasets[-1] = ElectrapretrainDataset(
                src_datasets[-1], src_datasets[-1].sizes, src_dict,
                srcelectra_datasets[-1], srcelectra_datasets[-1].sizes,
                electra_tokenizer,
                map_dataset=electra_mapping_dataset,
                left_pad_source=left_pad_source,
                left_pad_target=left_pad_target,
                max_source_positions=max_source_positions,
                max_target_positions=max_target_positions,
            )

        if masking:
            if input_mapping and split == 'train':
                bert_mapping_dataset = data_utils.load_indexed_dataset(
                    bert_mapping_prefix + src, dataset_impl=dataset_impl)
            else:
                bert_mapping_dataset = None
            src_datasets[-1] = MaskingDataset(
                src_datasets[-1], src_datasets[-1].sizes, src_dict,
                srcbert_datasets[-1], srcbert_datasets[-1].sizes,
                bert_tokenizer,
                map_dataset=bert_mapping_dataset,
                left_pad_source=left_pad_source,
                left_pad_target=left_pad_target,
                max_source_positions=max_source_positions,
                max_target_positions=max_target_positions,
            )
            if extra_data and split == 'train':
                assert input_mapping is True
                src_datasets[-1] = MaskingExtraDataset(
                    src_datasets[-1], src_datasets[-1].sizes, src_dict,
                    extra_datasets[-1], extra_datasets[-1].sizes,
                    extra_bert_datasets[-1], extra_bert_datasets[-1].sizes,
                    bert_tokenizer,
                    map_dataset=extra_bert_mapping_datasets[-1],
                    left_pad_source=left_pad_source,
                    left_pad_target=left_pad_target,
                    max_source_positions=max_source_positions,
                    max_target_positions=max_target_positions,
                )
                src_datasets[-1] = DenoisingBartExtraDataset(
                    src_datasets[-1], src_datasets[-1].sizes, src_dict,
                    extra_datasets[-1], extra_datasets[-1].sizes,
                    extra_bart_datasets[-1], extra_bart_datasets[-1].sizes,
                    bart_tokenizer,
                    map_dataset=extra_bart_mapping_datasets[-1],
                )

        logger.info("{} {} {}-{} {} examples".format(data_path, split_k, src, tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
    elif prepend_bos_src is not None:
        logger.info(f"prepending src bos: {prepend_bos_src}")
        src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset, src_dict.index("[{}]".format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(tgt_dataset, tgt_dict.index("[{}]".format(tgt)))
        eos = tgt_dict.index("[{}]".format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    src_bart_dataset = None
    src_bert_dataset = None
    src_electra_dataset = None
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        masking,
        src_bert_dataset,
        denoising,
        src_bart_dataset,
        src_electra_dataset,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )
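# A minimal usage sketch for the loader above (not from the original source):
# the "{split}.{src}-{tgt}.{lang}" files plus their ".bert."/".bert.map."
# companions must already exist under data_path in the chosen dataset_impl
# format. Dictionary paths and the BERT model name are illustrative
# assumptions. Note that `bert_model_name` is effectively required, since a
# BertTokenizer is built unconditionally.
from fairseq.data import Dictionary

src_dict = Dictionary.load("data-bin/dict.en.txt")  # hypothetical paths
tgt_dict = Dictionary.load("data-bin/dict.de.txt")
train_dataset = load_langpair_dataset(
    "data-bin", "train", "en", src_dict, "de", tgt_dict,
    combine=True, dataset_impl="mmap", upsample_primary=1,
    left_pad_source=True, left_pad_target=False,
    max_source_positions=1024, max_target_positions=1024,
    bert_model_name="bert-base-cased",
    masking=True, input_mapping=True,
)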
def load_dataset(self, split, epoch=1, **kwargs):
    """Load a dataset split."""

    def indexed_dataset(path, dictionary):
        if self.args.raw_text:
            raise Exception("Unable to handle raw text.")
        dataset = IndexedDataset(path, fix_lua_indexing=True)
        return dataset

    pair_datasets = OrderedDict()
    if split == "valid":
        self.datasets[split] = pair_datasets
        return

    if split not in self.config:
        raise FileNotFoundError("Dataset not found in config file: {}".format(split))

    size_by_corpus = defaultdict(int)
    size_sum = 0
    size_sum_with_subsampling = 0
    init_pair_datasets = {}

    for dataset_config in self.config[split]:
        src_path = os.path.dirname(dataset_config["src"])
        corpus_name = src_path.split("/")[-2]
        language_pair_name = src_path.split("/")[-1]
        pair_datasets_key = corpus_name + "-" + language_pair_name

        logger.info(f"loading... {pair_datasets_key}")
        if "src" in dataset_config:
            src_dataset = indexed_dataset(dataset_config["src"], self.src_dictionary)
        else:
            src_dataset = None

        if "tgt" in dataset_config:
            tgt_dataset = indexed_dataset(dataset_config["tgt"], self.tgt_dictionary)
        else:
            tgt_dataset = None

        dataset = LanguagePairDataset(
            src_dataset, src_dataset.sizes, self.src_dictionary,
            tgt_dataset, tgt_dataset.sizes, self.tgt_dictionary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
        )

        if pair_datasets_key in init_pair_datasets:
            logger.warning(
                f"Ignoring already added {pair_datasets_key}. "
                f"Consider using `sample` key in order to upsample."
            )
        else:
            init_pair_datasets[pair_datasets_key] = {
                "dataset": dataset,
                "sample": dataset_config.get("sample", None),
                "id": dataset_config.get("id", None),
                "len": len(dataset),
            }

    length_sum = 0
    weighted_freqs_sum = 0
    freq_per_dataset = {}
    vmax = 0
    vmin = 1
    weighted_freq_per_dataset = {}

    if self.args.weighting_alpha:
        for key in init_pair_datasets:
            if init_pair_datasets[key]["sample"] is None:
                length_sum += len(init_pair_datasets[key]["dataset"])

        for key in init_pair_datasets:
            if init_pair_datasets[key]["sample"] is None:
                val = float(init_pair_datasets[key]["len"]) / length_sum
                freq_per_dataset[key] = val
                weighted_freqs_sum += val ** self.args.weighting_alpha

        for key in freq_per_dataset:
            val = freq_per_dataset[key] ** self.args.weighting_alpha / weighted_freqs_sum
            vmin = min(vmin, val)
            vmax = max(vmax, val)
            weighted_freq_per_dataset[key] = val

    for pair_datasets_key in init_pair_datasets:
        dataset_config = init_pair_datasets[pair_datasets_key]
        dataset = dataset_config["dataset"]
        sample = dataset_config["sample"]
        if sample is None:
            sample = 1.0

        if pair_datasets_key in weighted_freq_per_dataset:
            w = vmax / weighted_freq_per_dataset[pair_datasets_key]
            sample = w

        sample = round(sample)

        initial_sample = sample
        initial_pair_datasets_key = pair_datasets_key

        while sample >= 1.0:
            assert pair_datasets_key not in pair_datasets, f"{pair_datasets_key} already in"
            size_sum_with_subsampling += len(dataset)
            pair_datasets[pair_datasets_key] = MultitaskDatasetWrapper(
                dataset, dataset_config.get("id", 0), 1.0, name=pair_datasets_key
            )
            size_sum += len(dataset)
            sample -= 1.0
            pair_datasets_key += "-up"

        assert sample < 1e-6, f"sample remains > 0 {pair_datasets_key}"

        logger.info(
            f"added pair {initial_pair_datasets_key} length {len(dataset)} "
            f"new_length = {len(dataset) * initial_sample}"
        )
        # NOTE: as in the original code, `corpus_name` here still holds the
        # last value assigned in the loading loop above.
        size_by_corpus[corpus_name] += len(dataset)

    self.datasets[split] = pair_datasets
    logger.info(
        f"Datasets number = {len(self.datasets[split])} size = {size_sum} "
        f"size_sum_with_subsampling = {size_sum_with_subsampling}"
    )
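# The loader above expects self.config[split] to be a list of per-corpus
# entries. A hypothetical entry illustrating the keys it actually reads
# ("src", "tgt", and the optional "sample"/"id"); the directory layout matters
# because the corpus name and language pair are recovered from the last two
# path components of os.path.dirname(entry["src"]).
example_config = {
    "train": [
        {
            "src": "/data/paracrawl/en-de/train.en.idx",
            "tgt": "/data/paracrawl/en-de/train.de.idx",
            "sample": 2.0,  # optional explicit upsampling factor
            "id": 0,        # optional task id passed to MultitaskDatasetWrapper
        },
    ],
}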
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = self.args.data.split(':')
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]

    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        src, tgt = self.args.source_lang, self.args.target_lang
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        src_datasets.append(
            indexed_dataset.make_dataset(prefix + src, impl=self.args.dataset_impl,
                                         fix_lua_indexing=True, dictionary=self.src_dict))
        tgt_datasets.append(
            indexed_dataset.make_dataset(prefix + tgt, impl=self.args.dataset_impl,
                                         fix_lua_indexing=True, dictionary=self.tgt_dict))

        print('| {} {} {} examples'.format(data_path, split_k, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = self.args.upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

    self.datasets[split] = LanguagePairDataset(
        src_dataset, src_dataset.sizes, self.src_dict,
        tgt_dataset, tgt_dataset.sizes, self.tgt_dict,
        left_pad_source=self.args.left_pad_source,
        left_pad_target=self.args.left_pad_target,
        max_source_positions=self.args.max_source_positions,
        max_target_positions=self.args.max_target_positions,
    )
def build_dataset_for_inference(self, src_tokens, src_lengths):
    return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary)
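# Hedged usage sketch for build_dataset_for_inference above; `task` and the
# example sentences are assumptions. LanguagePairDataset accepts a list of
# token tensors here, which is how interactive decoding typically feeds it.
src_tokens = [
    task.source_dictionary.encode_line(s, add_if_not_exist=False)
    for s in ["hello world", "good morning"]
]
src_lengths = [t.numel() for t in src_tokens]
infer_dataset = task.build_dataset_for_inference(src_tokens, src_lengths)
batch = infer_dataset.collater([infer_dataset[i] for i in range(len(infer_dataset))])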
def load_dataset(self, split, epoch=0, **kwargs):
    """Load a dataset split."""
    paths = self.args.data.split(os.pathsep)
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]

    def split_exists(split, src, tgt, lang):
        if src is not None:
            filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        else:
            filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, src, tgt))
        return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)

    def load_indexed_dataset(path, dictionary):
        return data_utils.load_indexed_dataset(path, dictionary, self.args.dataset_impl)

    # load parallel datasets
    src_datasets, tgt_datasets = {}, {}
    if (self.lambda_parallel > 0.0
            or self.lambda_parallel_steps is not None
            or not split.startswith("train")):
        for lang_pair in self.lang_pairs:
            src, tgt = lang_pair.split('-')
            if split_exists(split, src, tgt, src):
                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))
            elif split_exists(split, tgt, src, src):
                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src))
            else:
                continue
            src_datasets[lang_pair] = load_indexed_dataset(prefix + src, self.dicts[src])
            tgt_datasets[lang_pair] = load_indexed_dataset(prefix + tgt, self.dicts[tgt])
            logger.info('parallel-{} {} {} examples'.format(
                data_path, split, len(src_datasets[lang_pair])))
        if len(src_datasets) == 0:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

    # back translation datasets
    backtranslate_datasets = {}
    if (self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None) and split.startswith("train"):
        for lang_pair in self.lang_pairs:
            src, tgt = lang_pair.split('-')
            if not split_exists(split, tgt, None, tgt):
                raise FileNotFoundError(
                    'Dataset not found: backtranslation {} ({})'.format(split, data_path))
            filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
            dataset = load_indexed_dataset(filename, self.dicts[tgt])
            lang_pair_dataset_tgt = LanguagePairDataset(
                dataset, dataset.sizes, self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )
            lang_pair_dataset = LanguagePairDataset(
                dataset, dataset.sizes, src_dict=self.dicts[src],
                tgt=dataset, tgt_sizes=dataset.sizes, tgt_dict=self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )
            backtranslate_datasets[lang_pair] = BacktranslationDataset(
                tgt_dataset=self.alter_dataset_langtok(
                    lang_pair_dataset_tgt,
                    src_eos=self.dicts[tgt].eos(),
                    src_lang=tgt,
                    tgt_lang=src,
                ),
                backtranslation_fn=self.backtranslators[lang_pair],
                src_dict=self.dicts[src],
                tgt_dict=self.dicts[tgt],
                output_collater=self.alter_dataset_langtok(
                    lang_pair_dataset=lang_pair_dataset,
                    src_eos=self.dicts[src].eos(),
                    src_lang=src,
                    tgt_eos=self.dicts[tgt].eos(),
                    tgt_lang=tgt,
                ).collater,
            )
            logger.info('backtranslate-{}: {} {} {} examples'.format(
                tgt, data_path, split, len(backtranslate_datasets[lang_pair])))
            self.backtranslate_datasets[lang_pair] = backtranslate_datasets[lang_pair]

    # denoising autoencoder
    noising_datasets = {}
    if (self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None) and split.startswith("train"):
        for lang_pair in self.lang_pairs:
            _, tgt = lang_pair.split('-')
            if not split_exists(split, tgt, None, tgt):
                continue
            filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
            tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt])
            tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt])
            noising_dataset = NoisingDataset(
                tgt_dataset1,
                self.dicts[tgt],
                seed=1,
                max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                word_dropout_prob=self.args.word_dropout_prob,
                word_blanking_prob=self.args.word_blanking_prob,
            )
            noising_datasets[lang_pair] = self.alter_dataset_langtok(
                LanguagePairDataset(
                    noising_dataset, tgt_dataset1.sizes, self.dicts[tgt],
                    tgt_dataset2, tgt_dataset2.sizes, self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                ),
                src_eos=self.dicts[tgt].eos(),
                src_lang=tgt,
                tgt_eos=self.dicts[tgt].eos(),
                tgt_lang=tgt,
            )
            logger.info('denoising-{}: {} {} {} examples'.format(
                tgt, data_path, split, len(noising_datasets[lang_pair])))

    def language_pair_dataset(lang_pair):
        src, tgt = lang_pair.split('-')
        src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
        return self.alter_dataset_langtok(
            LanguagePairDataset(
                src_dataset, src_dataset.sizes, self.dicts[src],
                tgt_dataset, tgt_dataset.sizes, self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
                max_source_positions=self.args.max_source_positions,
                max_target_positions=self.args.max_target_positions,
            ),
            self.dicts[src].eos(), src,
            self.dicts[tgt].eos(), tgt,
        )

    self.datasets[split] = RoundRobinZipDatasets(
        OrderedDict(
            [(lang_pair, language_pair_dataset(lang_pair))
             for lang_pair in src_datasets.keys()]
            + [(_get_bt_dataset_key(lang_pair), dataset)
               for lang_pair, dataset in backtranslate_datasets.items()]
            + [(_get_denoising_dataset_key(lang_pair), dataset)
               for lang_pair, dataset in noising_datasets.items()]
        ),
        eval_key=None if self.training
        else "%s-%s" % (self.args.source_lang, self.args.target_lang),
    )
def load_dataset(self, split, seed=None):
    """Load split, which is train (monolingual data, optional parallel data),
    or eval (always parallel data).
    """
    if split == self.args.valid_subset:
        # tune set is always parallel
        primal_parallel, _, _ = data_utils.load_parallel_dataset(
            source_lang=self.source_lang,
            target_lang=self.target_lang,
            src_bin_path=self.args.forward_eval_source_binary_path,
            tgt_bin_path=self.args.forward_eval_target_binary_path,
            source_dictionary=self.primal_src_dict,
            target_dictionary=self.primal_tgt_dict,
            split=split,
            remove_eos_from_source=not self.args.append_eos_to_source,
            append_eos_to_target=True,
            char_source_dict=None,
            log_verbose=self.args.log_verbose,
        )
        # now just flip the source and target
        dual_parallel, _, _ = data_utils.load_parallel_dataset(
            source_lang=self.target_lang,
            target_lang=self.source_lang,
            src_bin_path=self.args.backward_eval_source_binary_path,
            tgt_bin_path=self.args.backward_eval_target_binary_path,
            source_dictionary=self.dual_src_dict,
            # NOTE: as in the original code, the dual *source* dictionary is
            # also passed as the target dictionary here.
            target_dictionary=self.dual_src_dict,
            split=split,
            remove_eos_from_source=not self.args.append_eos_to_source,
            append_eos_to_target=True,
            char_source_dict=None,
            log_verbose=self.args.log_verbose,
        )
        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict([
                ("primal_parallel", primal_parallel),
                ("dual_parallel", dual_parallel),
            ]))
    elif split == self.args.train_subset:
        src_dataset = data_utils.load_monolingual_dataset(
            self.args.train_mono_source_binary_path, is_source=True)
        tgt_dataset = data_utils.load_monolingual_dataset(
            self.args.train_mono_target_binary_path, is_source=True)
        primal_source_mono = LanguagePairDataset(
            src=src_dataset,
            src_sizes=src_dataset.sizes,
            src_dict=self.primal_src_dict,
            tgt=None,
            tgt_sizes=None,
            tgt_dict=None,
        )
        dual_source_mono = LanguagePairDataset(
            src=tgt_dataset,
            src_sizes=tgt_dataset.sizes,
            src_dict=self.dual_src_dict,
            tgt=None,
            tgt_sizes=None,
            tgt_dict=None,
        )
        primal_parallel, _, _ = data_utils.load_parallel_dataset(
            source_lang=self.source_lang,
            target_lang=self.target_lang,
            src_bin_path=self.args.forward_train_source_binary_path,
            tgt_bin_path=self.args.forward_train_target_binary_path,
            source_dictionary=self.primal_src_dict,
            target_dictionary=self.primal_tgt_dict,
            split=split,
            remove_eos_from_source=not self.args.append_eos_to_source,
            append_eos_to_target=True,
            char_source_dict=None,
            log_verbose=self.args.log_verbose,
        )
        dual_parallel, _, _ = data_utils.load_parallel_dataset(
            source_lang=self.target_lang,
            target_lang=self.source_lang,
            src_bin_path=self.args.backward_train_source_binary_path,
            tgt_bin_path=self.args.backward_train_target_binary_path,
            source_dictionary=self.dual_src_dict,
            target_dictionary=self.dual_src_dict,
            split=split,
            remove_eos_from_source=not self.args.append_eos_to_source,
            append_eos_to_target=True,
            char_source_dict=None,
            log_verbose=self.args.log_verbose,
        )
        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict([
                ("primal_parallel", primal_parallel),
                ("dual_parallel", dual_parallel),
                ("primal_source", primal_source_mono),
                ("dual_source", dual_source_mono),
            ]))
    else:
        raise ValueError("Invalid data split.")
def _backtranslation_dataset_helper(
    self,
    remove_eos_from_input_src,
    remove_eos_from_output_src,
):
    tgt_dataset = LanguagePairDataset(
        src=self.tgt_dataset,
        src_sizes=self.tgt_dataset.sizes,
        src_dict=self.tgt_dict,
        tgt=None,
        tgt_sizes=None,
        tgt_dict=None,
    )
    generator = SequenceGenerator(
        [self.model],
        tgt_dict=self.tgt_dict,
        max_len_a=0,
        max_len_b=200,
        beam_size=2,
        unk_penalty=0,
    )
    backtranslation_dataset = BacktranslationDataset(
        tgt_dataset=TransformEosDataset(
            dataset=tgt_dataset,
            eos=self.tgt_dict.eos(),
            # remove eos from the input src
            remove_eos_from_src=remove_eos_from_input_src,
        ),
        src_dict=self.tgt_dict,
        backtranslation_fn=(lambda sample: generator.generate([self.model], sample)),
        output_collater=TransformEosDataset(
            dataset=tgt_dataset,
            eos=self.tgt_dict.eos(),
            # if we remove eos from the input src, then we need to add it
            # back to the output tgt
            append_eos_to_tgt=remove_eos_from_input_src,
            remove_eos_from_src=remove_eos_from_output_src,
        ).collater,
        cuda=self.cuda,
    )
    dataloader = torch.utils.data.DataLoader(
        backtranslation_dataset,
        batch_size=2,
        collate_fn=backtranslation_dataset.collater,
    )
    backtranslation_batch_result = next(iter(dataloader))

    eos, pad, w1, w2 = self.tgt_dict.eos(), self.tgt_dict.pad(), self.w1, self.w2

    # Note that we sort by src_lengths and add left padding, so actually
    # ids will look like: [1, 0]
    expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]])
    if remove_eos_from_output_src:
        expected_src = expected_src[:, :-1]
    expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
    generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
    tgt_tokens = backtranslation_batch_result["target"]
    self.assertTensorEqual(expected_src, generated_src)
    self.assertTensorEqual(expected_tgt, tgt_tokens)
d1 = vocab
d2 = vocab
token1 = x.t()
tokens_ds1 = TokenBlockDataset(
    token1,
    sizes=src_lengths,
    break_mode='complete',
    block_size=1,
    pad=0,
    eos=1,
    include_targets=False,
)
token2 = x.t()
tokens_ds2 = TokenBlockDataset(
    token2,
    sizes=src_lengths,
    break_mode='complete',
    block_size=1,
    pad=0,
    eos=1,
    include_targets=False,
)
p_tokens_ds2 = PermutedDataset(tokens_ds2, d2, seed=123)
# NOTE: as in the original snippet, the pair dataset is built from the
# unpermuted `tokens_ds2`; `p_tokens_ds2` is left unused here.
dataset = LanguagePairDataset(
    tokens_ds1, tokens_ds1.sizes, d1,
    tokens_ds2, tokens_ds2.sizes, d2,
    shuffle=False,
)
def _load_dataset_multi_path_helper(
    self,
    split: str,
    src_multiple_bin_paths: Dict[str, str],
    tgt_multiple_bin_paths: Dict[str, str],
    dataset_upsampling: Optional[Dict[str, float]] = None,
    dataset_relative_ratio: Optional[Tuple[str, float]] = None,
    seed: Optional[int] = None,
    noiser: Optional[Dict[str, UnsupervisedMTNoising]] = None,
    is_npz: bool = True,
):
    corpora_map = pytorch_translate_data.ParallelCorporaMapConfig(
        src_files=src_multiple_bin_paths, tgt_files=tgt_multiple_bin_paths)
    datasets = OrderedDict()
    for key in corpora_map.src_files:
        src, tgt = corpora_map.src_files[key], corpora_map.tgt_files[key]
        tgt_dataset = pytorch_translate_data.InMemoryIndexedDataset.create_from_file(
            tgt, is_npz=is_npz)

        if self.char_source_dict is not None:
            src_dataset = char_data.InMemoryNumpyWordCharDataset.create_from_file(src)
        else:
            src_dataset = pytorch_translate_data.InMemoryIndexedDataset.create_from_file(
                src, is_npz=is_npz)

        src_sizes = src_dataset.sizes
        if noiser is not None and key in noiser:
            src_dataset = NoisingDataset(
                src_dataset=src_dataset,
                src_dict=self.source_dictionary,
                seed=seed,
                noiser=noiser[key],
            )
        if self.char_source_dict is not None:
            datasets[key] = char_data.LanguagePairSourceCharDataset(
                src=src_dataset,
                src_sizes=src_sizes,
                src_dict=self.source_dictionary,
                tgt=tgt_dataset,
                tgt_sizes=tgt_dataset.sizes,
                tgt_dict=self.target_dictionary,
            )
        else:
            datasets[key] = LanguagePairDataset(
                src=src_dataset,
                src_sizes=src_sizes,
                src_dict=self.source_dictionary,
                tgt=tgt_dataset,
                tgt_sizes=tgt_dataset.sizes,
                tgt_dict=self.target_dictionary,
                left_pad_source=False,
            )

    total_line_count = sum(len(datasets[key]) for key in datasets)
    if dataset_relative_ratio:
        ds, ratio = dataset_relative_ratio
        line_count = len(datasets[ds])
        # By definition ratio = u * line_count / sum(#lines of other datasets)
        u = (total_line_count - line_count) / line_count * ratio
        # NOTE: the factor is applied to the named dataset `ds`; the original
        # snippet wrote `{key: u}`, reusing the stale loop variable.
        dataset_upsampling = {ds: u}
    elif not dataset_upsampling:
        dataset_upsampling = {}

    print(f"|dataset upsampling:{dataset_upsampling}")
    ds_list = []
    sample_ratios = []
    for key, val in datasets.items():
        ds_list.append(val)
        sample_ratios.append(int(dataset_upsampling.get(key, 1)))

    self.datasets[split] = LanguagePairUpsamplingDataset(
        datasets=datasets.values(),
        sample_ratios=sample_ratios,
    )
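# Worked example of the dataset_relative_ratio arithmetic above (numbers
# assumed): with corpus A = 1,000 lines and corpus B = 9,000 lines and
# dataset_relative_ratio = ("A", 0.5), we want upsampled-A / others = 0.5:
#   u = (10,000 - 1,000) / 1,000 * 0.5 = 4.5
# so A's sample ratio becomes int(4.5) = 4 while B stays at 1.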
def load_langpair_dataset(data_path, split,
                          src, src_dict, tgt, tgt_dict,
                          combine, dataset_impl, upsample_primary,
                          left_pad_source, left_pad_target,
                          max_source_positions, max_target_positions,
                          prepend_bos=False,
                          load_alignments=False,
                          load_cls_labels=False,
                          load_cls_indices=False,
                          load_sample_weights=False,
                          truncate_source=False,
                          append_source_id=False,
                          shuffle=True):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info('{} {} {}-{} {} examples'.format(
            data_path, split_k, src, tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    src_prepended_bos = False
    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
        src_prepended_bos = True

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset, src_dict.index('[{}]'.format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(tgt_dataset, tgt_dict.index('[{}]'.format(tgt)))
        eos = tgt_dict.index('[{}]'.format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

    cls_dataset = None
    if load_cls_labels:
        cls_labels_path = os.path.join(data_path, '{}.cls'.format(split))
        if indexed_dataset.dataset_exists(cls_labels_path, impl=dataset_impl):
            cls_dataset = data_utils.load_indexed_dataset(cls_labels_path, None, dataset_impl)
            if truncate_source:
                cls_dataset = AppendTokenDataset(
                    TruncateDataset(
                        TruncateLastElementDataset(cls_dataset),
                        max_source_positions - 1,
                    ),
                    -1,  # will ignore -1 label in training
                )
            if src_prepended_bos:
                cls_dataset = PrependTokenDataset(cls_dataset, -1)
        else:
            print("cls_labels dataset NOT FOUND!", cls_labels_path)

    cls_indices_dataset = None
    if load_cls_indices:
        cls_indices_path = os.path.join(data_path, '{}.cls_ind'.format(split))
        if indexed_dataset.dataset_exists(cls_indices_path, impl=dataset_impl):
            cls_indices_dataset = data_utils.load_indexed_dataset(
                cls_indices_path, None, dataset_impl)

    sample_weights = None
    if load_sample_weights:
        weights_file = os.path.join(data_path, '{}.{}-{}.weights.npy'.format(split, src, tgt))
        assert os.path.exists(weights_file)
        with open(weights_file, 'rb') as f:
            sample_weights = np.load(f)
        logger.info('Loaded {} weights from {}'.format(len(sample_weights), weights_file))

    return LanguagePairDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset_sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=align_dataset,
        eos=eos,
        cls_dataset=cls_dataset,
        cls_indices_dataset=cls_indices_dataset,
        sample_weights=sample_weights,
        shuffle=shuffle,
    )
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    feature_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_features=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info('{} {} {}-{} {} examples'.format(
            data_path, split_k, src, tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset, src_dict.index('[{}]'.format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(tgt_dataset, tgt_dict.index('[{}]'.format(tgt)))
        eos = tgt_dict.index('[{}]'.format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)

    # feature_dict symbols look like:
    # ['<s>', '<pad>', '</s>', '<unk>', '<ori>', '<rep>', 'madeupword0000', 'madeupword0001']
    feature_dataset = None
    if load_features:
        feature_path = os.path.join(data_path, '{}.feature.{}-{}.{}'.format(split, src, tgt, src))
        if indexed_dataset.dataset_exists(feature_path, impl=dataset_impl):
            feature_dataset = data_utils.load_indexed_dataset(
                feature_path, feature_dict, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset_sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        feature_dataset=feature_dataset,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
    )
def load_dataset(self, split, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""
    logger.info("load dataset start")
    prefix = os.path.join(self.args.data, '{}.input-label'.format(split))

    # Read input sentences.
    sentences, lengths = [], []
    with open(prefix + '.input', encoding='utf-8') as file:
        for line in file:
            sentence = line.strip()
            # Tokenize the sentence, splitting on spaces.
            tokens = self.input_vocab.encode_line(sentence, add_if_not_exist=False)
            sentences.append(tokens)
            lengths.append(tokens.numel())

    # Read labels.
    labels = []
    with open(prefix + '.label', encoding='utf-8') as file:
        for line in file:
            label = line.strip()
            # Convert the label to a numeric ID.
            labels.append(torch.LongTensor([self.label_vocab.add_symbol(label)]))

    assert len(sentences) == len(labels)
    print('| {} {} {} examples'.format(self.args.data, split, len(sentences)))

    # We reuse LanguagePairDataset since classification can be modeled as a
    # sequence-to-sequence task where the target sequence has length 1.
    self.datasets[split] = LanguagePairDataset(
        src=sentences,
        src_sizes=lengths,
        src_dict=self.input_vocab,
        tgt=labels,
        tgt_sizes=torch.ones(len(labels)),  # targets have length 1
        tgt_dict=self.label_vocab,
        left_pad_source=False,
        # Since our target is a single class label, there's no need for
        # teacher forcing. If we set this to ``True`` then our Model's
        # ``forward()`` method would receive an additional argument called
        # *prev_output_tokens* that would contain a shifted version of the
        # target sequence.
        input_feeding=False,
    )
    logger.info("load dataset complete")
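# The classification loader above reads two aligned plain-text files named
# "{split}.input-label.input" and "{split}.input-label.label". A hypothetical
# two-example pair:
#
#   train.input-label.input:
#       un deux trois
#       guten Morgen
#   train.input-label.label:
#       French
#       German
#
# Each input line is space-tokenized through input_vocab.encode_line, and each
# label line becomes a length-1 target via label_vocab.add_symbol.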
def build_dataset_for_inference(self, src_tokens, src_lengths, tgt_tokens=None,
                                tgt_lengths=None, num_source_inputs=1):
    if num_source_inputs == 1:
        return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary,
                                   tgt=tgt_tokens, tgt_sizes=tgt_lengths)
    else:
        return MultiSourceTranslationDataset(src_tokens, src_lengths, self.source_dictionary,
                                             tgt=tgt_tokens, tgt_sizes=tgt_lengths)
def build_dataset_for_evaluation(self, src_tokens, src_lengths, tgt_tokens, tgt_lengths):
    return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary,
                               tgt_tokens, tgt_lengths)
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
    return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary,
                               tgt_dict=self.target_dictionary, constraints=constraints)
def load_dataset(self, split, epoch=0, **kwargs):
    if self.retrieve_fn is None:
        self.build_model(self.args)

    retrieve_dataset = None
    if self.retrieve_pool is None:
        paths = self.args.data.split(os.pathsep)
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]
        split_path = os.path.join(data_path, split)
        dataset = data_utils.load_indexed_dataset(
            split_path, self.dictionary, self.args.dataset_impl)
        if dataset is None:
            raise FileNotFoundError("Dataset not found: {} ({})".format(split, split_path))
        lang_pair_dataset = LanguagePairDataset(
            dataset, dataset.sizes, self.src_dict,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
        )
        if split == self.args.retrieve_split:
            print("split {} is used as the retrieve_pool".format(split))
            retrieve_dataset = lang_pair_dataset
        else:
            print("loading the retrieve split {}".format(self.args.retrieve_split))
            split_path = os.path.join(self.args.data, self.args.retrieve_split)
            dataset = data_utils.load_indexed_dataset(
                split_path, self.dictionary, self.args.dataset_impl)
            if dataset is None:
                raise FileNotFoundError("Dataset not found: {} ({})".format(
                    self.args.retrieve_split, split_path))
            if self.args.prune_num > 0:
                retrieve_dataset = LanguagePairMapDataset(
                    dataset, dataset.sizes, self.src_dict,
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
            else:
                retrieve_dataset = LanguagePairDataset(
                    dataset, dataset.sizes, self.src_dict,
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
        self.retrieve_pool = retrieve_dataset
    elif split == self.args.retrieve_split:
        print("skip reading split {} since it is used as the retrieve_pool".format(split))
        lang_pair_dataset = self.retrieve_pool
    else:
        paths = self.args.data.split(os.pathsep)
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]
        split_path = os.path.join(data_path, split)
        dataset = data_utils.load_indexed_dataset(
            split_path, self.dictionary, self.args.dataset_impl)
        if dataset is None:
            raise FileNotFoundError("Dataset not found: {} ({})".format(split, split_path))
        lang_pair_dataset = LanguagePairDataset(
            dataset, dataset.sizes, self.src_dict,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
        )

    # always use the unbiased estimator at test time;
    # avoid selecting self as template at training time
    if 'train' not in split and self.args.criterion != 'guu_elbo':
        sampling = True
        masks = None
    else:
        def read_mask(fpath):
            with open(fpath) as fin:
                return [int(x.rstrip()) for x in fin]

        sampling = options.eval_bool(self.args.reinforce)
        if os.path.exists(os.path.join(self.args.data, 'mask_id.txt')):
            masks = read_mask(os.path.join(self.args.data, 'mask_id.txt'))
        else:
            masks = None

    self.datasets[split] = RetrievePrototypeDataset(
        lang_pair_dataset,
        self.src_dict,
        retrieve_dataset=self.retrieve_pool,
        retrieve_fn=self.retrieve_fn,
        cuda=not self.args.cpu,
        num_samples=self.args.infer_ns,
        temperature=self.args.reinforce_temperature,
        sampling=sampling,
        edit_dict=self.edit_dict,
        split=split,
        masks=masks,
    )
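# The optional "mask_id.txt" consumed by read_mask above is plain text with
# one integer example id per line, e.g. (contents assumed):
#
#   12
#   407
#   9931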
def load_dataset(self, split, **kwargs):
    def split_exists(split, lang):
        filename = os.path.join(self.args.data, '{}.{}'.format(split, lang))
        if self.args.raw_text and IndexedRawTextDataset.exists(filename):
            return True
        elif not self.args.raw_text and IndexedDataset.exists(filename):
            return True
        return False

    def split_para_exists(split, key, lang):
        filename = os.path.join(self.args.data, '{}.{}.{}'.format(split, key, lang))
        if self.args.raw_text and IndexedRawTextDataset.exists(filename):
            return True
        elif not self.args.raw_text and IndexedDataset.exists(filename):
            return True
        return False

    def indexed_dataset(path, dictionary):
        if self.args.raw_text:
            return IndexedRawTextDataset(path, dictionary)
        elif IndexedDataset.exists(path):
            if self.args.lazy_load:
                return IndexedDataset(path, fix_lua_indexing=True)
            else:
                return IndexedCachedDataset(path, fix_lua_indexing=True)
        return None

    src_mono_datasets = {}
    for lang_pair in self.args.mono_lang_pairs:
        lang = lang_pair.split('-')[0]
        if split_exists(split, lang):
            prefix = os.path.join(self.args.data, '{}.{}'.format(split, lang))
        else:
            raise FileNotFoundError(
                'Not Found available {} dataset for ({}) lang'.format(split, lang))
        src_mono_datasets[lang_pair] = indexed_dataset(prefix, self.dicts[lang])
        print('| monolingual {}-{}: {} examples'.format(
            split, lang, len(src_mono_datasets[lang_pair])))

    src_para_datasets = {}
    for lang_pair in self.args.para_lang_pairs:
        src, tgt = lang_pair.split('-')
        key = '-'.join(sorted([src, tgt]))
        if not split_para_exists(split, key, src):
            raise FileNotFoundError(
                'Not Found available {}-{} para dataset for ({}) lang'.format(split, key, src))
        if not split_para_exists(split, key, tgt):
            raise FileNotFoundError(
                'Not Found available {}-{} para dataset for ({}) lang'.format(split, key, tgt))

        prefix = os.path.join(self.args.data, '{}.{}'.format(split, key))
        if '{}.{}'.format(key, src) not in src_para_datasets:
            src_para_datasets[key + '.' + src] = indexed_dataset(
                prefix + '.' + src, self.dicts[src])
        if '{}.{}'.format(key, tgt) not in src_para_datasets:
            src_para_datasets[key + '.' + tgt] = indexed_dataset(
                prefix + '.' + tgt, self.dicts[tgt])

        print('| bilingual {} {}-{}.{}: {} examples'.format(
            split, src, tgt, src, len(src_para_datasets[key + '.' + src])))
        print('| bilingual {} {}-{}.{}: {} examples'.format(
            split, src, tgt, tgt, len(src_para_datasets[key + '.' + tgt])))

    mt_para_dataset = {}
    for lang_pair in self.args.mt_steps:
        src, tgt = lang_pair.split('-')
        key = '-'.join(sorted([src, tgt]))
        src_key = key + '.' + src
        tgt_key = key + '.' + tgt
        src_dataset = src_para_datasets[src_key]
        tgt_dataset = src_para_datasets[tgt_key]
        mt_para_dataset[lang_pair] = LanguagePairDataset(
            src_dataset, src_dataset.sizes, self.dicts[src],
            tgt_dataset, tgt_dataset.sizes, self.dicts[tgt],
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
        )

    eval_para_dataset = {}
    if split != 'train':
        for lang_pair in self.args.valid_lang_pairs:
            src, tgt = lang_pair.split('-')
            if src == tgt:
                src_key = src + '-' + tgt
                tgt_key = src + '-' + tgt
                src_dataset = src_mono_datasets[src_key]
                tgt_dataset = src_mono_datasets[tgt_key]
            else:
                key = '-'.join(sorted([src, tgt]))
                src_key = key + '.' + src
                tgt_key = key + '.' + tgt
                src_dataset = src_para_datasets[src_key]
                tgt_dataset = src_para_datasets[tgt_key]
            eval_para_dataset[lang_pair] = LanguagePairDataset(
                src_dataset, src_dataset.sizes, self.dicts[src],
                tgt_dataset, tgt_dataset.sizes, self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
                max_source_positions=self.args.max_source_positions,
                max_target_positions=self.args.max_target_positions,
            )

    memt_para_dataset = {}
    if split == 'train':
        for lang_pair in self.args.memt_steps:
            src, tgt = lang_pair.split('-')
            key = '-'.join(sorted([src, tgt]))
            src_key = key + '.' + src
            tgt_key = key + '.' + tgt
            src_id, tgt_id = self.args.langs_id[src], self.args.langs_id[tgt]
            src_dataset = src_para_datasets[src_key]
            tgt_dataset = src_para_datasets[tgt_key]
            memt_para_dataset[lang_pair] = NoisyLanguagePairDataset(
                src_dataset, src_dataset.sizes,
                tgt_dataset, tgt_dataset.sizes,
                self.dicts[src], self.dicts[tgt],
                src_id, tgt_id,
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
                max_source_positions=self.args.max_source_positions,
                max_target_positions=self.args.max_target_positions,
                ratio=self.args.word_mask,
                pred_probs=self.args.pred_probs,
            )

    mass_mono_datasets = {}
    if split == 'train':
        for lang_pair in self.args.mass_steps:
            src_dataset = src_mono_datasets[lang_pair]
            lang = lang_pair.split('-')[0]
            mass_mono_dataset = MaskedLanguagePairDataset(
                src_dataset, src_dataset.sizes, self.dicts[lang],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
                max_source_positions=self.args.max_source_positions,
                max_target_positions=self.args.max_target_positions,
                shuffle=True,
                lang_id=self.args.langs_id[lang],
                ratio=self.args.word_mask,
                pred_probs=self.args.pred_probs,
            )
            mass_mono_datasets[lang_pair] = mass_mono_dataset

    self.datasets[split] = RoundRobinZipDatasets(
        OrderedDict(
            [(_get_mt_dataset_key(lang_pair), mt_para_dataset[lang_pair])
             for lang_pair in mt_para_dataset.keys()]
            + [(_get_memt_dataset_key(lang_pair), memt_para_dataset[lang_pair])
               for lang_pair in memt_para_dataset.keys()]
            + [(_get_mass_dataset_key(lang_pair), mass_mono_datasets[lang_pair])
               for lang_pair in mass_mono_datasets.keys()]
            + [(_get_mt_dataset_key(lang_pair), eval_para_dataset[lang_pair])
               for lang_pair in eval_para_dataset.keys()]
        ),
        eval_key=None if self.training else self.args.eval_lang_pair,
    )
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    shuffle,
    is_infer,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        src_datasets.append(
            data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl))
        tgt_datasets.append(
            data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl))

        print('| {} {} {}-{} {} examples'.format(
            data_path, split_k, src, tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

    # at inference time, use the dataset without truncation
    if is_infer:
        return LanguagePairDataset(
            src_dataset, src_dataset.sizes, src_dict,
            tgt_dataset, tgt_dataset.sizes, tgt_dict,
            left_pad_source=left_pad_source,
            left_pad_target=left_pad_target,
            max_source_positions=max_source_positions,
            max_target_positions=max_target_positions,
            shuffle=shuffle,
        )
    # for train and valid, use the truncating dataset
    else:
        return TruncateLanguagePairDataset(
            src_dataset, src_dataset.sizes, src_dict, DEFAULT_MAX_SRC_LEN,
            tgt_dataset, tgt_dataset.sizes, tgt_dict, DEFAULT_MAX_TGT_LEN,
            left_pad_source=left_pad_source,
            left_pad_target=left_pad_target,
            max_source_positions=max_source_positions,
            max_target_positions=max_target_positions,
            shuffle=shuffle,
        )
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    src_tag=None,
    tgt_tag=None,
    src_tau=-1,
    tgt_tau=-1,
    epoch=0,
    id_to_sample_probabilities=None,
    lm=None,
    idx_to_src_gradnorm=None,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        src_datasets.append(
            data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl))
        tgt_datasets.append(
            data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl))

        print('| {} {} {}-{} {} examples'.format(
            data_path, split_k, src, tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

    return LanguagePairDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset.sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        src_tag=src_tag,
        tgt_tag=tgt_tag,
        src_tau=src_tau,
        tgt_tau=tgt_tau,
        id_to_sample_probabilities=id_to_sample_probabilities,
        lm=lm,
        idx_to_src_gradnorm=idx_to_src_gradnorm,
    )
def load_dataset(self, split, combine=False, only_train=False):
    """Load a dataset split."""

    def split_exists(split, src, tgt, lang):
        filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        if self.args.raw_text and IndexedRawTextDataset.exists(filename):
            return True
        elif not self.args.raw_text and IndexedInMemoryDataset.exists(filename):
            return True
        return False

    def indexed_dataset(path, dictionary):
        if self.args.raw_text:
            return IndexedRawTextDataset(path, dictionary)
        elif IndexedInMemoryDataset.exists(path):
            return IndexedInMemoryDataset(path, fix_lua_indexing=True)
        return None

    src_datasets = []
    tgt_datasets = []
    pivot_datasets = []
    mt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        src, tgt, pivot, mt = (self.args.source_lang, self.args.target_lang,
                               self.args.p, self.args.mt)
        if split_exists(split_k, src, tgt, src):
            prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src):
            prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))

        src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
        tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))
        if only_train:
            pivot_datasets.append(indexed_dataset(prefix + pivot, self.tgt_dict))
            mt_datasets.append(indexed_dataset(prefix + mt, self.tgt_dict))

        print('| {} {} {} examples'.format(self.args.data, split_k, len(src_datasets[-1])))

        if not combine:
            break

    if only_train:
        assert len(src_datasets) == len(tgt_datasets) == len(pivot_datasets) == len(mt_datasets)
    else:
        assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
        src_sizes = src_dataset.sizes
        tgt_sizes = tgt_dataset.sizes
        if only_train:
            pivot_dataset = pivot_datasets[0]
            pivot_sizes = pivot_dataset.sizes
            mt_dataset = mt_datasets[0]
            mt_sizes = mt_dataset.sizes
        else:
            pivot_dataset = None
            pivot_sizes = None
            mt_dataset = None
            mt_sizes = None
    else:
        src_dataset = ConcatDataset(src_datasets)
        tgt_dataset = ConcatDataset(tgt_datasets)
        src_sizes = np.concatenate([ds.sizes for ds in src_datasets])
        tgt_sizes = np.concatenate([ds.sizes for ds in tgt_datasets])
        if only_train:
            pivot_dataset = ConcatDataset(pivot_datasets)
            pivot_sizes = np.concatenate([ds.sizes for ds in pivot_datasets])
            mt_dataset = ConcatDataset(mt_datasets)
            mt_sizes = np.concatenate([ds.sizes for ds in mt_datasets])
        else:
            pivot_dataset = None
            pivot_sizes = None
            mt_dataset = None
            mt_sizes = None

    self.datasets[split] = LanguagePairDataset(
        src_dataset, src_sizes, self.src_dict,
        pivot_dataset, pivot_sizes,
        mt_dataset, mt_sizes,
        tgt_dataset, tgt_sizes, self.tgt_dict,
        left_pad_source=self.args.left_pad_source,
        left_pad_target=self.args.left_pad_target,
        max_source_positions=self.args.max_source_positions,
        max_target_positions=self.args.max_target_positions,
    )
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """

    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        if self.args.raw_text and IndexedRawTextDataset.exists(filename):
            return True
        elif not self.args.raw_text and IndexedDataset.exists(filename):
            return True
        return False

    def indexed_dataset(path, dictionary):
        if self.args.raw_text:
            return IndexedRawTextDataset(path, dictionary)
        elif IndexedDataset.exists(path):
            return IndexedCachedDataset(path, fix_lua_indexing=True)
        return None

    src_datasets = []
    tgt_datasets = []
    data_paths = self.args.data

    for dk, data_path in enumerate(data_paths):
        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')

            # infer langcode
            src, tgt = self.args.source_lang, self.args.target_lang
            if split_exists(split_k, src, tgt, src, data_path):
                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
            elif split_exists(split_k, tgt, src, src, data_path):
                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
            else:
                if k > 0 or dk > 0:
                    break
                else:
                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

            src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
            tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))

            print('| {} {} {} examples'.format(data_path, split_k, len(src_datasets[-1])))

            if not combine:
                break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = self.args.upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

    self.datasets[split] = LanguagePairDataset(
        src_dataset, src_dataset.sizes, self.src_dict,
        tgt_dataset, tgt_dataset.sizes, self.tgt_dict,
        left_pad_source=self.args.left_pad_source,
        left_pad_target=self.args.left_pad_target,
        max_source_positions=self.args.max_source_positions,
        max_target_positions=self.args.max_target_positions,
    )
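# Self-contained sketch (stdlib only) of the two-level shard walk above: every data
# path is scanned for "train", "train1", ..., and FileNotFoundError is raised only
# if the *first* shard of the *first* path is missing. The `exists` set stands in
# for files on disk and is hypothetical.
import itertools


def gather_shards(data_paths, split, exists):
    shards = []
    for dk, data_path in enumerate(data_paths):
        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else "")
            if (data_path, split_k) in exists:
                shards.append((data_path, split_k))
            elif k > 0 or dk > 0:
                break
            else:
                raise FileNotFoundError("Dataset not found: {} ({})".format(split, data_path))
    return shards


print(gather_shards(
    ["/data/a", "/data/b"], "train",
    exists={("/data/a", "train"), ("/data/a", "train1"), ("/data/b", "train")},
))
# [('/data/a', 'train'), ('/data/a', 'train1'), ('/data/b', 'train')]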
def load_langpair_dataset(
    data_path, split,
    src, src_dict,
    tgt, tgt_dict,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target,
    max_source_positions, max_target_positions,
    extra_feature_dicts,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []
    extra_feature_datasets = defaultdict(list)

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        src_datasets.append(
            data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl))
        tgt_datasets.append(
            data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl))
        if extra_feature_dicts:
            for feature_type in extra_feature_dicts:
                extra_feature_datasets[feature_type].append(
                    data_utils.load_indexed_dataset(
                        prefix + feature_type,
                        extra_feature_dicts[feature_type],
                        dataset_impl))

        print('| {} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
        extra_feature_datasets = {
            feature_type: datasets[0]
            for feature_type, datasets in extra_feature_datasets.items()
        }
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        extra_feature_datasets = {
            # keep extra features aligned with the upsampled source/target indices
            feature_type: ConcatDataset(datasets, sample_ratios)
            for feature_type, datasets in extra_feature_datasets.items()
        }

    if len(extra_feature_datasets) > 0:
        return LanguagePairDatasetWithExtraFeatures(
            src_dataset, src_dataset.sizes, src_dict,
            tgt_dataset, tgt_dataset.sizes, tgt_dict,
            left_pad_source=left_pad_source,
            left_pad_target=left_pad_target,
            max_source_positions=max_source_positions,
            max_target_positions=max_target_positions,
            extra_feature_dicts=extra_feature_dicts,
            extra_features=extra_feature_datasets)
    else:
        return LanguagePairDataset(
            src_dataset, src_dataset.sizes, src_dict,
            tgt_dataset, tgt_dataset.sizes, tgt_dict,
            left_pad_source=left_pad_source,
            left_pad_target=left_pad_target,
            max_source_positions=max_source_positions,
            max_target_positions=max_target_positions,
        )
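# Hedged usage note: extra_feature_dicts maps a feature name to its Dictionary, and
# the loader expects one extra binarized file per feature next to the source/target
# files. For a hypothetical part-of-speech factor:
#
#     extra_feature_dicts = {"pos": pos_dict}
#     # expects train.de-en.pos alongside train.de-en.de / train.de-en.en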
def load_langpair_dataset(
    data_path, split,
    src, src_dict,
    tgt, tgt_dict,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target,
    max_source_positions, max_target_positions,
    prepend_bos=False, load_alignments=False,
    truncate_source=False,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)
        tgt_datasets.append(
            data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl))

        logger.info('{} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)

    return LanguagePairDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset.sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=align_dataset,
    )
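# A minimal stand-in (plain lists, no fairseq) for the truncate_source wrapper chain
# above, showing why it strips EOS first, truncates to max_source_positions - 1,
# then re-appends EOS so truncated examples stay well-formed. Token values are
# hypothetical.
def truncate_source_example(tokens, eos, max_source_positions):
    stripped = tokens[:-1] if tokens and tokens[-1] == eos else tokens  # StripTokenDataset
    truncated = stripped[:max_source_positions - 1]                     # TruncateDataset
    return truncated + [eos]                                            # AppendTokenDataset


print(truncate_source_example([5, 8, 13, 21, 34, 2], eos=2, max_source_positions=4))
# [5, 8, 13, 2]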
def build_dataset_for_inference(self, src_tokens, src_lengths):
    # TODO: add extra features if they exist
    return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary)
def load_langpair_dataset(
    self,
    data_path, split,
    src, src_dict,
    tgt, tgt_dict,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target,
    max_source_positions, max_target_positions,
    prepend_bos=False, load_alignments=False,
    truncate_source=False,
    src_dataset_transform_func=lambda dataset: dataset,
    tgt_dataset_transform_func=lambda dataset: dataset,
    src_lang_id=None, tgt_lang_id=None,
    langpairs_sharing_datasets=None,
):
    norm_direction = "-".join(sorted([src, tgt]))
    if langpairs_sharing_datasets is not None:
        src_dataset = langpairs_sharing_datasets.get(
            (data_path, split, norm_direction, src), "NotInCache")
        tgt_dataset = langpairs_sharing_datasets.get(
            (data_path, split, norm_direction, tgt), "NotInCache")
        align_dataset = langpairs_sharing_datasets.get(
            (data_path, split, norm_direction, src, tgt), "NotInCache")

    # a hack: if any one is not in the cache, we need to reload them all
    if (
        langpairs_sharing_datasets is None
        or src_dataset == "NotInCache"
        or tgt_dataset == "NotInCache"
        or align_dataset == "NotInCache"
        or split != getattr(self.args, "train_subset", None)
    ):
        # source and target datasets can be reused in reversed directions to save memory;
        # reversed directions of valid and test data will not share source and target datasets
        src_dataset, tgt_dataset, align_dataset = self.load_lang_dataset(
            data_path, split,
            src, src_dict,
            tgt, tgt_dict,
            combine, dataset_impl, upsample_primary,
            max_source_positions=max_source_positions,
            prepend_bos=prepend_bos,
            load_alignments=load_alignments,
            truncate_source=truncate_source,
        )
        src_dataset = src_dataset_transform_func(src_dataset)
        tgt_dataset = tgt_dataset_transform_func(tgt_dataset)
        if langpairs_sharing_datasets is not None:
            langpairs_sharing_datasets[(data_path, split, norm_direction, src)] = src_dataset
            langpairs_sharing_datasets[(data_path, split, norm_direction, tgt)] = tgt_dataset
            langpairs_sharing_datasets[(data_path, split, norm_direction, src, tgt)] = align_dataset
            if align_dataset is None:
                # no align data, so flag the reverse direction as well in sharing
                langpairs_sharing_datasets[(data_path, split, norm_direction, tgt, src)] = align_dataset
    else:
        logger.info(
            f"Reusing source and target datasets of [{split}] {tgt}-{src} for reversed direction: "
            f"[{split}] {src}-{tgt}: src length={len(src_dataset)}; tgt length={len(tgt_dataset)}"
        )

    return LanguagePairDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset.sizes if tgt_dataset is not None else None, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        src_lang_id=src_lang_id,
        tgt_lang_id=tgt_lang_id,
    )
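# Self-contained sketch of the sharing-cache keys used above: the direction is
# normalized by sorting, so en->de and de->en training data hash to the same
# source/target entries and each side is loaded only once. Values are placeholders.
def cache_keys(data_path, split, src, tgt):
    norm_direction = "-".join(sorted([src, tgt]))
    return {
        "src": (data_path, split, norm_direction, src),
        "tgt": (data_path, split, norm_direction, tgt),
        "align": (data_path, split, norm_direction, src, tgt),
    }


# the en side of en->de is exactly the en side of de->en
assert cache_keys("/data", "train", "en", "de")["src"] == \
       cache_keys("/data", "train", "de", "en")["tgt"]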
def load_dataset(self, split, combine=False):
    """Load a dataset split."""

    def split_exists(split, src, tgt, lang):
        filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        if self.args.raw_text and IndexedRawTokenIDDataset.exists(filename):
            return True
        elif not self.args.raw_text and IndexedInMemoryDataset.exists(filename):
            return True
        return False

    def indexed_dataset(path, dictionary):
        no_uniform_batching = (not self.args.uniform_n_seq_per_batch
                               and not self.args.uniform_seq_len_per_batch)
        if self.args.raw_text and no_uniform_batching:
            return IndexedRawTokenIDDataset(path, dictionary)
        elif IndexedInMemoryDataset.exists(path) and no_uniform_batching:
            return IndexedInMemoryDataset(path)
        elif (self.args.uniform_n_seq_per_batch
              and self.args.uniform_seq_len_per_batch
              and self.args.uniform_n_seq_in_dataset):
            return MockedInMemoryDataset(
                path,
                self.args.uniform_n_seq_in_dataset,
                self.args.uniform_n_seq_per_batch,
                self.args.uniform_seq_len_per_batch)
        return None

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        src, tgt = self.args.source_lang, self.args.target_lang
        if split_exists(split_k, src, tgt, src):
            prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src):
            prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))

        src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
        tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))

        print('| {} {} {} examples'.format(self.args.data, split_k, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
        src_sizes = src_dataset.sizes
        tgt_sizes = tgt_dataset.sizes
    else:
        src_dataset = ConcatDataset(src_datasets)
        tgt_dataset = ConcatDataset(tgt_datasets)
        src_sizes = np.concatenate([ds.sizes for ds in src_datasets])
        tgt_sizes = np.concatenate([ds.sizes for ds in tgt_datasets])

    self.datasets[split] = LanguagePairDataset(
        src_dataset, src_sizes, self.src_dict,
        tgt_dataset, tgt_sizes, self.tgt_dict,
        left_pad_source=self.args.left_pad_source,
        left_pad_target=self.args.left_pad_target,
        max_source_positions=self.args.max_source_positions,
        max_target_positions=self.args.max_target_positions,
        seq_len_multiple=self.args.seq_len_multiple,
    )
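# Hedged summary of the branching in indexed_dataset above (the flag names are this
# fork's own): without the uniform-batch flags it behaves like the stock loader,
# returning raw token-ID files or an in-memory indexed dataset; with both
# uniform_n_seq_per_batch and uniform_seq_len_per_batch set, it substitutes a
# MockedInMemoryDataset of uniform_n_seq_in_dataset synthetic sequences, presumably
# to exercise the batching pipeline with fixed-shape batches.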
def load_langpair_dataset(
    data_path, split,
    src, src_dict,
    tgt, tgt_dict,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target,
    max_source_positions, max_target_positions,
    prepend_bos=False, load_alignments=False,
    truncate_source=False, append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    prepend_bos_src=None,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError("Dataset not found: {} ({})".format(split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info("{} {} {}-{} {} examples".format(data_path, split_k, src, tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
    elif prepend_bos_src is not None:
        logger.info(f"prepending src bos: {prepend_bos_src}")
        src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset, src_dict.index("[{}]".format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(tgt_dataset, tgt_dict.index("[{}]".format(tgt)))
        eos = tgt_dict.index("[{}]".format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset_sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )
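# Plain-list sketch of append_source_id above: a language token such as "[de]" or
# "[en]" (assumed to already exist in the dictionaries) is appended after EOS on
# both sides, and the target's language-token index is returned as the new eos, in
# the mBART style, so generation stops on it. The indices below are made up.
def append_source_id_example(src_tokens, tgt_tokens, src_lang_idx, tgt_lang_idx):
    src_tokens = src_tokens + [src_lang_idx]   # AppendTokenDataset on the source
    tgt_tokens = tgt_tokens + [tgt_lang_idx]   # AppendTokenDataset on the target
    eos = tgt_lang_idx
    return src_tokens, tgt_tokens, eos


print(append_source_id_example([7, 9, 2], [4, 6, 2], src_lang_idx=250001, tgt_lang_idx=250004))
# ([7, 9, 2, 250001], [4, 6, 2, 250004], 250004)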
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    is_train_subset = split == getattr(self.args, "train_subset", None)
    if not is_train_subset:
        # if not the training set, use the first shard for valid and test
        paths = paths[:1]
    data_path = paths[(epoch - 1) % len(paths)]

    # infer langcode
    src, tgt = self.args.source_lang, self.args.target_lang

    """
    this is mask_word_initial; WordNoising uses mask_word_end or mask_bpe_cont.
    probably easiest to write a FlippedDataset that reverses sequences and use
    the standard pipeline.

    load_langpair_dataset:
        find files by pattern
        load_indexed source
        maybe truncate
        load target
        check shard counts
        sample ratios
        bos, source_id
        load_alignments
        LangpairDataset constructor
    """
    src_dataset, tgt_dataset = load_unpaired_langpair(
        data_path, split, src, self.src_dict, tgt, self.tgt_dict,
        combine=combine,
        dataset_impl=self.args.dataset_impl,
        max_source_positions=self.args.max_source_positions,
        max_target_positions=self.args.max_target_positions,
        truncate_source=self.args.truncate_source,
        prepend_bos=self.args.prepend_bos,
    )
    if self.args.bpe_dropout > 0:
        src_dataset = DynamicGPT2BPEDropoutResampling(
            self.args,
            src_dataset,
            self.source_dictionary,
            dropout=self.args.bpe_dropout,
        )

    # load backtranslation data
    if is_train_subset and not self.args.skip_backtranslation_data:
        """
        noised vs. unnoised validation set? they might converge at different times
        """
        bt_src_dataset, bt_tgt_dataset = load_unpaired_langpair(
            data_path, "{}.bt".format(split), src, self.src_dict, tgt, self.tgt_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            truncate_source=self.args.truncate_source,
            prepend_bos=self.args.prepend_bos,
        )
        if self.args.bpe == "gpt2":
            mask_is_beginning_of_word = get_whole_word_mask(self.args, self.source_dictionary)
            # np.bool was removed in recent NumPy releases; plain bool is equivalent
            mask_is_beginning_of_word = mask_is_beginning_of_word.numpy().astype(bool)
            if self.args.bpe_dropout > 0:
                bt_src_dataset = DynamicGPT2BPEDropoutResampling(
                    self.args,
                    bt_src_dataset,
                    self.source_dictionary,
                    dropout=self.args.bpe_dropout,
                )
            noiser = GPT2WordNoisingV2(
                self.src_dict,
                mask_is_beginning_of_word,
                self.args.max_word_shuffle_distance,
                self.args.word_dropout_prob,
                self.args.word_blanking_prob,
            )
            bt_src_dataset = DynamicNoisingDataset(
                bt_src_dataset,
                self.src_dict,
                seed=1,
                noiser=noiser,
            )
        else:
            assert self.args.bpe_dropout <= 0, "BPE dropout not supported for this BPE scheme"
            # standard BPE with @@ as the continuation marker
            bt_src_dataset = DynamicNoisingDataset(
                bt_src_dataset,
                self.src_dict,
                seed=1,
                max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                word_dropout_prob=self.args.word_dropout_prob,
                word_blanking_prob=self.args.word_blanking_prob,
            )

        if self.args.tagged_backtranslation:
            bt_src_dataset = AppendTokenDataset(
                AppendTokenDataset(
                    StripTokenDataset(bt_src_dataset, self.src_dict.eos()),
                    self.bt_idx),
                self.src_dict.eos(),
            )

        sample_ratios = [self.args.upsample_primary, 1]
        src_dataset = ConcatDataset([src_dataset, bt_src_dataset], sample_ratios)
        tgt_dataset = ConcatDataset([tgt_dataset, bt_tgt_dataset], sample_ratios)

    self.datasets[split] = LanguagePairDataset(
        src_dataset, src_dataset.sizes, self.src_dict,
        tgt_dataset, tgt_dataset.sizes, self.tgt_dict,
        left_pad_source=self.args.left_pad_source,
        left_pad_target=self.args.left_pad_target,
        align_dataset=None,
        eos=self.tgt_dict.eos(),
        num_buckets=self.args.num_batch_buckets,
        shuffle=(split not in ("test", "valid")),
    )
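# Stdlib sketch of the final ConcatDataset([...], sample_ratios) step above:
# upsample_primary repeats the bitext relative to the (usually much larger)
# backtranslated corpus. The corpora and sizes below are hypothetical.
def concat_with_ratios(datasets, sample_ratios):
    out = []
    for ds, ratio in zip(datasets, sample_ratios):
        out.extend(ds * ratio)  # each dataset is repeated `ratio` times
    return out


bitext, bt = ["b1", "b2"], ["s1", "s2", "s3", "s4"]
print(concat_with_ratios([bitext, bt], sample_ratios=[2, 1]))
# ['b1', 'b2', 'b1', 'b2', 's1', 's2', 's3', 's4']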