def load_sampled_multi_dataset(
    self, split, training, epoch=0, combine=False, shard_epoch=None, **kwargs
):
    datasets, data_param_list = self.load_split_datasets(
        split, training, epoch, combine, shard_epoch=shard_epoch, **kwargs
    )
    if training and split == getattr(self.args, "train_subset", None):
        sample_ratios = self.get_sampling_ratios(data_param_list, datasets, epoch)
        return SampledMultiDataset(
            OrderedDict(datasets),
            epoch=epoch,
            sampling_ratios=sample_ratios,
            eval_key=None,
            collate_format=CollateFormat.single,
            virtual_size=self.args.virtual_data_size,
            split=split,
            # when lang_tok altering is not used, a single collater can be shared
            shared_collater=self._shared_collater(),
        )
    else:
        # valid and test datasets degenerate into concatenated datasets
        return self.load_into_concat_dataset(split, datasets, data_param_list)
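
# A minimal usage sketch, not part of the original code. It assumes `manager` is an
# instance of the surrounding data-manager class with `args.train_subset == "train"`
# and `args.virtual_data_size` set; the variable names below are illustrative only:
#
#   train_data = manager.load_sampled_multi_dataset("train", training=True, epoch=1)
#   # -> SampledMultiDataset resampled via get_sampling_ratios() and exposing
#   #    args.virtual_data_size virtual examples per epoch
#
#   valid_data = manager.load_sampled_multi_dataset("valid", training=True, epoch=1)
#   # -> falls through to load_into_concat_dataset(), i.e. no resampling for eval splits
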
def load_into_concat_dataset(self, split, datasets, data_param_list):
    if self.args.lang_tok_replacing_bos_eos:
        # TODO: investigate why TransformEosLangPairDataset doesn't work with ConcatDataset
        return SampledMultiDataset(
            OrderedDict(datasets),
            sampling_ratios=None,
            eval_key=None,
            collate_format=CollateFormat.single,
            virtual_size=None,
            split=split,
        )
    return ConcatDataset([d for _, d in datasets])
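
# Sketch of the two shapes load_into_concat_dataset can return, assuming `datasets`
# is the list of (key, dataset) pairs produced by load_split_datasets:
#
#   # default case: a plain concatenation, so lengths simply add up
#   combined = ConcatDataset([d for _, d in datasets])
#   assert len(combined) == sum(len(d) for _, d in datasets)
#
#   # with args.lang_tok_replacing_bos_eos set: a SampledMultiDataset with
#   # sampling_ratios=None and virtual_size=None, which (as an assumption about
#   # SampledMultiDataset's defaults) keeps every example once while avoiding the
#   # ConcatDataset that TransformEosLangPairDataset reportedly does not work with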