def from_tsv(
    cls,
    root: str,
    cfg: S2TDataConfig,
    splits: str,
    tgt_dict,
    pre_tokenizer,
    bpe_tokenizer,
    is_train_split: bool,
    epoch: int,
    seed: int,
    n_frames_per_step: int = 1,
    speaker_to_id=None,
) -> SpeechToTextDataset:
    datasets = [
        cls._from_tsv(
            root,
            cfg,
            split,
            tgt_dict,
            is_train_split,
            pre_tokenizer,
            bpe_tokenizer,
            n_frames_per_step,
            speaker_to_id,
        )
        for split in splits.split(",")
    ]

    if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
        # temperature-based sampling
        size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
        datasets = [
            ResamplingDataset(
                d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
            )
            for r, d in zip(size_ratios, datasets)
        ]
    return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
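# Note: the `get_size_ratios` / `_get_size_ratios` classmethods called by the
# from_tsv creators in this file are not shown here. The helper below is a
# hypothetical reconstruction of the temperature-based sampling math
# (https://arxiv.org/abs/1907.05019) they are assumed to implement; the real
# fairseq helpers differ in signature (one variant takes datasets, another
# takes split names and sizes) and add logging.
import numpy as np


def _get_size_ratios_sketch(sizes, alpha=1.0):
    """Per-dataset up/down-sampling ratios computed from dataset sizes."""
    sizes = np.array(sizes, dtype=float)
    probs = sizes / sizes.sum()            # empirical sampling probabilities
    smoothed = probs ** alpha              # temperature smoothing; alpha < 1 flattens
    smoothed = smoothed / smoothed.sum()
    # ratio by which each dataset must be re-sized so that drawing uniformly
    # from the concatenation matches the smoothed probabilities
    return smoothed * sizes.sum() / sizes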
def resample_datasets(self, lang_datasets, lang_pairs_all, epoch):
    # For train subset, additionally up or down sample languages.
    if self.args.multilang_sampling_alpha == 1.0:
        return lang_datasets

    dataset_lengths = np.array(
        [len(d) for d in lang_datasets],
        dtype=float,
    )
    sample_probs = self._get_sample_prob(dataset_lengths)
    logger.info(
        "Sample probability by language pair: {}".format(
            {
                lp: "{0:.4f}".format(sample_probs[id])
                for id, lp in enumerate(lang_pairs_all)
            }
        )
    )
    size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
    logger.info(
        "Up/Down Sampling ratio by language: {}".format(
            {
                lp: "{0:.2f}".format(size_ratio[id])
                for id, lp in enumerate(lang_pairs_all)
            }
        )
    )
    resampled_lang_datasets = [
        ResamplingDataset(
            lang_datasets[i],
            size_ratio=size_ratio[i],
            seed=self.args.seed,
            epoch=epoch,
            replace=size_ratio[i] >= 1.0,
        )
        for i, d in enumerate(lang_datasets)
    ]
    return resampled_lang_datasets
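# The `_get_sample_prob` method called above (and by the multilingual
# load_dataset functions below) is not included in this file. A minimal sketch
# of the smoothing it is assumed to implement is given here, together with a
# tiny numeric illustration; the actual method on the task class may differ.
import numpy as np


def _get_sample_prob_sketch(dataset_lens, alpha):
    """Smoothed sampling probability by language: p_i proportional to
    (n_i / sum_j n_j) ** alpha. With alpha < 1, low-resource languages are
    upsampled relative to their empirical share."""
    prob = dataset_lens / dataset_lens.sum()
    smoothed_prob = prob ** alpha
    return smoothed_prob / smoothed_prob.sum()


# Illustration: two languages with 900 and 100 examples, alpha = 0.3.
lens = np.array([900.0, 100.0])
probs = _get_sample_prob_sketch(lens, alpha=0.3)   # roughly [0.66, 0.34]
size_ratio = probs * lens.sum() / lens             # roughly [0.73, 3.41]
# The low-resource language is upsampled by roughly 3.4x, while the
# high-resource one is downsampled to roughly 0.73x of its original size.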
def from_tsv(
    cls,
    root: str,
    data_cfg: S2TDataConfigSrc,
    splits: str,
    tgt_dict,
    src_dict,
    pre_tokenizer,
    bpe_tokenizer,
    is_train_split: bool,
    epoch: int,
    seed: int,
) -> SpeechToTextDatasetWithSrc:
    samples = []
    _splits = splits.split(",")
    for split in _splits:
        tsv_path = op.join(root, f"{split}.tsv")
        if not op.isfile(tsv_path):
            raise FileNotFoundError(f"Dataset not found: {tsv_path}")
        with open(tsv_path) as f:
            reader = csv.DictReader(
                f,
                delimiter="\t",
                quotechar=None,
                doublequote=False,
                lineterminator="\n",
                quoting=csv.QUOTE_NONE,
            )
            samples.append([dict(e) for e in reader])
    assert len(samples) > 0

    datasets = [
        cls._from_list(
            name,
            is_train_split,
            [s],
            data_cfg,
            tgt_dict,
            src_dict,
            pre_tokenizer,
            bpe_tokenizer,
        )
        for name, s in zip(_splits, samples)
    ]

    if is_train_split and len(_splits) > 1 and data_cfg.sampling_alpha != 1.0:
        # temperature-based sampling
        size_ratios = cls._get_size_ratios(
            _splits, [len(s) for s in samples], alpha=data_cfg.sampling_alpha
        )
        datasets = [
            ResamplingDataset(
                d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
            )
            for d, r in zip(datasets, size_ratios)
        ]
    return ConcatDataset(datasets)
def load_langpair_dataset(self, prepend_tgt_lang_tag=False, sampling_alpha=1.0, epoch=0):
    lang_pairs = []
    text_dataset = None
    split = "train"
    for lp in self.args.langpairs.split(","):
        src, tgt = lp.split("-")
        text_dataset = load_langpair_dataset(
            self.args.parallel_text_data,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=True,
            dataset_impl=None,
            upsample_primary=1,
            left_pad_source=False,
            left_pad_target=False,
            max_source_positions=self.args.max_positions_text,
            max_target_positions=self.args.max_target_positions,
            load_alignments=False,
            truncate_source=False,
        )
        if prepend_tgt_lang_tag:
            # TODO
            text_dataset = TransformEosLangPairDataset(
                text_dataset,
                src_eos=self.src_dict.eos(),
                tgt_bos=self.tgt_dict.eos(),  # 'prev_output_tokens' starts with eos
                new_tgt_bos=self.tgt_dict.index(LANG_TAG_TEMPLATE.format(tgt)),
            )
        lang_pairs.append(text_dataset)

    if len(lang_pairs) > 1:
        if sampling_alpha != 1.0:
            size_ratios = SpeechToTextDatasetCreator.get_size_ratios(
                self.args.langpairs.split(","),
                [len(s) for s in lang_pairs],
                alpha=sampling_alpha,
            )
            lang_pairs = [
                ResamplingDataset(d, size_ratio=r, epoch=epoch, replace=(r >= 1.0))
                for d, r in zip(lang_pairs, size_ratios)
            ]
        return ConcatDataset(lang_pairs)
    return text_dataset
def from_tsv(
    cls,
    root: str,
    cfg: S2TJointDataConfig,
    splits: str,
    tgt_dict,
    src_dict,
    pre_tokenizer,
    bpe_tokenizer,
    src_pre_tokenizer,
    src_bpe_tokenizer,
    is_train_split: bool,
    epoch: int,
    seed: int,
    append_eos: Optional[bool] = True,
    use_src_lang_id: Optional[int] = 0,
) -> SpeechToTextJointDataset:
    datasets = [
        cls._from_tsv(
            root,
            cfg,
            split,
            tgt_dict,
            src_dict,
            is_train_split,
            pre_tokenizer,
            bpe_tokenizer,
            src_pre_tokenizer,
            src_bpe_tokenizer,
            append_eos=append_eos,
            use_src_lang_id=use_src_lang_id,
        )
        for split in splits.split(",")
    ]

    if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
        # temperature-based sampling
        size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
        datasets = [
            ResamplingDataset(
                d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
            )
            for r, d in zip(size_ratios, datasets)
        ]
    return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
def test_resampling_dataset_batch_by_size_true(self):
    resampling_dataset = ResamplingDataset(
        self.dataset,
        self.weights,
        size_ratio=self.size_ratio,
        batch_by_size=True,
        seed=0,
    )
    results = self._test_common(resampling_dataset, iters=1000)

    # For batch_by_size = True, the batches should be returned in
    # increasing order of size.
    assert results["ordered_by_size"]

    # Allow tolerance in distribution error of 2%.
    assert results["max_distribution_diff"] < 0.02
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = self.args.data.split(":")
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]
    split_path = os.path.join(data_path, split)

    if self.langs is None:
        languages = sorted(
            [
                name
                for name in os.listdir(data_path)
                if os.path.isdir(os.path.join(data_path, name))
            ]
        )
    else:
        languages = self.langs.split(",")
        for name in languages:
            p = os.path.join(data_path, name)
            assert os.path.exists(p), "data not found: {}".format(p)

    logger.info("Training on {0} languages: {1}".format(len(languages), languages))
    logger.info(
        "Language to id mapping: {}".format(
            {lang: id for id, lang in enumerate(languages)}
        )
    )

    mask_whole_words = get_whole_word_mask(self.args, self.dictionary)
    language_without_segmentations = self.args.no_whole_word_mask_langs.split(",")
    lang_datasets = []
    for language in languages:
        split_path = os.path.join(data_path, language, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        end_token = (
            self.source_dictionary.index("[{}]".format(language))
            if self.args.add_lang_token
            else self.source_dictionary.eos()
        )

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample - 2,  # one less for <s> and one for </s>
            pad=self.source_dictionary.pad(),
            eos=end_token,
            break_mode=self.args.sample_break_mode,
        )
        logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
        dataset = AppendTokenDataset(dataset, end_token)

        lang_mask_whole_words = (
            mask_whole_words
            if language not in language_without_segmentations
            else None
        )
        lang_dataset = DenoisingDataset(
            dataset,
            dataset.sizes,
            self.dictionary,
            self.mask_idx,
            lang_mask_whole_words,
            shuffle=self.args.shuffle_instance,
            seed=self.seed,
            args=self.args,
            eos=None
            if not self.args.add_lang_token
            else self.source_dictionary.index("[{}]".format(language)),
        )
        lang_datasets.append(lang_dataset)

    dataset_lengths = np.array(
        [len(d) for d in lang_datasets],
        dtype=float,
    )
    logger.info(
        "loaded total {} blocks for all languages".format(
            int(dataset_lengths.sum()),
        )
    )
    if split == self.args.train_subset:
        # For train subset, additionally up or down sample languages.
        sample_probs = self._get_sample_prob(dataset_lengths)
        logger.info(
            "Sample probability by language: {}".format(
                {
                    lang: "{0:.4f}".format(sample_probs[id])
                    for id, lang in enumerate(languages)
                }
            )
        )
        size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
        logger.info(
            "Up/Down Sampling ratio by language: {}".format(
                {
                    lang: "{0:.2f}".format(size_ratio[id])
                    for id, lang in enumerate(languages)
                }
            )
        )

        resampled_lang_datasets = [
            ResamplingDataset(
                lang_datasets[i],
                size_ratio=size_ratio[i],
                seed=self.args.seed,
                epoch=epoch,
                replace=size_ratio[i] >= 1.0,
            )
            for i, d in enumerate(lang_datasets)
        ]
        dataset = ConcatDataset(
            resampled_lang_datasets,
        )
    else:
        dataset = ConcatDataset(lang_datasets)
        lang_splits = [split]
        for lang_id, lang_dataset in enumerate(lang_datasets):
            split_name = split + "_" + languages[lang_id]
            lang_splits.append(split_name)
            self.datasets[split_name] = lang_dataset

        if split in self.args.valid_subset:
            self.args.valid_subset = self.args.valid_subset.replace(
                split, ",".join(lang_splits)
            )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(dataset))

    self.datasets[split] = SortDataset(
        dataset,
        sort_order=[
            shuffle,
            dataset.sizes,
        ],
    )
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]

    languages = sorted(
        name
        for name in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, name))
    )

    logger.info("Training on {0} languages: {1}".format(len(languages), languages))
    logger.info(
        "Language to id mapping: {}".format(
            {lang: id for id, lang in enumerate(languages)}
        )
    )

    mask_whole_words = self._get_whole_word_mask()
    lang_datasets = []
    for lang_id, language in enumerate(languages):
        split_path = os.path.join(data_path, language, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError(
                'Dataset not found: {} ({})'.format(split, split_path)
            )

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample - 1,  # one less for <s>
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode=self.args.sample_break_mode,
        )
        logger.info('loaded {} blocks from: {}'.format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
            freq_weighted_replacement=self.args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
        )

        lang_dataset = NestedDictionaryDataset(
            {
                'net_input': {
                    'src_tokens': PadDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'src_lengths': NumelDataset(src_dataset, reduce=False),
                },
                'target': PadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_dataset, reduce=True),
                'lang_id': RawLabelDataset([lang_id] * src_dataset.sizes.shape[0]),
            },
            sizes=[src_dataset.sizes],
        )
        lang_datasets.append(lang_dataset)

    dataset_lengths = np.array(
        [len(d) for d in lang_datasets],
        dtype=float,
    )
    logger.info(
        'loaded total {} blocks for all languages'.format(
            dataset_lengths.sum(),
        )
    )
    if split == self.args.train_subset:
        # For train subset, additionally up or down sample languages.
        sample_probs = self._get_sample_prob(dataset_lengths)
        logger.info(
            "Sample probability by language: {}".format(
                {
                    lang: "{0:.4f}".format(sample_probs[id])
                    for id, lang in enumerate(languages)
                }
            )
        )
        size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
        logger.info(
            "Up/Down Sampling ratio by language: {}".format(
                {
                    lang: "{0:.2f}".format(size_ratio[id])
                    for id, lang in enumerate(languages)
                }
            )
        )

        resampled_lang_datasets = [
            ResamplingDataset(
                lang_datasets[i],
                size_ratio=size_ratio[i],
                seed=self.args.seed,
                epoch=epoch,
                replace=size_ratio[i] >= 1.0,
            )
            for i, d in enumerate(lang_datasets)
        ]
        dataset = ConcatDataset(resampled_lang_datasets)
    else:
        dataset = ConcatDataset(lang_datasets)
        lang_splits = [split]
        for lang_id, lang_dataset in enumerate(lang_datasets):
            split_name = split + '_' + languages[lang_id]
            lang_splits.append(split_name)
            self.datasets[split_name] = lang_dataset

        # [TODO]: This is hacky for now to print validation ppl for each
        # language individually. Maybe need task API changes to allow it
        # in more generic ways.
        if split in self.args.valid_subset:
            self.args.valid_subset = self.args.valid_subset.replace(
                split, ','.join(lang_splits)
            )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(dataset))

    self.datasets[split] = SortDataset(
        dataset,
        sort_order=[
            shuffle,
            dataset.sizes,
        ],
    )
def load_dataset(self, split: str, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    languages, data_path = MultilingualLanguageModelingTask._get_langs(
        self.args, epoch
    )
    lang_to_offline_shard_ratio = None
    if self.args.lang_to_offline_shard_ratio != "":
        lang_to_offline_shard_ratio = {}
        assert os.path.exists(
            self.args.lang_to_offline_shard_ratio
        ), "provided offline shard ratio file doesn't exist: {0}".format(
            self.args.lang_to_offline_shard_ratio
        )
        with open(self.args.lang_to_offline_shard_ratio) as fin:
            for line in fin:
                lang, ratio = line.strip().split("\t")
                ratio = float(ratio)
                lang_to_offline_shard_ratio[lang] = ratio

        logger.info(
            "Found offline sharded ratio: %s",
            lang_to_offline_shard_ratio,
        )

    if split == self.args.train_subset:
        logger.info(
            "Training on {0} languages: {1}".format(len(languages), languages)
        )
    else:
        logger.info(
            "Evaluating on {0} languages: {1}".format(len(languages), languages)
        )

    tokens_per_sample = self.args.tokens_per_sample - int(self.args.add_bos_token)

    fixed_pad_length = None
    if self.args.pad_to_fixed_length:
        fixed_pad_length = self.args.tokens_per_sample

    pad_to_bsz = None
    if self.args.pad_to_fixed_bsz:
        pad_to_bsz = (
            self.args.batch_size_valid if "valid" in split else self.args.batch_size
        )

    lang_datasets = []
    for lang_id, language in enumerate(languages):
        split_path = os.path.join(data_path, language, split)
        dataset = data_utils.load_indexed_dataset(
            split_path, self.dictionary, self.args.dataset_impl, combine=combine
        )
        # print('len(dataset) =', len(dataset))
        if dataset is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        dataset = maybe_shorten_dataset(
            dataset,
            split,
            self.args.shorten_data_split_list,
            self.args.shorten_method,
            tokens_per_sample,
            self.args.seed,
        )

        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            tokens_per_sample,
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=self.args.sample_break_mode,
            include_targets=True,
        )

        add_eos_for_other_targets = (
            self.args.sample_break_mode is not None
            and self.args.sample_break_mode != "none"
        )
        src_lang_idx, tgt_lang_idx = None, None
        if self.args.add_bos_token:
            src_lang_idx = self.dictionary.index(lang_token(language))
            tgt_lang_idx = self.output_dictionary.index(lang_token(language))

        lang_datasets.append(
            MonolingualDataset(
                dataset=dataset,
                sizes=dataset.sizes,
                src_vocab=self.dictionary,
                tgt_vocab=self.output_dictionary,
                add_eos_for_other_targets=add_eos_for_other_targets,
                shuffle=True,
                targets=self.targets,
                fixed_pad_length=fixed_pad_length,
                pad_to_bsz=pad_to_bsz,
                add_bos_token=self.args.add_bos_token,
                src_lang_idx=src_lang_idx,
                tgt_lang_idx=tgt_lang_idx,
            )
        )

    dataset_lengths = np.array(
        [len(d) for d in lang_datasets],
        dtype=float,
    )
    logger.info(
        "loaded total {} blocks for all languages".format(
            dataset_lengths.sum(),
        )
    )
    if split == self.args.train_subset:
        dataset_lengths_ratio_multiplier = np.ones(len(dataset_lengths))
        if lang_to_offline_shard_ratio is not None:
            dataset_lengths_ratio_multiplier = []
            for lang in languages:
                assert (
                    lang in lang_to_offline_shard_ratio
                ), "Lang: {0} missing in offline shard ratio file: {1}".format(
                    lang,
                    self.args.lang_to_offline_shard_ratio,
                )
                dataset_lengths_ratio_multiplier.append(
                    lang_to_offline_shard_ratio[lang]
                )
            dataset_lengths_ratio_multiplier = np.array(
                dataset_lengths_ratio_multiplier
            )
            true_dataset_lengths = (
                dataset_lengths * dataset_lengths_ratio_multiplier
            )
        else:
            true_dataset_lengths = dataset_lengths
        # For train subset, additionally up or down sample languages.
        sample_probs = self._get_sample_prob(true_dataset_lengths)

        logger.info(
            "Sample probability by language: %s",
            {
                lang: "{0:.4f}".format(sample_probs[id])
                for id, lang in enumerate(languages)
            },
        )
        size_ratio = (sample_probs * true_dataset_lengths.sum()) / dataset_lengths
        # TODO: add an option for shrinking all size ratios to below 1
        # if self.args.multilang_sampling_alpha != 1:
        #     size_ratio /= size_ratio.max()

        # Fix numeric errors in size ratio computation
        #   0.999999999999999999 -> 1
        #   1.000000000000000002 -> 1
        for i in range(len(size_ratio)):
            size_ratio[i] = round(size_ratio[i], 8)

        logger.info(
            "Up/Down Sampling ratio by language: %s",
            {
                lang: "{0:.2f}".format(size_ratio[id])
                for id, lang in enumerate(languages)
            },
        )
        logger.info(
            "Actual dataset size by language: %s",
            {
                lang: "{0:.2f}".format(len(lang_datasets[id]))
                for id, lang in enumerate(languages)
            },
        )
        resampled_lang_datasets = [
            ResamplingDataset(
                lang_datasets[i],
                size_ratio=size_ratio[i],
                seed=self.args.seed,
                epoch=epoch,
                replace=size_ratio[i] > 1.0,
            )
            for i, d in enumerate(lang_datasets)
        ]
        logger.info(
            "Resampled dataset size by language: %s",
            {
                lang: "{0:.2f}".format(len(resampled_lang_datasets[id]))
                for id, lang in enumerate(languages)
            },
        )
        dataset = ConcatDataset(resampled_lang_datasets)
    else:
        dataset = ConcatDataset(lang_datasets)
        lang_splits = [split]
        for lang_id, lang_dataset in enumerate(lang_datasets):
            split_name = split + "_" + languages[lang_id]
            lang_splits.append(split_name)
            self.datasets[split_name] = lang_dataset

        # [TODO]: This is hacky for now to print validation ppl for each
        # language individually. Maybe need task API changes to allow it
        # in more generic ways.
        if split in self.args.valid_subset:
            self.args.valid_subset = self.args.valid_subset.replace(
                split, ",".join(lang_splits)
            )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(dataset))

    self.datasets[split] = SortDataset(
        dataset,
        sort_order=[
            shuffle,
            dataset.sizes,
        ],
    )
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = self.args.data.split(':')
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]

    # infer langcode
    lg_datasets = []
    for lg in self.gt_langs:
        src, tgt = lg, lg
        bos_id = self.tgt_dict.index('[{}]'.format(lg))
        data_path_lg = os.path.join(data_path, lg)
        dataset = load_generation_pair_dataset(
            data_path_lg,
            split,
            tgt,
            self.src_dict,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=getattr(self.args, 'max_source_positions', 1024),
            max_target_positions=getattr(self.args, 'max_target_positions', 1024),
            load_alignments=self.args.load_alignments,
            prepend_bos=getattr(self.args, 'preprend_bos', False),
            append_source_id=True,
            common_eos=self.args.common_eos,
            lg_id=bos_id,
        )
        lg_datasets.append(dataset)

    dataset_lengths = np.array([len(d) for d in lg_datasets], dtype=float)
    sample_probs = self._get_sample_prob(dataset_lengths)
    logger.info(
        "| Sample probability by language: {}".format(
            {
                lang: "{0:.4f}".format(sample_probs[id])
                for id, lang in enumerate(self.gt_langs)
            }
        )
    )
    size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
    logger.info(
        "| Up/Down Sampling ratio by language: {}".format(
            {
                lang: "{0:.2f}".format(size_ratio[id])
                for id, lang in enumerate(self.gt_langs)
            }
        )
    )

    if split == getattr(self.args, "train_subset", "train"):
        resampled_lang_datasets = [
            ResamplingDataset(
                lg_datasets[i],
                size_ratio=size_ratio[i],
                seed=self.args.seed,
                epoch=epoch,
                replace=size_ratio[i] >= 1.0,
            )
            for i, d in enumerate(lg_datasets)
        ]
        dataset = ConcatDataset(resampled_lang_datasets)
    else:
        dataset = ConcatDataset(lg_datasets)
        lang_splits = [split]
        for lang_id, lang_dataset in enumerate(lg_datasets):
            split_name = split + '_' + self.gt_langs[lang_id]
            lang_splits.append(split_name)
            self.datasets[split_name] = lang_dataset

        if hasattr(self.args, "valid_subset"):
            if split in self.args.valid_subset:
                self.args.valid_subset = self.args.valid_subset.replace(
                    split, ','.join(lang_splits)
                )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(dataset))

    self.datasets[split] = SortDataset(
        dataset,
        sort_order=[
            shuffle,
            dataset.sizes,
        ],
    )
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    # if not training data set, use the first shard for valid and test
    if split != getattr(self.args, "train_subset", None):
        paths = paths[:1]
    data_path = paths[(epoch - 1) % len(paths)]

    lang_datasets = []
    for lang_pair in self.args.langs.split(","):
        [src, tgt] = lang_pair.split("-")

        # if multi-valid : valid.zh-en
        if "valid" in split and split != "valid":
            split_lang = split.split(".")[-1]
            if (lang_pair != split_lang) and (tgt + "-" + src != split_lang):
                continue
            # special for (fil and fi) langs.
            if lang_pair != split_lang and (
                tgt + "-" + src == split_lang and "fil" in split_lang
            ):
                continue

        data_path_temp = data_path
        if self.args.add_lang_token:
            add_langs = (src, tgt)
        else:
            add_langs = None

        dataset = load_langpair_dataset(
            data_path_temp,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            load_alignments=self.args.load_alignments,
            truncate_source=self.args.truncate_source,
            num_buckets=self.args.num_batch_buckets,
            shuffle=(split != "test"),
            pad_to_multiple=self.args.required_seq_len_multiple,
            plus_encoder_loss=self.args.plus_encoder_loss,
            add_langs=add_langs,
            shuffle_lang_pair=self.args.shuffle_lang_pair,
            args=self.args,
            word_trans_dict=self.word_trans_dict,
            word_align_dict=self.word_align,
            policy_ratio_dicts=self.policy_ratio_dicts,
        )
        lang_datasets.append(dataset)

    dataset_lengths = np.array(
        [len(d) for d in lang_datasets],
        dtype=float,
    )
    logger.info(
        "loaded total {} blocks for all languages".format(
            int(dataset_lengths.sum()),
        )
    )
    if split == self.args.train_subset:
        # For train subset, additionally up or down sample languages.
        sample_probs = self._get_sample_prob(dataset_lengths)
        logger.info(
            "Sample probability by language: {}".format(
                {
                    lang: "{0:.4f}".format(sample_probs[id])
                    for id, lang in enumerate(self.args.langs.split(","))
                }
            )
        )
        size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
        logger.info(
            "Up/Down Sampling ratio by language: {}".format(
                {
                    lang: "{0:.2f}".format(size_ratio[id])
                    for id, lang in enumerate(self.args.langs.split(","))
                }
            )
        )

        resampled_lang_datasets = [
            ResamplingDataset(
                lang_datasets[i],
                size_ratio=size_ratio[i],
                seed=self.args.seed,
                epoch=epoch,
                replace=size_ratio[i] >= 1.0,
            )
            for i, d in enumerate(lang_datasets)
        ]
        dataset = ConcatPairDataset(
            resampled_lang_datasets,
        )
    else:
        dataset = ConcatPairDataset(lang_datasets)

    self.datasets[split] = dataset