def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): """ Generate batches for inference. We prepend an eos token to src_tokens (or bos if `--add-bos-token` is set) and we append an eos to target. This is convenient both for generation with a prefix and LM scoring. """ tgt_dataset = TokenBlockDataset( src_tokens, src_lengths, block_size=None, # ignored for "eos" break mode pad=self.source_dictionary.pad(), eos=self.source_dictionary.eos(), break_mode="eos", ) src_dataset = PrependTokenDataset( StripTokenDataset( tgt_dataset, # remove eos from (end of) target sequence self.source_dictionary.eos(), ), token=(self.source_dictionary.bos() if getattr( self.args, "add_bos_token", False) else self.source_dictionary.eos()), ) return NestedDictionaryDataset( { "id": IdDataset(), "net_input": { "src_tokens": PadDataset(src_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False), "src_lengths": NumelDataset(src_dataset, reduce=False), }, "target": PadDataset(tgt_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False), }, sizes=[np.array(src_lengths)], )
def tgt_dataset_tranform_func(self, source_lang, target_lang, dataset, spec=None):
    if dataset is None:
        # note that target dataset can be None during inference time
        return None
    if self.args.lang_tok_replacing_bos_eos:
        # TODO: unify with alter_dataset_langtok
        # It is handled by self.alter_dataset_langtok.
        # The complication in self.alter_dataset_langtok
        # makes a unified framework difficult.
        return dataset
    # if not self.args.decoder_langtok:
    if not spec:
        return dataset
    tok = self.get_decoder_langtok(target_lang, spec)
    if tok:
        return PrependTokenDataset(dataset, tok)
    return dataset
def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
    dataset = StripTokenDataset(
        TokenBlockDataset(
            src_tokens,
            src_lengths,
            block_size=None,  # ignored for "eos" break mode
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode="eos",
        ),
        # remove eos from (end of) target sequence
        self.source_dictionary.eos(),
    )
    src_dataset = PrependTokenDataset(
        dataset,
        token=(
            self.source_dictionary.bos()
            if getattr(self.args, "add_bos_token", False)
            else self.source_dictionary.eos()
        ),
    )
    tgt_dataset = AppendTokenDataset(dataset, token=self.source_dictionary.pad())
    return NestedDictionaryDataset(
        {
            "id": IdDataset(),
            "net_input": {
                "src_tokens": PadDataset(
                    src_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                "src_lengths": NumelDataset(src_dataset, reduce=False),
            },
            "target": PadDataset(
                tgt_dataset,
                pad_idx=self.source_dictionary.pad(),
                left_pad=False,
            ),
        },
        sizes=[np.array(src_lengths)],
    )
def load_dataset(lang, lang_dict, prefix, dataset_length, sample_ratios=None):
    """
    Load one additional dataset and handle all of its parameters, rather than
    copying redundant code for each dataset. Relies on the enclosing scope for
    the dataset length (taken from src_dataset) and the sample_ratios.
    """
    lang_datasets = []
    lang_dataset = data_utils.load_indexed_dataset(prefix + lang, lang_dict, dataset_impl)
    if lang_dataset is not None:
        lang_datasets.append(lang_dataset)
    assert dataset_length == len(lang_datasets) or len(lang_datasets) == 0
    if dataset_length == 1:
        lang_dataset = lang_datasets[0] if len(lang_datasets) > 0 else None
    else:
        assert sample_ratios is not None
        if len(lang_datasets) > 0:
            lang_dataset = ConcatDataset(lang_datasets, sample_ratios)
        else:
            lang_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(lang_dict, "bos_index")
        if lang_dataset is not None:
            lang_dataset = PrependTokenDataset(lang_dataset, lang_dict.bos())

    eos = None
    if append_source_id:
        if lang_dataset is not None:
            lang_dataset = AppendTokenDataset(
                lang_dataset, lang_dict.index('[{}]'.format(lang)))

    lang_dataset_sizes = lang_dataset.sizes if lang_dataset is not None else None
    return lang_dataset, lang_dataset_sizes
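
# Hedged sketch of how sample_ratios upsample shards when several datasets are
# concatenated (as the helper above does via ConcatDataset). In fairseq a ratio
# of r effectively repeats a shard r times in the concatenated index space; the
# hypothetical helper below mimics that with plain lists so the effect is
# visible without indexed datasets.
def concat_with_ratios(datasets, sample_ratios):
    concatenated = []
    for ds, ratio in zip(datasets, sample_ratios):
        # repeat each shard `ratio` times, mirroring ConcatDataset's index math
        concatenated.extend(list(ds) * ratio)
    return concatenated

# e.g. concat_with_ratios([["a1", "a2"], ["b1"]], [2, 1])
#      == ["a1", "a2", "a1", "a2", "b1"]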
def _load_dataset_split(self, split, epoch, combine):
    paths = utils.split_paths(self.cfg.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path,
        self.source_dictionary,
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError(
            "Dataset not found: {} ({})".format(split, split_path))

    dataset = maybe_shorten_dataset(
        dataset,
        split,
        self.cfg.shorten_data_split_list,
        self.cfg.shorten_method,
        self.cfg.tokens_per_sample,
        self.cfg.seed,
    )

    # create continuous blocks of tokens
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.cfg.tokens_per_sample - 1,  # one less for <s>
        pad=self.source_dictionary.pad(),
        eos=self.source_dictionary.eos(),
        break_mode=self.cfg.sample_break_mode,
    )
    logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    return PrependTokenDataset(dataset, self.source_dictionary.bos())
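
# Illustrative sketch of TokenBlockDataset break modes over a token stream,
# assuming the stream is a list of already-tokenized sentences, each ending in
# eos. "none" packs tokens into fixed-size blocks ignoring sentence boundaries,
# "eos" makes one block per sentence, and "complete" packs whole sentences up
# to block_size. This mimics the semantics only; the real class indexes lazily.
def blocks(sentences, block_size, break_mode):
    if break_mode == "eos":
        return [list(s) for s in sentences]
    stream = [tok for s in sentences for tok in s]
    if break_mode == "none":
        return [stream[i:i + block_size] for i in range(0, len(stream), block_size)]
    if break_mode == "complete":
        out, cur = [], []
        for s in sentences:
            if cur and len(cur) + len(s) > block_size:
                out.append(cur)
                cur = []
            cur.extend(s)
        if cur:
            out.append(cur)
        return out
    raise ValueError(break_mode)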
def load_denoise_dataset(self, data_path: str, lang: str) -> FairseqDataset:
    """Classic denoising dataset."""
    dataset = data_utils.load_indexed_dataset(
        data_path, self.common_dict, self.args.dataset_impl)
    noisy_dataset = NoisingDataset(
        dataset,
        self.dictionary,
        seed=1,
        max_word_shuffle_distance=self.args.max_word_shuffle_distance,
        word_dropout_prob=self.args.word_dropout_prob,
        word_blanking_prob=self.args.word_blanking_prob,
    )
    noisy_dataset = PrependTokenDataset(
        noisy_dataset, _lang_token_index(self.dictionary, lang))

    clean_dataset = data_utils.load_indexed_dataset(
        data_path, self.common_dict, self.args.dataset_impl)
    denoising_dataset = self._langpair_dataset(noisy_dataset, clean_dataset)
    denoising_dataset = self._prepend_lang_bos_to_target(denoising_dataset, lang)
    return denoising_dataset
def load_bt_dataset(self, data_path: str, lang: str) -> FairseqDataset:
    """The BT dataset is generated with (tgt, tgt) pairs.
    The actual translation to a (generated_src, tgt) pair
    is done on the fly during training.
    """
    mono_dataset = data_utils.load_indexed_dataset(
        data_path, self.common_dict, self.args.dataset_impl)
    assert mono_dataset is not None, f"No dataset found for {lang}"

    mono_dataset_src = PrependTokenDataset(
        mono_dataset, _lang_token_index(self.dictionary, lang))

    mono_dataset_bt = self._langpair_dataset(mono_dataset_src, mono_dataset)
    logger.info(
        f"mono_lang = {lang} "
        f"lang token index = {_lang_token_index(self.dictionary, lang)} "
        f"lang token = {_lang_token(lang)}")

    mono_dataset_bt = self._prepend_lang_bos_to_target(mono_dataset_bt, lang)
    return mono_dataset_bt
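
# The two loaders above rely on _lang_token / _lang_token_index. A sketch of
# the expected contract, assuming the mBART-style "__lang__" spelling used by
# fairseq's online backtranslation task; treat the exact token format as an
# assumption, and the *_sketch names as hypothetical.
def _lang_token_sketch(lang: str) -> str:
    return f"__{lang}__"


def _lang_token_index_sketch(dictionary, lang: str) -> int:
    idx = dictionary.index(_lang_token_sketch(lang))
    # the token must already be in the dictionary; index() returns unk otherwise
    assert idx != dictionary.unk(), f"language token for {lang} missing from dict"
    return idx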
def load_dataset(self, split: str, combine: bool = False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""
    inputs_path = Path(self.args.data) / "{split}".format(split=split)
    src_tokens = data_utils.load_indexed_dataset(
        str(inputs_path),
        self.source_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    assert src_tokens is not None, "could not find dataset: {}".format(inputs_path)

    src_tokens = PrependTokenDataset(src_tokens, self.source_dictionary.bos())

    targets_path = Path(self.args.data) / "{}.term".format(split)
    labels = data_utils.load_indexed_dataset(
        str(targets_path),
        self._label_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    assert labels is not None, "could not find labels: {}".format(targets_path)

    clean_labels = NoBosEosDataset(labels, self.label_dictionary)
    word_mask = WordEndMaskDataset(
        src_tokens, self.dictionary, self.is_word_initial, bos_value=0, eos_value=0)

    dataset = {
        "id": IdDataset(),
        "net_input": {
            "src_tokens": RightPadDataset(
                src_tokens, pad_idx=self.source_dictionary.pad()),
            "nsrc_tokens": NumelDataset(src_tokens),
            "word_mask": RightPadDataset(word_mask, pad_idx=0),  # pad is zero since mask
        },
        "target_attrs": RightPadDataset(
            clean_labels, pad_idx=self.label_dictionary.pad()),
        "nsentences": NumSamplesDataset(),
        "ntokens": NumelDataset(src_tokens, reduce=True),
        "nwords": NumWordsDataset(src_tokens, self.dictionary, self.is_word_initial),
    }

    nested_dataset = NestedDictionaryDatasetFix(dataset, sizes=[src_tokens.sizes])

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    dataset = SortDataset(nested_dataset, sort_order=[shuffle])

    logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))
    self.datasets[split] = dataset
    return self.datasets[split]
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""

    def get_path(type, split):
        return os.path.join(self.args.data, type, split)

    def make_dataset(type, dictionary):
        split_path = get_path(type, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        return dataset

    src_tokens = make_dataset('data', self.source_dictionary)

    if self.args.init_token is not None:
        src_tokens = PrependTokenDataset(src_tokens, self.args.init_token)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    if self.args.truncate_sequence:
        src_tokens = TruncateDataset(src_tokens, self.args.max_positions)

    dataset = {
        'id': IdDataset(),
        'net_input': {
            'src_tokens': RightPadDataset(
                src_tokens,
                pad_idx=self.source_dictionary.pad(),
            ),
            'src_lengths': NumelDataset(src_tokens, reduce=False),
        },
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens, reduce=True),
    }

    label_dataset = make_dataset('label', self.label_dictionary)
    if label_dataset is not None:
        dataset.update(target=OffsetTokensDataset(
            StripTokenDataset(
                label_dataset,
                id_to_strip=self.label_dictionary.eos(),
            ),
            offset=-self.label_dictionary.nspecial,
        ))

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[src_tokens.sizes],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))
    self.datasets[split] = dataset
    return self.datasets[split]
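
# Hedged sketch of the label pipeline above: labels are stored as dictionary
# ids, so StripTokenDataset removes the eos and OffsetTokensDataset shifts ids
# down by nspecial to obtain 0-based class indices. Toy scalar values stand in
# for a real fairseq Dictionary (nspecial is typically 4: <s>, <pad>, </s>, <unk>).
def labels_to_class_ids(label_ids, eos_id=2, nspecial=4):
    stripped = [t for t in label_ids if t != eos_id]
    return [t - nspecial for t in stripped]

# e.g. labels_to_class_ids([4, 2]) == [0]  -> first non-special symbol is class 0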
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""

    def get_path(type, split):
        return os.path.join(self.args.data, type, split)

    def make_dataset(type, dictionary):
        split_path = get_path(type, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        return dataset

    input0 = make_dataset('input0', self.source_dictionary)
    assert input0 is not None, 'could not find dataset: {}'.format(
        get_path('input0', split))
    input1 = make_dataset('input1', self.source_dictionary)

    if self.args.init_token is not None:
        input0 = PrependTokenDataset(input0, self.args.init_token)

    if input1 is None:
        src_tokens = input0
    else:
        if self.args.separator_token is not None:
            input1 = PrependTokenDataset(input1, self.args.separator_token)
        src_tokens = ConcatSentencesDataset(input0, input1)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    src_tokens = maybe_shorten_dataset(
        src_tokens,
        split,
        self.args.shorten_data_split_whitelist,
        self.args.shorten_method,
        self.args.max_positions,
        self.args.seed,
    )

    dataset = {
        'id': IdDataset(),
        'net_input': {
            'src_tokens': RightPadDataset(
                src_tokens,
                pad_idx=self.source_dictionary.pad(),
            ),
            'src_lengths': NumelDataset(src_tokens, reduce=False),
        },
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens, reduce=True),
    }

    if self.args.add_prev_output_tokens:
        prev_tokens_dataset = RightPadDataset(
            RollDataset(src_tokens, 1),
            pad_idx=self.dictionary.pad(),
        )
        dataset['net_input'].update(
            prev_output_tokens=prev_tokens_dataset,
        )

    if not self.args.regression_target:
        label_dataset = make_dataset('label', self.label_dictionary)
        if label_dataset is not None:
            dataset.update(
                target=OffsetTokensDataset(
                    StripTokenDataset(
                        label_dataset,
                        id_to_strip=self.label_dictionary.eos(),
                    ),
                    offset=-self.label_dictionary.nspecial,
                )
            )
    else:
        label_path = "{0}.label".format(get_path('label', split))
        if os.path.exists(label_path):

            def parse_regression_target(i, line):
                values = line.split()
                assert len(values) == self.args.num_classes, \
                    f'expected num_classes={self.args.num_classes} regression target values on line {i}, found: "{line}"'
                return [float(x) for x in values]

            with open(label_path) as h:
                dataset.update(
                    target=RawLabelDataset([
                        parse_regression_target(i, line.strip())
                        for i, line in enumerate(h.readlines())
                    ])
                )

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[src_tokens.sizes],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))
    self.datasets[split] = dataset
    return self.datasets[split]
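
# Sketch of what add_prev_output_tokens produces above: RollDataset(src_tokens, 1)
# rotates each sample right by one position, so the final token (usually eos)
# becomes the first, giving the usual shifted inputs fed to a decoder. torch.roll
# reproduces the per-sample effect; the values below are illustrative only.
import torch

sample = torch.tensor([5, 6, 7, 2])          # ...where 2 == eos
prev_output = torch.roll(sample, shifts=1)   # tensor([2, 5, 6, 7])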
def pos_loader(data_path, split, src, src_dict, tgt, tgt_dict, anchor, anchor_dict,
               combine, dataset_impl, upsample_primary, left_pad_source,
               left_pad_target, max_source_positions, max_target_positions,
               prepend_bos=False, truncate_source=False, append_source_id=False):

    # Check the existence of the file
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path,
                                '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []
    anchor_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode (from a->b or from b->a)
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict,
                                                      dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict,
                                                      dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        anchor_prefix = os.path.join(data_path, anchor,
                                     '{}.{}-{}.'.format(split_k, anchor, tgt))
        anchor_dataset = data_utils.load_indexed_dataset(
            anchor_prefix + anchor, anchor_dict, dataset_impl)
        if anchor_dataset is not None:
            anchor_datasets.append(anchor_dataset)

        logger.info('{} {} {}-{} {} examples'.format(
            data_path, split_k, src, tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0
    # None is not available for anchors
    assert len(src_datasets) == len(anchor_datasets)

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
        anchor_dataset = anchor_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None
        anchor_dataset = ConcatDataset(anchor_datasets, sample_ratios)

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset,
                                         src_dict.index('[{}]'.format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index('[{}]'.format(tgt)))
        eos = tgt_dict.index('[{}]'.format(tgt))

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return POSGraphLanguagePairDatasetb(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        anchor_dataset,
        anchor_dataset.sizes,
        anchor_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        eos=eos)
def load_dataset(self, split, epoch=0, combine=False):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """

    def get_path(type, split):
        return os.path.join(self.args.data, type, split)

    def make_dataset(type, dictionary):
        split_path = get_path(type, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        return dataset

    input0 = make_dataset('input0', self.source_dictionary)
    assert input0 is not None, 'could not find dataset: {}'.format(
        get_path('input0', split))
    input1 = make_dataset('input1', self.source_dictionary)

    input0 = PrependTokenDataset(input0, self.source_dictionary.bos())

    if input1 is None:
        src_tokens = input0
    else:
        input1 = PrependTokenDataset(input1, self.source_dictionary.eos())
        src_tokens = ConcatSentencesDataset(input0, input1)

    src_tokens = TruncateDataset(src_tokens, self.args.max_positions)

    assert not self.args.mask_whole_words
    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        src_tokens,
        self.source_dictionary,
        pad_idx=self.source_dictionary.pad(),
        mask_idx=self.mask_idx,
        seed=self.args.seed,
        mask_prob=self.args.mask_prob,
        leave_unmasked_prob=self.args.leave_unmasked_prob,
        random_token_prob=self.args.random_token_prob,
        freq_weighted_replacement=self.args.freq_weighted_replacement,
        mask_whole_words=None,
    )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(src_dataset))

    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src_tokens': PadToLenDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                        pad_len=self.args.max_positions,
                    ),
                    'src_lengths': NumelDataset(src_dataset, reduce=False),
                },
                'target': PadToLenDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                    pad_len=self.args.max_positions,
                ),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_dataset, reduce=True),
            },
            sizes=[src_dataset.sizes],
        ),
        sort_order=[
            shuffle,
            src_dataset.sizes,
        ],
    )
def load_langpair_dataset(
    data_path, split,
    src, src_dict,
    tgt, tgt_dict,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target,
    max_source_positions, max_target_positions,
    prepend_bos=False, load_alignments=False,
    truncate_source=False,
    srcda=False, srcda_choice='uniform',
    tgtda=False, tgtda_choice='uniform',
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)
        tgt_datasets.append(
            data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)
        )

        print('| {} {} {}-{} {} examples'.format(
            data_path, split_k, src, tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)

    return LanguagePairDatasetDA(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset.sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=align_dataset,
        srcda=srcda, srcda_choice=srcda_choice,
        tgtda=tgtda, tgtda_choice=tgtda_choice,
    )
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path,
        self.source_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError(
            "Dataset not found: {} ({})".format(split, split_path))

    dataset = maybe_shorten_dataset(
        dataset,
        split,
        self.args.shorten_data_split_list,
        self.args.shorten_method,
        self.args.tokens_per_sample,
        self.args.seed,
    )

    # create continuous blocks of tokens
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample - 1,  # one less for <s>
        pad=self.source_dictionary.pad(),
        eos=self.source_dictionary.eos(),
        break_mode=self.args.sample_break_mode,
    )
    logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

    # create masked input and targets
    mask_whole_words = (get_whole_word_mask(self.args, self.source_dictionary)
                        if self.args.mask_whole_words else None)

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        dataset,
        self.source_dictionary,
        pad_idx=self.source_dictionary.pad(),
        mask_idx=self.mask_idx,
        seed=self.args.seed,
        mask_prob=self.args.mask_prob,
        leave_unmasked_prob=self.args.leave_unmasked_prob,
        random_token_prob=self.args.random_token_prob,
        freq_weighted_replacement=self.args.freq_weighted_replacement,
        mask_whole_words=mask_whole_words,
        mask_multiple_length=self.args.mask_multiple_length,
        mask_stdev=self.args.mask_stdev,
    )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(src_dataset))

    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                "id": IdDataset(),
                "net_input": {
                    "src_tokens": RightPadDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                    ),
                    "src_lengths": NumelDataset(src_dataset, reduce=False),
                },
                "target": RightPadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                ),
                "nsentences": NumSamplesDataset(),
                "ntokens": NumelDataset(src_dataset, reduce=True),
            },
            sizes=[src_dataset.sizes],
        ),
        sort_order=[
            shuffle,
            src_dataset.sizes,
        ],
    )
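
# A simplified sketch of MaskTokensDataset.apply_mask semantics as used above:
# each token is selected with probability mask_prob; of the selected tokens, a
# fraction leave_unmasked_prob is kept unchanged and random_token_prob is
# replaced by a random vocabulary item, the rest become <mask>. Targets are pad
# everywhere except at selected positions. Uses numpy only; the real
# implementation is vectorized differently and seeded per sample, so treat this
# as illustrative, not the library's code.
import numpy as np


def apply_mask_sketch(tokens, vocab_size, pad, mask,
                      mask_prob=0.15, leave_unmasked_prob=0.1,
                      random_token_prob=0.1, rng=np.random):
    tokens = np.asarray(tokens)
    selected = rng.random(len(tokens)) < mask_prob
    # loss is computed only at selected positions; everything else is pad
    target = np.where(selected, tokens, pad)
    src = tokens.copy()
    draws = rng.random(len(tokens))
    replace_random = selected & (draws < random_token_prob)
    keep = selected & (draws >= random_token_prob) & \
        (draws < random_token_prob + leave_unmasked_prob)
    src[selected & ~keep & ~replace_random] = mask
    src[replace_random] = rng.randint(0, vocab_size, replace_random.sum())
    return src, target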
def load_dataset(self, split, epoch=0, combine=False, data_selector=None):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = self.args.data.split(':')
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path,
        self.source_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))

    # create continuous blocks of tokens
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample - 1,  # one less for <s>
        pad=self.source_dictionary.pad(),
        eos=self.source_dictionary.eos(),
        break_mode=self.args.sample_break_mode,
    )
    print('| loaded {} blocks from: {}'.format(len(dataset), split_path))

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

    # create masked input and targets
    if self.args.mask_whole_words:
        bpe = encoders.build_bpe(self.args)
        if bpe is not None:

            def is_beginning_of_word(i):
                if i < self.source_dictionary.nspecial:
                    # special elements are always considered beginnings
                    return True
                tok = self.source_dictionary[i]
                if tok.startswith('madeupword'):
                    return True
                try:
                    return bpe.is_beginning_of_word(tok)
                except ValueError:
                    return True

            mask_whole_words = torch.ByteTensor(list(
                map(is_beginning_of_word, range(len(self.source_dictionary)))
            ))
    else:
        mask_whole_words = None

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        dataset,
        self.source_dictionary,
        pad_idx=self.source_dictionary.pad(),
        mask_idx=self.mask_idx,
        seed=self.args.seed,
        mask_prob=self.args.mask_prob,
        leave_unmasked_prob=self.args.leave_unmasked_prob,
        random_token_prob=self.args.random_token_prob,
        freq_weighted_replacement=self.args.freq_weighted_replacement,
        mask_whole_words=mask_whole_words,
    )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(src_dataset))

    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src_tokens': PadDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'src_lengths': NumelDataset(src_dataset, reduce=False),
                },
                'target': PadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_dataset, reduce=True),
            },
            sizes=[src_dataset.sizes],
        ),
        sort_order=[
            shuffle,
            src_dataset.sizes,
        ],
    )
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]

    languages = sorted(name for name in os.listdir(data_path)
                       if os.path.isdir(os.path.join(data_path, name)))

    logger.info("Training on {0} languages: {1}".format(len(languages), languages))
    logger.info("Language to id mapping: {}".format(
        {lang: id for id, lang in enumerate(languages)}))

    mask_whole_words = self._get_whole_word_mask()
    lang_datasets = []
    for lang_id, language in enumerate(languages):
        split_path = os.path.join(data_path, language, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                split, split_path))

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample - 1,  # one less for <s>
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode=self.args.sample_break_mode,
        )
        logger.info('loaded {} blocks from: {}'.format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
            freq_weighted_replacement=self.args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
        )

        lang_dataset = NestedDictionaryDataset(
            {
                'net_input': {
                    'src_tokens': PadDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'src_lengths': NumelDataset(src_dataset, reduce=False),
                },
                'target': PadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_dataset, reduce=True),
                'lang_id': RawLabelDataset([lang_id] * src_dataset.sizes.shape[0]),
            },
            sizes=[src_dataset.sizes],
        )
        lang_datasets.append(lang_dataset)

    dataset_lengths = np.array(
        [len(d) for d in lang_datasets],
        dtype=float,
    )
    logger.info('loaded total {} blocks for all languages'.format(
        dataset_lengths.sum()))

    if split == self.args.train_subset:
        # For train subset, additionally up or down sample languages.
        sample_probs = self._get_sample_prob(dataset_lengths)
        logger.info("Sample probability by language: {}".format(
            {lang: "{0:.4f}".format(sample_probs[id])
             for id, lang in enumerate(languages)}))
        size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
        logger.info("Up/Down Sampling ratio by language: {}".format(
            {lang: "{0:.2f}".format(size_ratio[id])
             for id, lang in enumerate(languages)}))

        resampled_lang_datasets = [
            ResamplingDataset(
                lang_datasets[i],
                size_ratio=size_ratio[i],
                seed=self.args.seed,
                epoch=epoch,
                replace=size_ratio[i] >= 1.0,
            )
            for i, d in enumerate(lang_datasets)
        ]
        dataset = ConcatDataset(resampled_lang_datasets)
    else:
        dataset = ConcatDataset(lang_datasets)
        lang_splits = [split]
        for lang_id, lang_dataset in enumerate(lang_datasets):
            split_name = split + '_' + languages[lang_id]
            lang_splits.append(split_name)
            self.datasets[split_name] = lang_dataset

        # [TODO]: This is hacky for now to print validation ppl for each
        # language individually. Maybe need task API changes to allow it
        # in more generic ways.
        if split in self.args.valid_subset:
            self.args.valid_subset = self.args.valid_subset.replace(
                split, ','.join(lang_splits))

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(dataset))

    self.datasets[split] = SortDataset(
        dataset,
        sort_order=[
            shuffle,
            dataset.sizes,
        ],
    )
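
# Sketch of the _get_sample_prob used above, assuming the standard
# temperature-based multilingual sampling (alpha is --multilang-sampling-alpha):
# languages are sampled proportionally to dataset size raised to alpha, so
# alpha < 1 upsamples low-resource languages.
import numpy as np


def get_sample_prob_sketch(dataset_lengths: np.ndarray, alpha: float) -> np.ndarray:
    prob = dataset_lengths / dataset_lengths.sum()
    smoothed = prob ** alpha
    return smoothed / smoothed.sum()

# size_ratio, as computed in the code above, then rescales each dataset so the
# concatenation realizes these probabilities:
#   size_ratio = sample_probs * dataset_lengths.sum() / dataset_lengths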
# create continuous blocks of tokens
dataset = TokenBlockDataset(
    dataset,
    dataset.sizes,
    tokens_per_sample - 2,  # one less for <s> and one for </s>
    pad=dictionary.pad(),
    eos=dictionary.eos(),
    break_mode=args.sample_break_mode,
    document_sep_len=0,
)
assert len(dataset) == prev_size

# prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
dataset = PrependTokenDataset(dataset, source_dictionary.bos())
dataset = AppendTokenDataset(dataset, source_dictionary.eos())

mask_whole_words = (get_whole_word_mask(args, source_dictionary)
                    if mask_length != 'subword' else None)

bpe = encoders.build_bpe(args)
eoh = dictionary.indices[bpe.encode('</h>')]
denoising_dataset = DenoisingDataset(
    dataset,
    dataset.sizes,
    dictionary,
    mask_idx,
    mask_whole_words,
    shuffle=False,
    seed=seed,
    args=args,
def load_dataset(self, split: str, combine: bool = False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""
    inputs_path = Path(self.args.data) / "{split}".format(split=split)
    src_tokens = data_utils.load_indexed_dataset(
        str(inputs_path),
        self.source_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    assert src_tokens is not None, "could not find dataset: {}".format(inputs_path)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    src_tokens = PrependTokenDataset(src_tokens, self.source_dictionary.bos())

    targets_path = Path(self.args.data) / "{}.term".format(split)
    term_labels = data_utils.load_indexed_dataset(
        str(targets_path),
        self.label_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    assert term_labels is not None, "could not find labels: {}".format(targets_path)

    term_cats, term_attrs = POSDataset.make_both(
        term_labels, self.dictionary, self.label_dictionary)

    def print_terms(term_cats, term_attrs):
        # Debug function
        cat_labels = [self.label_dictionary[t] for t in term_cats]
        attr_data = [t.nonzero().T for t in term_attrs if t.numel()]
        attr_labels = []
        for word_attr in attr_data:
            if not word_attr.numel():
                attr_labels.append([])
                continue
            attr_labels.append([
                self.label_dictionary[t + self.label_dictionary.nspecial]
                for t in word_attr[0]
            ])
        return cat_labels, attr_labels

    word_mask = WordEndMaskDataset(
        src_tokens, self.dictionary, self.is_word_initial, bos_value=0, eos_value=0)
    exclude_cats_mask = IgnoreLabelsDataset(term_cats, self.ignore_cats)

    dataset = {
        "id": IdDataset(),
        "net_input": {
            "src_tokens": RightPadDataset(
                src_tokens, pad_idx=self.source_dictionary.pad()),
            "nsrc_tokens": NumelDataset(src_tokens),
            "word_mask": RightPadDataset(word_mask, pad_idx=0),
        },
        "exclude_cats_mask": RightPadDataset(exclude_cats_mask, pad_idx=1),
        "target_cats": RightPadDataset(term_cats, pad_idx=self.label_dictionary.pad()),
        "target_attrs": RightPad2dDataset(term_attrs, pad_idx=0),
        "nsentences": NumSamplesDataset(),
        "ntokens": NumelDataset(src_tokens, reduce=True),
        "nwords": NumWordsDataset(src_tokens, self.dictionary, self.is_word_initial),
    }

    nested_dataset = NestedDictionaryDatasetFix(dataset, sizes=[src_tokens.sizes])

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(nested_dataset, sort_order=[shuffle])

    logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))
    self.datasets[split] = dataset
    return self.datasets[split]
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = self.args.data.split(":")
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]
    split_path = os.path.join(data_path, split)

    if self.langs is None:
        languages = sorted(
            [
                name
                for name in os.listdir(data_path)
                if os.path.isdir(os.path.join(data_path, name))
            ]
        )
    else:
        languages = self.langs.split(",")
        for name in languages:
            p = os.path.join(data_path, name)
            assert os.path.exists(p), "data not found: {}".format(p)

    logger.info("Training on {0} languages: {1}".format(len(languages), languages))
    logger.info("Language to id mapping: {}".format(
        {lang: id for id, lang in enumerate(languages)}))

    mask_whole_words = get_whole_word_mask(self.args, self.dictionary)
    language_without_segmentations = self.args.no_whole_word_mask_langs.split(",")
    lang_datasets = []
    for language in languages:
        split_path = os.path.join(data_path, language, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        end_token = (
            self.source_dictionary.index("[{}]".format(language))
            if self.args.add_lang_token
            else self.source_dictionary.eos()
        )

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample - 2,  # one less for <s> and one for </s>
            pad=self.source_dictionary.pad(),
            eos=end_token,
            break_mode=self.args.sample_break_mode,
        )
        logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
        dataset = AppendTokenDataset(dataset, end_token)

        lang_mask_whole_words = (
            mask_whole_words
            if language not in language_without_segmentations
            else None
        )
        lang_dataset = DenoisingDataset(
            dataset,
            dataset.sizes,
            self.dictionary,
            self.mask_idx,
            lang_mask_whole_words,
            shuffle=self.args.shuffle_instance,
            seed=self.seed,
            args=self.args,
            eos=None
            if not self.args.add_lang_token
            else self.source_dictionary.index("[{}]".format(language)),
        )
        lang_datasets.append(lang_dataset)

    dataset_lengths = np.array(
        [len(d) for d in lang_datasets],
        dtype=float,
    )
    logger.info(
        "loaded total {} blocks for all languages".format(
            int(dataset_lengths.sum()),
        )
    )
    if split == self.args.train_subset:
        # For train subset, additionally up or down sample languages.
        sample_probs = self._get_sample_prob(dataset_lengths)
        logger.info(
            "Sample probability by language: {}".format(
                {
                    lang: "{0:.4f}".format(sample_probs[id])
                    for id, lang in enumerate(languages)
                }
            )
        )
        size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
        logger.info(
            "Up/Down Sampling ratio by language: {}".format(
                {
                    lang: "{0:.2f}".format(size_ratio[id])
                    for id, lang in enumerate(languages)
                }
            )
        )

        resampled_lang_datasets = [
            ResamplingDataset(
                lang_datasets[i],
                size_ratio=size_ratio[i],
                seed=self.args.seed,
                epoch=epoch,
                replace=size_ratio[i] >= 1.0,
            )
            for i, d in enumerate(lang_datasets)
        ]
        dataset = ConcatDataset(resampled_lang_datasets)
    else:
        dataset = ConcatDataset(lang_datasets)
        lang_splits = [split]
        for lang_id, lang_dataset in enumerate(lang_datasets):
            split_name = split + "_" + languages[lang_id]
            lang_splits.append(split_name)
            self.datasets[split_name] = lang_dataset

        if split in self.args.valid_subset:
            self.args.valid_subset = self.args.valid_subset.replace(
                split, ",".join(lang_splits)
            )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(dataset))

    self.datasets[split] = SortDataset(
        dataset,
        sort_order=[
            shuffle,
            dataset.sizes,
        ],
    )
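
# Sketch of the add_lang_token convention assumed above: each language gets a
# "[lang]" symbol in the dictionary, which then replaces eos as the
# sentence-final token. A hypothetical setup helper; the task itself expects
# the symbols to have been added when the dictionary was built.
def add_lang_tokens_sketch(dictionary, languages):
    for lang in languages:
        dictionary.add_symbol("[{}]".format(lang))
    return {lang: dictionary.index("[{}]".format(lang)) for lang in languages}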
def load_KE_dataset(self, split, kedata_path, epoch=0, combine=False):
    paths = kedata_path.split(':')
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]

    def get_path(type):
        return os.path.join(data_path, type, split)

    def desc_dataset(type, dictionary, relation_desc=None):
        now_path = get_path(type)
        dataset = data_utils.load_indexed_dataset(
            now_path,
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if self.args.init_token is not None:
            dataset = PrependTokenDataset(dataset, self.args.init_token)
        if relation_desc is not None:
            dataset = ConcatSentencesDataset(dataset, relation_desc)
        # truncate the (entity description + relation description) sequence
        dataset = TruncateDataset(dataset, self.args.tokens_per_sample)
        dataset = RightPadDataset(dataset, pad_idx=self.source_dictionary.pad())
        return dataset

    assert not (self.args.relation_desc and self.args.relemb_from_desc)

    if self.args.relation_desc or self.args.relemb_from_desc:
        now_path = get_path('relation_desc')
        relation_desc = data_utils.load_indexed_dataset(
            now_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if self.args.relation_desc:
            if self.args.separator_token is not None:
                relation_desc = PrependTokenDataset(
                    relation_desc, self.args.separator_token)
            else:
                raise Exception("separator_token is None")
        elif self.args.relemb_from_desc:
            relation_desc = PrependTokenDataset(relation_desc, self.args.init_token)
            relation_desc = TruncateDataset(
                relation_desc, self.args.tokens_per_sample // 8)  # 64
            relation_desc = RightPadDataset(
                relation_desc, pad_idx=self.source_dictionary.pad())
    else:
        relation_desc = None

    head = desc_dataset("head", self.source_dictionary)
    tail = desc_dataset("tail", self.source_dictionary)
    nHead = desc_dataset("negHead", self.source_dictionary)
    nTail = desc_dataset("negTail", self.source_dictionary)
    head_r = desc_dataset("head", self.source_dictionary,
                          relation_desc if self.args.relation_desc else None)
    tail_r = desc_dataset("tail", self.source_dictionary,
                          relation_desc if self.args.relation_desc else None)

    assert len(nHead) % len(head) == 0, \
        "the number of negative KE instances must be a multiple of the number of positive ones"
    self.negative_sample_size = len(nHead) // len(head)

    relation = np.load(get_path("relation") + ".npy")
    sizes = np.load(get_path("sizes") + ".npy")

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(head))

    net_input = {
        'heads': head,
        'tails': tail,
        'nHeads': KeNegDataset(nHead, self.args),
        'nTails': KeNegDataset(nTail, self.args),
        'heads_r': head_r,
        'tails_r': tail_r,
        'src_lengths': FakeNumelDataset(sizes, reduce=False),
    }
    if self.args.relemb_from_desc:
        net_input['relation_desc'] = relation_desc

    dataset = SortDataset(
        NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': net_input,
                'target': RawLabelDataset(relation),
                'nsentences': NumSamplesDataset(),
                'ntokens': FakeNumelDataset(sizes, reduce=True),
            },
            sizes=[sizes],
        ),
        sort_order=[shuffle],
    )
    return dataset
def load_langpair_dataset(
    data_path, split,
    src, src_dict,
    tgt, tgt_dict,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target,
    max_source_positions, max_target_positions,
    prepend_bos=False, load_alignments=False,
    truncate_source=False, append_source_id=False,
    num_buckets=0,
    shuffle=True,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info('{} {} {}-{} {} examples'.format(
            data_path, split_k, src, tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None

        # these features are not yet implemented for the cluster code
        if prepend_bos:
            assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
            src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
            if tgt_dataset is not None:
                tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

        eos = None
        if append_source_id:
            src_dataset = AppendTokenDataset(
                src_dataset, src_dict.index('[{}]'.format(src)))
            if tgt_dataset is not None:
                tgt_dataset = AppendTokenDataset(
                    tgt_dataset, tgt_dict.index('[{}]'.format(tgt)))
            eos = tgt_dict.index('[{}]'.format(tgt))

        align_dataset = None
        if load_alignments:
            align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt))
            if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
                align_dataset = data_utils.load_indexed_dataset(
                    align_path, None, dataset_impl)

        tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
        return LanguagePairDataset(
            src_dataset,
            src_dataset.sizes,
            src_dict,
            tgt_dataset,
            tgt_dataset_sizes,
            tgt_dict,
            left_pad_source=left_pad_source,
            left_pad_target=left_pad_target,
            align_dataset=align_dataset,
            eos=eos,
            num_buckets=num_buckets,
            shuffle=shuffle,
        )
    else:
        # sample_ratios = [1] * len(src_datasets)
        # sample_ratios[0] = upsample_primary
        # src_dataset = ConcatDataset(src_datasets, sample_ratios)
        # if len(tgt_datasets) > 0:
        #     tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        # else:
        #     tgt_dataset = None
        datasets = []
        eos = None
        align_dataset = None
        for i in range(0, len(src_datasets)):
            src_dataset = src_datasets[i]
            tgt_dataset = tgt_datasets[i]
            tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
            datasets.append(
                LanguagePairDataset(
                    src_dataset,
                    src_dataset.sizes,
                    src_dict,
                    tgt_dataset,
                    tgt_dataset_sizes,
                    tgt_dict,
                    left_pad_source=left_pad_source,
                    left_pad_target=left_pad_target,
                    align_dataset=align_dataset,
                    eos=eos,
                    num_buckets=num_buckets,
                    shuffle=shuffle,
                ))
        return datasets
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = self.args.data.split(':')
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path,
        self.source_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(
            split, split_path))

    # create continuous blocks of tokens
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample - 1,  # one less for <s>
        pad=self.source_dictionary.pad(),
        eos=self.source_dictionary.eos(),
        break_mode=self.args.sample_break_mode,
    )
    print('| loaded {} blocks from: {}'.format(len(dataset), split_path))

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

    # create masked input and targets
    mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
        if self.args.mask_whole_words else None

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        dataset,
        self.source_dictionary,
        pad_idx=self.source_dictionary.pad(),
        mask_idx=self.mask_idx,
        seed=self.args.seed,
        mask_prob=self.args.mask_prob,
        leave_unmasked_prob=self.args.leave_unmasked_prob,
        random_token_prob=self.args.random_token_prob,
        freq_weighted_replacement=self.args.freq_weighted_replacement,
        mask_whole_words=mask_whole_words,
    )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(src_dataset))

    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src_tokens': PadDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'src_lengths': NumelDataset(src_dataset, reduce=False),
                },
                'target': PadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_dataset, reduce=True),
            },
            sizes=[src_dataset.sizes],
        ),
        sort_order=[
            shuffle,
            src_dataset.sizes,
        ],
    )
def load_lang_dataset(
    self,
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    max_source_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
):
    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")

        # infer langcode
        if self.split_exists(split_k, src, tgt, src, data_path, dataset_impl):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
        elif self.split_exists(split_k, tgt, src, src, data_path, dataset_impl):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                logger.error(
                    f"Dataset not found: {data_path}, {split_k}, {src}, {tgt}"
                )
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

        src_dataset = self.load_data(prefix + src, src_dict, dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)
        tgt_datasets.append(self.load_data(prefix + tgt, tgt_dict, dataset_impl))

        logger.info(
            "{} {} {}-{} {} examples".format(
                data_path, split_k, src, tgt, len(src_datasets[-1])
            )
        )

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(
            data_path, "{}.align.{}-{}".format(split, src, tgt)
        )
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl
            )

    return src_dataset, tgt_dataset, align_dataset
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""

    def get_path(type, split):
        return os.path.join(self.args.data, type, split)

    def make_dataset(type, dictionary):
        split_path = get_path(type, split)
        # load with the dictionary passed by the caller (the original hardcoded
        # self.source_dictionary here, which silently mis-loaded label data)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        return dataset

    input0 = make_dataset('input0', self.source_dictionary)
    assert input0 is not None, 'could not find dataset: {}'.format(
        get_path('input0', split))
    input1 = make_dataset('input1', self.source_dictionary)

    if self.args.init_token is not None:
        input0 = PrependTokenDataset(input0, self.args.init_token)

    if input1 is None:
        src_tokens = input0
    else:
        if self.args.separator_token is not None:
            input1 = PrependTokenDataset(input1, self.args.separator_token)
        src_tokens = ConcatSentencesDataset(input0, input1)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    if self.args.truncate_sequence:
        src_tokens = TruncateDataset(src_tokens, self.args.max_positions)

    dataset = {
        'id': IdDataset(),
        'net_input': {
            'src_tokens': RightPadDataset(
                src_tokens,
                pad_idx=self.source_dictionary.pad(),
            ),
            'src_lengths': NumelDataset(src_tokens, reduce=False),
        },
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens, reduce=True),
    }

    if not self.args.regression_target:
        label_dataset = make_dataset('label', self.target_dictionary)
        if label_dataset is not None:
            dataset.update(target=OffsetTokensDataset(
                StripTokenDataset(
                    label_dataset,
                    id_to_strip=self.target_dictionary.eos(),
                ),
                offset=-self.target_dictionary.nspecial,
            ))
    else:
        label_path = f"{get_path('label', split)}.label"
        if os.path.exists(label_path):
            with open(label_path) as h:
                dataset.update(target=RawLabelDataset(
                    [float(x.strip()) for x in h.readlines()]))

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[src_tokens.sizes],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    print(f"| Loaded {split} with #samples: {len(dataset)}")
    self.datasets[split] = dataset
    return self.datasets[split]
def load_generation_pair_dataset(
    data_path, split,
    tgt, src_dict, tgt_dict,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target,
    max_source_positions, max_target_positions,
    prepend_bos=False, load_alignments=False,
    truncate_source=False, append_source_id=False,
    common_eos=None,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, "src", "tgt", "src", data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, "src", "tgt"))
        elif split_exists(split_k, "tgt", "src", "src", data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, "tgt", "src"))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + "src", src_dict, dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + "tgt", tgt_dict, dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info('{} {} {}-{} {} examples'.format(
            data_path, split_k, "src", "tgt", len(src_datasets[-1])
        ))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        if common_eos is not None:
            src_dataset = AppendTokenDataset(
                src_dataset, src_dict.index('[{}]'.format(common_eos)))
            if tgt_dataset is not None:
                tgt_dataset = AppendTokenDataset(
                    tgt_dataset, tgt_dict.index('[{}]'.format(common_eos)))
            eos = tgt_dict.index('[{}]'.format(common_eos))
    bos = tgt_dict.index('[{}]'.format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, "src", "tgt"))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return GenerationPairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=align_dataset,
        eos=eos,
        bos=bos,
    )
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""

    def get_path(type, split):
        return os.path.join(self.args.data, type, split)

    def make_dataset(type, dictionary):
        split_path = get_path(type, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        return dataset

    input0 = make_dataset('input0', self.source_dictionary)
    input_options = [
        make_dataset('input{idx}'.format(idx=idx + 1), self.source_dictionary)
        for idx in range(self.args.num_classes)
    ]

    if self.args.separator_token is not None:
        input0 = PrependTokenDataset(input0, self.args.separator_token)

    src_tokens = []
    for input_option in input_options:
        if self.args.init_token is not None:
            input_option = PrependTokenDataset(input_option, self.args.init_token)
        if self.args.max_option_length is not None:
            input_option = TruncateDataset(input_option, self.args.max_option_length)
        src_token = ConcatSentencesDataset(input_option, input0)
        if self.args.truncate_sequence:
            src_token = TruncateDataset(src_token, self.args.max_positions)
        src_tokens.append(src_token)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens[0]))

    dataset = {
        'id': IdDataset(),
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens[0], reduce=True),
    }

    for src_token_idx in range(len(src_tokens)):
        dataset.update({
            'net_input{idx}'.format(idx=src_token_idx + 1): {
                'src_tokens': RightPadDataset(
                    src_tokens[src_token_idx],
                    pad_idx=self.source_dictionary.pad(),
                ),
                'src_lengths': NumelDataset(src_tokens[src_token_idx], reduce=False),
            }
        })

    label_path = '{}.label'.format(get_path('label', split))
    if os.path.exists(label_path):
        with open(label_path) as h:
            dataset.update(target=RawLabelDataset(
                [int(x.strip()) for x in h.readlines()]))

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))
    self.datasets[split] = dataset
    return self.datasets[split]
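
# Sketch of the per-candidate input layout built above: each candidate option
# is prepended with init_token, optionally truncated, then concatenated with
# the shared context (input0). ConcatSentencesDataset joins samples end to end;
# this hypothetical helper shows the per-sample result with plain tensors.
import torch


def build_candidate(option: torch.Tensor, context: torch.Tensor,
                    init_token: int, max_option_length: int) -> torch.Tensor:
    # prepend first, then truncate, matching the order of wrappers above
    option = torch.cat([torch.tensor([init_token]), option])[:max_option_length]
    return torch.cat([option, context])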
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""

    def get_path(key, split):
        return os.path.join(self.cfg.data, key, split)

    def make_dataset(key, dictionary):
        split_path = get_path(key, split)
        try:
            dataset = data_utils.load_indexed_dataset(
                split_path,
                dictionary,
                combine=combine,
            )
        except Exception as e:
            if "StorageException: [404] Path not found" in str(e):
                logger.warning(f"dataset {e} not found")
                dataset = None
            else:
                raise e
        return dataset

    input0 = make_dataset("input0", self.source_dictionary)
    assert input0 is not None, "could not find dataset: {}".format(
        get_path("input0", split))
    input1 = make_dataset("input1", self.source_dictionary)

    if self.cfg.init_token is not None:
        input0 = PrependTokenDataset(input0, self.cfg.init_token)

    if input1 is None:
        src_tokens = input0
    else:
        if self.cfg.separator_token is not None:
            input1 = PrependTokenDataset(input1, self.cfg.separator_token)
        src_tokens = ConcatSentencesDataset(input0, input1)

    with data_utils.numpy_seed(self.cfg.seed):
        shuffle = np.random.permutation(len(src_tokens))

    src_tokens = maybe_shorten_dataset(
        src_tokens,
        split,
        self.cfg.shorten_data_split_list,
        self.cfg.shorten_method,
        self.max_positions(),
        self.cfg.seed,
    )

    dataset = {
        "id": IdDataset(),
        "net_input": {
            "src_tokens": RightPadDataset(
                src_tokens,
                pad_idx=self.source_dictionary.pad(),
            ),
            "src_lengths": NumelDataset(src_tokens, reduce=False),
        },
        "nsentences": NumSamplesDataset(),
        "ntokens": NumelDataset(src_tokens, reduce=True),
    }

    if self.cfg.add_prev_output_tokens:
        prev_tokens_dataset = RightPadDataset(
            RollDataset(src_tokens, 1),
            pad_idx=self.dictionary.pad(),
        )
        dataset["net_input"].update(
            prev_output_tokens=prev_tokens_dataset,
        )

    if not self.cfg.regression_target:
        label_dataset = make_dataset("label", self.label_dictionary)
        if label_dataset is not None:
            dataset.update(target=OffsetTokensDataset(
                StripTokenDataset(
                    label_dataset,
                    id_to_strip=self.label_dictionary.eos(),
                ),
                offset=-self.label_dictionary.nspecial,
            ))
    else:
        label_path = "{0}.label".format(get_path("label", split))
        if os.path.exists(label_path):

            def parse_regression_target(i, line):
                values = line.split()
                assert (
                    len(values) == self.cfg.num_classes
                ), f'expected num_classes={self.cfg.num_classes} regression target values on line {i}, found: "{line}"'
                return [float(x) for x in values]

            with open(label_path) as h:
                dataset.update(target=RawLabelDataset([
                    parse_regression_target(i, line.strip())
                    for i, line in enumerate(h.readlines())
                ]))

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[src_tokens.sizes],
    )

    if self.cfg.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))
    self.datasets[split] = dataset
    return self.datasets[split]
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    add_lang_token=False,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path,
                                "{}.{}-{}.{}".format(split, src, tgt, lang))
        logger.info(filename)
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    def split_exists_self(split, src, data_path):
        # currently unused
        filename = os.path.join(data_path,
                                "{}.{}-{}.{}".format(split, src, src, src))
        logger.info(filename)
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    def split_exists_valid(split, lang, data_path):
        filename = os.path.join(data_path, "{}.{}".format(split, lang))
        logger.info(filename)
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")

        if "-" not in split_k:
            # infer langcode
            if split_exists(split_k, src, tgt, src, data_path):
                prefix = os.path.join(data_path,
                                      "{}.{}-{}.".format(split_k, src, tgt))
            elif split_exists(split_k, tgt, src, src, data_path):
                prefix = os.path.join(data_path,
                                      "{}.{}-{}.".format(split_k, tgt, src))
            elif k > 0:
                break
            else:
                raise FileNotFoundError(
                    "Dataset not found: {} ({}) {} {}".format(
                        split, data_path, src, tgt))
        else:
            # infer langcode
            if split_exists_valid(split_k, src, data_path):
                prefix = os.path.join(data_path, split_k + ".")
            elif k > 0:
                break
            else:
                raise FileNotFoundError("Dataset not found: {} ({})".format(
                    split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict,
                                                      dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict,
                                                      dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info("{} {} {}-{} {} examples".format(
            data_path, split_k, src, tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        logger.info("::::data sample_ratios:{}".format(sample_ratios))
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset,
                                         src_dict.index("[{}]".format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index("[{}]".format(tgt)))
        eos = tgt_dict.index("[{}]".format(tgt))

    if add_lang_token:
        src_dataset = PrependTokenDataset(src_dataset,
                                          src_dict.index("[{}]".format(src)))
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(
                tgt_dataset, tgt_dict.index("[{}]".format(tgt)))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path,
                                  "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )
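# A minimal sketch (illustrative, not fairseq's ConcatDataset) of what
# `sample_ratios` with `upsample_primary` does above: the k-th shard is
# repeated sample_ratios[k] times before concatenation, so the primary
# corpus is visited more often per epoch. Data values are made up.
def concat_with_ratios(datasets, sample_ratios):
    out = []
    for ds, ratio in zip(datasets, sample_ratios):
        out.extend(ds * ratio)  # repeat each shard `ratio` times
    return out

primary = ["p1", "p2"]       # e.g. in-domain bitext
extra = ["e1", "e2", "e3"]   # e.g. back-translated data
print(concat_with_ratios([primary, extra], [2, 1]))
# ['p1', 'p2', 'p1', 'p2', 'e1', 'e2', 'e3']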
def load_dataset(self, split: str, combine: bool = False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""
    inputs_path = Path(self.args.data) / f"{split}.text"
    src_tokens = data_utils.load_indexed_dataset(
        str(inputs_path),
        self.source_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    assert src_tokens is not None, "could not find dataset: {}".format(
        inputs_path)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    src_tokens = PrependTokenDataset(src_tokens,
                                     self.source_dictionary.bos())
    word_masks_w_bos = WordEndMaskDataset(src_tokens,
                                          self.dictionary,
                                          self.is_word_initial,
                                          bos_value=1,
                                          eos_value=0)

    nterm_targets_path = Path(self.args.data) / "{}.nonterm".format(split)
    labelled_spans = data_utils.load_indexed_dataset(
        str(nterm_targets_path),
        self.nterm_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    assert labelled_spans is not None, \
        "could not find nonterminal labels: {}".format(nterm_targets_path)

    target_spans, nterm_cats = DynamicLabelledSpanDataset.make_both(
        labelled_spans,
        self.nterm_dictionary,
        seed=self.args.seed,
    )

    dataset = {
        "id": IdDataset(),
        "net_input": {
            "src_tokens": RightPadDataset(
                src_tokens, pad_idx=self.source_dictionary.pad()),
            "nsrc_tokens": NumelDataset(src_tokens),
            "word_mask_w_bos": RightPadDataset(word_masks_w_bos, pad_idx=0),
        },
        "target_span_labels": RightPadDataset(
            nterm_cats, pad_idx=self.nterm_dictionary.pad()),
        "target_spans": RightPadDataset(target_spans, pad_idx=0),
        "ntarget_span_labels": NumelDataset(nterm_cats),
        "nsentences": NumSamplesDataset(),
        "ntokens": NumelDataset(src_tokens, reduce=True),
        "nwords": NumWordsDataset(src_tokens, self.dictionary,
                                  self.is_word_initial),
    }

    nested_dataset = NestedDictionaryDatasetFix(dataset,
                                                sizes=[src_tokens.sizes])
    dataset = SortDataset(nested_dataset, sort_order=[shuffle])

    logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
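# An illustrative sketch (assumed convention, not the actual
# WordEndMaskDataset) of the `word_mask_w_bos` built above: given subword
# pieces and a predicate marking word-initial pieces, emit 1 at the last
# piece of each word, with bos_value=1 and eos_value=0 as in the call
# above. Piece strings and the predicate are hypothetical.
def word_end_mask(pieces, is_word_initial, bos_value=1, eos_value=0):
    mask = [bos_value]  # position of the prepended bos
    for i, _ in enumerate(pieces):
        nxt_starts_word = (i + 1 == len(pieces)
                           or is_word_initial(pieces[i + 1]))
        mask.append(1 if nxt_starts_word else 0)
    mask[-1] = eos_value  # the eos position carries 0
    return mask

pieces = ["▁the", "▁un", "break", "able", "▁vow", "<eos>"]
print(word_end_mask(pieces, lambda p: p.startswith(("▁", "<"))))
# [1, 1, 0, 0, 1, 1, 0]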
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""

    def get_path(type, field, split):
        return os.path.join(self.args.data, type, field, split)

    def make_dataset(type, field, dictionary):
        split_path = get_path(type, field, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        return dataset

    input0 = {}
    input1 = {}
    for field in configs.fields:
        input0[field] = make_dataset('input0', field,
                                     self.source_dictionary[field])
        assert input0[field] is not None, \
            'could not find dataset: {}'.format(
                get_path('input0', field, split))
        input1[field] = make_dataset('input1', field,
                                     self.source_dictionary[field])
        assert input1[field] is not None, \
            'could not find dataset: {}'.format(
                get_path('input1', field, split))
        assert len(input0[field]) == len(input1[field]), \
            'input pair different length'

        if self.args.init_token is not None:
            input0[field] = PrependTokenDataset(input0[field],
                                                self.args.init_token)
            input1[field] = PrependTokenDataset(input1[field],
                                                self.args.init_token)

        if self.args.truncate_sequence:
            input0[field] = TruncateDataset(input0[field],
                                            self.args.max_positions)
            input1[field] = TruncateDataset(input1[field],
                                            self.args.max_positions)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(input0[field]))

    # note: below the loop, `field` is the last entry of configs.fields;
    # all fields share the same per-sample lengths, so this is safe
    dataset = {
        'id': IdDataset(),
        'net_input0': {
            'src_tokens': {
                field: RightPadDataset(
                    input0[field],
                    pad_idx=self.source_dictionary[field].pad())
                for field in configs.fields
            },
            'src_lengths': NumelDataset(input0[field], reduce=False),
        },
        'net_input1': {
            'src_tokens': {
                field: RightPadDataset(
                    input1[field],
                    pad_idx=self.source_dictionary[field].pad())
                for field in configs.fields
            },
            'src_lengths': NumelDataset(input1[field], reduce=False),
        },
        'nsentences': NumSamplesDataset(),
        'ntokens0': NumelDataset(input0[field], reduce=True),
        'ntokens1': NumelDataset(input1[field], reduce=True),
    }

    label_path = "{0}.label".format(get_path('label', '', split))
    if os.path.exists(label_path):
        with open(label_path) as h:
            dataset.update(target=RawLabelDataset(
                [float(x.strip()) for x in h.readlines()]))

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[np.maximum(input0[field].sizes, input1[field].sizes)],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
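# A minimal sketch of the `sizes` computation above: for a pair of inputs,
# each example's size is the element-wise maximum of the two lengths, so
# sorting and batching are driven by whichever side of the pair is longer.
# The arrays below are made-up lengths.
import numpy as np

sizes0 = np.array([5, 12, 7])  # per-sample lengths of input0
sizes1 = np.array([9, 4, 7])   # per-sample lengths of input1
print(np.maximum(sizes0, sizes1))  # [ 9 12  7]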