def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
    """
    Generate batches for inference. We assume that the input begins with a
    bos symbol (`<s>`) and ends with an eos symbol (`</s>`).
    """
    pad = self.source_dictionary.pad()
    eos = self.source_dictionary.eos()
    src_dataset = TokenBlockDataset(
        src_tokens,
        src_lengths,
        block_size=self.args.tokens_per_sample - 2,  # for <s> and </s>
        pad=pad,
        eos=eos,
        break_mode=self.args.sample_break_mode,
        document_sep_len=0,
    )
    prev_output_tokens = PrependTokenDataset(
        StripTokenDataset(src_dataset, eos), eos
    )
    src_dataset = PadDataset(src_dataset, pad_idx=pad, left_pad=False)
    return NestedDictionaryDataset(
        {
            "id": IdDataset(),
            "net_input": {
                "src_tokens": src_dataset,
                "src_lengths": NumelDataset(src_dataset, reduce=False),
                "prev_output_tokens": PadDataset(
                    prev_output_tokens, pad_idx=pad, left_pad=False
                ),
            },
            "target": src_dataset,
        },
        sizes=[np.array(src_lengths)],
    )
def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True):
    src_dataset = RightPadDataset(
        TokenBlockDataset(
            src_tokens,
            src_lengths,
            self.args.tokens_per_sample - 1,  # one less for <s>
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode="eos",
        ),
        pad_idx=self.source_dictionary.pad(),
    )
    src_dataset = PrependTokenDataset(src_dataset, self.source_dictionary.bos())
    src_dataset = NestedDictionaryDataset(
        {
            "id": IdDataset(),
            "net_input": {
                "src_tokens": src_dataset,
                "src_lengths": NumelDataset(src_dataset, reduce=False),
            },
        },
        sizes=src_lengths,
    )
    if sort:
        src_dataset = SortDataset(src_dataset, sort_order=[src_lengths])
    return src_dataset
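# Usage sketch (not part of the original code) for the build_dataset_for_inference
# variants above. It assumes `task` is an already set-up fairseq task instance
# exposing this method, and that the raw text has already been run through the
# task's tokenizer/BPE; the sentences and variable names are illustrative only.
import torch

sentences = ["the quick brown fox", "hello world"]
src_tokens = [
    task.source_dictionary.encode_line(s, add_if_not_exist=False).long()
    for s in sentences
]
src_lengths = torch.LongTensor([t.numel() for t in src_tokens])

dataset = task.build_dataset_for_inference(src_tokens, src_lengths, sort=False)
# Collate all items into one padded mini-batch.
batch = dataset.collater([dataset[i] for i in range(len(dataset))])
print(batch["net_input"]["src_tokens"].shape)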
def load_dataset(
    self, split: str, epoch=1, combine=False, **kwargs
) -> MonolingualDataset:
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0

    data_path = paths[(epoch - 1) % len(paths)]
    split_path = os.path.join(data_path, split)

    # each process has its own copy of the raw data (likely to be an np.memmap)
    dataset = data_utils.load_indexed_dataset(
        split_path, self.dictionary, self.args.dataset_impl, combine=combine
    )
    if dataset is None:
        raise FileNotFoundError(f"Dataset not found: {split} ({split_path})")

    dataset = maybe_shorten_dataset(
        dataset,
        split,
        self.args.shorten_data_split_list,
        self.args.shorten_method,
        self.args.tokens_per_sample,
        self.args.seed,
    )
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample,
        pad=self.dictionary.pad(),
        eos=self.dictionary.eos(),
        break_mode=self.args.sample_break_mode,
        include_targets=True,
        use_plasma_view=self.args.use_plasma_view,
        split_path=split_path,
        plasma_path=self.args.plasma_path,
    )

    add_eos_for_other_targets = (
        self.args.sample_break_mode is not None
        and self.args.sample_break_mode != "none"
    )

    self.datasets[split] = MonolingualDataset(
        dataset=dataset,
        sizes=dataset.sizes,
        src_vocab=self.dictionary,
        tgt_vocab=self.output_dictionary,
        add_eos_for_other_targets=add_eos_for_other_targets,
        shuffle=True,
        targets=self.targets,
        add_bos_token=self.args.add_bos_token,
    )
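# Standalone sketch (not fairseq's TokenBlockDataset, which also handles eos
# bookkeeping, targets and plasma-backed views) of what break_mode="none" does
# above: treat the corpus as one long token stream and cut it into fixed-size
# blocks of tokens_per_sample, ignoring sentence boundaries.
import numpy as np

def blocks_break_mode_none(sizes, block_size):
    """Return (start, length) slices covering a stream of sum(sizes) tokens."""
    total = int(np.sum(sizes))
    return [
        (start, min(block_size, total - start))
        for start in range(0, total, block_size)
    ]

# e.g. sentences of 5, 3 and 7 tokens with tokens_per_sample=6
print(blocks_break_mode_none([5, 3, 7], 6))  # [(0, 6), (6, 6), (12, 3)]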
def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
    """
    Generate batches for inference. We prepend an eos token to src_tokens
    (or bos if `--add-bos-token` is set) and we append a <pad> to target.
    This is convenient both for generation with a prefix and LM scoring.
    """
    dataset = StripTokenDataset(
        TokenBlockDataset(
            src_tokens,
            src_lengths,
            block_size=None,  # ignored for "eos" break mode
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode="eos",
        ),
        # remove eos from (end of) target sequence
        self.source_dictionary.eos(),
    )
    src_dataset = PrependTokenDataset(
        dataset,
        token=(
            self.source_dictionary.bos()
            if getattr(self.args, "add_bos_token", False)
            else self.source_dictionary.eos()
        ),
    )
    tgt_dataset = AppendTokenDataset(dataset, token=self.source_dictionary.pad())
    return NestedDictionaryDataset(
        {
            "id": IdDataset(),
            "net_input": {
                "src_tokens": PadDataset(
                    src_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                "src_lengths": NumelDataset(src_dataset, reduce=False),
            },
            "target": PadDataset(
                tgt_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False
            ),
        },
        sizes=[np.array(src_lengths)],
    )
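# Symbolic illustration (toy strings, not real dictionary indices) of the
# source/target layout the wrappers above produce for one sentence
# "w1 w2 w3 </s>" when --add-bos-token is not set:
#   model input : </s> w1 w2 w3       (trailing eos stripped, eos prepended)
#   target      : w1   w2 w3 <pad>    (pad appended; <pad> is ignored by the loss)
sentence = ["w1", "w2", "w3", "</s>"]
stripped = sentence[:-1]            # StripTokenDataset drops the trailing eos
src = ["</s>"] + stripped           # PrependTokenDataset (bos instead if --add-bos-token)
tgt = stripped + ["<pad>"]          # AppendTokenDataset
assert len(src) == len(tgt)
print(list(zip(src, tgt)))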
def _load_single_lang_dataset(self, split, epoch):
    loaded_datasets = []

    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")
        path = os.path.join(data_path, split_k)

        ds = data_utils.load_indexed_dataset(
            path, self.dictionary, self.args.dataset_impl
        )
        if ds is None:
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

        # Since we append the classification_token to each block, we need to
        # effectively create blocks of length tokens_per_sample - 1
        loaded_datasets.append(
            TokenBlockDataset(
                ds,
                ds.sizes,
                self.args.tokens_per_sample - 1,
                pad=self.dictionary.pad(),
                eos=self.dictionary.eos(),
            )
        )

        logger.info(
            "{} {} {} examples".format(data_path, split_k, len(loaded_datasets[-1]))
        )

    if len(loaded_datasets) == 1:
        dataset = loaded_datasets[0]
        sizes = dataset.sizes
    else:
        dataset = ConcatDataset(loaded_datasets)
        sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

    return dataset, sizes
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path,
        self.dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError(
            "Dataset not found: {} ({})".format(split, split_path)
        )

    dataset = StripTokenDataset(dataset, self.dictionary.eos())

    dataset = maybe_shorten_dataset(
        dataset,
        split,
        self.args.shorten_data_split_list,
        self.args.shorten_method,
        self.args.tokens_per_sample,
        self.args.seed,
    )

    # create continuous blocks of tokens
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample - 2,  # one less for <s> and one for </s>
        pad=self.dictionary.pad(),
        eos=self.dictionary.eos(),
        break_mode=self.args.sample_break_mode,
        document_sep_len=0,
    )

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
    dataset = AppendTokenDataset(dataset, self.source_dictionary.eos())

    mask_whole_words = (
        get_whole_word_mask(self.args, self.source_dictionary)
        if self.args.mask_length != "subword"
        else None
    )

    self.datasets[split] = DenoisingDataset(
        dataset,
        dataset.sizes,
        self.dictionary,
        self.mask_idx,
        mask_whole_words,
        shuffle=self.args.shuffle_instance,
        seed=self.seed,
        args=self.args,
    )
    logger.info(
        "Split: {0}, Loaded {1} samples of denoising_dataset".format(
            split,
            len(self.datasets[split]),
        )
    )
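# Sanity check of the length bookkeeping above: blocks are built two tokens
# short so that the <s> prepended and </s> appended afterwards bring each
# sample back to tokens_per_sample (512 is just an example value).
tokens_per_sample = 512
block_size = tokens_per_sample - 2          # value passed to TokenBlockDataset
assert 1 + block_size + 1 == tokens_per_sample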
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path,
        self.source_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError(
            "Dataset not found: {} ({})".format(split, split_path)
        )

    dataset = maybe_shorten_dataset(
        dataset,
        split,
        self.args.shorten_data_split_list,
        self.args.shorten_method,
        self.args.tokens_per_sample,
        self.args.seed,
    )

    # create continuous blocks of tokens
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample - 1,  # one less for <s>
        pad=self.source_dictionary.pad(),
        eos=self.source_dictionary.eos(),
        break_mode=self.args.sample_break_mode,
    )
    logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

    # create masked input and targets
    mask_whole_words = (
        get_whole_word_mask(self.args, self.source_dictionary)
        if self.args.mask_whole_words
        else None
    )

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        dataset,
        self.source_dictionary,
        pad_idx=self.source_dictionary.pad(),
        mask_idx=self.mask_idx,
        seed=self.args.seed,
        mask_prob=self.args.mask_prob,
        leave_unmasked_prob=self.args.leave_unmasked_prob,
        random_token_prob=self.args.random_token_prob,
        freq_weighted_replacement=self.args.freq_weighted_replacement,
        mask_whole_words=mask_whole_words,
        mask_multiple_length=self.args.mask_multiple_length,
        mask_stdev=self.args.mask_stdev,
    )

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_dataset))

    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                "id": IdDataset(),
                "net_input": {
                    "src_tokens": RightPadDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                    ),
                    "src_lengths": NumelDataset(src_dataset, reduce=False),
                },
                "target": RightPadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                ),
                "nsentences": NumSamplesDataset(),
                "ntokens": NumelDataset(src_dataset, reduce=True),
            },
            sizes=[src_dataset.sizes],
        ),
        sort_order=[
            shuffle,
            src_dataset.sizes,
        ],
    )
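# Simplified, self-contained sketch of the BERT-style masking scheme that
# mask_prob / leave_unmasked_prob / random_token_prob parameterize above.
# This is NOT fairseq's MaskTokensDataset (which also supports whole-word
# masking, span lengths via mask_multiple_length/mask_stdev, and pads the
# target); the probability values below are illustrative only.
import numpy as np

def toy_mask(tokens, vocab_size, mask_idx, mask_prob=0.15,
             leave_unmasked_prob=0.1, random_token_prob=0.1, seed=0):
    rng = np.random.RandomState(seed)
    tokens = np.asarray(tokens)
    target = np.full_like(tokens, -1)              # -1 marks "not predicted"
    chosen = rng.rand(len(tokens)) < mask_prob     # positions to predict
    target[chosen] = tokens[chosen]

    src = tokens.copy()
    u = rng.rand(len(tokens))
    replace_random = chosen & (u < random_token_prob)
    keep_original = chosen & ~replace_random & (u < random_token_prob + leave_unmasked_prob)
    src[chosen & ~replace_random & ~keep_original] = mask_idx
    src[replace_random] = rng.randint(0, vocab_size, size=int(replace_random.sum()))
    return src, target

src, tgt = toy_mask(list(range(10, 30)), vocab_size=100, mask_idx=3)
print(src)
print(tgt)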
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]

    languages = sorted(
        name
        for name in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, name))
    )

    logger.info("Training on {0} languages: {1}".format(len(languages), languages))
    logger.info(
        "Language to id mapping: %s",
        {lang: id for id, lang in enumerate(languages)},
    )

    mask_whole_words = self._get_whole_word_mask()
    lang_datasets = []
    for lang_id, language in enumerate(languages):
        split_path = os.path.join(data_path, language, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample - 1,  # one less for <s>
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode=self.args.sample_break_mode,
        )
        logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
            freq_weighted_replacement=self.args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
        )

        lang_dataset = NestedDictionaryDataset(
            {
                "net_input": {
                    "src_tokens": PadDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    "src_lengths": NumelDataset(src_dataset, reduce=False),
                },
                "target": PadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                "nsentences": NumSamplesDataset(),
                "ntokens": NumelDataset(src_dataset, reduce=True),
                "lang_id": RawLabelDataset([lang_id] * src_dataset.sizes.shape[0]),
            },
            sizes=[src_dataset.sizes],
        )
        lang_datasets.append(lang_dataset)

    dataset_lengths = np.array(
        [len(d) for d in lang_datasets],
        dtype=float,
    )
    logger.info(
        "loaded total {} blocks for all languages".format(dataset_lengths.sum())
    )
    if split == self.args.train_subset:
        # For train subset, additionally up or down sample languages.
        sample_probs = self._get_sample_prob(dataset_lengths)
        logger.info(
            "Sample probability by language: %s",
            {
                lang: "{0:.4f}".format(sample_probs[id])
                for id, lang in enumerate(languages)
            },
        )
        size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
        logger.info(
            "Up/Down Sampling ratio by language: %s",
            {
                lang: "{0:.2f}".format(size_ratio[id])
                for id, lang in enumerate(languages)
            },
        )

        resampled_lang_datasets = [
            ResamplingDataset(
                lang_datasets[i],
                size_ratio=size_ratio[i],
                seed=self.args.seed,
                epoch=epoch,
                replace=size_ratio[i] >= 1.0,
            )
            for i, d in enumerate(lang_datasets)
        ]
        dataset = ConcatDataset(resampled_lang_datasets)
    else:
        dataset = ConcatDataset(lang_datasets)
        lang_splits = [split]
        for lang_id, lang_dataset in enumerate(lang_datasets):
            split_name = split + "_" + languages[lang_id]
            lang_splits.append(split_name)
            self.datasets[split_name] = lang_dataset

        # [TODO]: This is hacky for now to print validation ppl for each
        # language individually. Maybe need task API changes to allow it
        # in more generic ways.
        if split in self.args.valid_subset:
            self.args.valid_subset = self.args.valid_subset.replace(
                split, ",".join(lang_splits)
            )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(dataset))

    self.datasets[split] = SortDataset(
        dataset,
        sort_order=[
            shuffle,
            dataset.sizes,
        ],
    )
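# _get_sample_prob is not shown in this excerpt. A common choice, and a
# reasonable mental model for the sampling ratios logged above, is
# exponent-smoothed sampling over per-language dataset sizes; the helper name,
# alpha value and toy counts below are assumptions for illustration only.
import numpy as np

def smoothed_sample_prob(dataset_lengths, alpha=0.7):
    """alpha < 1 flattens the distribution, upsampling low-resource languages."""
    prob = dataset_lengths / dataset_lengths.sum()
    smoothed = prob ** alpha
    return smoothed / smoothed.sum()

lengths = np.array([1_000_000.0, 50_000.0, 5_000.0])    # toy per-language block counts
sample_probs = smoothed_sample_prob(lengths)
size_ratio = (sample_probs * lengths.sum()) / lengths   # same formula as above
print(sample_probs.round(4), size_ratio.round(2))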