def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): if constraints is not None: raise NotImplementedError( "Constrained decoding with the multilingual_translation task is not supported" ) src_data = ListDataset(src_tokens, src_lengths) dataset = LanguagePairDataset(src_data, src_lengths, self.source_dictionary) src_langtok_spec, tgt_langtok_spec = self.args.langtoks["main"] if self.args.lang_tok_replacing_bos_eos: dataset = self.data_manager.alter_dataset_langtok( dataset, src_eos=self.source_dictionary.eos(), src_lang=self.args.source_lang, tgt_eos=self.target_dictionary.eos(), tgt_lang=self.args.target_lang, src_langtok_spec=src_langtok_spec, tgt_langtok_spec=tgt_langtok_spec, ) else: dataset.src = self.data_manager.src_dataset_tranform_func( self.args.source_lang, self.args.target_lang, dataset=dataset.src, spec=src_langtok_spec, ) return dataset
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
    if constraints is not None:
        # Though see Susanto et al. (ACL 2020): https://www.aclweb.org/anthology/2020.acl-main.325/
        raise NotImplementedError(
            "Constrained decoding with the translation_lev task is not supported"
        )
    return LanguagePairDataset(
        src_tokens, src_lengths, self.source_dictionary, append_bos=True
    )
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
    return LanguagePairDataset(
        src_tokens,
        src_lengths,
        self.source_dictionary,
        tgt_dict=self.target_dictionary,
        constraints=constraints,
    )
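# Usage sketch (not from the snippets above): how inference code typically feeds raw,
# eos-terminated source tensors through this kind of dataset and collates a batch.
# The toy dictionary and token ids below are illustrative assumptions.
import torch
from fairseq.data import Dictionary, LanguagePairDataset

src_dict = Dictionary()  # pad/eos/unk/bos symbols are registered automatically
word_ids = [src_dict.add_symbol(w) for w in ["hello", "world"]]

src_tokens = [torch.LongTensor(word_ids + [src_dict.eos()])]
src_lengths = torch.LongTensor([t.numel() for t in src_tokens])

# Equivalent to the plain translation task's build_dataset_for_inference (no constraints).
dataset = LanguagePairDataset(src_tokens, src_lengths, src_dict)
batch = dataset.collater([dataset[i] for i in range(len(dataset))])
# batch["net_input"]["src_tokens"] / ["src_lengths"] are what the sequence generator consumes.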
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
    # append the source language token (e.g. "[en_XX]") to every source sequence
    src_lang_id = self.source_dictionary.index("[{}]".format(self.args.source_lang))
    source_tokens = []
    for s_t in src_tokens:
        s_t = torch.cat([s_t, s_t.new(1).fill_(src_lang_id)])
        source_tokens.append(s_t)
    dataset = LanguagePairDataset(
        source_tokens,
        src_lengths,
        self.source_dictionary,
        tgt_dict=self.target_dictionary,
        constraints=constraints,
    )
    return dataset
def language_pair_dataset(lang_pair):
    src, tgt = lang_pair.split("-")
    src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
    return self.alter_dataset_langtok(
        LanguagePairDataset(
            src_dataset,
            src_dataset.sizes,
            self.dicts[src],
            tgt_dataset,
            tgt_dataset.sizes,
            self.dicts[tgt],
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
        ),
        self.dicts[src].eos(),
        src,
        self.dicts[tgt].eos(),
        tgt,
    )
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): if constraints is not None: raise NotImplementedError( "Constrained decoding with the multilingual_translation task is not supported" ) lang_pair = "%s-%s" % (self.args.source_lang, self.args.target_lang) return RoundRobinZipDatasets( OrderedDict([( lang_pair, self.alter_dataset_langtok( LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary), src_eos=self.source_dictionary.eos(), src_lang=self.args.source_lang, tgt_eos=self.target_dictionary.eos(), tgt_lang=self.args.target_lang, ), )]), eval_key=lang_pair, )
def load_langpair_dataset(
    self,
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    src_dataset_transform_func=lambda dataset: dataset,
    tgt_dataset_transform_func=lambda dataset: dataset,
    src_lang_id=None,
    tgt_lang_id=None,
    langpairs_sharing_datasets=None,
):
    norm_direction = "-".join(sorted([src, tgt]))
    if langpairs_sharing_datasets is not None:
        src_dataset = langpairs_sharing_datasets.get(
            (data_path, split, norm_direction, src), "NotInCache"
        )
        tgt_dataset = langpairs_sharing_datasets.get(
            (data_path, split, norm_direction, tgt), "NotInCache"
        )
        align_dataset = langpairs_sharing_datasets.get(
            (data_path, split, norm_direction, src, tgt), "NotInCache"
        )

    # a hack: if any one of them is not in the cache, we need to reload them all
    if (
        langpairs_sharing_datasets is None
        or src_dataset == "NotInCache"
        or tgt_dataset == "NotInCache"
        or align_dataset == "NotInCache"
        or split != getattr(self.args, "train_subset", None)
    ):
        # source and target datasets can be reused in reversed directions to save memory
        # reversed directions of valid and test data will not share source and target datasets
        src_dataset, tgt_dataset, align_dataset = self.load_lang_dataset(
            data_path,
            split,
            src,
            src_dict,
            tgt,
            tgt_dict,
            combine,
            dataset_impl,
            upsample_primary,
            max_source_positions=max_source_positions,
            prepend_bos=prepend_bos,
            load_alignments=load_alignments,
            truncate_source=truncate_source,
        )
        src_dataset = src_dataset_transform_func(src_dataset)
        tgt_dataset = tgt_dataset_transform_func(tgt_dataset)
        if langpairs_sharing_datasets is not None:
            langpairs_sharing_datasets[
                (data_path, split, norm_direction, src)
            ] = src_dataset
            langpairs_sharing_datasets[
                (data_path, split, norm_direction, tgt)
            ] = tgt_dataset
            langpairs_sharing_datasets[
                (data_path, split, norm_direction, src, tgt)
            ] = align_dataset
            if align_dataset is None:
                # no align data so flag the reverse direction as well in sharing
                langpairs_sharing_datasets[
                    (data_path, split, norm_direction, tgt, src)
                ] = align_dataset
    else:
        logger.info(
            f"Reusing source and target datasets of [{split}] {tgt}-{src} for reversed direction: "
            f"[{split}] {src}-{tgt}: src length={len(src_dataset)}; tgt length={len(tgt_dataset)}"
        )

    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset.sizes if tgt_dataset is not None else None,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        src_lang_id=src_lang_id,
        tgt_lang_id=tgt_lang_id,
    )
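# Standalone sketch (helper name is hypothetical) of how the langpairs_sharing_datasets
# cache keys are normalized: both directions of a pair collapse onto one key, so the
# en-de and de-en training data are loaded once and reused.
def _sharing_key(data_path, split, src, tgt, side):
    norm_direction = "-".join(sorted([src, tgt]))
    return (data_path, split, norm_direction, side)

assert _sharing_key("data-bin", "train", "en", "de", "en") == \
       _sharing_key("data-bin", "train", "de", "en", "en")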
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info(
            "{} {} {}-{} {} examples".format(
                data_path, split_k, src, tgt, len(src_datasets[-1])
            )
        )

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset, src_dict.index("[{}]".format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(tgt_dataset, tgt_dict.index("[{}]".format(tgt)))
        eos = tgt_dict.index("[{}]".format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )
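# Standalone sketch (helper name and example path are hypothetical) of the on-disk naming
# convention that split_exists() probes: "{split}.{src}-{tgt}.{lang}", with additional
# shards "{split}1", "{split}2", ... consumed when combine=True.
import os

def expected_prefixes(data_path, split, src, tgt, num_shards=2):
    prefixes = []
    for k in range(num_shards):
        split_k = split + (str(k) if k > 0 else "")
        # the forward direction is tried first, then the reversed direction
        prefixes.append(os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt)))
        prefixes.append(os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src)))
    return prefixes

print(expected_prefixes("data-bin/wmt17_de_en", "train", "de", "en"))
# ['data-bin/wmt17_de_en/train.de-en.', 'data-bin/wmt17_de_en/train.en-de.',
#  'data-bin/wmt17_de_en/train1.de-en.', 'data-bin/wmt17_de_en/train1.en-de.']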
def load_dataset(self, split, epoch=1, **kwargs):
    """Load a dataset split."""
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]

    def split_exists(split, src, tgt, lang):
        if src is not None:
            filename = os.path.join(
                data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)
            )
        else:
            filename = os.path.join(
                data_path, "{}.{}-None.{}".format(split, src, tgt)
            )
        return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)

    def load_indexed_dataset(path, dictionary):
        return data_utils.load_indexed_dataset(
            path, dictionary, self.args.dataset_impl
        )

    # load parallel datasets
    src_datasets, tgt_datasets = {}, {}
    if (
        self.lambda_parallel > 0.0
        or self.lambda_parallel_steps is not None
        or not split.startswith("train")
    ):
        for lang_pair in self.lang_pairs:
            src, tgt = lang_pair.split("-")
            if split_exists(split, src, tgt, src):
                prefix = os.path.join(data_path, "{}.{}-{}.".format(split, src, tgt))
            elif split_exists(split, tgt, src, src):
                prefix = os.path.join(data_path, "{}.{}-{}.".format(split, tgt, src))
            else:
                continue
            src_datasets[lang_pair] = load_indexed_dataset(prefix + src, self.dicts[src])
            tgt_datasets[lang_pair] = load_indexed_dataset(prefix + tgt, self.dicts[tgt])
            logger.info(
                "parallel-{} {} {} examples".format(
                    data_path, split, len(src_datasets[lang_pair])
                )
            )
        if len(src_datasets) == 0:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, data_path)
            )

    # back translation datasets
    backtranslate_datasets = {}
    if (
        self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
    ) and split.startswith("train"):
        for lang_pair in self.lang_pairs:
            src, tgt = lang_pair.split("-")
            if not split_exists(split, tgt, None, tgt):
                raise FileNotFoundError(
                    "Dataset not found: backtranslation {} ({})".format(split, data_path)
                )
            filename = os.path.join(data_path, "{}.{}-None.{}".format(split, tgt, tgt))
            dataset = load_indexed_dataset(filename, self.dicts[tgt])
            lang_pair_dataset_tgt = LanguagePairDataset(
                dataset,
                dataset.sizes,
                self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )
            lang_pair_dataset = LanguagePairDataset(
                dataset,
                dataset.sizes,
                src_dict=self.dicts[src],
                tgt=dataset,
                tgt_sizes=dataset.sizes,
                tgt_dict=self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )
            backtranslate_datasets[lang_pair] = BacktranslationDataset(
                tgt_dataset=self.alter_dataset_langtok(
                    lang_pair_dataset_tgt,
                    src_eos=self.dicts[tgt].eos(),
                    src_lang=tgt,
                    tgt_lang=src,
                ),
                backtranslation_fn=self.backtranslators[lang_pair],
                src_dict=self.dicts[src],
                tgt_dict=self.dicts[tgt],
                output_collater=self.alter_dataset_langtok(
                    lang_pair_dataset=lang_pair_dataset,
                    src_eos=self.dicts[src].eos(),
                    src_lang=src,
                    tgt_eos=self.dicts[tgt].eos(),
                    tgt_lang=tgt,
                ).collater,
            )
            logger.info(
                "backtranslate-{}: {} {} {} examples".format(
                    tgt,
                    data_path,
                    split,
                    len(backtranslate_datasets[lang_pair]),
                )
            )
            self.backtranslate_datasets[lang_pair] = backtranslate_datasets[lang_pair]

    # denoising autoencoder
    noising_datasets = {}
    if (
        self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None
    ) and split.startswith("train"):
        for lang_pair in self.lang_pairs:
            _, tgt = lang_pair.split("-")
            if not split_exists(split, tgt, None, tgt):
                continue
            filename = os.path.join(data_path, "{}.{}-None.{}".format(split, tgt, tgt))
            tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt])
            tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt])
            noising_dataset = NoisingDataset(
                tgt_dataset1,
                self.dicts[tgt],
                seed=1,
                max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                word_dropout_prob=self.args.word_dropout_prob,
                word_blanking_prob=self.args.word_blanking_prob,
            )
            noising_datasets[lang_pair] = self.alter_dataset_langtok(
                LanguagePairDataset(
                    noising_dataset,
                    tgt_dataset1.sizes,
                    self.dicts[tgt],
                    tgt_dataset2,
                    tgt_dataset2.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                ),
                src_eos=self.dicts[tgt].eos(),
                src_lang=tgt,
                tgt_eos=self.dicts[tgt].eos(),
                tgt_lang=tgt,
            )
            logger.info(
                "denoising-{}: {} {} {} examples".format(
                    tgt,
                    data_path,
                    split,
                    len(noising_datasets[lang_pair]),
                )
            )

    def language_pair_dataset(lang_pair):
        src, tgt = lang_pair.split("-")
        src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
        return self.alter_dataset_langtok(
            LanguagePairDataset(
                src_dataset,
                src_dataset.sizes,
                self.dicts[src],
                tgt_dataset,
                tgt_dataset.sizes,
                self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            ),
            self.dicts[src].eos(),
            src,
            self.dicts[tgt].eos(),
            tgt,
        )

    self.datasets[split] = RoundRobinZipDatasets(
        OrderedDict(
            [
                (lang_pair, language_pair_dataset(lang_pair))
                for lang_pair in src_datasets.keys()
            ]
            + [
                (_get_bt_dataset_key(lang_pair), dataset)
                for lang_pair, dataset in backtranslate_datasets.items()
            ]
            + [
                (_get_denoising_dataset_key(lang_pair), dataset)
                for lang_pair, dataset in noising_datasets.items()
            ]
        ),
        eval_key=None
        if self.training
        else "%s-%s" % (self.args.source_lang, self.args.target_lang),
    )
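# Illustrative sketch (assumed to mirror the _get_bt_dataset_key / _get_denoising_dataset_key
# helpers referenced above) of the keys that end up in the RoundRobinZipDatasets: one plain
# parallel key per pair, plus "bt:"- and "denoising:"-prefixed keys.
def _get_bt_dataset_key(lang_pair):
    return "bt:" + lang_pair

def _get_denoising_dataset_key(lang_pair):
    return "denoising:" + lang_pair

lang_pairs = ["de-en", "en-de"]
keys = (
    lang_pairs
    + [_get_bt_dataset_key(p) for p in lang_pairs]
    + [_get_denoising_dataset_key(p) for p in lang_pairs]
)
print(keys)
# ['de-en', 'en-de', 'bt:de-en', 'bt:en-de', 'denoising:de-en', 'denoising:en-de']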