Exemple #1
0
 def load_sampled_multi_dataset(
     self, split, training, epoch=0, combine=False, shard_epoch=None, **kwargs
 ):
     """Build the dataset object for ``split``.

     The per-split datasets and their parameter list are obtained from
     ``self.load_split_datasets``. What happens next depends on the split:

     * ``training`` is truthy and ``split`` equals ``args.train_subset``:
       the datasets are wrapped in a ``SampledMultiDataset`` whose sampling
       ratios come from ``self.get_sampling_ratios`` and whose virtual size
       is ``args.virtual_data_size``.
     * anything else: the work is delegated to
       ``self.load_into_concat_dataset``.

     Args:
         split: name of the split to load.
         training: whether the split is being loaded for training.
         epoch: forwarded to the loader, the ratio computation and the
             sampled dataset.
         combine: forwarded to ``load_split_datasets``.
         shard_epoch: forwarded to ``load_split_datasets`` by keyword.
         **kwargs: forwarded untouched to ``load_split_datasets``.

     Returns:
         Whatever ``SampledMultiDataset(...)`` or
         ``load_into_concat_dataset(...)`` produces for the chosen path.
     """
     named_datasets, param_list = self.load_split_datasets(
         split, training, epoch, combine, shard_epoch=shard_epoch, **kwargs
     )

     train_subset = getattr(self.args, "train_subset", None)
     if not (training and split == train_subset):
         # Only the training subset is sampled; every other split is handed
         # to the concat-style loader.
         return self.load_into_concat_dataset(split, named_datasets, param_list)

     ratios = self.get_sampling_ratios(param_list, named_datasets, epoch)
     return SampledMultiDataset(
         OrderedDict(named_datasets),
         epoch=epoch,
         sampling_ratios=ratios,
         eval_key=None,
         collate_format=CollateFormat.single,
         virtual_size=self.args.virtual_data_size,
         split=split,
         # Original note: when lang-token altering is not used, one collater
         # can be shared by all datasets (the decision lives in
         # _shared_collater, which is not visible from here).
         shared_collater=self._shared_collater(),
     )
 def load_into_concat_dataset(self, split, datasets, data_param_list):
     """Combine ``datasets`` into a single non-sampled dataset.

     Args:
         split: name of the split; only used on the
             ``SampledMultiDataset`` path.
         datasets: iterable of ``(key, dataset)`` pairs.
         data_param_list: accepted for interface compatibility; not read
             by this method.

     Returns:
         A ``ConcatDataset`` over the dataset halves of the pairs, unless
         ``args.lang_tok_replacing_bos_eos`` is truthy, in which case a
         ``SampledMultiDataset`` with no sampling ratios and no virtual
         size is returned instead.
     """
     if not self.args.lang_tok_replacing_bos_eos:
         return ConcatDataset([dataset for _key, dataset in datasets])

     # TODO: investigate why TransformEosLangPairDataset doesn't work with
     # ConcatDataset; until then this path uses SampledMultiDataset with
     # sampling disabled.
     return SampledMultiDataset(
         OrderedDict(datasets),
         sampling_ratios=None,
         eval_key=None,
         collate_format=CollateFormat.single,
         virtual_size=None,
         split=split,
     )