def _load_single_lang_dataset(self, split, epoch): loaded_datasets = [] paths = self.args.data.split(':') assert len(paths) > 0 data_path = paths[epoch % len(paths)] for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') path = os.path.join(data_path, split_k) ds = data_utils.load_indexed_dataset(path, self.dictionary, self.args.dataset_impl) if ds is None: if k > 0: break else: raise FileNotFoundError( 'Dataset not found: {} ({})'.format(split, data_path)) # Since we append each block with the classification_token, # we need to effectively create blocks of length # tokens_per_sample-1 loaded_datasets.append( TokenBlockDataset( ds, ds.sizes, self.args.tokens_per_sample - 1, pad=self.dictionary.pad(), eos=self.dictionary.eos(), )) print('| {} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1]))) if len(loaded_datasets) == 1: dataset = loaded_datasets[0] sizes = dataset.sizes else: dataset = ConcatDataset(loaded_datasets) sizes = np.concatenate([ds.sizes for ds in loaded_datasets]) return dataset, sizes
def from_tsv(cls, root: str, data_cfg: S2TDataConfig, splits: str, tgt_dict, pre_tokenizer, bpe_tokenizer, is_train_split: bool, epoch: int, seed: int, audio_dict) -> AudioDictDataset: samples = [] _splits = splits.split(",") for split in _splits: tsv_path = op.join(root, f"{split}.tsv") if not op.isfile(tsv_path): raise FileNotFoundError(f"Dataset not found: {tsv_path}") with open(tsv_path) as f: reader = csv.DictReader( f, delimiter="\t", quotechar=None, doublequote=False, lineterminator="\n", quoting=csv.QUOTE_NONE, ) samples.append([dict(e) for e in reader]) assert len(samples) > 0 datasets = [ cls._from_list(name, is_train_split, [s], data_cfg, tgt_dict, pre_tokenizer, bpe_tokenizer, audio_dict) for name, s in zip(_splits, samples) ] if is_train_split and len( _splits) > 1 and data_cfg.sampling_alpha != 1.0: # temperature-based sampling size_ratios = cls._get_size_ratios(_splits, [len(s) for s in samples], alpha=data_cfg.sampling_alpha) datasets = [ ResamplingDataset(d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)) for d, r in zip(datasets, size_ratios) ] return ConcatDataset(datasets)
def from_tsv( cls, root: str, cfg: S2TDataConfig, splits: str, tgt_dict, pre_tokenizer, bpe_tokenizer, is_train_split: bool, epoch: int, seed: int, n_frames_per_step: int = 1, speaker_to_id=None, ) -> SpeechToTextDataset: datasets = [ cls._from_tsv( root, cfg, split, tgt_dict, is_train_split, pre_tokenizer, bpe_tokenizer, n_frames_per_step, speaker_to_id, ) for split in splits.split(",") ] if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0: # temperature-based sampling size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha) datasets = [ ResamplingDataset(d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)) for r, d in zip(size_ratios, datasets) ] return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
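# Both from_tsv variants above hand the temperature-based resampling math to a
# _get_size_ratios / get_size_ratios helper. The sketch below is a minimal,
# hypothetical version of that computation, assuming the usual alpha-smoothed
# multinomial formulation; the function name and signature are illustrative,
# not the real helper's API.
import numpy as np

def get_size_ratios_sketch(sizes, alpha=0.3):
    """Map per-split sizes to resampling ratios (illustrative only)."""
    # prob_k  ~ size_k ** alpha / sum_j size_j ** alpha   (temperature sampling)
    # ratio_k = prob_k * total_size / size_k              (up/down-sampling factor)
    sizes = np.asarray(sizes, dtype=float)
    smoothed = sizes ** alpha
    probs = smoothed / smoothed.sum()
    return probs * sizes.sum() / sizes

# Example: one large and one small split; alpha < 1 boosts the small one.
print(get_size_ratios_sketch([900_000, 100_000], alpha=0.3))
# -> roughly [0.73, 3.41]: the small split gets oversampled ~3.4x.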
def load_dataset(self, split, combine=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ if self.args.dataset_from_json: raise NotImplementedError datasets = [] for path in self.paths: try: ds = get_datasets_from_indexed_filterbanks( path, self.args.target_lang, self.tgt_dict, split, self.args.dataset_impl, self.args.skip_normalization, self.args.legacy_audio_fix_lua_indexing) if self.training: if self.args.context_type == 'src': context_ds = FilterBanksDataset( os.path.join(path, split) + ".context.npz", self.args.dataset_impl == "cached", self.args.legacy_audio_fix_lua_indexing) else: context_ds = data_utils.load_indexed_dataset( os.path.join(path, split) + ".context." + self.args.target_lang, self.tgt_dict, self.args.dataset_impl) datasets.append(ContextAwareDataset( ds, context_ds, self.tgt_dict, self.args.context_type == 'src')) else: datasets.append(ds) except Exception: logger.warning("Split {} not found in {}. Skipping...".format(split, path)) assert len(datasets) > 0 if len(datasets) > 1: self.datasets[split] = ConcatDataset(datasets) else: self.datasets[split] = datasets[0]
def load_dataset(lang, lang_dict, prefix, dataset_length, sample_ratios=None):
    """
    Function to load additional dataset and deal with all parameters.
    Easier than copying redundant code for each dataset.
    Requires src_dataset to provide the length and sample_ratios.
    """
    lang_datasets = []
    lang_dataset = data_utils.load_indexed_dataset(prefix + lang, lang_dict, dataset_impl)
    if lang_dataset is not None:
        lang_datasets.append(lang_dataset)
    assert dataset_length == len(lang_datasets) or len(lang_datasets) == 0
    if dataset_length == 1:
        lang_dataset = lang_datasets[0] if len(lang_datasets) > 0 else None
    else:
        assert sample_ratios is not None
        if len(lang_datasets) > 0:
            lang_dataset = ConcatDataset(lang_datasets, sample_ratios)
        else:
            lang_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            lang_dict, "bos_index")
        if lang_dataset is not None:
            lang_dataset = PrependTokenDataset(lang_dataset, lang_dict.bos())

    eos = None
    if append_source_id:
        if lang_dataset is not None:
            lang_dataset = AppendTokenDataset(
                lang_dataset, lang_dict.index('[{}]'.format(lang)))

    lang_dataset_sizes = lang_dataset.sizes if lang_dataset is not None else None
    return lang_dataset, lang_dataset_sizes
def __load_dataset(self, split, lang_pair): src, tgt = lang_pair.split('-') datasets = [] for path in self.paths: try: ds = get_datasets_from_indexed_filterbanks( path, tgt, self.dicts[tgt], split, self.args.dataset_impl, self.args.skip_normalization, self.args.legacy_audio_fix_lua_indexing) datasets.append(ds) except Exception: logger.warning("Split {} not found in {}. Skipping...".format( split, path)) assert len(datasets) > 0 if len(datasets) > 1: dataset = ConcatDataset(datasets) else: dataset = datasets[0] return self.alter_dataset_langtok(dataset, src_eos=None, src_lang=src, tgt_eos=self.dicts[tgt].eos(), tgt_lang=tgt)
def load_dataset(self, split, combine=False): """ Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ loaded_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') path = os.path.join(self.args.data, split_k) if self.args.raw_text and IndexedRawTextDataset.exists(path): ds = IndexedRawTextDataset(path, self.dictionary) elif not self.args.raw_text and IndexedDataset.exists(path): if self.args.lazy_load: ds = IndexedDataset(path, fix_lua_indexing=True) else: ds = IndexedCachedDataset(path, fix_lua_indexing=True) else: if k > 0: break else: raise FileNotFoundError( 'Dataset not found: {} ({})'.format( split, self.args.data)) with data_utils.numpy_seed(self.seed + k): loaded_datasets.append( BlockPairDataset( ds, self.dictionary, ds.sizes, self.args.tokens_per_sample, break_mode=self.args.break_mode, )) logger.info('{} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1]))) if not combine: break if len(loaded_datasets) == 1: dataset = loaded_datasets[0] sizes = dataset.sizes else: dataset = ConcatDataset(loaded_datasets) sizes = np.concatenate([ds.sizes for ds in loaded_datasets]) self.datasets[split] = MaskedLMDataset( dataset=dataset, sizes=sizes, vocab=self.dictionary, pad_idx=self.dictionary.pad(), mask_idx=self.dictionary.mask(), classif_token_idx=self.dictionary.cls(), sep_token_idx=self.dictionary.sep(), shuffle=False, seed=self.seed, )
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = self.args.data.split(":")
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]
    split_path = os.path.join(data_path, split)

    if self.langs is None:
        languages = sorted(
            [
                name
                for name in os.listdir(data_path)
                if os.path.isdir(os.path.join(data_path, name))
            ]
        )
    else:
        languages = self.langs.split(",")
        for name in languages:
            p = os.path.join(data_path, name)
            assert os.path.exists(p), "data not found: {}".format(p)

    logger.info("Training on {0} languages: {1}".format(len(languages), languages))
    logger.info(
        "Language to id mapping: %s", {lang: id for id, lang in enumerate(languages)}
    )

    mask_whole_words = get_whole_word_mask(self.args, self.dictionary)
    language_without_segmentations = self.args.no_whole_word_mask_langs.split(",")
    lang_datasets = []
    for language in languages:
        split_path = os.path.join(data_path, language, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        end_token = (
            self.source_dictionary.index("[{}]".format(language))
            if self.args.add_lang_token
            else self.source_dictionary.eos()
        )

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample - 2,  # one for <s>, one for the appended end token
            pad=self.source_dictionary.pad(),
            eos=end_token,
            break_mode=self.args.sample_break_mode,
        )
        logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
        dataset = AppendTokenDataset(dataset, end_token)

        lang_mask_whole_words = (
            mask_whole_words
            if language not in language_without_segmentations
            else None
        )
        lang_dataset = DenoisingDataset(
            dataset,
            dataset.sizes,
            self.dictionary,
            self.mask_idx,
            lang_mask_whole_words,
            shuffle=self.args.shuffle_instance,
            seed=self.seed,
            args=self.args,
            eos=None
            if not self.args.add_lang_token
            else self.source_dictionary.index("[{}]".format(language)),
        )
        lang_datasets.append(lang_dataset)

    dataset_lengths = np.array(
        [len(d) for d in lang_datasets],
        dtype=float,
    )
    logger.info(
        "loaded total {} blocks for all languages".format(
            int(dataset_lengths.sum()),
        )
    )
    if split == self.args.train_subset:
        # For train subset, additionally up or down sample languages.
        sample_probs = self._get_sample_prob(dataset_lengths)
        logger.info(
            "Sample probability by language: {}".format(
                {
                    lang: "{0:.4f}".format(sample_probs[id])
                    for id, lang in enumerate(languages)
                }
            )
        )
        size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
        logger.info(
            "Up/Down Sampling ratio by language: {}".format(
                {
                    lang: "{0:.2f}".format(size_ratio[id])
                    for id, lang in enumerate(languages)
                }
            )
        )

        resampled_lang_datasets = [
            ResamplingDataset(
                lang_datasets[i],
                size_ratio=size_ratio[i],
                seed=self.args.seed,
                epoch=epoch,
                replace=size_ratio[i] >= 1.0,
            )
            for i, d in enumerate(lang_datasets)
        ]
        dataset = ConcatDataset(
            resampled_lang_datasets,
        )
    else:
        dataset = ConcatDataset(lang_datasets)
        lang_splits = [split]
        for lang_id, lang_dataset in enumerate(lang_datasets):
            split_name = split + "_" + languages[lang_id]
            lang_splits.append(split_name)
            self.datasets[split_name] = lang_dataset

        if split in self.args.valid_subset:
            self.args.valid_subset = self.args.valid_subset.replace(
                split, ",".join(lang_splits)
            )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(dataset))

    self.datasets[split] = SortDataset(
        dataset,
        sort_order=[
            shuffle,
            dataset.sizes,
        ],
    )
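# Both multilingual loaders in this section (the denoising one above and the
# language-modeling one further down) call self._get_sample_prob(dataset_lengths)
# before deriving size_ratio. The class below is a minimal, assumed sketch of
# that smoothing step; the attribute name multilang_sampling_alpha and the class
# itself are illustrative stand-ins, not the task's actual code.
import numpy as np

class SamplingProbSketch:
    def __init__(self, multilang_sampling_alpha=0.7):
        self.multilang_sampling_alpha = multilang_sampling_alpha

    def _get_sample_prob(self, dataset_lens):
        # empirical language distribution ...
        prob = dataset_lens / dataset_lens.sum()
        # ... smoothed by alpha: alpha=1 keeps it proportional,
        # alpha -> 0 pushes it toward uniform
        smoothed = prob ** self.multilang_sampling_alpha
        return smoothed / smoothed.sum()

lengths = np.array([1_000_000.0, 10_000.0])
probs = SamplingProbSketch(0.7)._get_sample_prob(lengths)
print(probs, probs * lengths.sum() / lengths)  # low-resource ratio comes out > 1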
def get_asr_dataset_from_json( data_path, split, tgt_dict, combine, upsample_primary, max_source_positions, max_target_positions, seed=1, specaugment_config=None, ): """ Parse data json and create dataset. See espresso/tools/asr_prep_json.py which pack json from raw files Json example: { "011c0202": { "feat": "fbank/raw_fbank_pitch_train_si284.1.ark:54819", "token_text": "T H E <space> H O T E L", "utt2num_frames": "693", }, "011c0203": { ... } } """ src_datasets = [] tgt_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else "") data_json_path = os.path.join(data_path, "{}.json".format(split_k)) if not os.path.isfile(data_json_path): if k > 0: break else: raise FileNotFoundError( "Dataset not found: {}".format(data_json_path)) with open(data_json_path, "rb") as f: loaded_json = json.load(f, object_pairs_hook=OrderedDict) utt_ids, feats, token_text, utt2num_frames = [], [], [], [] for utt_id, val in loaded_json.items(): utt_ids.append(utt_id) feats.append(val["feat"]) if "token_text" in val: token_text.append(val["token_text"]) if "utt2num_frames" in val: utt2num_frames.append(int(val["utt2num_frames"])) assert len(utt2num_frames) == 0 or len(utt_ids) == len(utt2num_frames) src_datasets.append( FeatScpCachedDataset( utt_ids, feats, utt2num_frames=utt2num_frames, seed=seed, specaugment_config=specaugment_config if split == "train" else None, ordered_prefetch=True, )) if len(token_text) > 0: assert len(utt_ids) == len(token_text) assert tgt_dict is not None tgt_datasets.append(AsrTextDataset(utt_ids, token_text, tgt_dict)) logger.info("{} {} examples".format(data_json_path, len(src_datasets[-1]))) if not combine: break assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 feat_dim = src_datasets[0].feat_dim if len(src_datasets) == 1: src_dataset = src_datasets[0] tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None else: for i in range(1, len(src_datasets)): assert feat_dim == src_datasets[i].feat_dim, \ "feature dimension does not match across multiple json files" sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) if len(tgt_datasets) > 0: tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) else: tgt_dataset = None tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None return AsrDataset( src_dataset, src_dataset.sizes, tgt_dataset, tgt_dataset_sizes, tgt_dict, left_pad_source=False, left_pad_target=False, max_source_positions=max_source_positions, max_target_positions=max_target_positions, )
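# Many loaders here (this one, the language-pair loaders, and the nstack/dptree
# variants) share the sample_ratios[0] = upsample_primary +
# ConcatDataset(datasets, sample_ratios) idiom once combine pulls in extra
# train1, train2, ... shards. The snippet below is a plain-list illustration of
# what those integer ratios mean, under the assumption that a ratio of r simply
# repeats a dataset r times inside the concatenation; it is not the real
# ConcatDataset implementation.
from itertools import chain

def concat_with_ratios(datasets, sample_ratios):
    # stand-in for ConcatDataset(datasets, sample_ratios) index semantics
    return list(chain.from_iterable(ds * r for ds, r in zip(datasets, sample_ratios)))

primary = ["p0", "p1"]        # e.g. examples from train.json
extra = ["e0", "e1", "e2"]    # e.g. examples from train1.json

sample_ratios = [1, 1]
sample_ratios[0] = 2          # upsample_primary = 2
print(concat_with_ratios([primary, extra], sample_ratios))
# ['p0', 'p1', 'p0', 'p1', 'e0', 'e1', 'e2']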
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """

    def split_exists(split, data_type, data_path):
        filename = os.path.join(data_path, f'{split}.{data_type}')
        assert not self.args.raw_text
        exists = [
            IndexedDataset.exists(
                os.path.join(data_path, f'{split}.{data_type}.{k}'))
            for k in DPTREE_KEYS
        ]
        if all(exists):
            return True
        else:
            print(f'The following modalities do not all exist: {exists}')
            return False

    def indexed_dataset(path):
        assert IndexedCachedDataset.exists(
            path), f'IndexedCachedDataset.exists({path})'
        return IndexedCachedDataset(path, fix_lua_indexing=True)

    def dptree_indexed_dataset(path):
        assert DPTreeIndexedCachedDataset.exists(
            path), f'DPTreeIndexedCachedDataset.exists({path})'
        return DPTreeIndexedCachedDataset(path, fix_lua_indexing=True)

    tgt_datasets = []
    src_datasets_dict = {k: [] for k in DPTREE_KEYS}

    # singular data path
    data_path = self.args.data
    print(f'| split = {split}')
    print(f'| self.args.data = {self.args.data}')

    lang = self.args.source_lang
    src, tgt = 'input', 'target'

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        if split_exists(split_k, src, data_path):
            prefix = os.path.join(data_path, f'{split_k}.')
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    'Dataset not found: {} ({})'.format(split, data_path))

        for modality in src_datasets_dict.keys():
            src_datasets_dict[modality].append(
                dptree_indexed_dataset(f'{prefix}{src}.{modality}'))
        tgt_datasets.append(indexed_dataset(prefix + tgt))

        print('| {} {} {} examples'.format(data_path, split_k, len(tgt_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets_dict[DPTREE_KEYS[0]]) == len(tgt_datasets)

    if len(tgt_datasets) == 1:
        src_dataset_dict = {k: v[0] for k, v in src_datasets_dict.items()}
        tgt_dataset = tgt_datasets[0]
    else:
        # ratios must be sized to the number of loaded shards
        # (the original sized them from an always-empty src_datasets list)
        sample_ratios = [1] * len(tgt_datasets)
        sample_ratios[0] = self.args.upsample_primary
        src_dataset_dict = {
            k: ConcatDataset(v, sample_ratios)
            for k, v in src_datasets_dict.items()
        }
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

    src_sizes = src_dataset_dict['nodes'].sizes.reshape(-1, 2).sum(-1)

    self.datasets[split] = NodeStackFromDPTreeSepMonoClassificationDataset(
        src_dataset_dict, src_sizes, self.source_dictionary,
        tgt_dataset,
        left_pad_source=self.args.left_pad_source,
        max_source_positions=self.args.max_source_positions,
    )
def load_dataset(self, split, combine=False, epoch=0): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ def split_exists(split, src, tgt, lang, data_path): filename = os.path.join( data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) if self.args.raw_text and IndexedRawTextDataset.exists(filename): return True elif not self.args.raw_text and IndexedDataset.exists(filename): return True return False def indexed_dataset(path, dictionary): if self.args.raw_text: return IndexedRawTextDataset(path, dictionary) elif IndexedDataset.exists(path): return IndexedCachedDataset(path, fix_lua_indexing=True) return None src_datasets = [] tgt_datasets = [] data_paths = self.args.data for dk, data_path in enumerate(data_paths): for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') # infer langcode src, tgt = self.args.source_lang, self.args.target_lang if split_exists(split_k, src, tgt, src, data_path): prefix = os.path.join( data_path, '{}.{}-{}.'.format(split_k, src, tgt)) elif split_exists(split_k, tgt, src, src, data_path): prefix = os.path.join( data_path, '{}.{}-{}.'.format(split_k, tgt, src)) else: if k > 0 or dk > 0: break else: raise FileNotFoundError( 'Dataset not found: {} ({})'.format( split, data_path)) src_datasets.append( indexed_dataset(prefix + src, self.src_dict)) tgt_datasets.append( indexed_dataset(prefix + tgt, self.tgt_dict)) print('| {} {} {} examples'.format(data_path, split_k, len(src_datasets[-1]))) if not combine: break assert len(src_datasets) == len(tgt_datasets) if len(src_datasets) == 1: src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0] else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = self.args.upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) if split == "train": train = True seed = None elif split == "valid": train = True seed = 1 elif split == "test": train = False seed = 1 else: raise Exception('No such split: ' + str(split)) self.datasets[split] = LanguagePairSelfDatasetMask( src_dataset, src_dataset.sizes, self.src_dict, tgt_dataset, tgt_dataset.sizes, self.tgt_dict, left_pad_source=self.args.left_pad_source, left_pad_target=self.args.left_pad_target, max_source_positions=self.args.max_source_positions, max_target_positions=self.args.max_target_positions, shuffle=False, dynamic_length=self.args.dynamic_length, mask_range=self.args.mask_range, train=train, seed=seed, full_masking=self.args.full_masking, dynamic_masking=self.args.dynamic_masking, skip_eos=self.args.skip_eos, )
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """

    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(
            data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        if self.args.raw_text and IndexedRawTextDataset.exists(filename):
            return True
        elif not self.args.raw_text and IndexedDataset.exists(filename):
            return True
        return False

    def indexed_dataset(path, dictionary, copy_ext_dict=False, src_dataset=None):
        if self.args.raw_text:
            return IndexedRawTextDataset(path, dictionary,
                                         copy_ext_dict=copy_ext_dict,
                                         src_dataset=src_dataset)
        elif IndexedDataset.exists(path):
            if self.args.lazy_load:
                return IndexedDataset(path, fix_lua_indexing=True)
            else:
                return IndexedCachedDataset(path, fix_lua_indexing=True)
        return None

    def indexed_label(path):
        if IndexedRawLabelDataset.exists(path):
            return IndexedRawLabelDataset(path)
        else:
            print('Label file not found: {}'.format(path))
            return None

    src_datasets = []
    tgt_datasets = []
    src_labels = []
    tgt_labels = []

    data_paths = self.args.data

    # If there are additional files, name them train1, train2, etc.
    for dk, data_path in enumerate(data_paths):
        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')

            # infer langcode
            src, tgt = self.args.source_lang, self.args.target_lang
            if split_exists(split_k, src, tgt, src, data_path):
                prefix = os.path.join(
                    data_path, '{}.{}-{}.'.format(split_k, src, tgt))
            elif split_exists(split_k, tgt, src, src, data_path):
                prefix = os.path.join(
                    data_path, '{}.{}-{}.'.format(split_k, tgt, src))
            else:
                if k > 0 or dk > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(split, data_path))

            src_dataset = indexed_dataset(prefix + src, self.src_dict,
                                          self.args.copy_ext_dict)
            tgt_dataset = indexed_dataset(prefix + tgt, self.tgt_dict,
                                          self.args.copy_ext_dict, src_dataset)
            # src_dataset provides lines, sizes, tokens_list and words_list
            src_datasets.append(src_dataset)
            tgt_datasets.append(tgt_dataset)

            # label indices
            label_prefix = os.path.join(data_path, '{}.label.'.format(split_k))
            src_label = indexed_label(label_prefix + src + '.txt')
            tgt_label = indexed_label(label_prefix + tgt + '.txt')
            src_labels.append(src_label)
            tgt_labels.append(tgt_label)

            print('| {} {} {} examples'.format(data_path, split_k, len(src_datasets[-1])))

            if not combine:
                break

    assert len(src_datasets) == len(tgt_datasets)

    src_label, tgt_label = None, None
    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
        src_label, tgt_label = src_labels[0], tgt_labels[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = self.args.upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

    # dataset composed of the source and target data
    self.datasets[split] = LanguagePairDataset(
        src_dataset, src_dataset.sizes, self.src_dict, src_label,
        tgt_dataset, tgt_dataset.sizes, self.tgt_dict, tgt_label,
        left_pad_source=self.args.left_pad_source,
        left_pad_target=self.args.left_pad_target,
        max_source_positions=self.args.max_source_positions,
        max_target_positions=self.args.max_target_positions,
    )
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    add_lang_token=False,
):
    def split_exists(split, src, tgt, lang, data_path):
        logger.info(
            os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)))
        filename = os.path.join(data_path,
                                "{}.{}-{}.{}".format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    def split_exists_self(split, src, data_path):
        logger.info(
            os.path.join(data_path, "{}.{}-{}.{}".format(split, src, src, src)))
        filename = os.path.join(data_path,
                                "{}.{}-{}.{}".format(split, src, src, src))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    def split_exists_valid(split, lang, data_path):
        logger.info(os.path.join(data_path, "{}.{}".format(split, lang)))
        filename = os.path.join(data_path, "{}.{}".format(split, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")

        prefix_src = None
        prefix_tgt = None
        if "-" not in split_k:
            # infer langcode
            if split_exists(split_k, src, tgt, src, data_path):
                prefix = os.path.join(data_path,
                                      "{}.{}-{}.".format(split_k, src, tgt))
            elif split_exists(split_k, tgt, src, src, data_path):
                prefix = os.path.join(data_path,
                                      "{}.{}-{}.".format(split_k, tgt, src))
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        "Dataset not found: {} ({}) {} {}".format(
                            split, data_path, src, tgt))
        else:
            # infer langcode
            if split_exists_valid(split_k, src, data_path):
                prefix = os.path.join(data_path, split_k + ".")
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        "Dataset not found: {} ({}) ".format(split, data_path))

        if prefix_src is not None:
            prefix = prefix_src
        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict,
                                                      dataset_impl)
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        if prefix_tgt is not None:
            prefix = prefix_tgt
        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict,
                                                      dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info("{} {} {}-{} {} examples".format(data_path, split_k, src,
                                                     tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        logger.info("::::data sample_ratios:{}".format(sample_ratios))
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset,
                                         src_dict.index("[{}]".format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index("[{}]".format(tgt)))
        eos = tgt_dict.index("[{}]".format(tgt))

    eos = None
    if add_lang_token:
        src_dataset = PrependTokenDataset(src_dataset,
                                          src_dict.index("[{}]".format(src)))
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(
                tgt_dataset, tgt_dict.index("[{}]".format(tgt)))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path,
                                  "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )
def load_generation_pair_dataset( data_path, split, tgt, src_dict, tgt_dict, combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, max_target_positions, prepend_bos=False, load_alignments=False, truncate_source=False, append_source_id=False, common_eos=None ): def split_exists(split, src, tgt, lang, data_path): filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) src_datasets = [] tgt_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') # infer langcode if split_exists(split_k, "src", "tgt", "src", data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, "src", "tgt")) elif split_exists(split_k, "tgt", "src", "src", data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, "tgt", "src")) else: if k > 0: break else: raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path)) src_dataset = data_utils.load_indexed_dataset(prefix + "src", src_dict, dataset_impl) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) tgt_dataset = data_utils.load_indexed_dataset(prefix + "tgt", tgt_dict, dataset_impl) if tgt_dataset is not None: tgt_datasets.append(tgt_dataset) logger.info('{} {} {}-{} {} examples'.format( data_path, split_k, "src", "tgt", len(src_datasets[-1]) )) if not combine: break assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 if len(src_datasets) == 1: src_dataset = src_datasets[0] tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) if len(tgt_datasets) > 0: tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) else: tgt_dataset = None if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) if tgt_dataset is not None: tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) eos = None if append_source_id: if common_eos is not None: src_dataset = AppendTokenDataset(src_dataset, src_dict.index('[{}]'.format(common_eos))) if tgt_dataset is not None: tgt_dataset = AppendTokenDataset(tgt_dataset, tgt_dict.index('[{}]'.format(common_eos))) eos = tgt_dict.index('[{}]'.format(common_eos)) bos = tgt_dict.index('[{}]'.format(tgt)) align_dataset = None if load_alignments: align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, "src", "tgt")) if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl) tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None return GenerationPairDataset( src_dataset, src_dataset.sizes, src_dict, tgt_dataset, tgt_dataset_sizes, tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, max_source_positions=max_source_positions, max_target_positions=max_target_positions, align_dataset=align_dataset, eos=eos, bos=bos )
def pos_loader(data_path, split, src, src_dict, tgt, tgt_dict, anchor, anchor_dict, combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, max_target_positions, prepend_bos=False, truncate_source=False, append_source_id=False): # Check the existence of the file def split_exists(split, src, tgt, lang, data_path): filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) src_datasets = [] tgt_datasets = [] anchor_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') # infer langcode (from a->b or from b->a) if split_exists(split_k, src, tgt, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt)) elif split_exists(split_k, tgt, src, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src)) else: if k > 0: break else: raise FileNotFoundError('Dataset not found: {} ({})'.format( split, data_path)) src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl) if tgt_dataset is not None: tgt_datasets.append(tgt_dataset) anchor_prefix = os.path.join(data_path, anchor, '{}.{}-{}.'.format(split_k, anchor, tgt)) anchor_dataset = data_utils.load_indexed_dataset( anchor_prefix + anchor, anchor_dict, dataset_impl) if anchor_dataset is not None: anchor_datasets.append(anchor_dataset) logger.info('{} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1]))) if not combine: break assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 # None is not avaliable for anchors assert len(src_datasets) == len(anchor_datasets) if len(src_datasets) == 1: src_dataset = src_datasets[0] tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None anchor_dataset = anchor_datasets[0] else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) if len(tgt_datasets) > 0: tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) else: tgt_dataset = None anchor_dataset = ConcatDataset(anchor_datasets, sample_ratios) if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr( tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) if tgt_dataset is not None: tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) eos = None if append_source_id: src_dataset = AppendTokenDataset(src_dataset, src_dict.index('[{}]'.format(src))) if tgt_dataset is not None: tgt_dataset = AppendTokenDataset( tgt_dataset, tgt_dict.index('[{}]'.format(tgt))) eos = tgt_dict.index('[{}]'.format(tgt)) tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None return POSGraphLanguagePairDatasetb( src_dataset, src_dataset.sizes, src_dict, anchor_dataset, anchor_dataset.sizes, anchor_dict, tgt_dataset, tgt_dataset_sizes, tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, max_source_positions=max_source_positions, max_target_positions=max_target_positions, eos=eos)
def load_dataset(self, split: str, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    languages, data_path = MultilingualLanguageModelingTask._get_langs(
        self.args, epoch)
    lang_to_offline_shard_ratio = None
    if self.args.lang_to_offline_shard_ratio != "":
        lang_to_offline_shard_ratio = {}
        assert os.path.exists(
            self.args.lang_to_offline_shard_ratio
        ), "provided offline shard ratio file doesn't exist: {0}".format(
            self.args.lang_to_offline_shard_ratio)
        with open(self.args.lang_to_offline_shard_ratio) as fin:
            for line in fin:
                lang, ratio = line.strip().split("\t")
                ratio = float(ratio)
                lang_to_offline_shard_ratio[lang] = ratio

        logger.info(
            "Found offline sharded ratio: %s",
            lang_to_offline_shard_ratio,
        )

    if split == self.args.train_subset:
        logger.info("Training on {0} languages: {1}".format(
            len(languages), languages))
    else:
        logger.info("Evaluating on {0} languages: {1}".format(
            len(languages), languages))

    tokens_per_sample = self.args.tokens_per_sample - int(
        self.args.add_bos_token)

    fixed_pad_length = None
    if self.args.pad_to_fixed_length:
        fixed_pad_length = self.args.tokens_per_sample

    pad_to_bsz = None
    if self.args.pad_to_fixed_bsz:
        pad_to_bsz = (self.args.batch_size_valid
                      if "valid" in split else self.args.batch_size)

    lang_datasets = []
    for lang_id, language in enumerate(languages):
        split_path = os.path.join(data_path, language, split)
        dataset = data_utils.load_indexed_dataset(split_path,
                                                  self.dictionary,
                                                  self.args.dataset_impl,
                                                  combine=combine)
        if dataset is None:
            raise FileNotFoundError("Dataset not found: {} ({})".format(
                split, split_path))

        dataset = maybe_shorten_dataset(
            dataset,
            split,
            self.args.shorten_data_split_list,
            self.args.shorten_method,
            tokens_per_sample,
            self.args.seed,
        )

        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            tokens_per_sample,
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=self.args.sample_break_mode,
            include_targets=True,
        )

        add_eos_for_other_targets = (
            self.args.sample_break_mode is not None
            and self.args.sample_break_mode != "none")
        src_lang_idx, tgt_lang_idx = None, None
        if self.args.add_bos_token:
            src_lang_idx = self.dictionary.index(lang_token(language))
            tgt_lang_idx = self.output_dictionary.index(lang_token(language))

        lang_datasets.append(
            MonolingualDataset(
                dataset=dataset,
                sizes=dataset.sizes,
                src_vocab=self.dictionary,
                tgt_vocab=self.output_dictionary,
                add_eos_for_other_targets=add_eos_for_other_targets,
                shuffle=True,
                targets=self.targets,
                fixed_pad_length=fixed_pad_length,
                pad_to_bsz=pad_to_bsz,
                add_bos_token=self.args.add_bos_token,
                src_lang_idx=src_lang_idx,
                tgt_lang_idx=tgt_lang_idx,
            ))

    dataset_lengths = np.array(
        [len(d) for d in lang_datasets],
        dtype=float,
    )
    logger.info("loaded total {} blocks for all languages".format(
        dataset_lengths.sum(), ))
    if split == self.args.train_subset:
        dataset_lengths_ratio_multiplier = np.ones(len(dataset_lengths))
        if lang_to_offline_shard_ratio is not None:
            dataset_lengths_ratio_multiplier = []
            for lang in languages:
                assert (
                    lang in lang_to_offline_shard_ratio
                ), "Lang: {0} missing in offline shard ratio file: {1}".format(
                    lang,
                    self.args.lang_to_offline_shard_ratio,
                )
                dataset_lengths_ratio_multiplier.append(
                    lang_to_offline_shard_ratio[lang])
            dataset_lengths_ratio_multiplier = np.array(
                dataset_lengths_ratio_multiplier)
            true_dataset_lengths = (dataset_lengths *
                                    dataset_lengths_ratio_multiplier)
        else:
            true_dataset_lengths = dataset_lengths
        # For train subset, additionally up or down sample languages.
        sample_probs = self._get_sample_prob(true_dataset_lengths)

        logger.info(
            "Sample probability by language: %s",
            {
                lang: "{0:.4f}".format(sample_probs[id])
                for id, lang in enumerate(languages)
            },
        )
        size_ratio = (sample_probs * true_dataset_lengths.sum()) / dataset_lengths
        # TODO: add an option for shrinking all size ratios to below 1
        # if self.args.multilang_sampling_alpha != 1:
        #     size_ratio /= size_ratio.max()

        # Fix numeric errors in size ratio computation
        #   0.999999999999999999 -> 1
        #   1.000000000000000002 -> 1
        for i in range(len(size_ratio)):
            size_ratio[i] = round(size_ratio[i], 8)

        logger.info(
            "Up/Down Sampling ratio by language: %s",
            {
                lang: "{0:.2f}".format(size_ratio[id])
                for id, lang in enumerate(languages)
            },
        )
        logger.info(
            "Actual dataset size by language: %s",
            {
                lang: "{0:.2f}".format(len(lang_datasets[id]))
                for id, lang in enumerate(languages)
            },
        )
        resampled_lang_datasets = [
            ResamplingDataset(
                lang_datasets[i],
                size_ratio=size_ratio[i],
                seed=self.args.seed,
                epoch=epoch,
                replace=size_ratio[i] > 1.0,
            ) for i, d in enumerate(lang_datasets)
        ]
        logger.info(
            "Resampled dataset size by language: %s",
            {
                lang: "{0:.2f}".format(len(resampled_lang_datasets[id]))
                for id, lang in enumerate(languages)
            },
        )
        dataset = ConcatDataset(resampled_lang_datasets)
    else:
        dataset = ConcatDataset(lang_datasets)
        lang_splits = [split]
        for lang_id, lang_dataset in enumerate(lang_datasets):
            split_name = split + "_" + languages[lang_id]
            lang_splits.append(split_name)
            self.datasets[split_name] = lang_dataset

        # [TODO]: This is hacky for now to print validation ppl for each
        # language individually. Maybe need task API changes to allow it
        # in more generic ways.
        if split in self.args.valid_subset:
            self.args.valid_subset = self.args.valid_subset.replace(
                split, ",".join(lang_splits))

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(dataset))

    self.datasets[split] = SortDataset(
        dataset,
        sort_order=[
            shuffle,
            dataset.sizes,
        ],
    )
def load_langpair_dataset( data_path, split, src, src_dict, tgt, tgt_dict, combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, max_target_positions, prepend_bos=False, load_alignments=False, truncate_source=False, srcda=False, srcda_choice='uniform', tgtda=False, tgtda_choice='uniform' ): def split_exists(split, src, tgt, lang, data_path): filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) return indexed_dataset.dataset_exists(filename, impl=dataset_impl) src_datasets = [] tgt_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') # infer langcode if split_exists(split_k, src, tgt, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt)) elif split_exists(split_k, tgt, src, src, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src)) else: if k > 0: break else: raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path)) src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) tgt_datasets.append( data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl) ) print('| {} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1]))) if not combine: break assert len(src_datasets) == len(tgt_datasets) if len(src_datasets) == 1: src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0] else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) align_dataset = None if load_alignments: align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt)) if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl) return LanguagePairDatasetDA( src_dataset, src_dataset.sizes, src_dict, tgt_dataset, tgt_dataset.sizes, tgt_dict, left_pad_source=left_pad_source, left_pad_target=left_pad_target, max_source_positions=max_source_positions, max_target_positions=max_target_positions, align_dataset=align_dataset, srcda=srcda, srcda_choice=srcda_choice, tgtda=tgtda, tgtda_choice=tgtda_choice )
def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ paths = utils.split_paths(self.args.data) assert len(paths) > 0 if self.args.multiple_datasets: if len(paths) == 1: paths = [ os.path.join(paths[0], p) for p in next(os.walk(paths[0]))[1] ] datasets = [ ShardedDataset( self.dictionary, self.args.dataset_impl, path, split, epoch - 1, combine=combine, ) for path in paths ] ds_names = [ds.name for ds in datasets] if split in self.subsample_splits: sizes = [sum(d.sizes) for d in datasets] min_sz = min(sizes) ratios = [min_sz / sz for sz in sizes] datasets = [ SubsampleDataset(d, r) if r < 1 else d for d, r in zip(datasets, ratios) ] dataset = ConcatDataset(datasets) else: data_path = paths[(epoch - 1) % len(paths)] split_path = os.path.join(data_path, split) dataset = data_utils.load_indexed_dataset(split_path, self.dictionary, self.args.dataset_impl, combine=combine) if dataset is None: raise FileNotFoundError("Dataset not found: {} ({})".format( split, split_path)) ds_names = [None] dataset = TokenBlockDataset( dataset, dataset.sizes, self.args.tokens_per_sample, pad=self.dictionary.pad(), eos=self.dictionary.eos(), break_mode=self.args.sample_break_mode, include_targets=True, ) if self.args.prepend_ds_name: dataset = PrependDataset( dataset, prepend_getter=ds_name_getter( offset=0, generic_ds_name_chance=self.args.generic_ds_name_chance, dictionary=self.dictionary, ), ensure_first_token_is=self.dictionary.eos(), ) dataset = ReplaceDataset( dataset, replace_map={ self.dictionary.eos(): self.dictionary.indices["\\n"] }, offsets=[1, -1], ) add_eos_for_other_targets = (self.args.sample_break_mode is not None and self.args.sample_break_mode != "none") dataset = MonolingualDataset( dataset, dataset.sizes, self.dictionary, self.output_dictionary, add_eos_for_other_targets=add_eos_for_other_targets, shuffle=True, targets=self.targets, add_bos_token=self.args.add_bos_token, ) if self.args.colorize_ds_name: ds_names.append("generic") min_ds = min(self.dictionary.indices[n] for n in ds_names) dataset = ColorizeDataset( dataset, color_getter=ds_name_getter( offset=-min_ds, generic_ds_name_chance=self.args.generic_ds_name_chance, dictionary=self.dictionary, ), ) self.datasets[split] = dataset
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""
    if self.cfg.data.endswith("1"):
        data_shard = (epoch - 1) % self.cfg.num_data_splits + 1
        data_path = self.cfg.data[:-1] + str(data_shard)
    else:
        data_path = self.cfg.data

    def get_path(type, data_split):
        return os.path.join(data_path, str(type), data_split)

    def make_dataset(type, dictionary, data_split, combine):
        split_path = get_path(type, data_split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            combine=combine,
        )
        return dataset

    def load_split(data_split, metric):
        input_src = None
        if self.cfg.include_src:
            input_src = make_dataset("input_src", self.dictionary, data_split,
                                     combine=False)
            assert input_src is not None, "could not find dataset: {}".format(
                get_path("input_src", data_split))

        input_tgt = make_dataset("input_tgt", self.dictionary, data_split,
                                 combine=False)
        assert input_tgt is not None, "could not find dataset: {}".format(
            get_path("input_tgt", data_split))

        label_path = f"{get_path(metric, data_split)}.{metric}"
        assert os.path.exists(label_path), f"could not find dataset: {label_path}"

        np_labels = np.loadtxt(label_path)
        if self.cfg.target_metric == "ter":
            np_labels = -np_labels
        label = RawLabelDataset(np_labels)

        return input_src, input_tgt, label

    src_datasets = []
    tgt_datasets = []
    label_datasets = []

    if split == self.cfg.train_subset:
        for k in itertools.count():
            split_k = "train" + (str(k) if k > 0 else "")
            prefix = os.path.join(data_path, "input_tgt", split_k)
            if not indexed_dataset.dataset_exists(prefix, impl=None):
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(f"Dataset not found: {prefix}")
            input_src, input_tgt, label = load_split(split_k,
                                                     self.cfg.target_metric)
            src_datasets.append(input_src)
            tgt_datasets.append(input_tgt)
            label_datasets.append(label)
    else:
        input_src, input_tgt, label = load_split(split, self.cfg.target_metric)
        src_datasets.append(input_src)
        tgt_datasets.append(input_tgt)
        label_datasets.append(label)

    if len(tgt_datasets) == 1:
        input_tgt, label = tgt_datasets[0], label_datasets[0]
        if self.cfg.include_src:
            input_src = src_datasets[0]
    else:
        input_tgt = ConcatDataset(tgt_datasets)
        label = ConcatDataset(label_datasets)
        if self.cfg.include_src:
            input_src = ConcatDataset(src_datasets)

    input_tgt = TruncateDataset(input_tgt, self.cfg.max_positions)
    if self.cfg.include_src:
        input_src = PrependTokenDataset(input_src, self.dictionary.bos())
        input_src = TruncateDataset(input_src, self.cfg.max_positions)
        src_lengths = NumelDataset(input_src, reduce=False)
        src_tokens = ConcatSentencesDataset(input_src, input_tgt)
    else:
        src_tokens = PrependTokenDataset(input_tgt, self.dictionary.bos())
        src_lengths = NumelDataset(src_tokens, reduce=False)

    dataset = {
        "id": IdDataset(),
        "net_input": {
            "src_tokens": RightPadDataset(
                src_tokens,
                pad_idx=self.source_dictionary.pad(),
            ),
            "src_lengths": src_lengths,
        },
        "nsentences": NumSamplesDataset(),
        "ntokens": NumelDataset(src_tokens, reduce=True),
        "target": label,
    }

    dataset = NestedDictionaryDataset(
        dataset,
        sizes=[src_tokens.sizes],
    )

    assert len(dataset) % self.cfg.mt_beam == 0, (
        "dataset size (%d) is not a multiple of beam size (%d)"
        % (len(dataset), self.cfg.mt_beam))

    # no need to shuffle valid/test sets
    if not self.cfg.no_shuffle and split == self.cfg.train_subset:
        # need to keep all hypotheses together
        start_idx = np.arange(0, len(dataset), self.cfg.mt_beam)
        with data_utils.numpy_seed(self.cfg.seed + epoch):
            np.random.shuffle(start_idx)

        idx = np.arange(0, self.cfg.mt_beam)
        shuffle = np.tile(idx, (len(start_idx), 1)).reshape(-1) + np.tile(
            start_idx, (self.cfg.mt_beam, 1)).transpose().reshape(-1)

        dataset = SortDataset(
            dataset,
            sort_order=[shuffle],
        )

    logger.info(f"Loaded {split} with #samples: {len(dataset)}")

    self.datasets[split] = dataset
    return self.datasets[split]
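# The np.tile arithmetic at the end of the reranking loader above is easier to
# follow with concrete numbers. A small sketch assuming mt_beam=3 and two source
# sentences: the sentence blocks get permuted while each sentence's hypotheses
# stay contiguous.
import numpy as np

mt_beam = 3
n_samples = 6                                   # 2 sentences x 3 hypotheses
start_idx = np.arange(0, n_samples, mt_beam)    # [0, 3], first index of each block
start_idx = np.array([3, 0])                    # pretend np.random.shuffle gave this order

idx = np.arange(0, mt_beam)                     # [0, 1, 2], offsets within one beam
shuffle = np.tile(idx, (len(start_idx), 1)).reshape(-1) + np.tile(
    start_idx, (mt_beam, 1)
).transpose().reshape(-1)

print(shuffle)  # [3 4 5 0 1 2] -> blocks move, hypotheses stay together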
def load_dataset(self, split, combine=False, **kwargs): def split_exists(split, src, tgt, lang, data_path): filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) if self.args.raw_text and IndexedRawTextDataset.exists(filename): return True elif not self.args.raw_text and IndexedDataset.exists(filename): return True return False def indexed_dataset(path, dictionary): if self.args.raw_text: raise NotImplementedError elif IndexedDataset.exists(path): return DPTreeIndexedCachedDataset(path, fix_lua_indexing=True) return None src_datasets_dict = {k: [] for k in NSTACK_KEYS} tgt_datasets = [] data_paths = self.args.data print(f'| split = {split}') print(f'| self.args.data = {self.args.data}') for dk, data_path in enumerate(data_paths): for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') # infer langcode src, tgt = self.args.source_lang, self.args.target_lang if split_exists(split_k, src, tgt, tgt, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt)) elif split_exists(split_k, tgt, src, tgt, data_path): prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src)) else: if k > 0 or dk > 0: break else: raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path)) for modality in src_datasets_dict.keys(): src_datasets_dict[modality].append(indexed_dataset(f'{prefix}{src}.{modality}', self.src_dict)) # src_datasets.append(indexed_dataset(prefix + src, self.src_dict)) tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict)) print('| {} {} {} examples'.format(data_path, split_k, len(tgt_datasets[-1]))) if not combine: break assert len(src_datasets_dict[NSTACK_KEYS[0]]) == len(tgt_datasets) if len(tgt_datasets) == 1: # src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0] src_dataset_dict = {k: v[0] for k, v in src_datasets_dict.items()} tgt_dataset = tgt_datasets[0] else: sample_ratios = [1] * len(tgt_datasets) sample_ratios[0] = self.args.upsample_primary # src_dataset = ConcatDataset(src_datasets, sample_ratios) src_dataset_dict = {k: ConcatDataset(v, sample_ratios) for k, v in src_datasets_dict.items()} tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) # src_sizes = src_dataset_dict['nodes'].sizes # src_sizes = src_dataset_dict['nodes'].sizes.reshape(-1, 2).sum(-1) leave_shape = src_dataset_dict['leaves'].sizes.reshape(-1, 2) node_shape = src_dataset_dict['nodes'].sizes.reshape(-1, 2) # leaves_sizes = leave_shape.sum(-1) # nodes_sizes = node_shape.sum(-1) leaves_sizes = leave_shape.prod(-1) nodes_sizes = node_shape.prod(-1) # print(f'| FIXED VERSION, size must be prod(-1)') src_sizes = leaves_sizes + nodes_sizes src_nsents = leave_shape[:, 0] # print(f'Some leave_size: {leave_shape[:10]}') # print(f'Some src_nsent: ({src_nsents[:10]})') self.datasets[split] = Nstack2SeqPairDataset( src_dataset_dict, src_sizes, self.src_dict, src_nsents, tgt_dataset, tgt_dataset.sizes, self.tgt_dict, left_pad_source=self.args.left_pad_source, left_pad_target=self.args.left_pad_target, max_source_positions=self.args.max_source_positions, max_target_positions=self.args.max_target_positions, remove_eos_from_source=self.args.remove_eos_from_source, append_eos_to_target=self.args.append_eos_to_target, input_feeding=self.args.input_feeding, is_infer=self.args.infer_mode )
def load_lang_dataset( self, data_path, split, src, src_dict, tgt, tgt_dict, combine, dataset_impl, upsample_primary, max_source_positions, prepend_bos=False, load_alignments=False, truncate_source=False, ): src_datasets = [] tgt_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else "") # infer langcode if self.split_exists(split_k, src, tgt, src, data_path, dataset_impl): prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt)) elif self.split_exists(split_k, tgt, src, src, data_path, dataset_impl): prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src)) else: if k > 0: break else: logger.error( f"Dataset not found: {data_path}, {split_k}, {src}, {tgt}" ) raise FileNotFoundError( "Dataset not found: {} ({})".format(split, data_path) ) src_dataset = self.load_data(prefix + src, src_dict, dataset_impl) if truncate_source: src_dataset = AppendTokenDataset( TruncateDataset( StripTokenDataset(src_dataset, src_dict.eos()), max_source_positions - 1, ), src_dict.eos(), ) src_datasets.append(src_dataset) tgt_datasets.append(self.load_data(prefix + tgt, tgt_dict, dataset_impl)) logger.info( "{} {} {}-{} {} examples".format( data_path, split_k, src, tgt, len(src_datasets[-1]) ) ) if not combine: break assert len(src_datasets) == len(tgt_datasets) if len(src_datasets) == 1: src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0] else: sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) if prepend_bos: assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index") src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) align_dataset = None if load_alignments: align_path = os.path.join( data_path, "{}.align.{}-{}".format(split, src, tgt) ) if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): align_dataset = data_utils.load_indexed_dataset( align_path, None, dataset_impl ) return src_dataset, tgt_dataset, align_dataset
def get_asr_dataset_from_json( data_path, split, tgt_dict, combine, upsample_primary=1, num_buckets=0, shuffle=True, pad_to_multiple=1, seed=1, global_cmvn_stats_path=None, specaugment_config=None, ): """ Parse data json and create dataset. See espresso/tools/asr_prep_json.py which pack json from raw files Json example: { "011c0202": { "feat": "fbank/raw_fbank_pitch_train_si284.1.ark:54819" or "wave": "/export/corpora5/LDC/LDC93S6B/11-1.1/wsj0/si_tr_s/011/011c0202.wv1" or "command": "sph2pipe -f wav /export/corpora5/LDC/LDC93S6B/11-1.1/wsj0/si_tr_s/011/011c0202.wv1 |", "text": "THE HOTEL", "utt2num_frames": "693", }, "011c0203": { ... } } """ src_datasets = [] tgt_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else "") data_json_path = os.path.join(data_path, "{}.json".format(split_k)) if not os.path.isfile(data_json_path): if k > 0: break else: raise FileNotFoundError( "Dataset not found: {}".format(data_json_path) ) with open(data_json_path, "rb") as f: loaded_json = json.load(f, object_pairs_hook=OrderedDict) utt_ids, audios, texts, utt2num_frames = [], [], [], [] for utt_id, val in loaded_json.items(): utt_ids.append(utt_id) if "feat" in val: audio = val["feat"] elif "wave" in val: audio = val["wave"] elif "command" in val: audio = val["command"] else: raise KeyError( f"'feat', 'wave' or 'command' should be present as a field for the entry {utt_id} in {data_json_path}" ) audios.append(audio) if "text" in val: texts.append(val["text"]) if "utt2num_frames" in val: utt2num_frames.append(int(val["utt2num_frames"])) assert len(utt2num_frames) == 0 or len(utt_ids) == len(utt2num_frames) if "feat" in next(iter(loaded_json.items())): extra_kwargs = {} else: extra_kwargs = {"feat_dim": 80, "feature_type": "fbank"} if global_cmvn_stats_path is not None: feature_transforms_config = { "transforms": ["global_cmvn"], "global_cmvn": {"stats_npz_path": global_cmvn_stats_path} } extra_kwargs["feature_transforms_config"] = feature_transforms_config src_datasets.append(AudioFeatDataset( utt_ids, audios, utt2num_frames=utt2num_frames, seed=seed, specaugment_config=specaugment_config if split == "train" else None, **extra_kwargs )) if len(texts) > 0: assert len(utt_ids) == len(texts) assert tgt_dict is not None tgt_datasets.append(AsrTextDataset(utt_ids, texts, tgt_dict)) logger.info("{} {} examples".format(data_json_path, len(src_datasets[-1]))) if not combine: break assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 feat_dim = src_datasets[0].feat_dim if len(src_datasets) == 1: src_dataset = src_datasets[0] tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None else: for i in range(1, len(src_datasets)): assert ( feat_dim == src_datasets[i].feat_dim ), "feature dimension does not match across multiple json files" sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) if len(tgt_datasets) > 0: tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) else: tgt_dataset = None tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None return AsrDataset( src_dataset, src_dataset.sizes, tgt_dataset, tgt_dataset_sizes, tgt_dict, left_pad_source=False, left_pad_target=False, num_buckets=num_buckets, shuffle=shuffle, pad_to_multiple=pad_to_multiple, )
def load_dataset(self, split, combine=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ loaded_datasets = [] for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') path = os.path.join(self.args.data, split_k) if self.args.raw_text and IndexedRawTextDataset.exists(path): ds = IndexedRawTextDataset(path, self.dictionary) elif not self.args.raw_text and IndexedDataset.exists(path): if self.args.lazy_load: ds = IndexedDataset(path, fix_lua_indexing=True) else: ds = IndexedCachedDataset(path, fix_lua_indexing=True) else: if k > 0: break else: raise FileNotFoundError( 'Dataset not found: {} ({})'.format( split, self.args.data)) loaded_datasets.append( TokenBlockDataset( ds, ds.sizes, self.args.tokens_per_sample, pad=self.dictionary.pad(), eos=self.dictionary.eos(), break_mode=self.args.sample_break_mode, include_targets=True, )) print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1]))) if not combine: break if len(loaded_datasets) == 1: dataset = loaded_datasets[0] sizes = dataset.sizes else: dataset = ConcatDataset(loaded_datasets) sizes = np.concatenate([ds.sizes for ds in loaded_datasets]) add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none' self.datasets[split] = MonolingualDataset( dataset, sizes, self.dictionary, self.output_dictionary, add_eos_for_other_targets=add_eos_for_other_targets, shuffle=True, targets=self.targets, )
def load_langpair_dataset(
    data_path, split,
    src, src_dict,
    tgt, tgt_dict,
    ter, xml_dico, xml_params,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target,
    max_source_positions, max_target_positions,
    shuffle=True, task='translation_qe'
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, split + '.')
        return os.path.exists(filename + src)

    src_datasets = []
    tgt_datasets = []
    ter_datasets = []
    xml_datasets = []
    word_tag_datasets = []
    gap_tag_datasets = []
    bpe_tag_datasets = []
    xml_bpe_tag_datasets = []
    src_word_tag_datasets = []
    src_bpe_tag_datasets = []
    xml_src_bpe_tag_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, split_k + '.')
            prefix_xml = os.path.join(data_path, 'xml_data/' + split_k + '.')
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, split_k + '.')
            prefix_xml = os.path.join(data_path, 'xml_data/' + split_k + '.')
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        src_datasets.append(
            data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
        )
        tgt_datasets.append(
            data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)
        )
        ter_datasets.append(
            torch.from_numpy(np.loadtxt(prefix + ter))
        )
        xml_datasets.append(
            data_utils.load_indexed_dataset(prefix_xml + src, xml_dico, dataset_impl, path_xml=prefix_xml + tgt)
        )
        if task == 'translation_qe_word':
            word_tag_datasets.append(data_utils.load_word_qe_tags(prefix + 'word_tags'))
            gap_tag_datasets.append(data_utils.load_word_qe_tags(prefix + 'gap_tags'))
            src_word_tag_datasets.append(data_utils.load_word_qe_tags(prefix + 'src_word_tags'))
            bpe_tag_datasets.append(data_utils.load_bpe_tags(prefix + 'bpe'))
            src_bpe_tag_datasets.append(data_utils.load_bpe_tags(prefix + 'src_bpe'))
            xml_bpe_tag_datasets.append(data_utils.load_bpe_tags(prefix_xml + 'bpe'))
            xml_src_bpe_tag_datasets.append(data_utils.load_bpe_tags(prefix_xml + 'src_bpe'))

        print('| {} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)
    assert len(src_datasets) == len(ter_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset, ter_dataset, xml_dataset = \
            src_datasets[0], tgt_datasets[0], ter_datasets[0], xml_datasets[0]
        if task == 'translation_qe_word':
            word_tag_dataset, gap_tag_dataset = word_tag_datasets[0], gap_tag_datasets[0]
            bpe_tag_dataset, xml_bpe_tag_dataset = bpe_tag_datasets[0], xml_bpe_tag_datasets[0]
            src_word_tag_dataset, src_bpe_tag_dataset = src_word_tag_datasets[0], src_bpe_tag_datasets[0]
            xml_src_bpe_tag_dataset = xml_src_bpe_tag_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        ter_dataset = ConcatDataset(ter_datasets, sample_ratios)
        xml_dataset = ConcatDataset(xml_datasets, sample_ratios)
        if task == 'translation_qe_word':
            # concatenate the word-level tag datasets as well, mirroring the other
            # datasets; otherwise the *_tag_dataset variables used below would be
            # undefined when more than one shard is loaded
            word_tag_dataset = ConcatDataset(word_tag_datasets, sample_ratios)
            gap_tag_dataset = ConcatDataset(gap_tag_datasets, sample_ratios)
            bpe_tag_dataset = ConcatDataset(bpe_tag_datasets, sample_ratios)
            xml_bpe_tag_dataset = ConcatDataset(xml_bpe_tag_datasets, sample_ratios)
            src_word_tag_dataset = ConcatDataset(src_word_tag_datasets, sample_ratios)
            src_bpe_tag_dataset = ConcatDataset(src_bpe_tag_datasets, sample_ratios)
            xml_src_bpe_tag_dataset = ConcatDataset(xml_src_bpe_tag_datasets, sample_ratios)

    if task == 'translation_qe_word':
        return LanguagePairWordDataset(
            src_dataset, src_dataset.sizes, src_dict,
            tgt_dataset, tgt_dataset.sizes, tgt_dict,
            ter_dataset, xml_dataset, xml_dico, xml_params,
            xml_pad_indx=xml_params.pad_index,
            word_tag=word_tag_dataset, gap_tag=gap_tag_dataset,
            bpe_tag=bpe_tag_dataset, xml_bpe_tag=xml_bpe_tag_dataset,
            src_word_tag=src_word_tag_dataset, src_bpe_tag=src_bpe_tag_dataset,
            xml_src_bpe_tag=xml_src_bpe_tag_dataset,
            left_pad_source=left_pad_source, left_pad_target=left_pad_target,
            max_source_positions=max_source_positions,
            max_target_positions=max_target_positions,
            shuffle=shuffle,
        )
    else:
        return LanguagePairDataset(
            src_dataset, src_dataset.sizes, src_dict,
            tgt_dataset, tgt_dataset.sizes, tgt_dict,
            ter_dataset, xml_dataset, xml_dico, xml_params,
            xml_pad_indx=xml_params.pad_index,
            left_pad_source=left_pad_source, left_pad_target=left_pad_target,
            max_source_positions=max_source_positions,
            max_target_positions=max_target_positions,
            shuffle=shuffle,
        )
def load_dataset(self, split, combine=False):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        if self.args.raw_text and IndexedRawTextDataset.exists(filename):
            return True
        elif not self.args.raw_text and IndexedDataset.exists(filename):
            return True
        return False

    def indexed_dataset(path, dictionary):
        if self.args.raw_text and IndexedRawTextDataset.exists(path):
            return IndexedRawTextDataset(path, dictionary)
        elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
            return IndexedDataset(path, fix_lua_indexing=False)
        return None

    src_datasets = []
    tgt_datasets = []
    data_paths = self.args.data

    for dk, data_path in enumerate(data_paths):
        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')

            # infer langcode
            src, tgt = self.args.source_lang, self.args.target_lang
            if split_exists(split_k, src, tgt, src, data_path):
                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
            elif split_exists(split_k, tgt, src, src, data_path):
                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
            else:
                if k > 0 or dk > 0:
                    break
                else:
                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

            src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
            tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))

            print('| {} {} {} examples'.format(data_path, split_k, len(src_datasets[-1])))

            if not combine:
                break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = self.args.upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

    self.datasets[split] = SummerizationLanguagePairDataset(
        src_dataset, src_dataset.sizes, self.src_dict,
        tgt_dataset, tgt_dataset.sizes, self.tgt_dict,
        left_pad_source=self.args.left_pad_source,
        left_pad_target=self.args.left_pad_target,
        max_source_positions=self.args.max_source_positions,
        max_target_positions=self.args.max_target_positions,
        with_target=(split != 'test'),
    )
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    loaded_datasets = []

    paths = self.args.data.split(':')
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        path = os.path.join(data_path, split_k)
        ds = indexed_dataset.make_dataset(path, impl=self.args.dataset_impl,
                                          fix_lua_indexing=True, dictionary=self.dictionary)

        if ds is None:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        loaded_datasets.append(
            TokenBlockDataset(
                ds, ds.sizes, self.args.tokens_per_sample,
                pad=self.dictionary.pad(), eos=self.dictionary.eos(),
                break_mode=self.args.sample_break_mode, include_targets=True,
            ))

        print('| {} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1])))

        if not combine:
            break

    if len(loaded_datasets) == 1:
        dataset = loaded_datasets[0]
        sizes = dataset.sizes
    else:
        dataset = ConcatDataset(loaded_datasets)
        sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

    add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'

    self.datasets[split] = MonolingualDataset(
        dataset, sizes, self.dictionary, self.output_dictionary,
        add_eos_for_other_targets=add_eos_for_other_targets, shuffle=True,
        targets=self.targets, add_bos_token=self.args.add_bos_token,
    )
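The paths[epoch % len(paths)] lookup above rotates through the colon-separated directories in args.data so that each epoch reads a single shard. A small standalone illustration follows (directory names are made up); note that some loaders later in this file index with (epoch - 1) because their epoch counter starts at 1.

paths = "/data/shard0:/data/shard1:/data/shard2".split(":")

def data_path_for_epoch(epoch, paths, epochs_start_at_one=False):
    # pick one shard per epoch, cycling through the available directories
    index = (epoch - 1) if epochs_start_at_one else epoch
    return paths[index % len(paths)]

assert data_path_for_epoch(0, paths) == "/data/shard0"
assert data_path_for_epoch(4, paths) == "/data/shard1"
assert data_path_for_epoch(1, paths, epochs_start_at_one=True) == "/data/shard0"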
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    def indexed_dataset(path, dictionary):
        return IndexedCachedDataset(path, fix_lua_indexing=True)

    src_datasets = []
    tgt_datasets = []
    src_lngs = []
    tgt_lngs = []
    dataset_ids = []
    dataset_names = []
    lng_borders = [0]
    data_path = self.args.data[0]
    fns = glob.glob(os.path.join(data_path, f'{split}.*'))
    lng_pairs = list(set([f.split('.')[1] for f in fns]))
    lng_pairs = sorted(lng_pairs)
    ds_idx = 0
    sources = [s for s in self.args.sources.split(",") if s != '']
    targets = [t for t in self.args.targets.split(",") if t != '']
    is_distill = self.args.criterion == 'distill_label_smoothed_cross_entropy' and split == 'train'
    topk_idxs = []
    topk_probs = []
    expert_scores = []

    for idx, lng_pair in enumerate(lng_pairs):
        src, tgt = lng_pair.split('-')
        prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))

        def add_dataset(src, tgt):
            if (src not in sources and len(sources) > 0) or (tgt not in targets and len(targets) > 0):
                return 0
            if not os.path.exists(prefix + src + ".bin") or \
                    not os.path.exists(prefix + tgt + ".bin"):
                return 0
            if is_distill and not os.path.exists(
                    os.path.join(self.args.data[0], '{}_{}_topk_idx.idx'.format(src, tgt))):
                return 0
            src_ds = indexed_dataset(prefix + src, self.src_dict)
            src_datasets.append(src_ds)
            tgt_ds = indexed_dataset(prefix + tgt, self.tgt_dict)
            tgt_datasets.append(tgt_ds)

            l = len(src_ds)
            if self.args.data_limit != '' \
                    and src + "-" + tgt == self.args.data_limit.split(':')[0] \
                    and l > int(self.args.data_limit.split(':')[1]):
                l = int(self.args.data_limit.split(':')[1])
                src_datasets[-1].size = l
                tgt_datasets[-1].size = l
                l = len(src_ds)

            print("| Add dataset {} -> {}. size:{}".format(src, tgt, l))
            lng_borders.append(lng_borders[-1] + l)
            dataset_names.append(f"{src}_{tgt}")
            for i in range(l):
                src_lngs.append(self.lng2id[src])
                tgt_lngs.append(self.lng2id[tgt])
                dataset_ids.append(ds_idx)

            if is_distill:
                assert self.args.data_limit == ''
                path = os.path.join(self.args.data[0], '{}_{}_topk_idx'.format(src, tgt))
                topk_idxs.append(TeacherOutputDataset(path))
                path = os.path.join(self.args.data[0], '{}_{}_topk_prob'.format(src, tgt))
                topk_probs.append(TeacherOutputDataset(path))
                expert_bleu = os.path.join(
                    self.args.data[0], 'expert_bleu_{}_{}.json'.format(src, tgt))
                expert_bleu = json.load(open(expert_bleu))
                expert_scores.append(expert_bleu[f"bleu_{src}_{tgt}"])
            return 1

        ds_idx += add_dataset(src, tgt)
        ds_idx += add_dataset(tgt, src)

    src_dataset = ConcatDataset(src_datasets)
    tgt_dataset = ConcatDataset(tgt_datasets)
    src_sizes = np.concatenate([ds.sizes for ds in src_datasets])
    tgt_sizes = np.concatenate([ds.sizes for ds in tgt_datasets])

    topk_idx_dataset = None
    topk_probs_dataset = None
    if is_distill:
        topk_idx_dataset = ConcatDataset(topk_idxs)
        topk_probs_dataset = ConcatDataset(topk_probs)
        assert len(topk_probs_dataset) == len(tgt_dataset), (len(topk_probs_dataset), len(tgt_dataset))
        assert len(topk_idx_dataset) == len(tgt_dataset)

    self.datasets[split] = UniversalDataset(
        self.args,
        src_dataset, src_sizes, self.src_dict, src_lngs, tgt_lngs,
        tgt_dataset, tgt_sizes, self.tgt_dict,
        dataset_ids, lng_borders, dataset_names,
        topk_idxs=topk_idx_dataset, topk_probs=topk_probs_dataset,
        left_pad_source=self.args.left_pad_source,
        left_pad_target=self.args.left_pad_target,
        max_source_positions=self.args.max_source_positions,
        max_target_positions=self.args.max_target_positions,
        expert_scores=expert_scores,
        is_train=split == 'train',
    )
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]

    languages = sorted(
        name for name in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, name))
    )

    logger.info("Training on {0} languages: {1}".format(len(languages), languages))
    logger.info("Language to id mapping: {}".format(
        {lang: id for id, lang in enumerate(languages)}
    ))

    mask_whole_words = self._get_whole_word_mask()
    lang_datasets = []
    for lang_id, language in enumerate(languages):
        split_path = os.path.join(data_path, language, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample - 1,  # one less for <s>
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode=self.args.sample_break_mode,
        )
        logger.info('loaded {} blocks from: {}'.format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
            freq_weighted_replacement=self.args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
        )

        lang_dataset = NestedDictionaryDataset(
            {
                'net_input': {
                    'src_tokens': PadDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'src_lengths': NumelDataset(src_dataset, reduce=False),
                },
                'target': PadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_dataset, reduce=True),
                'lang_id': RawLabelDataset([lang_id] * src_dataset.sizes.shape[0]),
            },
            sizes=[src_dataset.sizes],
        )
        lang_datasets.append(lang_dataset)

    dataset_lengths = np.array(
        [len(d) for d in lang_datasets],
        dtype=float,
    )
    logger.info('loaded total {} blocks for all languages'.format(dataset_lengths.sum()))

    if split == self.args.train_subset:
        # For train subset, additionally up or down sample languages.
        sample_probs = self._get_sample_prob(dataset_lengths)
        logger.info("Sample probability by language: {}".format({
            lang: "{0:.4f}".format(sample_probs[id])
            for id, lang in enumerate(languages)
        }))
        size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
        logger.info("Up/Down Sampling ratio by language: {}".format({
            lang: "{0:.2f}".format(size_ratio[id])
            for id, lang in enumerate(languages)
        }))

        resampled_lang_datasets = [
            ResamplingDataset(
                lang_datasets[i],
                size_ratio=size_ratio[i],
                seed=self.args.seed,
                epoch=epoch,
                replace=size_ratio[i] >= 1.0,
            )
            for i, d in enumerate(lang_datasets)
        ]
        dataset = ConcatDataset(resampled_lang_datasets)
    else:
        dataset = ConcatDataset(lang_datasets)
        lang_splits = [split]
        for lang_id, lang_dataset in enumerate(lang_datasets):
            split_name = split + '_' + languages[lang_id]
            lang_splits.append(split_name)
            self.datasets[split_name] = lang_dataset

        # [TODO]: This is hacky for now to print validation ppl for each
        # language individually. Maybe need task API changes to allow it
        # in more generic ways.
        if split in self.args.valid_subset:
            self.args.valid_subset = self.args.valid_subset.replace(
                split, ','.join(lang_splits))

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(dataset))

    self.datasets[split] = SortDataset(
        dataset,
        sort_order=[
            shuffle,
            dataset.sizes,
        ],
    )
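The resampling branch above relies on _get_sample_prob, which is not shown in this excerpt. The sketch below is a common formulation of temperature-based sampling that is consistent with how sample_probs and size_ratio are used here: with alpha < 1, low-resource languages receive proportionally more samples. Treat it as an assumption about that helper, not its actual implementation.

import numpy as np

def get_sample_prob(dataset_lengths, alpha=0.7):
    # assumed smoothing: p_i proportional to (n_i / N) ** alpha, then renormalized
    prob = dataset_lengths / dataset_lengths.sum()
    smoothed = prob ** alpha
    return smoothed / smoothed.sum()

lengths = np.array([1_000_000.0, 10_000.0])  # a high-resource and a low-resource language
probs = get_sample_prob(lengths, alpha=0.7)
size_ratio = probs * lengths.sum() / lengths  # ratios > 1 mean the language is up-sampled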
def load_dataset(self, split, epoch=1, combine=False):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    loaded_datasets = []

    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]
    logger.info("data_path: {}".format(data_path))

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        path = os.path.join(data_path, split_k)
        ds = indexed_dataset.make_dataset(
            path,
            impl=self.args.dataset_impl,
            fix_lua_indexing=True,
            dictionary=self.dictionary,
        )

        if ds is None:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        with data_utils.numpy_seed(self.seed + k):
            loaded_datasets.append(
                BlockPairDataset(
                    ds,
                    self.dictionary,
                    ds.sizes,
                    self.args.tokens_per_sample,
                    break_mode=self.args.break_mode,
                    doc_break_size=1,
                ))

        logger.info('{} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1])))

        if not combine:
            break

    if len(loaded_datasets) == 1:
        dataset = loaded_datasets[0]
        sizes = dataset.sizes
    else:
        dataset = ConcatDataset(loaded_datasets)
        sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

    self.datasets[split] = MaskedLMDataset(
        dataset=dataset,
        sizes=sizes,
        vocab=self.dictionary,
        pad_idx=self.dictionary.pad(),
        mask_idx=self.dictionary.mask(),
        classif_token_idx=self.dictionary.cls(),
        sep_token_idx=self.dictionary.sep(),
        shuffle=self.args.shuffle_dataset,
        seed=self.seed,
    )
def load_langpair_dataset(
    data_path, split,
    src, src_dict,
    tgt, tgt_dict,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target,
    max_source_positions, max_target_positions,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            src_prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            src_prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
        elif split_exists(split_k, src, tgt.split("_")[0], src, data_path):
            src_prefix = os.path.join(
                data_path, '{}.{}-{}.'.format(split_k, src, tgt.split("_")[0]))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
        src_datasets.append(
            data_utils.load_indexed_dataset(src_prefix + src, src_dict, dataset_impl))

        if split_exists(split_k, src, tgt, tgt, data_path):
            tgt_prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, src, tgt.split("_")[0], tgt, data_path):
            tgt_prefix = os.path.join(
                data_path, '{}.{}-{}.'.format(split_k, src, tgt.split("_")[0]))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
        tgt_datasets.append(
            data_utils.load_indexed_dataset(tgt_prefix + tgt, tgt_dict, dataset_impl))

        print('| {} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets)

    if len(src_datasets) == 1:
        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

    return LanguagePairDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset.sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
    )
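For the multi-shard case above, ConcatDataset receives sample_ratios whose first entry is upsample_primary, i.e. the primary shard is repeated that many times while the remaining shards appear once. The hypothetical helper below just computes the resulting length to make the effect concrete.

def concat_lengths(shard_lengths, upsample_primary=1):
    # reproduces the length arithmetic of ConcatDataset(datasets, sample_ratios)
    sample_ratios = [1] * len(shard_lengths)
    sample_ratios[0] = upsample_primary
    return sum(r * n for r, n in zip(sample_ratios, shard_lengths))

assert concat_lengths([100, 50], upsample_primary=3) == 350  # primary shard counted three times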