def load_sampled_multi_epoch_dataset(
    self, split, training, epoch=0, combine=False, shard_epoch=None, **kwargs
):
    datasets, data_param_list = self.load_split_datasets(
        split, training, epoch, combine, shard_epoch=shard_epoch, **kwargs
    )
    if training and split == getattr(self.args, "train_subset", None):
        sample_ratios = self.get_sampling_ratios(data_param_list, datasets, epoch)
        data_sizes = self.get_ordered_train_dataset_sizes(data_param_list, datasets)
        self.data_ratios = data_sizes / sum(data_sizes)
        datasets = OrderedDict(datasets)
        # take the target side of each "src-tgt" key when grouping by target
        # language, otherwise the source side
        idx = 1 if self.target_group == "target_lang" else 0
        my_lang_ids = [
            _lang_id(self.lang_dict, key.split(":")[-1].split("-")[idx])
            for key in datasets.keys()
        ]
        logger.info("Mapped lang ids = {}".format(my_lang_ids))
        return SampledMultiEpochDataset(
            datasets,
            epoch=epoch,
            shard_epoch=shard_epoch,
            # valid and test datasets will degenerate to concatenated datasets:
            sampling_ratios=sample_ratios,
            eval_key=None,
            collate_format=CollateFormat.single,
            virtual_size=self.args.virtual_data_size,
            split=split,
            virtual_epoch_size=self.args.virtual_epoch_size,
            # if not using lang_tok altering, simplified to use the same collater
            shared_collater=self._shared_collater(),
            remapped_lang_ids=np.array(my_lang_ids),
        )
    else:
        return self.load_into_concat_dataset(split, datasets, data_param_list)

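Compared with the plain loader that follows, the extended version above adds two steps: it normalizes per-dataset sizes into self.data_ratios and maps each dataset key to a language id for remapped_lang_ids. The standalone sketch below illustrates both steps; the "main:en-de" key format, the contents of lang_dict, and the example sizes are illustrative assumptions, not values taken from the data manager.

import numpy as np
from collections import OrderedDict

# Hypothetical per-language training-set sizes; the real values come from
# get_ordered_train_dataset_sizes. The division assumes a numpy array, since a
# plain Python list does not support elementwise "/".
data_sizes = np.array([1_000_000, 250_000, 50_000])
data_ratios = data_sizes / sum(data_sizes)
print(data_ratios.round(3))  # [0.769 0.192 0.038]

def map_lang_ids(datasets, lang_dict, target_group="target_lang"):
    # "main:en-de" -> "en-de" -> ("en", "de"); take the target side (index 1)
    # when grouping by target language, otherwise the source side (index 0)
    idx = 1 if target_group == "target_lang" else 0
    return [lang_dict[key.split(":")[-1].split("-")[idx]] for key in datasets.keys()]

datasets = OrderedDict([("main:en-de", None), ("main:en-fr", None)])
lang_dict = {"en": 0, "de": 1, "fr": 2}  # stand-in for the real language dictionary
print(map_lang_ids(datasets, lang_dict))  # [1, 2]
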
def load_sampled_multi_epoch_dataset(self, split, training, epoch=0, combine=False, shard_epoch=None, **kwargs):
    datasets, data_param_list = self.load_split_datasets(
        split, training, epoch, combine, shard_epoch=shard_epoch, **kwargs
    )
    if training and split == getattr(self.args, "train_subset", None):
        sample_ratios = self.get_sampling_ratios(data_param_list, datasets, epoch)
        return SampledMultiEpochDataset(
            OrderedDict(datasets),
            epoch=epoch,
            shard_epoch=shard_epoch,
            # valid and test datasets will degenerate to concatenated datasets:
            sampling_ratios=sample_ratios,
            eval_key=None,
            collate_format=CollateFormat.single,
            virtual_size=self.args.virtual_data_size,
            split=split,
            virtual_epoch_size=self.args.virtual_epoch_size,
            # if not using lang_tok altering, simplified to use the same collater
            shared_collater=self._shared_collater(),
        )
    else:
        return self.load_into_concat_dataset(split, datasets, data_param_list)

def load_into_sampled_multi_epoch_dataset(
    self, split, datasets, data_param_list, epoch, shard_epoch=None
):
    sample_ratios = self.get_sampling_ratios(data_param_list, datasets, epoch)
    return SampledMultiEpochDataset(
        OrderedDict(datasets),
        epoch=epoch,
        shard_epoch=shard_epoch,
        # valid and test datasets will degenerate to concatenated datasets:
        sampling_ratios=sample_ratios,
        eval_key=None,
        batch_by_size=True,
        collate_format=CollateFormat.single,
        virtual_size=self.args.virtual_data_size,
        split=split,
        virtual_epoch_size=self.args.virtual_epoch_size,
        # if not using lang_tok altering, simplified to use the same collater
        shared_collater=self._shared_collater(),
    )

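All three loaders hand SampledMultiEpochDataset both virtual_size (from args.virtual_data_size) and virtual_epoch_size (from args.virtual_epoch_size). A rough sketch of how the two relate, assuming the virtual dataset is served in consecutive chunks of virtual_epoch_size examples; the numbers are illustrative, not defaults:

import math

virtual_data_size = 1_000_000   # total virtual examples in one full pass over the sampled mixture
virtual_epoch_size = 250_000    # virtual examples served per training "epoch"
num_virtual_epochs = math.ceil(virtual_data_size / virtual_epoch_size)
print(num_virtual_epochs)  # 4 virtual epochs per full pass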