def build_dataset_for_inference(self, src_tokens, src_lengths, language="en_XX", **kwargs):
    """
    Generate batches for inference. We prepend an eos token to src_tokens
    (or bos if `--add-bos-token` is set) and we append a <pad> to target.
    This is convenient both for generation with a prefix and LM scoring.
    """
    dataset = StripTokenDataset(
        TokenBlockDataset(
            src_tokens,
            src_lengths,
            block_size=None,  # ignored for "eos" break mode
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode="eos",
        ),
        # remove eos from (end of) target sequence
        self.source_dictionary.eos(),
    )
    src_lang_idx = self.dictionary.index(lang_token(language))
    src_dataset = PrependTokenDataset(
        dataset,
        token=(
            (src_lang_idx or self.source_dictionary.bos())
            if getattr(self.args, "add_bos_token", False)
            else self.source_dictionary.eos()
        ),
    )
    max_seq_len = max(src_lengths) + 1
    tgt_dataset = AppendTokenDataset(dataset, token=self.source_dictionary.pad())
    return NestedDictionaryDataset(
        {
            "id": IdDataset(),
            "net_input": {
                "src_tokens": PadDataset(
                    src_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                    pad_length=max_seq_len,
                ),
                "src_lengths": NumelDataset(src_dataset, reduce=False),
            },
            "target": PadDataset(
                tgt_dataset,
                pad_idx=self.source_dictionary.pad(),
                left_pad=False,
                pad_length=max_seq_len,
            ),
        },
        sizes=[np.array(src_lengths)],
    )
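# Illustrative sketch (not part of the snippets in this listing, assumptions labelled inline):
# the inference pipeline above conceptually rewrites each source sequence "w1 w2 ... </s>" into
#   src    = <prefix> w1 w2 ...   (prefix = lang/bos token if --add-bos-token, else </s>)
#   target = w1 w2 ... <pad>
# so that src and target line up one position apart for prefix generation and LM scoring.
# A toy, torch-only rendition of that token surgery, with hypothetical special-token ids:
import torch

eos, bos, pad, lang_idx = 2, 0, 1, 250001  # hypothetical ids
tokens = torch.tensor([37, 512, 90, eos])

stripped = tokens[:-1] if tokens[-1] == eos else tokens       # StripTokenDataset(eos)
src = torch.cat([torch.tensor([lang_idx or bos]), stripped])  # PrependTokenDataset (lang/bos branch)
tgt = torch.cat([stripped, torch.tensor([pad])])              # AppendTokenDataset(pad)
assert src.numel() == tgt.numel()  # equal lengths; the real code then pads both to max(src_lengths) + 1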
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path,
        self.dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError("Dataset not found: {} ({})".format(split, split_path))

    dataset = StripTokenDataset(dataset, self.dictionary.eos())

    # create continuous blocks of tokens
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample - 2,  # one less for <s> and one for </s>
        pad=self.dictionary.pad(),
        eos=self.dictionary.eos(),
        break_mode=self.args.sample_break_mode,
        document_sep_len=0,
    )

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
    dataset = AppendTokenDataset(dataset, self.source_dictionary.eos())

    mask_whole_words = (
        get_whole_word_mask(self.args, self.source_dictionary)
        if self.args.mask_length != "subword"
        else None
    )

    self.datasets[split] = DenoisingDataset(
        dataset,
        dataset.sizes,
        self.dictionary,
        self.mask_idx,
        mask_whole_words,
        shuffle=self.args.shuffle_instance,
        seed=self.seed,
        args=self.args,
    )
    logger.info(
        "Split: {0}, Loaded {1} samples of denoising_dataset".format(
            split,
            len(self.datasets[split]),
        )
    )
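# Sketch of the size bookkeeping used above (plain arithmetic, hypothetical value):
# TokenBlockDataset is asked for blocks of tokens_per_sample - 2 so that, after the
# <s> prepend and </s> append, a sample still fits within tokens_per_sample.
tokens_per_sample = 512  # hypothetical --tokens-per-sample
block_size = tokens_per_sample - 2
assert 1 + block_size + 1 == tokens_per_sample  # <s> + block + </s>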
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """
    TODO:
    - break_mode=",。"
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]

    def get_path(type, split):
        return os.path.join(data_path, type, split)

    def make_dataset(type, dictionary):
        split_path = get_path(type, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))
        return dataset

    dataset = make_dataset('input', self.dictionary)
    dataset = TruncateDataset(
        RStripTokenDataset(dataset, self.dictionary.eos()),
        self.args.tokens_per_sample - 2)

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    # https://github.com/pytorch/fairseq/blob/master/fairseq/tasks/translation.py#L71
    # https://github.com/pytorch/fairseq/blob/77983ee1a52c4e011e54cc6bfa5352b7811ec96d/fairseq/tasks/denoising.py#L127
    dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
    dataset = AppendTokenDataset(dataset, self.source_dictionary.eos())

    meta_dataset = make_dataset('meta', self.meta_dictionary)
    meta_dataset = StripTokenDataset(
        meta_dataset, id_to_strip=self.meta_dictionary.eos())

    s2s_dataset = KnowledgeLanguagePairDataset.apply_mask(
        dataset,
        dataset.sizes,
        self.source_dictionary,
        meta=meta_dataset,
        meta_sizes=meta_dataset.sizes,
        meta_dict=self.meta_dictionary,
        shuffle=True,
        mask_prob=self.args.mask_prob,
        leave_unmasked_prob=self.args.leave_unmasked_prob,
        random_token_prob=self.args.random_token_prob,
        sub_task=self.args.sub_task,
    )

    self.datasets[split] = s2s_dataset
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """
    TODO:
    - break_mode=",。"
    """
    paths = utils.split_paths(self.cfg.data)
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]

    def get_path(type, split):
        return os.path.join(data_path, type, split)

    def make_dataset(type, dictionary):
        split_path = get_path(type, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            self.cfg.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))
        return dataset

    dataset = make_dataset('input', self.dictionary)
    dataset = TruncateDataset(
        RStripTokenDataset(dataset, self.dictionary.eos()),
        self.cfg.tokens_per_sample - 2)

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
    dataset = AppendTokenDataset(dataset, self.source_dictionary.eos())

    meta_dataset = make_dataset('meta', self.meta_dictionary)
    meta_dataset = StripTokenDataset(
        meta_dataset, id_to_strip=self.meta_dictionary.eos())

    s2s_dataset = KnowledgeLanguagePairDataset.apply_mask(
        dataset,
        dataset.sizes,
        self.source_dictionary,
        meta=meta_dataset,
        meta_sizes=meta_dataset.sizes,
        meta_dict=self.meta_dictionary,
        shuffle=True,
        mask_idx=self.mask_idx,
        mask_prob=self.cfg.mask_prob,
        leave_unmasked_prob=self.cfg.leave_unmasked_prob,
        random_token_prob=self.cfg.random_token_prob,
        sub_task=self.cfg.sub_task,
    )

    self.datasets[split] = s2s_dataset
def main(args):
    tokenizer = build_tokenizer(args)
    src_indices = get_indices(args.input_src, tokenizer)
    trg_indices = get_indices(args.input_trg, tokenizer)

    src_dataset = IndexDataset(src_indices)
    trg_dataset = IndexDataset(trg_indices)

    eos = tokenizer.sep_token_id
    bos = tokenizer.cls_token_id
    max_pos = args.max_pos

    datasets = []

    src_dataset = TruncateDataset(
        StripTokenDataset(src_dataset, eos),
        max_pos - 2,
    )
    trg_dataset = TruncateDataset(
        StripTokenDataset(trg_dataset, eos),
        max_pos - 2,
    )

    src_dataset = PrependTokenDataset(src_dataset, bos)
    trg_dataset = PrependTokenDataset(trg_dataset, bos)

    src_dataset = AppendTokenDataset(src_dataset, eos)
    trg_dataset = AppendTokenDataset(trg_dataset, eos)

    print("| get all items ...")
    # items = [i for i in tqdm(dataset)]
    items = []
    for t1, t2 in tqdm(zip(src_dataset, trg_dataset)):
        items.append(t1)
        items.append(t2)

    print("| writing binary file ...")
    prefix = os.path.join(args.output, "train.0")
    save_items(items, prefix, len(tokenizer))
def _load_dataset_split(self, split, epoch, combine):
    paths = utils.split_paths(self.cfg.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path,
        self.dictionary,
        self.cfg.dataset_impl,
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError(
            "Dataset not found: {} ({})".format(split, split_path)
        )

    dataset = StripTokenDataset(dataset, self.dictionary.eos())

    dataset = maybe_shorten_dataset(
        dataset,
        split,
        self.cfg.shorten_data_split_list,
        self.cfg.shorten_method,
        self.cfg.tokens_per_sample,
        self.cfg.seed,
    )

    # create continuous blocks of tokens
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.cfg.tokens_per_sample - 2,  # one less for <s> and one for </s>
        pad=self.dictionary.pad(),
        eos=self.dictionary.eos(),
        break_mode=self.cfg.sample_break_mode,
        document_sep_len=0,
    )
    logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
    dataset = AppendTokenDataset(dataset, self.source_dictionary.eos())
    return dataset
def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
    dataset = StripTokenDataset(
        TokenBlockDataset(
            src_tokens,
            src_lengths,
            block_size=None,  # ignored for "eos" break mode
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode="eos",
        ),
        # remove eos from (end of) target sequence
        self.source_dictionary.eos(),
    )
    src_dataset = PrependTokenDataset(
        dataset,
        token=(
            self.source_dictionary.bos()
            if getattr(self.args, "add_bos_token", False)
            else self.source_dictionary.eos()
        ),
    )
    tgt_dataset = AppendTokenDataset(dataset, token=self.source_dictionary.pad())
    return NestedDictionaryDataset(
        {
            "id": IdDataset(),
            "net_input": {
                "src_tokens": PadDataset(
                    src_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                "src_lengths": NumelDataset(src_dataset, reduce=False),
            },
            "target": PadDataset(
                tgt_dataset,
                pad_idx=self.source_dictionary.pad(),
                left_pad=False,
            ),
        },
        sizes=[np.array(src_lengths)],
    )
def load_dataset(lang, lang_dict, prefix, dataset_length, sample_ratios=None):
    """
    Function to load an additional dataset and deal with all parameters.
    Easier than copying redundant code for each dataset.
    Requires src_dataset to provide the length and sample_ratios.
    """
    lang_datasets = []
    lang_dataset = data_utils.load_indexed_dataset(prefix + lang, lang_dict, dataset_impl)
    if lang_dataset is not None:
        lang_datasets.append(lang_dataset)
    assert dataset_length == len(lang_datasets) or len(lang_datasets) == 0
    if dataset_length == 1:
        lang_dataset = lang_datasets[0] if len(lang_datasets) > 0 else None
    else:
        assert sample_ratios is not None
        if len(lang_datasets) > 0:
            lang_dataset = ConcatDataset(lang_datasets, sample_ratios)
        else:
            lang_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(lang_dict, "bos_index")
        if lang_dataset is not None:
            lang_dataset = PrependTokenDataset(lang_dataset, lang_dict.bos())

    eos = None
    if append_source_id:
        if lang_dataset is not None:
            lang_dataset = AppendTokenDataset(
                lang_dataset, lang_dict.index('[{}]'.format(lang)))

    lang_dataset_sizes = lang_dataset.sizes if lang_dataset is not None else None
    return lang_dataset, lang_dataset_sizes
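# Illustrative sketch (plain Python, toy data): the ConcatDataset(sample_ratios) pattern used
# by this helper and the loaders below upsamples the primary shard by repeating its examples,
# so "train" can be weighted against additional shards "train1", "train2", ...
shards = [["a", "b"], ["c"], ["d", "e"]]   # stand-ins for indexed dataset shards
upsample_primary = 2                       # hypothetical --upsample-primary
sample_ratios = [1] * len(shards)
sample_ratios[0] = upsample_primary
concatenated = [ex for shard, ratio in zip(shards, sample_ratios) for _ in range(ratio) for ex in shard]
assert concatenated == ["a", "b", "a", "b", "c", "d", "e"]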
def load_langpair_dataset( data_path, split, src, src_dict, tgt, tgt_dict, combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, max_target_positions, prepend_bos=False, load_alignments=False, truncate_source=False, append_source_id=False, num_buckets=0, shuffle=True, pad_to_multiple=1, # Masked LM parameters. mask_idx: int = 0, seed: int = 1, mask_prob: float = 0.01, leave_unmasked_prob: float = 0.0, random_token_prob: float = 0.0, freq_weighted_replacement: bool = False, mask_whole_words: torch.Tensor = None, mask_multiple_length: int = 1, mask_stdev: float = 0.0, ): def split_exists(split, src, tgt, lang, data_path): filename = os.path.join( data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) src_datasets = [] tgt_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else "") # infer langcode if split_exists(split_k, src, tgt, src, data_path): prefix = os.path.join( data_path, "{}.{}-{}.".format(split_k, src, tgt)) elif split_exists(split_k, tgt, src, src, data_path): prefix = os.path.join( data_path, "{}.{}-{}.".format(split_k, tgt, src)) else: if k > 0: break else: raise FileNotFoundError( "Dataset not found: {} ({})".format(split, data_path) ) src_dataset = data_utils.load_indexed_dataset( prefix + src, src_dict, dataset_impl ) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) tgt_dataset = data_utils.load_indexed_dataset( prefix + tgt, tgt_dict, dataset_impl ) if tgt_dataset is not None: tgt_datasets.append(tgt_dataset) logger.info( "{} {} {}-{} {} examples".format( data_path, split_k, src, tgt, len(src_datasets[-1]) ) ) if not combine: break # logger.info('Length of Source DataSets: {}'.format(len(src_datasets))) assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 if len(src_datasets) == 1: src_dataset = src_datasets[0] tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) if len(tgt_datasets) > 0: tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) else: tgt_dataset = None if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr( tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) if tgt_dataset is not None: tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) eos = None if append_source_id: src_dataset = AppendTokenDataset( src_dataset, src_dict.index("[{}]".format(src)) ) if tgt_dataset is not None: tgt_dataset = AppendTokenDataset( tgt_dataset, tgt_dict.index("[{}]".format(tgt)) ) eos = tgt_dict.index("[{}]".format(tgt)) align_dataset = None if load_alignments: align_path = os.path.join( data_path, "{}.align.{}-{}".format(split, src, tgt)) if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): align_dataset = data_utils.load_indexed_dataset( align_path, None, dataset_impl ) tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None # mask source dataset. 
src_dataset, masked_src_dataset = MaskTokensDataset.apply_mask( src_dataset, src_dict, pad_idx=src_dict.pad(), mask_idx=mask_idx, seed=seed, mask_prob=mask_prob, leave_unmasked_prob=leave_unmasked_prob, random_token_prob=random_token_prob, freq_weighted_replacement=freq_weighted_replacement, mask_whole_words=mask_whole_words, mask_multiple_length=mask_multiple_length, mask_stdev=mask_stdev, ) # Print samples. # if split == 'valid': # print(src_dataset[1]) # print(masked_src_dataset[1]) return LanguagePairDataset( src_dataset, src_dataset.sizes, src_dict, tgt_dataset, tgt_dataset_sizes, tgt_dict, # for Mask LM loss calculation. masked_src_dataset, masked_src_dataset.sizes, left_pad_source=left_pad_source, left_pad_target=left_pad_target, align_dataset=align_dataset, eos=eos, num_buckets=num_buckets, shuffle=shuffle, pad_to_multiple=pad_to_multiple, )
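# Sketch with toy numbers (not the defaults used above): the masking parameters follow the
# BERT-style recipe that fairseq's MaskTokensDataset is modelled on -- select ~mask_prob of the
# positions, then within that set keep leave_unmasked_prob unchanged and replace
# random_token_prob with a random token; the rest become <mask>.
mask_prob, leave_unmasked_prob, random_token_prob = 0.15, 0.1, 0.1
p_mask_token = mask_prob * (1 - leave_unmasked_prob - random_token_prob)
assert abs(p_mask_token - 0.12) < 1e-9  # expected fraction of all positions replaced by <mask>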
def load_langpair_dataset( data_path, split, src, src_dict, tgt, tgt_dict, user_context_path, combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, max_target_positions, prepend_bos=False, load_alignments=False, truncate_source=False, append_source_id=False ): def split_exists(split, src, tgt, lang, data_path): filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) src_datasets = [] tgt_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') # infer langcode if split_exists(split_k, src, tgt, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt)) elif split_exists(split_k, tgt, src, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src)) else: if k > 0: break else: raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path)) src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl) if tgt_dataset is not None: tgt_datasets.append(tgt_dataset) logger.info('{} {} {}-{} {} examples'.format( data_path, split_k, src, tgt, len(src_datasets[-1]) )) if not combine: break assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 if len(src_datasets) == 1: src_dataset = src_datasets[0] tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) if len(tgt_datasets) > 0: tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) else: tgt_dataset = None if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) if tgt_dataset is not None: tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) eos = None if append_source_id: src_dataset = AppendTokenDataset(src_dataset, src_dict.index('[{}]'.format(src))) if tgt_dataset is not None: tgt_dataset = AppendTokenDataset(tgt_dataset, tgt_dict.index('[{}]'.format(tgt))) eos = tgt_dict.index('[{}]'.format(tgt)) align_dataset = None if load_alignments: align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt)) if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl) tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None logger.info('Loading user big issues from {}'.format(split)) user_context = pickle.load(open(os.path.join(user_context_path, '{}_with_claim.users_big_issues.pkl'.format(split)), 'rb')) return ExtendedLanguagePairDataset( src_dataset, src_dataset.sizes, src_dict, user_context, tgt_dataset, tgt_dataset_sizes, tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, max_source_positions=max_source_positions, max_target_positions=max_target_positions, align_dataset=align_dataset, eos=eos )
def load_langpair_with_additional_data_dataset( data_path, split, src, src_dict, tgt, tgt_dict, combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, max_target_positions, prepend_bos=False, load_alignments=False, truncate_source=False, add_dir=None, add_lang=None, add_dict=None, userdirname=None, ): def split_exists(split, src, tgt, lang, data_path): filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) src_datasets = [] tgt_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') # infer langcode if split_exists(split_k, src, tgt, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt)) elif split_exists(split_k, tgt, src, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src)) else: if k > 0: break else: raise FileNotFoundError('Dataset not found: {} ({})'.format( split, data_path)) src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) tgt_datasets.append( data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)) print('| {} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1]))) if not combine: break assert len(src_datasets) == len(tgt_datasets) if len(src_datasets) == 1: src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0] else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr( tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) align_dataset = None if load_alignments: align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt)) if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): align_dataset = data_utils.load_indexed_dataset( align_path, None, dataset_impl) add_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') additional_data_path = f'{data_path}/{add_dir}' # infer langcode if split_exists(split_k, add_lang, 'None', add_lang, additional_data_path): prefix = os.path.join( additional_data_path, '{}.{}-{}.'.format(split_k, add_lang, 'None')) else: if k > 0: break else: raise FileNotFoundError('Dataset not found: {} ({})'.format( split, additional_data_path)) add_dataset = data_utils.load_indexed_dataset(prefix + add_lang, add_dict, dataset_impl) if truncate_source: add_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(add_dataset, add_dict.eos()), max_source_positions - 1, ), add_dict.eos(), ) add_datasets.append(add_dataset) print('| {} {} {}-{} {} examples'.format(data_path, split_k, add_lang, 'None', len(add_datasets[-1]))) if not combine: break if len(add_datasets) == 1: add_dataset = add_datasets[0] else: raise Exception # sample_ratios = [1] * len(src_datasets) # sample_ratios[0] = upsample_primary # src_dataset = ConcatDataset(src_datasets, sample_ratios) # tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) if add_dataset: import sys module_parent, module_name = os.path.split( 
os.path.abspath(userdirname)) add_user_module(userdirname) return sys.modules[ module_name].data.LanguagePairWithAdditionalDataDataset( src_dataset, src_dataset.sizes, src_dict, tgt_dataset, tgt_dataset.sizes, tgt_dict, add_dataset, add_dataset.sizes, add_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, max_source_positions=max_source_positions, max_target_positions=max_target_positions, align_dataset=align_dataset, ) else: return LanguagePairDataset( src_dataset, src_dataset.sizes, src_dict, tgt_dataset, tgt_dataset.sizes, tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, max_source_positions=max_source_positions, max_target_positions=max_target_positions, align_dataset=align_dataset, )
def load_langpair_dataset( data_path, split, src, src_dict, tgt, tgt_dict, combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, max_target_positions, prepend_bos=False, load_alignments=False, truncate_source=False, append_source_id=False, num_buckets=0, shuffle=True, ): def split_exists(split, src, tgt, lang, data_path): filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) src_datasets = [] tgt_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') # infer langcode if split_exists(split_k, src, tgt, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt)) elif split_exists(split_k, tgt, src, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src)) else: if k > 0: break else: raise FileNotFoundError('Dataset not found: {} ({})'.format( split, data_path)) src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl) if tgt_dataset is not None: tgt_datasets.append(tgt_dataset) logger.info('{} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1]))) if not combine: break assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 if len(src_datasets) == 1: src_dataset = src_datasets[0] tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None # these features are not yet implemented for the cluster code if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr( tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) if tgt_dataset is not None: tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) eos = None if append_source_id: src_dataset = AppendTokenDataset( src_dataset, src_dict.index('[{}]'.format(src))) if tgt_dataset is not None: tgt_dataset = AppendTokenDataset( tgt_dataset, tgt_dict.index('[{}]'.format(tgt))) eos = tgt_dict.index('[{}]'.format(tgt)) align_dataset = None if load_alignments: align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt)) if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): align_dataset = data_utils.load_indexed_dataset( align_path, None, dataset_impl) tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None return LanguagePairDataset( src_dataset, src_dataset.sizes, src_dict, tgt_dataset, tgt_dataset_sizes, tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, align_dataset=align_dataset, eos=eos, num_buckets=num_buckets, shuffle=shuffle, ) else: # sample_ratios = [1] * len(src_datasets) # sample_ratios[0] = upsample_primary # src_dataset = ConcatDataset(src_datasets, sample_ratios) # if len(tgt_datasets) > 0: # tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) # else: # tgt_dataset = None datasets = [] eos = None align_dataset = None for i in range(0, len(src_datasets)): src_dataset = src_datasets[i] tgt_dataset = tgt_datasets[i] tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None datasets.append( LanguagePairDataset( src_dataset, src_dataset.sizes, src_dict, tgt_dataset, tgt_dataset_sizes, tgt_dict, left_pad_source=left_pad_source, 
left_pad_target=left_pad_target, align_dataset=align_dataset, eos=eos, num_buckets=num_buckets, shuffle=shuffle, )) return datasets
def load_lang_dataset( self, data_path, split, src, src_dict, tgt, tgt_dict, combine, dataset_impl, upsample_primary, max_source_positions, prepend_bos=False, load_alignments=False, truncate_source=False, ): src_datasets = [] tgt_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else "") # infer langcode if self.split_exists(split_k, src, tgt, src, data_path, dataset_impl): prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt)) elif self.split_exists(split_k, tgt, src, src, data_path, dataset_impl): prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src)) else: if k > 0: break else: logger.error( f"Dataset not found: {data_path}, {split_k}, {src}, {tgt}" ) raise FileNotFoundError( "Dataset not found: {} ({})".format(split, data_path) ) src_dataset = self.load_data(prefix + src, src_dict, dataset_impl) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) tgt_datasets.append(self.load_data(prefix + tgt, tgt_dict, dataset_impl)) logger.info( "{} {} {}-{} {} examples".format( data_path, split_k, src, tgt, len(src_datasets[-1]) ) ) if not combine: break assert len(src_datasets) == len(tgt_datasets) if len(src_datasets) == 1: src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0] else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) align_dataset = None if load_alignments: align_path = os.path.join( data_path, "{}.align.{}-{}".format(split, src, tgt) ) if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): align_dataset = data_utils.load_indexed_dataset( align_path, None, dataset_impl ) return src_dataset, tgt_dataset, align_dataset
def load_langpair_dataset(data_path, split, src, src_dict, tgt, tgt_dict, combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, max_target_positions, prepend_bos=False, load_alignments=False, load_cls_labels=False, load_cls_indices=False, load_sample_weights=False, truncate_source=False, append_source_id=False, shuffle=True): def split_exists(split, src, tgt, lang, data_path): filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) src_datasets = [] tgt_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') # infer langcode if split_exists(split_k, src, tgt, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt)) elif split_exists(split_k, tgt, src, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src)) else: if k > 0: break else: raise FileNotFoundError('Dataset not found: {} ({})'.format( split, data_path)) src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl) if tgt_dataset is not None: tgt_datasets.append(tgt_dataset) logger.info('{} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1]))) if not combine: break assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 if len(src_datasets) == 1: src_dataset = src_datasets[0] tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) if len(tgt_datasets) > 0: tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) else: tgt_dataset = None src_prepended_bos = False if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr( tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) if tgt_dataset is not None: tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) src_prepended_bos = True eos = None if append_source_id: src_dataset = AppendTokenDataset(src_dataset, src_dict.index('[{}]'.format(src))) if tgt_dataset is not None: tgt_dataset = AppendTokenDataset( tgt_dataset, tgt_dict.index('[{}]'.format(tgt))) eos = tgt_dict.index('[{}]'.format(tgt)) align_dataset = None if load_alignments: align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt)) if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): align_dataset = data_utils.load_indexed_dataset( align_path, None, dataset_impl) tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None cls_dataset = None if load_cls_labels: cls_labels_path = os.path.join(data_path, '{}.cls'.format(split)) if indexed_dataset.dataset_exists(cls_labels_path, impl=dataset_impl): cls_dataset = data_utils.load_indexed_dataset( cls_labels_path, None, dataset_impl) if truncate_source: cls_dataset = AppendTokenDataset( TruncateDataset( TruncateLastElementDataset(cls_dataset), max_source_positions - 1, ), -1, # will ignore -1 label in training ) if src_prepended_bos: cls_dataset = PrependTokenDataset(cls_dataset, -1) else: print("cls_labels dataset NOT FOUND!", cls_labels_path) cls_indices_dataset = None if 
load_cls_indices: cls_indices_path = os.path.join(data_path, '{}.cls_ind'.format(split)) if indexed_dataset.dataset_exists(cls_indices_path, impl=dataset_impl): cls_indices_dataset = data_utils.load_indexed_dataset( cls_indices_path, None, dataset_impl) sample_weights = None if load_sample_weights: weights_file = os.path.join( data_path, '{}.{}-{}.weights.npy'.format(split, src, tgt)) assert os.path.exists(weights_file) with open(weights_file, 'rb') as f: sample_weights = np.load(f) logger.info('Loaded {} weights from {}'.format(len(sample_weights), weights_file)) return LanguagePairDataset( src_dataset, src_dataset.sizes, src_dict, tgt_dataset, tgt_dataset_sizes, tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, max_source_positions=max_source_positions, max_target_positions=max_target_positions, align_dataset=align_dataset, eos=eos, cls_dataset=cls_dataset, cls_indices_dataset=cls_indices_dataset, sample_weights=sample_weights, shuffle=shuffle, )
def load_langpair_dataset( data_path, split, src, src_dict, tgt, tgt_dict, combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, max_target_positions, prepend_bos=False, load_alignments=False, truncate_source=False, append_source_id=False, num_buckets=0, shuffle=True, pad_to_multiple=1, add_lang_token=False, ): def split_exists(split, src, tgt, lang, data_path): logger.info( os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))) filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) def split_exists_self(split, src, data_path): logger.info( os.path.join(data_path, "{}.{}-{}.{}".format(split, src, src, src))) filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, src, src)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) def split_exists_valid(split, lang, data_path): logger.info(os.path.join(data_path, "{}.{}".format(split, lang))) filename = os.path.join(data_path, "{}.{}".format(split, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) src_datasets = [] tgt_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else "") # print(split_k, src, tgt, src, data_path) prefix_src = None prefix_tgt = None if not "-" in split_k: # infer langcode if split_exists(split_k, src, tgt, src, data_path): prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt)) elif split_exists(split_k, tgt, src, src, data_path): prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src)) else: if k > 0: break else: raise FileNotFoundError( "Dataset not found: {} ({}) {} {}".format( split, data_path, src, tgt)) else: # infer langcode if split_exists_valid(split_k, src, data_path): prefix = os.path.join(data_path, split_k + ".") else: if k > 0: break else: raise FileNotFoundError( "Dataset not found: {} ({}) ".format(split, data_path)) if prefix_src != None: prefix = prefix_src src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) if prefix_tgt != None: prefix = prefix_tgt tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl) if tgt_dataset is not None: tgt_datasets.append(tgt_dataset) logger.info("{} {} {}-{} {} examples".format(data_path, split_k, src, tgt, len(src_datasets[-1]))) if not combine: break assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 if len(src_datasets) == 1: src_dataset = src_datasets[0] tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary logger.info("::::data sample_ratios:{}".format(sample_ratios)) src_dataset = ConcatDataset(src_datasets, sample_ratios) if len(tgt_datasets) > 0: tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) else: tgt_dataset = None if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr( tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) if tgt_dataset is not None: tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) eos = None if append_source_id: src_dataset = AppendTokenDataset(src_dataset, src_dict.index("[{}]".format(src))) if tgt_dataset is not None: tgt_dataset = AppendTokenDataset( 
tgt_dataset, tgt_dict.index("[{}]".format(tgt))) eos = tgt_dict.index("[{}]".format(tgt)) eos = None if add_lang_token: src_dataset = PrependTokenDataset(src_dataset, src_dict.index("[{}]".format(src))) if tgt_dataset is not None: tgt_dataset = PrependTokenDataset( tgt_dataset, tgt_dict.index("[{}]".format(tgt))) align_dataset = None if load_alignments: align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt)) if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): align_dataset = data_utils.load_indexed_dataset( align_path, None, dataset_impl) tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None return LanguagePairDataset( src_dataset, src_dataset.sizes, src_dict, tgt_dataset, tgt_dataset_sizes, tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, align_dataset=align_dataset, eos=eos, num_buckets=num_buckets, shuffle=shuffle, pad_to_multiple=pad_to_multiple, )
def load_ape_dataset( data_path, split, src_dict, tgt_dict, combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, max_target_positions, prepend_bos=False, load_alignments=False, truncate_source=False, append_source_id=False, num_buckets=0, input_type='src_only', src_type="src", ): """ ignoring src and tgt name. Assume $split.src, $split.mt, and $split.pe exist """ src = src_type mt = "mt" tgt = "pe" term = "term" src_factor = src_type + "_embed" mt_factor = "mt_embed" def split_exists(split, lang, data_path): filename = os.path.join(data_path, '{}.{}'.format(split, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) def load_dataset(lang, lang_dict, prefix, dataset_length, sample_ratios=None): """ Function to load additional dataset and deal with all parameters. Easier than copying redudant code for each dataset. Requires src_dataset to provide the length and sample_ratios. """ lang_datasets = [] lang_dataset = data_utils.load_indexed_dataset(prefix + lang, lang_dict, dataset_impl) if lang_dataset is not None: lang_datasets.append(lang_dataset) assert dataset_length == len(lang_datasets) or len(lang_datasets) == 0 if dataset_length == 1: lang_dataset = lang_datasets[0] if len(lang_datasets) > 0 else None else: assert sample_ratios is not None if len(lang_datasets) > 0: lang_dataset = ConcatDataset(lang_datasets, sample_ratios) else: lang_dataset = None if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr( lang_dict, "bos_index") if lang_dataset is not None: lang_dataset = PrependTokenDataset(lang_dataset, lang_dict.bos()) eos = None if append_source_id: if lang_dataset is not None: lang_dataset = AppendTokenDataset( lang_dataset, lang_dict.index('[{}]'.format(lang))) lang_dataset_sizes = lang_dataset.sizes if lang_dataset is not None else None return lang_dataset, lang_dataset_sizes src_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') # infer langcode if split_exists(split_k, src, data_path): prefix = os.path.join(data_path, '{}.'.format(split_k)) elif split_exists(split_k, mt, data_path): prefix = os.path.join(data_path, '{}.'.format(split_k)) elif split_exists(split_k, tgt, data_path): prefix = os.path.join(data_path, '{}.'.format(split_k)) elif split_exists(split_k, term, data_path): prefix = os.path.join(data_path, '{}.'.format(split_k)) elif split_exists(split_k, src_factor, data_path): prefix = os.path.join(data_path, '{}.'.format(split_k)) else: if k > 0: break else: raise FileNotFoundError('Dataset not found: {} ({})'.format( split, data_path)) src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) if not combine: break dataset_length = len(src_datasets) sample_ratios = None if len(src_datasets) == 1: src_dataset = src_datasets[0] else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr( tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) eos = None if append_source_id: src_dataset = AppendTokenDataset(src_dataset, src_dict.index('[{}]'.format(src))) eos = tgt_dict.index('[{}]'.format(tgt)) align_dataset = None mt_dataset, mt_dataset_sizes = load_dataset(mt, 
tgt_dict, prefix, dataset_length, sample_ratios=sample_ratios) tgt_dataset, tgt_dataset_sizes = load_dataset(tgt, tgt_dict, prefix, dataset_length, sample_ratios=sample_ratios) term_dataset, term_dataset_sizes = load_dataset( term, tgt_dict, prefix, dataset_length, sample_ratios=sample_ratios) src_factor_dataset, src_factor_dataset_sizes = load_dataset( src_factor, tgt_dict, prefix, dataset_length, sample_ratios=sample_ratios) mt_factor_dataset, mt_factor_dataset_sizes = load_dataset( mt_factor, tgt_dict, prefix, dataset_length, sample_ratios=sample_ratios) logger.info('{} {} {} examples'.format(data_path, split_k, len(src_datasets[-1]))) return APEDataset(src_dataset, src_dataset.sizes, src_dict, tgt_dataset, tgt_dataset_sizes, tgt_dict, mt_dataset, mt_dataset_sizes, term_dataset, term_dataset_sizes, src_factor_dataset, src_factor_dataset_sizes, mt_factor_dataset, mt_factor_dataset_sizes, left_pad_source=left_pad_source, left_pad_target=left_pad_target, align_dataset=align_dataset, eos=eos, num_buckets=num_buckets, input_type=input_type)
def load_seq_sql_dataset(data_path, split, src, src_dict, prev_src_dict, sql, sql_dict, prev_sql_dict, encoder_embed_path, encoder_embed_dim, decoder_embed_path, decoder_embed_dim, encoder_random_embedding_path, decoder_random_embedding_path, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, max_target_positions, truncate_source, prepend_bos): src_datasets = [] sql_datasets = [] prefix = os.path.join(data_path, split) src_dataset = data_utils.load_indexed_dataset(prefix + '.' + src, src_dict, dataset_impl) #col_sizes = get_col_sizes(prefix + '.col') if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) sql_datasets.append( data_utils.load_indexed_dataset(prefix + '.' + sql, sql_dict, dataset_impl)) assert len(src_datasets) == len(sql_datasets) if len(src_datasets) == 1: src_dataset, sql_dataset = src_datasets[0], sql_datasets[0] else: #not implemented sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) sql_dataset = ConcatDataset(sql_datasets, sample_ratios) if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr( sql_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) sql_dataset = PrependTokenDataset(sql_dataset, sql_dict.bos()) return Seq2SqlPairDataSet( src_dataset, src_dataset.sizes, src_dict, prev_src_dict, sql_dataset, sql_dataset.sizes, sql_dict, prev_sql_dict, encoder_embed_path, encoder_embed_dim, decoder_embed_path, decoder_embed_dim, encoder_random_embedding_path, decoder_random_embedding_path, left_pad_source=left_pad_source, left_pad_target=left_pad_target, max_source_positions=max_source_positions, max_target_positions=max_target_positions, )
def load_langpair_dataset( data_path, split, src, src_dict, tgt, tgt_dict, combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, max_target_positions, prepend_bos=False, load_alignments=False, truncate_source=False, append_source_id=False, num_buckets=0, shuffle=True, pad_to_multiple=1, plus_encoder_loss=False, add_langs=None, shuffle_lang_pair=None, args=None, word_trans_dict=None, word_align_dict=None, policy_ratio_dicts=None, ): def split_exists(split, src, tgt, lang, data_path): filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) def split_exists_valid(split, lang, data_path): filename = os.path.join(data_path, "{}.{}".format(split, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) src_datasets = [] tgt_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else "") if not "-" in split_k: # infer langcode if split_exists(split_k, src, tgt, src, data_path): prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt)) elif split_exists(split_k, tgt, src, src, data_path): prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src)) else: if k > 0: break else: raise FileNotFoundError( "Dataset not found: {} ({}) {} {}".format(split, data_path, src, tgt) ) else: # for multi-valid if split_exists_valid( split_k, src, data_path): prefix = os.path.join(data_path, split_k+".") else: if k > 0: break else: raise FileNotFoundError( "Dataset not found: {} ({}) ".format(split, data_path) ) src_dataset = data_utils.load_indexed_dataset( prefix + src, src_dict, dataset_impl ) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) # for monolingual instances. if src == tgt: tgt_dataset = copy.deepcopy(src_dataset) else: tgt_dataset = data_utils.load_indexed_dataset( prefix + tgt, tgt_dict, dataset_impl ) if tgt_dataset is not None: tgt_datasets.append(tgt_dataset) if not combine: break assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 if len(src_datasets) == 1: src_dataset = src_datasets[0] tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) if len(tgt_datasets) > 0: tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) else: tgt_dataset = None if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) if tgt_dataset is not None: tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) eos = None # add src and tag lang id on the biganing of sens. 
if add_langs: src_dataset = PrependTokenDataset( src_dataset, src_dict.index("[{}]".format(src)) ) if tgt_dataset is not None: tgt_dataset = PrependTokenDataset( tgt_dataset, tgt_dict.index("[{}]".format(tgt)) ) align_dataset = None if load_alignments: align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt)) if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): align_dataset = data_utils.load_indexed_dataset( align_path, None, dataset_impl ) tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None return DDenoisingPairDatasetDynaReplace( src_dataset, src_dataset.sizes, src_dict, tgt_dataset, tgt_dataset_sizes, tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, align_dataset=align_dataset, eos=eos, num_buckets=num_buckets, shuffle=shuffle, pad_to_multiple=pad_to_multiple, plus_encoder_loss=plus_encoder_loss, add_langs=add_langs, shuffle_lang_pair=shuffle_lang_pair, args=args , word_trans_dict=word_trans_dict , word_align_dict=word_align_dict, policy_ratio_dicts= policy_ratio_dicts, )
def load_langpair_dataset( data_path, split, src, src_dict, tgt, tgt_dict, combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, max_target_positions, prepend_bos=False, load_alignments=False, truncate_source=False, append_source_id=False, num_buckets=0, shuffle=True, pad_to_multiple=1, prepend_bos_src=None, bert_model_name=None, bart_model_name=None, electra_model_name=None, electra_pretrain=False, denoising=False, masking=False, extra_data=False, input_mapping=False, mask_ratio=None, random_ratio=None, insert_ratio=None, rotate_ratio=None, permute_sentence_ratio=None, ): def split_exists(split, src, tgt, lang, data_path): filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) src_datasets = [] tgt_datasets = [] bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name, do_lower_case=False) if denoising: bart_tokenizer = AutoTokenizer.from_pretrained(bart_model_name, do_lower_case=False) #bart_tokenizer = BartTokenizer.from_pretrained(bart_model_name, do_lower_case=False) if electra_pretrain: electra_tokenizer = ElectraTokenizer.from_pretrained( electra_model_name) srcbert_datasets = [] extra_datasets = [] extra_bert_datasets = [] extra_bert_mapping_datasets = [] extra_bart_datasets = [] extra_bart_mapping_datasets = [] if denoising: srcbart_datasets = [] if electra_pretrain: srcelectra_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else "") # infer langcode if split_exists(split_k, src, tgt, src, data_path): prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt)) bertprefix = os.path.join( data_path, '{}.bert.{}-{}.'.format(split_k, src, tgt)) bert_mapping_prefix = os.path.join( data_path, '{}.bert.map.{}-{}.'.format(split_k, src, tgt)) if denoising: bartprefix = os.path.join( data_path, '{}.bart.{}-{}.'.format(split_k, src, tgt)) bart_mapping_prefix = os.path.join( data_path, '{}.bart.map.{}-{}.'.format(split_k, src, tgt)) if electra_pretrain: electraprefix = os.path.join( data_path, '{}.electra.{}-{}.'.format(split_k, src, tgt)) electra_mapping_prefix = os.path.join( data_path, '{}.electra.map.{}-{}.'.format(split_k, src, tgt)) if extra_data: extraprefix = os.path.join( data_path, '{}.extra.{}-{}.'.format(split_k, src, tgt)) extra_bert_prefix = os.path.join( data_path, '{}.extra.bert.{}-{}.'.format(split_k, src, tgt)) extra_bert_mapping_prefix = os.path.join( data_path, '{}.extra.bert.map.{}-{}.'.format(split_k, src, tgt)) extra_bart_prefix = os.path.join( data_path, '{}.extra.bart.{}-{}.'.format(split_k, src, tgt)) extra_bart_mapping_prefix = os.path.join( data_path, '{}.extra.bart.map.{}-{}.'.format(split_k, src, tgt)) elif split_exists(split_k, tgt, src, src, data_path): prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src)) bertprefix = os.path.join( data_path, '{}.bert.{}-{}.'.format(split_k, tgt, src)) bert_mapping_prefix = os.path.join( data_path, '{}.bert.map.{}-{}.'.format(split_k, src, tgt)) if denoising: bartprefix = os.path.join( data_path, '{}.bart.{}-{}.'.format(split_k, tgt, src)) bart_mapping_prefix = os.path.join( data_path, '{}.bart.map.{}-{}.'.format(split_k, src, tgt)) if electra_pretrain: electraprefix = os.path.join( data_path, '{}.electra.{}-{}.'.format(split_k, src, tgt)) electra_mapping_prefix = os.path.join( data_path, '{}.electra.map.{}-{}.'.format(split_k, src, tgt)) if extra_data: extraprefix = os.path.join( data_path, '{}.extra.{}-{}.'.format(split_k, src, 
tgt)) extra_bert_prefix = os.path.join( data_path, '{}.extra.bert.{}-{}.'.format(split_k, src, tgt)) extra_bert_mapping_prefix = os.path.join( data_path, '{}.extra.bert.map.{}-{}.'.format(split_k, src, tgt)) extra_bart_prefix = os.path.join( data_path, '{}.extra.bart.{}-{}.'.format(split_k, src, tgt)) extra_bart_mapping_prefix = os.path.join( data_path, '{}.extra.bart.map.{}-{}.'.format(split_k, src, tgt)) else: if k > 0: break else: raise FileNotFoundError("Dataset not found: {} ({})".format( split, data_path)) src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl) if tgt_dataset is not None: tgt_datasets.append(tgt_dataset) # srcbert_datasets.append(indexed_dataset.make_dataset(bertprefix + src, impl=dataset_impl, # fix_lua_indexing=True, )) # if denoising: # srcbart_datasets.append(indexed_dataset.make_dataset(bartprefix + src, impl=dataset_impl, # fix_lua_indexing=True, )) # if extra_data: # extra_datasets.append(indexed_dataset.make_dataset(extraprefix + src, impl=dataset_impl, # fix_lua_indexing=True, )) srcbert_datasets.append( data_utils.load_indexed_dataset( bertprefix + src, dataset_impl=dataset_impl, )) if denoising: srcbart_datasets.append( data_utils.load_indexed_dataset( bartprefix + src, dataset_impl=dataset_impl, )) if electra_pretrain: srcelectra_datasets.append( data_utils.load_indexed_dataset( electraprefix + src, dataset_impl=dataset_impl, )) if extra_data and split == 'train': extra_datasets.append( data_utils.load_indexed_dataset( extraprefix + src, dataset_impl=dataset_impl, )) extra_bert_datasets.append( data_utils.load_indexed_dataset( extra_bert_prefix + src, dataset_impl=dataset_impl, )) extra_bert_mapping_datasets.append( data_utils.load_indexed_dataset( extra_bert_mapping_prefix + src, dataset_impl=dataset_impl, )) extra_bart_datasets.append( data_utils.load_indexed_dataset( extra_bart_prefix + src, dataset_impl=dataset_impl, )) extra_bart_mapping_datasets.append( data_utils.load_indexed_dataset( extra_bart_mapping_prefix + src, dataset_impl=dataset_impl, )) #import pdb; pdb.set_trace() assert extra_datasets != [] or extra_bert_datasets != [] or extra_bert_mapping_datasets != [] or extra_bart_datasets != [] or extra_bart_mapping_datasets != [] #extra_datasets = extra_datasets[0] #import pdb; pdb.set_trace() src_datasets[-1] = PrependTokenDataset(src_datasets[-1], token=src_dict.bos_index) if extra_data and split == 'train': extra_datasets[-1] = PrependTokenDataset(extra_datasets[-1], token=src_dict.bos_index) if denoising is True: if input_mapping is True and split == 'train': bart_mapping_dataset = data_utils.load_indexed_dataset( bart_mapping_prefix + src, dataset_impl=dataset_impl) else: bart_mapping_dataset = None src_datasets[-1] = DenoisingBartDataset( src_datasets[-1], src_datasets[-1].sizes, src_dict, srcbart_datasets[-1], srcbart_datasets[-1].sizes, bart_tokenizer, map_dataset=bart_mapping_dataset, mask_ratio=mask_ratio, random_ratio=random_ratio, insert_ratio=insert_ratio, rotate_ratio=rotate_ratio, permute_sentence_ratio=permute_sentence_ratio, ) if electra_pretrain is True: if input_mapping is True and split == 'train': electra_mapping_dataset = data_utils.load_indexed_dataset( electra_mapping_prefix + src, dataset_impl=dataset_impl) else: 
electra_mapping_dataset = None src_datasets[-1] = ElectrapretrainDataset( src_datasets[-1], src_datasets[-1].sizes, src_dict, srcelectra_datasets[-1], srcelectra_datasets[-1].sizes, electra_tokenizer, map_dataset=electra_mapping_dataset, left_pad_source=left_pad_source, left_pad_target=left_pad_target, max_source_positions=max_source_positions, max_target_positions=max_target_positions, ) if masking is True: if input_mapping is True and split == 'train': #bert_mapping_dataset = indexed_dataset.make_dataset(bert_mapping_prefix + src, impl=dataset_impl, fix_lua_indexing=True) bert_mapping_dataset = data_utils.load_indexed_dataset( bert_mapping_prefix + src, dataset_impl=dataset_impl) else: bert_mapping_dataset = None src_datasets[-1] = MaskingDataset( src_datasets[-1], src_datasets[-1].sizes, src_dict, srcbert_datasets[-1], srcbert_datasets[-1].sizes, bert_tokenizer, map_dataset=bert_mapping_dataset, left_pad_source=left_pad_source, left_pad_target=left_pad_target, max_source_positions=max_source_positions, max_target_positions=max_target_positions, ) if extra_data is True and split == 'train': assert input_mapping is True src_datasets[-1] = MaskingExtraDataset( src_datasets[-1], src_datasets[-1].sizes, src_dict, extra_datasets[-1], extra_datasets[-1].sizes, extra_bert_datasets[-1], extra_bert_datasets[-1].sizes, bert_tokenizer, map_dataset=extra_bert_mapping_datasets[-1], left_pad_source=left_pad_source, left_pad_target=left_pad_target, max_source_positions=max_source_positions, max_target_positions=max_target_positions, ) src_datasets[-1] = DenoisingBartExtraDataset( src_datasets[-1], src_datasets[-1].sizes, src_dict, extra_datasets[-1], extra_datasets[-1].sizes, extra_bart_datasets[-1], extra_bart_datasets[-1].sizes, bart_tokenizer, map_dataset=extra_bart_mapping_datasets[-1], ) logger.info("{} {} {}-{} {} examples".format(data_path, split_k, src, tgt, len(src_datasets[-1]))) if not combine: break assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 if len(src_datasets) == 1: src_dataset = src_datasets[0] tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None # srcbert_datasets = srcbert_datasets[0] # if denoising: # srcbart_datasets = srcbart_datasets[0] else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) if len(tgt_datasets) > 0: tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) else: tgt_dataset = None if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr( tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) if tgt_dataset is not None: tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) elif prepend_bos_src is not None: logger.info(f"prepending src bos: {prepend_bos_src}") src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src) eos = None if append_source_id: src_dataset = AppendTokenDataset(src_dataset, src_dict.index("[{}]".format(src))) if tgt_dataset is not None: tgt_dataset = AppendTokenDataset( tgt_dataset, tgt_dict.index("[{}]".format(tgt))) eos = tgt_dict.index("[{}]".format(tgt)) align_dataset = None if load_alignments: align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt)) if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): align_dataset = data_utils.load_indexed_dataset( align_path, None, dataset_impl) tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None src_bart_dataset = None src_bert_dataset = None src_electra_dataset = None 
return LanguagePairDataset( src_dataset, src_dataset.sizes, src_dict, tgt_dataset, tgt_dataset_sizes, tgt_dict, masking, src_bert_dataset, denoising, src_bart_dataset, src_electra_dataset, left_pad_source=left_pad_source, left_pad_target=left_pad_target, align_dataset=align_dataset, eos=eos, num_buckets=num_buckets, shuffle=shuffle, pad_to_multiple=pad_to_multiple, )
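When several of these flags are enabled together, each stage re-wraps src_datasets[-1], so the wrappers nest in the fixed order denoising -> electra -> masking (-> extra). A toy sketch of that composition pattern, with stand-in classes rather than the real dataset wrappers:

# Stand-in illustration only; the real wrappers are DenoisingBartDataset,
# ElectrapretrainDataset, MaskingDataset, etc.
class Wrap:
    def __init__(self, inner, name):
        self.inner, self.name = inner, name

    def __repr__(self):
        return "{}({!r})".format(self.name, self.inner)

dataset = "src"
for enabled, name in [(True, "DenoisingBart"), (True, "Electrapretrain"), (True, "Masking")]:
    if enabled:
        # mirrors src_datasets[-1] = Wrapper(src_datasets[-1], ...)
        dataset = Wrap(dataset, name)

print(dataset)  # Masking(Electrapretrain(DenoisingBart('src')))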
def load_langpair_dataset( data_path, split, src, src_dict, tgt, tgt_dict, combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, max_target_positions, prepend_bos=False, load_alignments=False, load_dependency=False, gold_dependency=False, dependency_with_input=False, truncate_source=False, remove_eos_from_source=True, append_source_id=False, num_buckets=0, shuffle=True, ): def split_exists(split, src, tgt, lang, data_path): filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) src_datasets = [] tgt_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') # infer langcode if split_exists(split_k, src, tgt, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt)) elif split_exists(split_k, tgt, src, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src)) else: if k > 0: break else: raise FileNotFoundError('Dataset not found: {} ({})'.format( split, data_path)) src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl) if tgt_dataset is not None: tgt_datasets.append(tgt_dataset) logger.info('{} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1]))) if not combine: break assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 if len(src_datasets) == 1: src_dataset = src_datasets[0] tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) if len(tgt_datasets) > 0: tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) else: tgt_dataset = None if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr( tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) if tgt_dataset is not None: tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) eos = None if append_source_id: src_dataset = AppendTokenDataset(src_dataset, src_dict.index('[{}]'.format(src))) if tgt_dataset is not None: tgt_dataset = AppendTokenDataset( tgt_dataset, tgt_dict.index('[{}]'.format(tgt))) eos = tgt_dict.index('[{}]'.format(tgt)) align_dataset = None if load_alignments: align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt)) if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): align_dataset = data_utils.load_indexed_dataset( align_path, None, dataset_impl) src_dep, tgt_dep = None, None if load_dependency: src_dep_path = os.path.join(data_path, '{}.dep.{}'.format(split, src)) tgt_dep_path = os.path.join(data_path, '{}.dep.{}'.format(split, tgt)) if os.path.exists(src_dep_path): src_deps = [] with open(src_dep_path, 'r') as src_dep_data: for h in src_dep_data: src_deps.append( torch.LongTensor( [[i, int(x) - 1] for i, x in enumerate(h.strip().split())])) src_dep = RawLabelDataset(src_deps) if os.path.exists(tgt_dep_path): tgt_deps = [] with open(tgt_dep_path, 'r') as tgt_dep_data: for h in tgt_dep_data: tgt_deps.append( torch.LongTensor( [[i, int(x) - 1] for i, x in enumerate(h.strip().split())])) tgt_dep = 
RawLabelDataset(tgt_deps) tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None return LanguagePairDatasetWithDependency( src_dataset, src_dataset.sizes, src_dict, tgt_dataset, tgt_dataset_sizes, tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, remove_eos_from_source=remove_eos_from_source, align_dataset=align_dataset, eos=eos, src_dep=src_dep, tgt_dep=tgt_dep, dependency_with_input=dependency_with_input, gold_dependency=gold_dependency, num_buckets=num_buckets, shuffle=shuffle, )
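For reference, each line of the .dep files read above is a whitespace-separated list of 1-based head indices, one per token, which the loader turns into [token_index, head_index - 1] pairs. A minimal sketch on a hypothetical line:

import torch

# Hypothetical 4-token sentence; "0" marks the root token's head.
dep_line = "2 0 2 2"

pairs = torch.LongTensor(
    [[i, int(x) - 1] for i, x in enumerate(dep_line.strip().split())]
)
# tensor([[ 0,  1],
#         [ 1, -1],   # root: head "0" becomes -1
#         [ 2,  1],
#         [ 3,  1]])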
# create continuous blocks of tokens dataset = TokenBlockDataset( dataset, dataset.sizes, tokens_per_sample - 2, # one less for <s> and one for </s> pad=dictionary.pad(), eos=dictionary.eos(), break_mode=args.sample_break_mode, document_sep_len=0, ) assert len(dataset) == prev_size # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT) dataset = PrependTokenDataset(dataset, source_dictionary.bos()) dataset = AppendTokenDataset(dataset, source_dictionary.eos()) mask_whole_words = (get_whole_word_mask(args, source_dictionary) if mask_length != 'subword' else None) bpe = encoders.build_bpe(args) eoh = dictionary.indices[bpe.encode('</h>')] denoising_dataset = DenoisingDataset(dataset, dataset.sizes, dictionary, mask_idx, mask_whole_words, shuffle=False, seed=seed, args=args, eoh=eoh)
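The </h> (end-of-headline) index is resolved by first running the marker through the task's BPE and then looking the result up in the dictionary. A hedged sketch of a slightly safer variant of that lookup, assuming `bpe` and `dictionary` are the objects built above and that </h> encodes to a single known symbol:

# Sketch only; `bpe` and `dictionary` come from the surrounding task.
encoded = bpe.encode("</h>")
eoh = dictionary.index(encoded)  # unlike dictionary.indices[...], index() falls
                                 # back to <unk> instead of raising KeyError
assert eoh != dictionary.unk(), "</h> is not a single known dictionary symbol"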
def load_pos_langpair_dataset( data_path, split, src, src_dict, tgt, tgt_dict, combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, max_target_positions, prepend_bos=False, load_alignments=False, truncate_source=False, append_source_id=False ): # Check the existence of the file def split_exists(split, src, tgt, lang, data_path): filename = os.path.join( data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) src_datasets = [] tgt_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') # infer langcode (from a->b or from b->a) if split_exists(split_k, src, tgt, src, data_path): prefix = os.path.join( data_path, '{}.{}-{}.'.format(split_k, src, tgt)) elif split_exists(split_k, tgt, src, src, data_path): prefix = os.path.join( data_path, '{}.{}-{}.'.format(split_k, tgt, src)) else: if k > 0: break else: raise FileNotFoundError( 'Dataset not found: {} ({})'.format(split, data_path)) src_dataset = data_utils.load_indexed_dataset( prefix + src, src_dict, dataset_impl) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) tgt_dataset = data_utils.load_indexed_dataset( prefix + tgt, tgt_dict, dataset_impl) if tgt_dataset is not None: tgt_datasets.append(tgt_dataset) logger.info('{} {} {}-{} {} examples'.format( data_path, split_k, src, tgt, len(src_datasets[-1]) )) if not combine: break assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 if len(src_datasets) == 1: src_dataset = src_datasets[0] tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) if len(tgt_datasets) > 0: tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) else: tgt_dataset = None if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr( tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) if tgt_dataset is not None: tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) eos = None if append_source_id: src_dataset = AppendTokenDataset( src_dataset, src_dict.index('[{}]'.format(src))) if tgt_dataset is not None: tgt_dataset = AppendTokenDataset( tgt_dataset, tgt_dict.index('[{}]'.format(tgt))) eos = tgt_dict.index('[{}]'.format(tgt)) align_dataset = None if load_alignments: align_path = os.path.join( data_path, '{}.align.{}-{}'.format(split, src, tgt)) if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): align_dataset = data_utils.load_indexed_dataset( align_path, None, dataset_impl) tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None # Load POS Graph def graph_exist(data_path, split, src, tgt, lang): existence = True row_path = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, src)) + '.row' col_path = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, src)) + '.col' anchor_path = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, src)) + '.anchor' if(not os.path.exists(row_path)): existence = False elif(not os.path.exists(col_path)): existence = False elif(not os.path.exists(anchor_path)): existence = False return existence pos_graphs_l = [] pos_anchors_l = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') existence = 
graph_exist(data_path, split_k, src, tgt, src) if(not existence): if(k == 0): raise FileNotFoundError('POS Graph Dataset not found') if(k > 0): break pos_rows = codecs.open(os.path.join( data_path, '{}.{}-{}.{}'.format(split_k, src, tgt, src)) + '.row', 'r', 'utf-8').readlines() pos_cols = codecs.open(os.path.join( data_path, '{}.{}-{}.{}'.format(split_k, src, tgt, src)) + '.col', 'r', 'utf-8').readlines() pos_graphs = [] print('Loading graphs' + '.' * 50) assert len(pos_cols) == len(pos_rows) pbar = tqdm(total=len(pos_cols)) for n, (row, col) in enumerate(zip(pos_rows, pos_cols)): pos_row = [eval(i) for i in row.strip().split()] pos_col = [eval(i) for i in col.strip().split()] pos_graphs.append((pos_row, pos_col)) pbar.update() pbar.close() pos_anchors = codecs.open(os.path.join( data_path, '{}.{}-{}.{}'.format(split_k, src, tgt, src)) + '.anchor', 'r', 'utf-8').readlines() anchors = [] for line in pos_anchors: anchors.append([eval(i) for i in line.strip().split()]) pos_graphs_l.extend(pos_graphs) pos_anchors_l.extend(anchors) assert (len(pos_anchors_l) == len(pos_graphs_l)) and (len(src_dataset.sizes) == len(pos_anchors_l)) return POSGraphLanguagePairDataset( src_dataset, src_dataset.sizes, src_dict, pos_anchors_l, pos_graphs_l, tgt_dataset, tgt_dataset_sizes, tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, max_source_positions=max_source_positions, max_target_positions=max_target_positions, align_dataset=align_dataset, eos=eos )
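The .row, .col, and .anchor files read above each hold one whitespace-separated list of integers per sentence (presumably the row/column indices of the POS-graph edges and the anchor token positions). A minimal parsing sketch on hypothetical lines, using int() where the loader uses eval() (equivalent for plain integers, but safer):

# Hypothetical single-sentence inputs.
row_line = "0 1 1 2"     # edge source indices
col_line = "1 0 2 1"     # edge target indices
anchor_line = "0 2 5"    # anchor token positions

pos_row = [int(i) for i in row_line.strip().split()]
pos_col = [int(i) for i in col_line.strip().split()]
anchor = [int(i) for i in anchor_line.strip().split()]

pos_graph = (pos_row, pos_col)   # the (row, col) tuple appended per sentence
assert len(pos_row) == len(pos_col)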
def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ paths = self.args.data.split(":") assert len(paths) > 0 data_path = paths[(epoch - 1) % len(paths)] split_path = os.path.join(data_path, split) if self.langs is None: languages = sorted( [ name for name in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, name)) ] ) else: languages = self.langs.split(",") for name in languages: p = os.path.join(data_path, name) assert os.path.exists(p), "data not found: {}".format(p) logger.info("Training on {0} languages: {1}".format(len(languages), languages)) logger.info( "Language to id mapping: ", {lang: id for id, lang in enumerate(languages)} ) mask_whole_words = get_whole_word_mask(self.args, self.dictionary) language_without_segmentations = self.args.no_whole_word_mask_langs.split(",") lang_datasets = [] for language in languages: split_path = os.path.join(data_path, language, split) dataset = data_utils.load_indexed_dataset( split_path, self.source_dictionary, self.args.dataset_impl, combine=combine, ) if dataset is None: raise FileNotFoundError( "Dataset not found: {} ({})".format(split, split_path) ) end_token = ( self.source_dictionary.index("[{}]".format(language)) if self.args.add_lang_token else self.source_dictionary.eos() ) # create continuous blocks of tokens dataset = TokenBlockDataset( dataset, dataset.sizes, self.args.tokens_per_sample - 2, # one less for <s> pad=self.source_dictionary.pad(), eos=end_token, break_mode=self.args.sample_break_mode, ) logger.info("loaded {} blocks from: {}".format(len(dataset), split_path)) # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT) dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) dataset = AppendTokenDataset(dataset, end_token) lang_mask_whole_words = ( mask_whole_words if language not in language_without_segmentations else None ) lang_dataset = DenoisingDataset( dataset, dataset.sizes, self.dictionary, self.mask_idx, lang_mask_whole_words, shuffle=self.args.shuffle_instance, seed=self.seed, args=self.args, eos=None if not self.args.add_lang_token else self.source_dictionary.index("[{}]".format(language)), ) lang_datasets.append(lang_dataset) dataset_lengths = np.array( [len(d) for d in lang_datasets], dtype=float, ) logger.info( "loaded total {} blocks for all languages".format( int(dataset_lengths.sum()), ) ) if split == self.args.train_subset: # For train subset, additionally up or down sample languages. 
sample_probs = self._get_sample_prob(dataset_lengths) logger.info( "Sample probability by language: {}".format( { lang: "{0:.4f}".format(sample_probs[id]) for id, lang in enumerate(languages) } ) ) size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths logger.info( "Up/Down Sampling ratio by language: {}".format( { lang: "{0:.2f}".format(size_ratio[id]) for id, lang in enumerate(languages) } ) ) resampled_lang_datasets = [ ResamplingDataset( lang_datasets[i], size_ratio=size_ratio[i], seed=self.args.seed, epoch=epoch, replace=size_ratio[i] >= 1.0, ) for i, d in enumerate(lang_datasets) ] dataset = ConcatDataset( resampled_lang_datasets, ) else: dataset = ConcatDataset(lang_datasets) lang_splits = [split] for lang_id, lang_dataset in enumerate(lang_datasets): split_name = split + "_" + languages[lang_id] lang_splits.append(split_name) self.datasets[split_name] = lang_dataset if split in self.args.valid_subset: self.args.valid_subset = self.args.valid_subset.replace( split, ",".join(lang_splits) ) with data_utils.numpy_seed(self.args.seed + epoch): shuffle = np.random.permutation(len(dataset)) self.datasets[split] = SortDataset( dataset, sort_order=[ shuffle, dataset.sizes, ], )
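_get_sample_prob is not shown in this excerpt; in fairseq's multilingual denoising task it is usually temperature-style smoothing of the per-language size distribution, p_l proportional to (n_l / sum_k n_k)^alpha. A sketch under that assumption, including how size_ratio above then turns the probabilities into up/down-sampling factors:

import numpy as np

def _get_sample_prob(dataset_lengths, alpha=0.7):
    """Smooth the per-language distribution; alpha=1.0 keeps the empirical
    proportions, alpha -> 0 approaches uniform sampling."""
    prob = dataset_lengths / dataset_lengths.sum()
    smoothed = prob ** alpha
    return smoothed / smoothed.sum()

lengths = np.array([1_000_000.0, 10_000.0])          # hypothetical corpus sizes
sample_probs = _get_sample_prob(lengths)
size_ratio = sample_probs * lengths.sum() / lengths  # >1 upsamples, <1 downsamples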
def load_langpair_dataset(data_path, split, src, src_dict, tgt, tgt_dict, combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, max_target_positions, prepend_bos=False, load_alignments=False, truncate_source=False, append_source_id=False, explicit_str_att=False): def split_exists(split, src, tgt, lang, data_path): filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) src_datasets = [] tgt_datasets = [] sent_id_datasets = [] chains_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') # infer langcode if split_exists(split_k, src, tgt, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt)) elif split_exists(split_k, tgt, src, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src)) else: if k > 0: break else: raise FileNotFoundError('Dataset not found: {} ({})'.format( split, data_path)) pre_src_dataset = data_utils.load_indexed_dataset( prefix + src, src_dict, dataset_impl) if truncate_source: src_dataset = AppendTokenDataset(TruncateDataset( StripTokenDataset(pre_src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), split=split) src_datasets.append(src_dataset) else: src_datasets.append(pre_src_dataset) if split_exists(split_k, src, tgt, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt)) elif split_exists(split_k, tgt, src, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src)) else: if k > 0: break else: raise FileNotFoundError('Dataset not found: {} ({})'.format( split, data_path)) sent_id_dataset = SentIdsRawDataset(prefix + 'source.sentids') if truncate_source: sent_id_dataset = AppendLastTokenDataset(TruncateNDimDataset( StripTokenFromMaskDataset(sent_id_dataset, pre_src_dataset, src_dict.eos()), max_source_positions - 1, dim=1), split=split) sent_id_datasets.append(sent_id_dataset) if explicit_str_att: chains_dataset = ChainsDataset(prefix + 'source.chains') chains_datasets.append(chains_dataset) tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl) if tgt_dataset is not None: tgt_datasets.append(tgt_dataset) logger.info('{} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1]))) if not combine: break assert (len(src_datasets) == len(tgt_datasets) and len(src_datasets) == len(sent_id_datasets)) or len(tgt_datasets) == 0 if len(src_datasets) == 1: src_dataset = src_datasets[0] sent_id_dataset = sent_id_datasets[0] chains_dataset = chains_datasets if explicit_str_att: chains_dataset = chains_datasets[0] tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary # src_dataset = ConcatDataset(src_datasets, sample_ratios) # if len(tgt_datasets) > 0: # tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) # else: # tgt_dataset = None if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr( tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) sent_id_dataset = PrependFirstTokenDataset(sent_id_dataset) if tgt_dataset is not None: tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) eos = None if append_source_id: src_dataset = AppendTokenDataset(src_dataset, src_dict.index('[{}]'.format(src)), split=split) sent_id_dataset = AppendLastTokenDataset(sent_id_dataset, 
split=split) if tgt_dataset is not None: tgt_dataset = AppendTokenDataset( tgt_dataset, tgt_dict.index('[{}]'.format(tgt))) eos = tgt_dict.index('[{}]'.format(tgt)) align_dataset = None # if load_alignments: # align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt)) # chains = torch.load(load_alignments) # if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): # align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl) tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None return StructSumDataset(src_dataset, src_dataset.sizes, src_dict, tgt_dataset, tgt_dataset_sizes, tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, max_source_positions=max_source_positions, max_target_positions=max_target_positions, align_dataset=align_dataset, eos=eos, src_sent_ids=sent_id_dataset, split=split, chains_dataset=chains_dataset, explicit_str_att=explicit_str_att)
def load_generation_pair_dataset( data_path, split, tgt, src_dict, tgt_dict, combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, max_target_positions, prepend_bos=False, load_alignments=False, truncate_source=False, append_source_id=False, common_eos=None ): def split_exists(split, src, tgt, lang, data_path): filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) src_datasets = [] tgt_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') # infer langcode if split_exists(split_k, "src", "tgt", "src", data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, "src", "tgt")) elif split_exists(split_k, "tgt", "src", "src", data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, "tgt", "src")) else: if k > 0: break else: raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path)) src_dataset = data_utils.load_indexed_dataset(prefix + "src", src_dict, dataset_impl) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) tgt_dataset = data_utils.load_indexed_dataset(prefix + "tgt", tgt_dict, dataset_impl) if tgt_dataset is not None: tgt_datasets.append(tgt_dataset) logger.info('{} {} {}-{} {} examples'.format( data_path, split_k, "src", "tgt", len(src_datasets[-1]) )) if not combine: break assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 if len(src_datasets) == 1: src_dataset = src_datasets[0] tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) if len(tgt_datasets) > 0: tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) else: tgt_dataset = None if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) if tgt_dataset is not None: tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) eos = None if append_source_id: if common_eos is not None: src_dataset = AppendTokenDataset(src_dataset, src_dict.index('[{}]'.format(common_eos))) if tgt_dataset is not None: tgt_dataset = AppendTokenDataset(tgt_dataset, tgt_dict.index('[{}]'.format(common_eos))) eos = tgt_dict.index('[{}]'.format(common_eos)) bos = tgt_dict.index('[{}]'.format(tgt)) align_dataset = None if load_alignments: align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, "src", "tgt")) if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl) tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None return GenerationPairDataset( src_dataset, src_dataset.sizes, src_dict, tgt_dataset, tgt_dataset_sizes, tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, max_source_positions=max_source_positions, max_target_positions=max_target_positions, align_dataset=align_dataset, eos=eos, bos=bos )
def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ paths = utils.split_paths(self.args.data) assert len(paths) > 0 is_train_subset = split == getattr(self.args, "train_subset", None) if not is_train_subset: # if not training data set, use the first shard for valid and test paths = paths[:1] data_path = paths[(epoch - 1) % len(paths)] # infer langcode src, tgt = self.args.source_lang, self.args.target_lang """ this is mask_word_initial WordNoising uses mask_word_end or mask_bpe_cont probably easiest to write FlippedDataset that reverses sequences and use the standard pipeline load_langpair_dataset: find files by pattern load_indexed source maybe truncate load target check shard counts sample ratios bos, source_id load_alignments LangpairDataset constructor """ src_dataset, tgt_dataset = load_unpaired_langpair( data_path, split, src, self.src_dict, tgt, self.tgt_dict, combine=combine, dataset_impl=self.args.dataset_impl, max_source_positions=self.args.max_source_positions, max_target_positions=self.args.max_target_positions, truncate_source=self.args.truncate_source, prepend_bos=self.args.prepend_bos, ) if self.args.bpe_dropout > 0: src_dataset = DynamicGPT2BPEDropoutResampling( self.args, src_dataset, self.source_dictionary, dropout=self.args.bpe_dropout, ) # load backtranslation if is_train_subset and not self.args.skip_backtranslation_data: """ noised vs unnoised valdation set? they might converge at different times """ bt_src_dataset, bt_tgt_dataset = load_unpaired_langpair( # data_path, "{}.bt".format(split), src, self.src_dict, tgt, self.tgt_dict, data_path, "{}.bt".format(split), src, self.src_dict, tgt, self.tgt_dict, combine=combine, dataset_impl=self.args.dataset_impl, max_source_positions=self.args.max_source_positions, max_target_positions=self.args.max_target_positions, truncate_source=self.args.truncate_source, prepend_bos=self.args.prepend_bos, ) if self.args.bpe == "gpt2": mask_is_beginning_of_word = get_whole_word_mask( self.args, self.source_dictionary) mask_is_beginning_of_word = mask_is_beginning_of_word.numpy( ).astype(np.bool) # noiser = GPT2WordNoising( # self.src_dict, # mask_is_beginning_of_word, # self.args.max_word_shuffle_distance, # self.args.word_dropout_prob, # self.args.word_blanking_prob, # ) if self.args.bpe_dropout > 0: bt_src_dataset = DynamicGPT2BPEDropoutResampling( self.args, bt_src_dataset, self.source_dictionary, dropout=self.args.bpe_dropout, ) noiser = GPT2WordNoisingV2( self.src_dict, mask_is_beginning_of_word, self.args.max_word_shuffle_distance, self.args.word_dropout_prob, self.args.word_blanking_prob, ) bt_src_dataset = DynamicNoisingDataset( bt_src_dataset, self.src_dict, seed=1, noiser=noiser, ) # try: # from icecream import ic # ic.configureOutput(includeContext=True) # except ImportError: # Graceful fallback if IceCream isn't installed. 
# ic = lambda *a: None if not a else (a[0] if len(a) == 1 else a) # noqa # ic("gpt2 bbpe") # bpe = encoders.build_bpe(self.args) # def decode(foo): # return bpe.decode(self.src_dict.string(foo)) # def disp(foo): # return " ".join([bpe.decode(i) for i in self.src_dict.string(foo).split(" ")]) # # foo = [bpe.decode(str(i)) for i in range(0,1000)] # # doo = [bpe.decode((i)) for i in self.src_dict.symbols[4:1000]] # for i in range(5): # ic(_bt_src_dataset[i]) # ic(decode(_bt_src_dataset[i])) # ic(disp(_bt_src_dataset[i])) # ic(disp(bt_src_dataset[i])) # ic(bt_src_dataset[i]) # import pdb; pdb.set_trace() else: assert self.args.bpe_dropout <= 0, "BPE dropout not supported for this BPE scheme" # standard bpe with @@ as continuation marker bt_src_dataset = DynamicNoisingDataset( bt_src_dataset, self.src_dict, seed=1, max_word_shuffle_distance=self.args. max_word_shuffle_distance, word_dropout_prob=self.args.word_dropout_prob, word_blanking_prob=self.args.word_blanking_prob, ) # if self.append_backtranslation_tag: if self.args.tagged_backtranslation: bt_src_dataset = AppendTokenDataset( AppendTokenDataset( StripTokenDataset(bt_src_dataset, self.src_dict.eos()), self.bt_idx), self.src_dict.eos(), ) sample_ratios = [self.args.upsample_primary, 1] src_dataset = ConcatDataset([src_dataset, bt_src_dataset], sample_ratios) tgt_dataset = ConcatDataset([tgt_dataset, bt_tgt_dataset], sample_ratios) self.datasets[split] = LanguagePairDataset( src_dataset, src_dataset.sizes, self.src_dict, tgt_dataset, tgt_dataset.sizes, self.tgt_dict, left_pad_source=self.args.left_pad_source, left_pad_target=self.args.left_pad_target, align_dataset=None, eos=self.tgt_dict.eos(), num_buckets=self.args.num_batch_buckets, shuffle=(split not in ("test", "valid")), )
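With --tagged-backtranslation, the noised back-translated source is rewrapped so the tag lands between the sentence and the final eos (strip eos, append tag, append eos). A small sketch of the equivalent token-level operation, assuming bt_idx is the index of a tag symbol already added to the source dictionary:

import torch

def tag_backtranslation(tokens, eos_idx, bt_idx):
    """Mirror StripTokenDataset + AppendTokenDataset(bt) + AppendTokenDataset(eos)."""
    if len(tokens) and tokens[-1] == eos_idx:
        tokens = tokens[:-1]
    return torch.cat([tokens, tokens.new_tensor([bt_idx, eos_idx])])

# e.g. eos=2, bt=7: tensor([12, 57, 9, 2]) -> tensor([12, 57, 9, 7, 2])
print(tag_backtranslation(torch.tensor([12, 57, 9, 2]), eos_idx=2, bt_idx=7))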
def load_langpair_dataset( data_path, split, src, src_dict, tgt, tgt_dict, combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, max_target_positions, prepend_bos=False, load_alignments=False, truncate_source=False, srcda=False, srcda_choice='uniform', tgtda=False, tgtda_choice='uniform' ): def split_exists(split, src, tgt, lang, data_path): filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) src_datasets = [] tgt_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') # infer langcode if split_exists(split_k, src, tgt, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt)) elif split_exists(split_k, tgt, src, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src)) else: if k > 0: break else: raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path)) src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) tgt_datasets.append( data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl) ) print('| {} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1]))) if not combine: break assert len(src_datasets) == len(tgt_datasets) if len(src_datasets) == 1: src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0] else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) align_dataset = None if load_alignments: align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt)) if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl) return LanguagePairDatasetDA( src_dataset, src_dataset.sizes, src_dict, tgt_dataset, tgt_dataset.sizes, tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, max_source_positions=max_source_positions, max_target_positions=max_target_positions, align_dataset=align_dataset, srcda=srcda, srcda_choice=srcda_choice, tgtda=tgtda, tgtda_choice=tgtda_choice )
def load_langpair_dataset( data_path, split, src, src_dict, tgt, tgt_dict, feature_dict, combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, max_target_positions, prepend_bos=False, load_features=False, load_alignments=False, truncate_source=False, append_source_id=False, num_buckets=0, ): def split_exists(split, src, tgt, lang, data_path): filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) src_datasets = [] tgt_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') # infer langcode if split_exists(split_k, src, tgt, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt)) elif split_exists(split_k, tgt, src, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src)) else: if k > 0: break else: raise FileNotFoundError('Dataset not found: {} ({})'.format( split, data_path)) src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl) if tgt_dataset is not None: tgt_datasets.append(tgt_dataset) logger.info('{} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1]))) if not combine: break assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 if len(src_datasets) == 1: src_dataset = src_datasets[0] tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) if len(tgt_datasets) > 0: tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) else: tgt_dataset = None if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr( tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) if tgt_dataset is not None: tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) eos = None if append_source_id: src_dataset = AppendTokenDataset(src_dataset, src_dict.index('[{}]'.format(src))) if tgt_dataset is not None: tgt_dataset = AppendTokenDataset( tgt_dataset, tgt_dict.index('[{}]'.format(tgt))) eos = tgt_dict.index('[{}]'.format(tgt)) align_dataset = None if load_alignments: align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt)) if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): align_dataset = data_utils.load_indexed_dataset( align_path, None, dataset_impl) # print("feature_dict", feature_dict.symbols, feature_dict.count) #feature_dict ['<s>', '<pad>', '</s>', '<unk>', '<ori>', '<rep>', 'madeupword0000', 'madeupword0001'] [1, 1, 1, 1, 18558611, 5354704, 0, 0] feature_dataset = None if load_features: feature_path = os.path.join( data_path, '{}.feature.{}-{}.{}'.format(split, src, tgt, src)) if indexed_dataset.dataset_exists(feature_path, impl=dataset_impl): feature_dataset = data_utils.load_indexed_dataset( feature_path, feature_dict, dataset_impl) tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None return LanguagePairDataset( src_dataset, src_dataset.sizes, src_dict, tgt_dataset, tgt_dataset_sizes, tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, 
feature_dataset=feature_dataset, align_dataset=align_dataset, eos=eos, num_buckets=num_buckets, )
def pos_loader(data_path, split, src, src_dict, tgt, tgt_dict, anchor, anchor_dict, combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, max_target_positions, prepend_bos=False, truncate_source=False, append_source_id=False): # Check the existence of the file def split_exists(split, src, tgt, lang, data_path): filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) src_datasets = [] tgt_datasets = [] anchor_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') # infer langcode (from a->b or from b->a) if split_exists(split_k, src, tgt, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt)) elif split_exists(split_k, tgt, src, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src)) else: if k > 0: break else: raise FileNotFoundError('Dataset not found: {} ({})'.format( split, data_path)) src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl) if tgt_dataset is not None: tgt_datasets.append(tgt_dataset) anchor_prefix = os.path.join(data_path, anchor, '{}.{}-{}.'.format(split_k, anchor, tgt)) anchor_dataset = data_utils.load_indexed_dataset( anchor_prefix + anchor, anchor_dict, dataset_impl) if anchor_dataset is not None: anchor_datasets.append(anchor_dataset) logger.info('{} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1]))) if not combine: break assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 # None is not avaliable for anchors assert len(src_datasets) == len(anchor_datasets) if len(src_datasets) == 1: src_dataset = src_datasets[0] tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None anchor_dataset = anchor_datasets[0] else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) if len(tgt_datasets) > 0: tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) else: tgt_dataset = None anchor_dataset = ConcatDataset(anchor_datasets, sample_ratios) if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr( tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) if tgt_dataset is not None: tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) eos = None if append_source_id: src_dataset = AppendTokenDataset(src_dataset, src_dict.index('[{}]'.format(src))) if tgt_dataset is not None: tgt_dataset = AppendTokenDataset( tgt_dataset, tgt_dict.index('[{}]'.format(tgt))) eos = tgt_dict.index('[{}]'.format(tgt)) tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None return POSGraphLanguagePairDatasetb( src_dataset, src_dataset.sizes, src_dict, anchor_dataset, anchor_dataset.sizes, anchor_dict, tgt_dataset, tgt_dataset_sizes, tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, max_source_positions=max_source_positions, max_target_positions=max_target_positions, eos=eos)
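Note that pos_loader expects the anchor stream in a subdirectory named after the anchor "language", with the anchor code used both in the pair name and as the extension. A small sketch of the resulting paths, with hypothetical language codes:

import os

data_path, split_k, src, tgt, anchor = "data-bin", "train", "de", "en", "pos"

prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
anchor_prefix = os.path.join(data_path, anchor, "{}.{}-{}.".format(split_k, anchor, tgt))

print(prefix + src)             # data-bin/train.de-en.de
print(prefix + tgt)             # data-bin/train.de-en.en
print(anchor_prefix + anchor)   # data-bin/pos/train.pos-en.pos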
def load_unpaired_langpair( data_path, split, src, src_dict, tgt, tgt_dict, combine, dataset_impl, max_source_positions, max_target_positions, prepend_bos=False, truncate_source=False, append_source_id=False, ): def split_exists(split, src, tgt, lang, data_path): filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) src_datasets = [] tgt_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else "") # infer langcode if split_exists(split_k, src, tgt, src, data_path): prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt)) elif split_exists(split_k, tgt, src, src, data_path): prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src)) else: if k > 0: break else: raise FileNotFoundError("Dataset not found: {} ({})".format( split, data_path)) src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl) if tgt_dataset is not None: tgt_datasets.append(tgt_dataset) logger.info("{} {} {}-{} {} examples".format(data_path, split_k, src, tgt, len(src_datasets[-1]))) if not combine: break assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 if len(src_datasets) == 1: src_dataset = src_datasets[0] tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None else: sample_ratios = [1] * len(src_datasets) src_dataset = ConcatDataset(src_datasets, sample_ratios) if len(tgt_datasets) > 0: tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) else: tgt_dataset = None if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr( tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) if tgt_dataset is not None: tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) if append_source_id: src_dataset = AppendTokenDataset(src_dataset, src_dict.index("[{}]".format(src))) if tgt_dataset is not None: tgt_dataset = AppendTokenDataset( tgt_dataset, tgt_dict.index("[{}]".format(tgt))) return src_dataset, tgt_dataset
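Unlike the other loaders in this section, load_unpaired_langpair returns the raw (source, target) datasets rather than a LanguagePairDataset, leaving it to the caller (as in the back-translation task above) to add noising, BPE dropout, or tags before pairing them. A hedged usage sketch with hypothetical arguments, assuming data_path, src_dict, and tgt_dict are already set up as in that task:

src_dataset, tgt_dataset = load_unpaired_langpair(
    data_path, "train", "de", src_dict, "en", tgt_dict,
    combine=True,
    dataset_impl="mmap",
    max_source_positions=1024,
    max_target_positions=1024,
    truncate_source=True,
)

# ... optional task-specific wrapping (noising, tagging, upsampling) goes here ...

pair = LanguagePairDataset(
    src_dataset, src_dataset.sizes, src_dict,
    tgt_dataset, tgt_dataset.sizes, tgt_dict,
    left_pad_source=True, left_pad_target=False,
)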