def make_dataset(type, dictionary):
    # Nested helper: `get_path`, `split`, `combine` and `self` come from the
    # enclosing load_dataset scope.
    split_path = get_path(type, split)
    dataset = data_utils.load_indexed_dataset(
        split_path,
        dictionary,  # was self.source_dictionary, which silently ignored the parameter
        self.args.dataset_impl,
        combine=combine,
    )
    return dataset
def load_decode_data(path, mydict):
    # Load an mmap-binarized dataset and prepend <s> to every sample.
    dataset = data_utils.load_indexed_dataset(
        path,
        mydict,
        'mmap',
        combine=False,
    )
    dataset = PrependTokenDataset(dataset, mydict.bos())
    return dataset
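# Minimal usage sketch for load_decode_data above (added, not from the original
# source): the paths are hypothetical and assume a corpus binarized with
# fairseq-preprocess into the mmap format this helper expects.
from fairseq.data import Dictionary

vocab = Dictionary.load("data-bin/dict.txt")            # hypothetical path
decode_ds = load_decode_data("data-bin/valid", vocab)   # hypothetical split path
print(len(decode_ds), decode_ds[0])                     # each sample starts with <s>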
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = self.args.data.split(':')
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path,
        self.dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(
            split, split_path))

    dataset = StripTokenDataset(dataset, self.dictionary.eos())

    # create continuous blocks of tokens
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample - 2,  # one less for <s> and one for </s>
        pad=self.dictionary.pad(),
        eos=self.dictionary.eos(),
        break_mode=self.args.sample_break_mode,
        document_sep_len=0,
    )

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
    dataset = AppendTokenDataset(dataset, self.source_dictionary.eos())

    mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
        if self.args.mask_length != 'subword' else None

    self.datasets[split] = DenoisingDataset(
        dataset,
        dataset.sizes,
        self.dictionary,
        self.mask_idx,
        mask_whole_words,
        shuffle=self.args.shuffle_instance,
        seed=self.seed,
        args=self.args,
    )
    print("| Split: {0}, Loaded {1} samples of denoising_dataset".format(
        split, len(self.datasets[split])))
def load_text_annotations(path, prefix):
    # Text is an mmap-binarized indexed dataset; annotations are a NumPy
    # array stored next to it.
    text_data = load_indexed_dataset(
        os.path.join(path, prefix + '.text'),
        None,
        dataset_impl='mmap',
    )
    assert text_data is not None
    annotation_data = np.load(os.path.join(path, prefix + '.annotations.npy'))
    assert annotation_data is not None
    return text_data, annotation_data
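# Minimal usage sketch for load_text_annotations (added, not from the original
# source): it presumes hypothetical files data-bin/train.text.{bin,idx} and
# data-bin/train.annotations.npy sit side by side, as the helper requires.
text_data, annotation_data = load_text_annotations("data-bin", "train")  # hypothetical paths
print(len(text_data), annotation_data.shape)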
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path, self.dictionary, self.args.dataset_impl, combine=combine
    )
    if dataset is None:
        raise FileNotFoundError(
            "Dataset not found: {} ({})".format(split, split_path)
        )

    dataset = maybe_shorten_dataset(
        dataset,
        split,
        self.args.shorten_data_split_list,
        self.args.shorten_method,
        self.args.tokens_per_sample,
        self.args.seed,
    )

    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample,
        pad=self.dictionary.pad(),
        eos=self.dictionary.eos(),
        break_mode=self.args.sample_break_mode,
        include_targets=True,
    )

    add_eos_for_other_targets = (
        self.args.sample_break_mode is not None
        and self.args.sample_break_mode != "none"
    )

    self.datasets[split] = self._initialize_dataset(
        dataset=dataset,
        sizes=dataset.sizes,
        src_vocab=self.dictionary,
        tgt_vocab=self.output_dictionary,
        add_eos_for_other_targets=add_eos_for_other_targets,
        shuffle=True,
        targets=self.targets,
        add_bos_token=self.args.add_bos_token,
    )
def load_dataset(self, split, shuffle=True, **kwargs):
    """Load a dataset split."""

    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(
            data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)

    # infer langcode
    src, tgt = self.args.source_lang, self.args.target_lang
    if split_exists(split, src, tgt, src, self.args.data):
        prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
    elif split_exists(split, tgt, src, src, self.args.data):
        prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
    else:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(
            split, self.args.data))

    src_dataset = data_utils.load_indexed_dataset(
        prefix + src, self.src_dict, self.args.dataset_impl)

    rng = np.random.RandomState(self.args.seed)

    # needs to be updated with an extractive summarization dataset
    self.datasets[split] = SentsPermAndPredictMaskDataset(
        src_dataset, src_dataset.sizes, self.src_dict,
        None, None, self.tgt_dict,
        left_pad_source=self.args.left_pad_source,
        left_pad_target=self.args.left_pad_target,
        max_source_positions=self.args.max_source_positions,
        max_target_positions=self.args.max_target_positions,
        shuffle=shuffle,
        # keyword name mirrors the dataset's signature (sic: "poniter")
        is_poniter_net=(self.args.predict_arch == 'pointer_net'),
        max_sent_len=self.args.max_sent_length,
        max_doc_len=self.args.max_doc_length,
        masked_sent_prob=self.args.masked_sent_prob,
        max_predictions_per_doc=self.args.max_predictions_per_doc,
        rng=rng,
        shuffle_prob=self.args.shuffle_prob,
        doc_sizes=None,
        mask_other_sents=eval(self.args.mask_other_sents),
        max_tokens_len=self.args.max_roberta_position,
        fix_ratio=self.args.fix_ratio,
        bert_model=self.args.roberta_model,
    )
def make_dataset(type, dictionary):
    # Nested helper: `get_path`, `split`, `combine` and `self` come from the
    # enclosing scope.
    split_path = get_path(type, split)
    dataset = data_utils.load_indexed_dataset(
        split_path,
        dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    assert dataset is not None, "could not find dataset: {}".format(
        get_path(type, split))
    return dataset
def __init__(
    self,
    data_dir,
    split,
    sample_rate,
    max_sample_size=None,
    min_sample_size=0,
    shuffle=True,
    pad=False,
    normalize=False,
    num_buckets=0,
    compute_mask_indices=False,
    **mask_compute_kwargs,
):
    super().__init__(
        sample_rate=sample_rate,
        max_sample_size=max_sample_size,
        min_sample_size=min_sample_size,
        shuffle=shuffle,
        pad=pad,
        normalize=normalize,
        compute_mask_indices=compute_mask_indices,
        **mask_compute_kwargs,
    )

    from fairseq.data import data_utils, Dictionary

    self.fnames_dict = Dictionary.load(os.path.join(data_dir, "dict.txt"))

    root_path = os.path.join(data_dir, f"{split}.root")
    if os.path.exists(root_path):
        with open(root_path, "r") as f:
            self.root_dir = next(f).strip()
    else:
        self.root_dir = None

    fnames_path = os.path.join(data_dir, split)
    self.fnames = data_utils.load_indexed_dataset(fnames_path, self.fnames_dict)

    lengths_path = os.path.join(data_dir, f"{split}.lengths")
    with open(lengths_path, "r") as f:
        for line in f:
            sz = int(line.rstrip())
            assert (
                sz >= min_sample_size
            ), f"Min sample size is not supported for binarized dataset, but found a sample with size {sz}"
            self.sizes.append(sz)

    self.sizes = np.array(self.sizes, dtype=np.int64)
    self.set_bucket_info(num_buckets)
    logger.info(f"loaded {len(self.fnames)} samples")
def load_mask_data(path, mydict):
    # One big list: each item is a document matrix whose entries are per-node
    # values, for token_id and …
    dataset = data_utils.load_indexed_dataset(path, mydict, 'mmap', combine=False)
    # Group into complete documents of up to 512 - 1 tokens, leaving room for <s>.
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        512 - 1,
        pad=mydict.pad(),
        eos=mydict.eos(),
        break_mode='complete_doc',
    )
    dataset = PrependTokenDataset(dataset, mydict.bos())
    return dataset
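# Minimal usage sketch for load_mask_data (added; paths are hypothetical).
# Each item is one complete document of up to 511 tokens (break_mode
# 'complete_doc'), plus the prepended <s>.
from fairseq.data import Dictionary

mask_dict = Dictionary.load("data-bin/dict.txt")        # hypothetical path
mask_ds = load_mask_data("data-bin/train.mask", mask_dict)
print(len(mask_ds), mask_ds[0].shape)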
def load_dataset(self, split, shuffle=True):
    """Load a dataset split."""

    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)

    # infer langcode
    src, tgt = self.args.source_lang, self.args.target_lang
    if split_exists(split, src, tgt, src, self.args.data):
        prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
    elif split_exists(split, tgt, src, src, self.args.data):
        prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
    else:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))

    src_dataset = data_utils.load_indexed_dataset(prefix + src, self.src_dict, self.args.dataset_impl)
    tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, self.tgt_dict, self.args.dataset_impl)

    # needs to be updated with an extractive summarization dataset
    self.datasets[split] = ExtractSumRobertaLongDataset(
        src_dataset, src_dataset.sizes, self.src_dict,
        tgt_dataset, tgt_dataset.sizes if tgt_dataset else None, self.tgt_dict,
        left_pad_source=self.args.left_pad_source,
        left_pad_target=self.args.left_pad_target,
        max_source_positions=self.args.max_source_positions,
        max_target_positions=self.args.max_target_positions,
        shuffle=shuffle,
        max_sent_len=self.args.max_sent_length,
        max_doc_len=self.args.max_doc_length,
        mask_other_sents=eval(self.args.mask_other_sents),
    )
def make_dataset(type, dictionary):
    split_path = get_path(type, split)
    dataset = data_utils.load_indexed_dataset(
        split_path,
        dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(
            split, split_path))
    return dataset
def load_text_object_dataset(self, split, **kwargs):
    objects_dataset = ObjectDataset(self.args.data_dir, split, max_obj=self.args.max_obj)
    span_idxs = self.item2span_idxs(
        sent_num=objects_dataset.sent_num,
        max_src_sent=self.args.max_src_sent,
    )
    text_file = text_bin_file(self.args.data_dir, split)  # os.path.join(self.args.data_dir, split)
    text_dataset = data_utils.load_indexed_dataset(text_file, self.vocab_dict)

    self.datasets[split] = TextObjectDataset(
        text_dataset=text_dataset,
        image_dataset=objects_dataset,
        vocab_dict=self.vocab_dict,
        span_idxs=span_idxs,
        shuffle=(split == "train"),
    )
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path,
        self.dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(
            split, split_path))

    dataset = maybe_shorten_dataset(
        dataset,
        split,
        self.args.shorten_data_split_list,
        self.args.shorten_method,
        self.args.tokens_per_sample,
        self.args.seed,
    )

    # create continuous blocks of tokens (block_size is 511 or 512)
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample,
        pad=self.source_dictionary.pad(),
        eos=self.source_dictionary.eos(),
        break_mode=self.args.sample_break_mode,
    )
    logger.info('loaded {} blocks from: {}'.format(len(dataset), split_path))

    s2s_dataset = MaskedLanguagePairDataset.apply_mask(
        dataset,
        dataset.sizes,
        self.source_dictionary,
        shuffle=True,
        mask_prob=self.args.mask_prob,
        leave_unmasked_prob=self.args.leave_unmasked_prob,
        random_token_prob=self.args.random_token_prob,
    )

    self.datasets[split] = s2s_dataset
def load_denoise_dataset(self, data_path: str, lang: str) -> FairseqDataset:
    """Classic denoising dataset."""
    dataset = data_utils.load_indexed_dataset(
        data_path, self.common_dict, self.args.dataset_impl)
    noisy_dataset = NoisingDataset(
        dataset,
        self.dictionary,
        seed=1,
        max_word_shuffle_distance=self.args.max_word_shuffle_distance,
        word_dropout_prob=self.args.word_dropout_prob,
        word_blanking_prob=self.args.word_blanking_prob,
    )
    noisy_dataset = PrependTokenDataset(
        noisy_dataset, _lang_token_index(self.dictionary, lang))

    clean_dataset = data_utils.load_indexed_dataset(
        data_path, self.common_dict, self.args.dataset_impl)
    denoising_dataset = self._langpair_dataset(noisy_dataset, clean_dataset)
    denoising_dataset = self._prepend_lang_bos_to_target(
        denoising_dataset, lang)
    return denoising_dataset
def _load_single_lang_dataset(self, split, epoch):
    loaded_datasets = []

    paths = self.args.data.split(":")
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")
        path = os.path.join(data_path, split_k)

        ds = data_utils.load_indexed_dataset(
            path, self.dictionary, self.args.dataset_impl
        )
        if ds is None:
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

        # Since we append each block with the classification_token,
        # we need to effectively create blocks of length
        # tokens_per_sample - 1
        loaded_datasets.append(
            TokenBlockDataset(
                ds,
                ds.sizes,
                self.args.tokens_per_sample - 1,
                pad=self.dictionary.pad(),
                eos=self.dictionary.eos(),
            )
        )

        print(
            "| {} {} {} examples".format(
                data_path, split_k, len(loaded_datasets[-1])
            )
        )

    if len(loaded_datasets) == 1:
        dataset = loaded_datasets[0]
        sizes = dataset.sizes
    else:
        dataset = ConcatDataset(loaded_datasets)
        sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

    return dataset, sizes
def load_binarized_bin(dict, input):
    # `dict` and `input` shadow built-ins; kept to preserve the original
    # signature.
    dictionary = Dictionary.load(dict) if dict is not None else None
    dataset = data_utils.load_indexed_dataset(
        input,
        dictionary,
        default='lazy',
    )
    # The original returned inside the loop and so yielded only the first
    # line; collect every line instead.
    lines = []
    for tensor_line in dataset:
        if dictionary is None:
            lines.append(' '.join(str(int(x)) for x in tensor_line))
        else:
            lines.append(dictionary.string(tensor_line))
    return lines
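# Minimal usage sketch for load_binarized_bin (added; paths are hypothetical).
# With a dictionary, each binarized tensor is detokenized back to a string;
# without one, raw token ids are joined with spaces.
lines = load_binarized_bin("data-bin/dict.txt", "data-bin/train")
print("\n".join(lines[:3]))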
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = self.args.data.split(os.pathsep)
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path, self.dictionary, self.args.dataset_impl, combine=combine
    )
    if dataset is None:
        raise FileNotFoundError(
            "Dataset not found: {} ({})".format(split, split_path)
        )

    if self.args.truncate_sequence:
        dataset = TruncateDataset(dataset, self.args.tokens_per_sample)

    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample,
        pad=self.dictionary.pad(),
        eos=self.dictionary.eos(),
        break_mode=self.args.sample_break_mode,
        include_targets=True,
    )

    add_eos_for_other_targets = (
        self.args.sample_break_mode is not None
        and self.args.sample_break_mode != "none"
    )

    self.datasets[split] = MonolingualDataset(
        dataset,
        dataset.sizes,
        self.dictionary,
        self.output_dictionary,
        add_eos_for_other_targets=add_eos_for_other_targets,
        shuffle=not getattr(self.args, 'lm_eval', False),
        targets=self.targets,
        add_bos_token=self.args.add_bos_token,
    )
def make_dataset(key, dictionary):
    split_path = get_path(key, split)
    try:
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            combine=combine,
        )
    except Exception as e:
        if "StorageException: [404] Path not found" in str(e):
            logger.warning(f"dataset {e} not found")
            dataset = None
        else:
            raise
    return dataset
def desc_dataset(type, dictionary, relation_desc=None):
    now_path = get_path(type)
    dataset = data_utils.load_indexed_dataset(
        now_path,
        dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    if self.args.init_token is not None:
        dataset = PrependTokenDataset(dataset, self.args.init_token)
    if relation_desc is not None:
        dataset = ConcatSentencesDataset(dataset, relation_desc)
    dataset = TruncateDataset(dataset, self.args.tokens_per_sample)
    dataset = RightPadDataset(dataset, pad_idx=self.source_dictionary.pad())
    return dataset
def load_feature_dataset(self, split, **kwargs):
    features_dataset = FeatureDataset(self.args.data_dir, split)
    span_idxs = self.get_span_info(sent_num=features_dataset.sent_num, split=split)
    text_file = text_bin_file(self.args.data_dir, split)  # os.path.join(self.args.data_dir, split)
    text_dataset = data_utils.load_indexed_dataset(text_file, self.vocab_dict)

    self.datasets[split] = MMITextImageDataset(
        text_dataset=text_dataset,
        image_dataset=features_dataset,
        vocab_dict=self.vocab_dict,
        span_idxs=span_idxs,
        shuffle=(split == "train"),
    )
def get_datasets_from_indexed_filterbanks(
        data_path, tgt_lang, tgt_dict, split, dataset_impl, skip_norm,
        legacy_audio_fix_lua_indexing):
    """Creates a dataset reading precomputed filterbanks and the corresponding
    targets saved as indexed datasets."""
    assert tgt_lang is not None
    prefix = os.path.join(data_path, split)
    src_dataset = FilterBanksDataset(
        prefix + ".npz", dataset_impl == "cached", legacy_audio_fix_lua_indexing)
    tgt_dataset = data_utils.load_indexed_dataset(
        prefix + "." + tgt_lang, tgt_dict, dataset_impl)
    return FilterBankToTextDataset(
        src_dataset, tgt_dataset, tgt_dict, skip_normalization=skip_norm)
def get_xlco_dataset(args, dataset_path, vocab, mask_idx, combine=False):
    dataset = data_utils.load_indexed_dataset(
        dataset_path, vocab, args.dataset_impl, combine=combine)

    dataset, _ = MaskTokensDataset.apply_mask(
        dataset,
        vocab=vocab,
        pad_idx=vocab.pad(),
        mask_idx=mask_idx,
        seed=args.seed,
        mask_prob=args.mask_prob,
        mask_whole_words=None,
    )

    dataset = XlcoDataset(dataset, vocab)
    return dataset
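# Minimal usage sketch for get_xlco_dataset (added): the args fields mirror
# exactly what the function reads (dataset_impl, seed, mask_prob); the paths
# and values are hypothetical, and <mask> is registered so mask_idx resolves.
from argparse import Namespace
from fairseq.data import Dictionary

vocab = Dictionary.load("data-bin/dict.txt")            # hypothetical path
mask_idx = vocab.add_symbol("<mask>")
args = Namespace(dataset_impl="mmap", seed=1, mask_prob=0.15)
xlco_ds = get_xlco_dataset(args, "data-bin/train.xlco", vocab, mask_idx)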
def load_langpair_dataset(data_path, split, src, src_feat_roots, tgt, tgt_dict,
                          dataset_impl, left_pad_source, left_pad_target,
                          max_source_positions, max_target_positions,
                          multilv_args, prepend_bos=False, load_alignments=False,
                          truncate_source=False, use_bucketing=True):
    prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))
    src_dataset = load_sign_dataset(prefix + src, src_feat_roots)
    tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)
    assert len(src_dataset) == len(tgt_dataset)

    logger.info('{} {} {}-{} {} examples'.format(
        data_path, split, src, tgt, len(src_dataset)))

    # if prepend_bos:
    #     assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
    #     src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
    #     tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    align_dataset = None
    return SignLanguagePairDataset(
        src_dataset, src_dataset.sizes,
        tgt_dataset, tgt_dataset.sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=align_dataset,
        use_bucketing=use_bucketing,
        multilv_args=multilv_args,
    )
def _load_dataset_split(self, split, epoch, combine):
    paths = utils.split_paths(self.cfg.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path,
        self.dictionary,
        self.cfg.dataset_impl,
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError(
            "Dataset not found: {} ({})".format(split, split_path)
        )

    dataset = StripTokenDataset(dataset, self.dictionary.eos())

    dataset = maybe_shorten_dataset(
        dataset,
        split,
        self.cfg.shorten_data_split_list,
        self.cfg.shorten_method,
        self.cfg.tokens_per_sample,
        self.cfg.seed,
    )

    # create continuous blocks of tokens
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.cfg.tokens_per_sample - 2,  # one less for <s> and one for </s>
        pad=self.dictionary.pad(),
        eos=self.dictionary.eos(),
        break_mode=self.cfg.sample_break_mode,
        document_sep_len=0,
    )
    logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
    dataset = AppendTokenDataset(dataset, self.source_dictionary.eos())
    return dataset
def main():
    parser = get_parser()
    args = parser.parse_args()

    dictionary = Dictionary.load(args.dict) if args.dict is not None else None
    dataset = data_utils.load_indexed_dataset(
        args.input,
        dictionary,
        dataset_impl=args.dataset_impl,
        default="lazy",
    )

    for tensor_line in dataset:
        if dictionary is None:
            line = " ".join([str(int(x)) for x in tensor_line])
        else:
            line = dictionary.string(tensor_line)
        print(line)
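# Hypothetical invocation of the script above (added): the flag names assume
# get_parser() defines --input, --dict and --dataset-impl, in the style of
# fairseq's read_binarized.py; adjust to the actual parser.
#
#   python read_binarized.py --input data-bin/train --dict data-bin/dict.txt --dataset-impl mmap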
def load_dataset(self, split, **kwargs):
    features_dir = os.path.join(
        self.args.features_dir, f'{split}-features-{self.args.features}')

    image_ids_file = os.path.join(self.args.captions_dir, f'{split}-ids.txt')
    image_ids = data.read_image_ids(image_ids_file, non_redundant=self.scst)

    if self.scst and split == 'valid':
        image_ids = image_ids[:self.args.scst_validation_set_size]

    if self.scst:
        captions_file = os.path.join(self.args.captions_dir, f'{split}-captions.tok.json')
        captions_ds = data.CaptionsDataset(captions_file, image_ids)
    else:
        captions_file = os.path.join(
            self.args.captions_dir, f'{split}-captions.{self.args.captions_lang}')
        captions_ds = data_utils.load_indexed_dataset(captions_file, self.captions_dict)

    if self.args.features == 'grid':
        image_ds = data.GridFeaturesDataset(features_dir, image_ids, grid_shape=(8, 8))
    elif self.args.features == 'obj':
        image_metadata_file = os.path.join(features_dir, 'metadata.csv')
        image_metadata = data.read_image_metadata(image_metadata_file)
        image_ds = data.ObjectFeaturesDataset(features_dir, image_ids, image_metadata)
    else:
        raise ValueError(f'Invalid --features option: {self.args.features}')

    self.datasets[split] = data.ImageCaptionDataset(
        img_ds=image_ds,
        cap_ds=captions_ds,
        cap_dict=self.captions_dict,
        scst=self.scst,
        shuffle=True,
    )
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    if self.args.dataset_from_json:
        raise NotImplementedError

    datasets = []
    for path in self.paths:
        try:
            ds = get_datasets_from_indexed_filterbanks(
                path,
                self.args.target_lang,
                self.tgt_dict,
                split,
                self.args.dataset_impl,
                self.args.skip_normalization,
                self.args.legacy_audio_fix_lua_indexing)
            if self.training:
                if self.args.context_type == 'src':
                    context_ds = FilterBanksDataset(
                        os.path.join(path, split) + ".context.npz",
                        self.args.dataset_impl == "cached",
                        self.args.legacy_audio_fix_lua_indexing)
                else:
                    context_ds = data_utils.load_indexed_dataset(
                        os.path.join(path, split) + ".context." + self.args.target_lang,
                        self.tgt_dict,
                        self.args.dataset_impl)
                datasets.append(ContextAwareDataset(
                    ds, context_ds, self.tgt_dict, self.args.context_type == 'src'))
            else:
                datasets.append(ds)
        except Exception:
            logger.warning("Split {} not found in {}. Skipping...".format(split, path))

    assert len(datasets) > 0
    if len(datasets) > 1:
        self.datasets[split] = ConcatDataset(datasets)
    else:
        self.datasets[split] = datasets[0]
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = self.args.data.split(':')
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path,
        self.dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))

    self.datasets[split] = self.build_s2s_dataset(dataset)
def load_dataset(lang, lang_dict, prefix, dataset_length, sample_ratios=None):
    """
    Function to load an additional dataset and handle all its parameters.
    Easier than copying redundant code for each dataset.
    Requires src_dataset to provide the length and sample_ratios.
    Nested helper: `dataset_impl`, `prepend_bos`, `append_source_id` and
    `src_dict` come from the enclosing scope.
    """
    lang_datasets = []
    lang_dataset = data_utils.load_indexed_dataset(prefix + lang, lang_dict, dataset_impl)
    if lang_dataset is not None:
        lang_datasets.append(lang_dataset)
    assert dataset_length == len(lang_datasets) or len(lang_datasets) == 0

    if dataset_length == 1:
        lang_dataset = lang_datasets[0] if len(lang_datasets) > 0 else None
    else:
        assert sample_ratios is not None
        if len(lang_datasets) > 0:
            lang_dataset = ConcatDataset(lang_datasets, sample_ratios)
        else:
            lang_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(lang_dict, "bos_index")
        if lang_dataset is not None:
            lang_dataset = PrependTokenDataset(lang_dataset, lang_dict.bos())

    eos = None
    if append_source_id:
        if lang_dataset is not None:
            lang_dataset = AppendTokenDataset(
                lang_dataset, lang_dict.index('[{}]'.format(lang)))

    lang_dataset_sizes = lang_dataset.sizes if lang_dataset is not None else None
    return lang_dataset, lang_dataset_sizes
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path,
        self.dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(
            split, split_path))

    dataset = CodeCompletionDataset(
        dataset,
        dataset.sizes,
        self.dictionary,
        split_fn=self.split_fn,
        shuffle=(split != 'test'),
        max_source_positions=self.args.max_source_positions,
        max_target_positions=self.args.max_target_positions,
        append_eos_to_source=True,
        append_eos_to_target=True,
    )
    self.datasets[split] = dataset
    logger.info(
        "Split: {0}, Loaded {1} samples of CodeCompletionDataset".format(
            split, len(self.datasets[split])))