Example #1
    def load_sampled_multi_epoch_dataset(
        self, split, training, epoch=0, combine=False, shard_epoch=None, **kwargs
    ):
        datasets, data_param_list = self.load_split_datasets(
            split, training, epoch, combine, shard_epoch=shard_epoch, **kwargs
        )

        if training and split == getattr(self.args, "train_subset", None):
            sample_ratios = self.get_sampling_ratios(data_param_list, datasets, epoch)
            data_sizes = self.get_ordered_train_dataset_sizes(data_param_list, datasets)
            self.data_ratios = data_sizes / sum(data_sizes)
            datasets = OrderedDict(datasets)
            idx = 1 if self.target_group == "target_lang" else 0
            my_lang_ids = [
                _lang_id(self.lang_dict, key.split(":")[-1].split("-")[idx])
                for key in list(datasets.keys())
            ]
            logger.info("Mapped lang ids = {}".format(my_lang_ids))
            return SampledMultiEpochDataset(
                datasets,
                epoch=epoch,
                shard_epoch=shard_epoch,
                # valid and test datasets degenerate to concatenated datasets:
                sampling_ratios=sample_ratios,
                eval_key=None,
                collate_format=CollateFormat.single,
                virtual_size=self.args.virtual_data_size,
                split=split,
                virtual_epoch_size=self.args.virtual_epoch_size,
                # if lang_tok is not altered, the same collater can be shared
                shared_collater=self._shared_collater(),
                remapped_lang_ids=np.array(my_lang_ids)
            )
        else:
            return self.load_into_concat_dataset(split, datasets, data_param_list)
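Example #1 extends the stock helper with two extra pieces of bookkeeping: it normalizes the ordered training dataset sizes into self.data_ratios, and it maps each dataset key of the form "prefix:src-tgt" to an integer language id with _lang_id. The sketch below replays that key parsing and ratio normalization in isolation; the toy lang_dict and the stand-in _lang_id are assumptions for illustration, not the project's actual dictionary lookup.

import numpy as np

# Hypothetical stand-in for the _lang_id helper: look up a language symbol
# in a dictionary-like mapping and return its integer index.
def _lang_id(lang_dict, lang):
    return lang_dict[lang]

# Toy language dictionary and dataset keys of the form "prefix:src-tgt".
lang_dict = {"en": 0, "de": 1, "fr": 2}
dataset_keys = ["main:en-de", "main:en-fr"]

# idx = 0 picks the source language, idx = 1 the target language,
# mirroring the target_group check in Example #1.
idx = 1
my_lang_ids = [
    _lang_id(lang_dict, key.split(":")[-1].split("-")[idx]) for key in dataset_keys
]
print(my_lang_ids)  # [1, 2]

# Size-proportional ratios, as stored in self.data_ratios in Example #1.
data_sizes = np.array([100_000, 25_000])
print(data_sizes / data_sizes.sum())  # [0.8 0.2]
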
Example #2
    def load_sampled_multi_epoch_dataset(
        self, split, training, epoch=0, combine=False, shard_epoch=None, **kwargs
    ):
        datasets, data_param_list = self.load_split_datasets(
            split, training, epoch, combine, shard_epoch=shard_epoch, **kwargs
        )
        if training and split == getattr(self.args, "train_subset", None):
            sample_ratios = self.get_sampling_ratios(data_param_list, datasets, epoch)
            return SampledMultiEpochDataset(
                OrderedDict(datasets),
                epoch=epoch,
                shard_epoch=shard_epoch,
                # valid and test datasets degenerate to concatenated datasets:
                sampling_ratios=sample_ratios,
                eval_key=None,
                collate_format=CollateFormat.single,
                virtual_size=self.args.virtual_data_size,
                split=split,
                virtual_epoch_size=self.args.virtual_epoch_size,
                # if lang_tok is not altered, the same collater can be shared
                shared_collater=self._shared_collater(),
            )
        else:
            return self.load_into_concat_dataset(split, datasets, data_param_list)
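get_sampling_ratios itself is not shown in these examples. A common way to derive such ratios, and the behaviour fairseq exposes through a sampling temperature, is to smooth the size-proportional probabilities with an exponent 1/T so that low-resource datasets are upsampled. The snippet below is a minimal sketch of that standard formula, not necessarily the exact implementation behind these helpers.

import numpy as np

def temperature_sampling_ratios(sizes, temperature=1.5):
    # p_i proportional to (n_i / sum_j n_j) ** (1 / T); T > 1 upsamples
    # small datasets, T = 1 keeps pure size-proportional sampling.
    probs = np.asarray(sizes, dtype=np.float64)
    probs = probs / probs.sum()
    smoothed = probs ** (1.0 / temperature)
    return smoothed / smoothed.sum()

print(temperature_sampling_ratios([100_000, 25_000, 5_000]))
# roughly [0.65 0.26 0.09], versus size-proportional [0.77 0.19 0.04]
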
Example #3
    def load_into_sampled_multi_epoch_dataset(
        self, split, datasets, data_param_list, epoch, shard_epoch=None
    ):
        sample_ratios = self.get_sampling_ratios(data_param_list, datasets, epoch)
        return SampledMultiEpochDataset(
            OrderedDict(datasets),
            epoch=epoch,
            shard_epoch=shard_epoch,
            # valid and test datasets degenerate to concatenated datasets:
            sampling_ratios=sample_ratios,
            eval_key=None,
            batch_by_size=True,
            collate_format=CollateFormat.single,
            virtual_size=self.args.virtual_data_size,
            split=split,
            virtual_epoch_size=self.args.virtual_epoch_size,
            # if lang_tok is not altered, the same collater can be shared
            shared_collater=self._shared_collater(),
        )
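All three variants hand the ratios to SampledMultiEpochDataset together with virtual_data_size and virtual_epoch_size. As a rough reading of the intent (an assumption, not the library's exact bookkeeping): the ratios decide how many of the virtual examples each dataset contributes, and the virtual epoch size chops that pool into smaller epochs so shuffling and checkpointing can happen more often than once per full pass over the virtual data.

# Hypothetical numbers illustrating how the knobs interact.
sampling_ratios = [0.65, 0.26, 0.09]
virtual_data_size = 1_000_000
virtual_epoch_size = 200_000

per_dataset = [round(r * virtual_data_size) for r in sampling_ratios]
print(per_dataset)  # [650000, 260000, 90000]
print(virtual_data_size // virtual_epoch_size)  # 5 virtual epochs per full pass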