# Example #1 (score: 0)
    def _setup_dataloader_from_config(self, cfg: DictConfig):
        """Build a ``torch.utils.data.DataLoader`` from a data-section config.

        Two data sources are supported:

        * ``use_tarred_dataset=True`` — sharded tar files described by a JSON
          metadata file (``TarredTranslationDataset``),
        * otherwise — plain parallel text files batched in memory
          (``TranslationDataset``).

        In multilingual mode (``self.multilingual``) the relevant config
        fields are expected to hold lists with one entry per language pair,
        and the per-language datasets are combined with ``ConcatDataset``.

        NOTE(review): indexes ``self.multilingual_ids[idx]`` even in the
        non-multilingual case — presumably it is ``[None]`` (or a one-element
        list) when not multilingual; confirm against the model's setup code.

        Args:
            cfg: data config (e.g. the ``train_ds``/``validation_ds`` section).

        Returns:
            A ``DataLoader`` over the selected dataset. ``batch_size=1``
            because batching is performed inside the datasets themselves
            (token-bucketed batches).

        Raises:
            FileNotFoundError: tarred data requested but metadata / tar file
                paths cannot be found in config or metadata.
            ValueError: source and target file lists differ in length.
        """
        if cfg.get("use_tarred_dataset", False):
            # Tarred datasets require a metadata file describing the shards.
            if cfg.get("metadata_file") is None:
                raise FileNotFoundError("Trying to use tarred data set but could not find metadata path in config.")
            else:
                # Normalize to a list so single- and multi-lingual cases
                # share the loop below. In multilingual mode the config is
                # expected to already contain a list of metadata files.
                if not self.multilingual:
                    metadata_file_list = [cfg.get('metadata_file')]
                else:
                    metadata_file_list = cfg.get('metadata_file')

                datasets = []
                for idx, metadata_file in enumerate(metadata_file_list):
                    with open(metadata_file) as metadata_reader:
                        metadata = json.load(metadata_reader)
                    # Tar file paths may come from the config or, as a
                    # fallback, from the metadata JSON. When both are
                    # present the config value wins (logged below).
                    if cfg.get('tar_files') is None:
                        tar_files = metadata.get('tar_files')
                        if tar_files is not None:
                            logging.info(f'Loading from tarred dataset {tar_files}')
                        else:
                            raise FileNotFoundError("Could not find tarred dataset in config or metadata.")
                    else:
                        tar_files = cfg.get('tar_files')
                        if self.multilingual:
                            # One entry of tar paths per language pair.
                            tar_files = tar_files[idx]
                        if metadata.get('tar_files') is not None:
                            logging.info(
                                f'Tar file paths found in both cfg and metadata using one in cfg by default - {tar_files}'
                            )

                    dataset = TarredTranslationDataset(
                        text_tar_filepaths=tar_files,
                        metadata_path=metadata_file,
                        encoder_tokenizer=self.encoder_tokenizer,
                        decoder_tokenizer=self.decoder_tokenizer,
                        shuffle_n=cfg.get("tar_shuffle_n", 100),
                        shard_strategy=cfg.get("shard_strategy", "scatter"),
                        global_rank=self.global_rank,
                        world_size=self.world_size,
                        reverse_lang_direction=cfg.get("reverse_lang_direction", False),
                        # Language-pair token prepended to each example;
                        # see the NOTE in the docstring about idx 0.
                        prepend_id=self.multilingual_ids[idx],
                    )
                    datasets.append(dataset)

                if self.multilingual:
                    # Sample across language-pair datasets; technique /
                    # temperature / probabilities come straight from config.
                    dataset = ConcatDataset(
                        datasets=datasets,
                        sampling_technique=cfg.get('concat_sampling_technique'),
                        sampling_temperature=cfg.get('concat_sampling_temperature'),
                        sampling_probabilities=cfg.get('concat_sampling_probabilities'),
                        global_rank=self.global_rank,
                        world_size=self.world_size,
                    )
                else:
                    dataset = datasets[0]

            # No sampler here: tarred/webdataset-style datasets manage
            # their own shuffling (shuffle_n) and sharding.
            return torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=1,
                num_workers=cfg.get("num_workers", 2),
                pin_memory=cfg.get("pin_memory", False),
                drop_last=cfg.get("drop_last", False),
            )
        else:
            # Plain text datasets: normalize file names to parallel lists.
            if not self.multilingual:
                src_file_list = [cfg.src_file_name]
                tgt_file_list = [cfg.tgt_file_name]
            else:
                src_file_list = cfg.src_file_name
                tgt_file_list = cfg.tgt_file_name

            if len(src_file_list) != len(tgt_file_list):
                raise ValueError(
                    'The same number of filepaths must be passed in for source and target while training multilingual.'
                )

            datasets = []
            for idx, src_file in enumerate(src_file_list):
                dataset = TranslationDataset(
                    dataset_src=str(Path(src_file).expanduser()),
                    dataset_tgt=str(Path(tgt_file_list[idx]).expanduser()),
                    tokens_in_batch=cfg.tokens_in_batch,
                    clean=cfg.get("clean", False),
                    max_seq_length=cfg.get("max_seq_length", 512),
                    min_seq_length=cfg.get("min_seq_length", 1),
                    max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
                    max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
                    cache_ids=cfg.get("cache_ids", False),
                    cache_data_per_node=cfg.get("cache_data_per_node", False),
                    use_cache=cfg.get("use_cache", False),
                    reverse_lang_direction=cfg.get("reverse_lang_direction", False),
                    prepend_id=self.multilingual_ids[idx],
                )
                # Pre-compute token-bucketed batches with both tokenizers.
                dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)
                datasets.append(dataset)

            if self.multilingual:
                dataset = ConcatDataset(
                    datasets=datasets,
                    shuffle=cfg.get('shuffle'),
                    sampling_technique=cfg.get('concat_sampling_technique'),
                    sampling_temperature=cfg.get('concat_sampling_temperature'),
                    sampling_probabilities=cfg.get('concat_sampling_probabilities'),
                    global_rank=self.global_rank,
                    world_size=self.world_size,
                )
                # ConcatDataset handles shuffling itself, so no sampler.
                return torch.utils.data.DataLoader(
                    dataset=dataset,
                    batch_size=1,
                    num_workers=cfg.get("num_workers", 2),
                    pin_memory=cfg.get("pin_memory", False),
                    drop_last=cfg.get("drop_last", False),
                )
            else:
                dataset = datasets[0]

        # Single in-memory dataset: choose an explicit sampler per config.
        if cfg.shuffle:
            sampler = pt_data.RandomSampler(dataset)
        else:
            sampler = pt_data.SequentialSampler(dataset)
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=1,
            sampler=sampler,
            num_workers=cfg.get("num_workers", 2),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=cfg.get("drop_last", False),
        )
# Example #2 (score: 0)
 def _setup_dataloader_from_config(self, cfg: DictConfig):
     """Build a ``torch.utils.data.DataLoader`` from a data-section config.

     Three data sources are supported, checked in order:

     1. ``load_from_cached_dataset=True`` — unpickle a pre-batched dataset
        from ``cfg.src_file_name`` (src and tgt must name the same file),
     2. ``use_tarred_dataset=True`` — sharded tar files plus a metadata
        file (``TarredTranslationDataset``),
     3. otherwise — plain parallel text files batched in memory
        (``TranslationDataset``).

     Args:
         cfg: data config (e.g. the ``train_ds``/``validation_ds`` section).

     Returns:
         A ``DataLoader`` over the selected dataset. ``batch_size=1``
         because batching is performed inside the datasets themselves.

     Raises:
         ValueError: cached dataset requested but src != tgt file name.
         FileNotFoundError: tarred data or its metadata path is missing.
     """
     if cfg.get("load_from_cached_dataset", False):
         logging.info('Loading from cached dataset %s' %
                      (cfg.src_file_name))
         if cfg.src_file_name != cfg.tgt_file_name:
             raise ValueError(
                 "src must be equal to target for cached dataset")
         # SECURITY: pickle.load executes arbitrary code on load — only
         # open caches produced by a trusted run of this model.
         # Use a context manager so the file handle is closed promptly
         # (previously the handle was leaked until garbage collection).
         with open(cfg.src_file_name, 'rb') as cached_dataset_file:
             dataset = pickle.load(cached_dataset_file)
         dataset.reverse_lang_direction = cfg.get("reverse_lang_direction",
                                                  False)
     elif cfg.get("use_tarred_dataset", False):
         # Tarred datasets need both the tar shards and a metadata file.
         if cfg.get('tar_files') is None:
             raise FileNotFoundError("Could not find tarred dataset.")
         logging.info(f'Loading from tarred dataset {cfg.get("tar_files")}')
         if cfg.get("metadata_file", None) is None:
             raise FileNotFoundError(
                 "Could not find metadata path in config")
         dataset = TarredTranslationDataset(
             text_tar_filepaths=cfg.tar_files,
             metadata_path=cfg.metadata_file,
             encoder_tokenizer=self.encoder_tokenizer,
             decoder_tokenizer=self.decoder_tokenizer,
             shuffle_n=cfg.get("tar_shuffle_n", 100),
             shard_strategy=cfg.get("shard_strategy", "scatter"),
             global_rank=self.global_rank,
             world_size=self.world_size,
             reverse_lang_direction=cfg.get("reverse_lang_direction",
                                            False),
         )
         # No sampler: tarred datasets manage their own shuffling
         # (shuffle_n) and sharding, so return the loader directly.
         return torch.utils.data.DataLoader(
             dataset=dataset,
             batch_size=1,
             num_workers=cfg.get("num_workers", 2),
             pin_memory=cfg.get("pin_memory", False),
             drop_last=cfg.get("drop_last", False),
         )
     else:
         # Plain parallel text files, batched in memory.
         dataset = TranslationDataset(
             dataset_src=str(Path(cfg.src_file_name).expanduser()),
             dataset_tgt=str(Path(cfg.tgt_file_name).expanduser()),
             tokens_in_batch=cfg.tokens_in_batch,
             clean=cfg.get("clean", False),
             max_seq_length=cfg.get("max_seq_length", 512),
             min_seq_length=cfg.get("min_seq_length", 1),
             max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
             max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
             cache_ids=cfg.get("cache_ids", False),
             cache_data_per_node=cfg.get("cache_data_per_node", False),
             use_cache=cfg.get("use_cache", False),
             reverse_lang_direction=cfg.get("reverse_lang_direction",
                                            False),
         )
         # Pre-compute token-bucketed batches with both tokenizers.
         dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)
     # Cached and in-memory datasets reach here: pick an explicit sampler.
     if cfg.shuffle:
         sampler = pt_data.RandomSampler(dataset)
     else:
         sampler = pt_data.SequentialSampler(dataset)
     return torch.utils.data.DataLoader(
         dataset=dataset,
         batch_size=1,
         sampler=sampler,
         num_workers=cfg.get("num_workers", 2),
         pin_memory=cfg.get("pin_memory", False),
         drop_last=cfg.get("drop_last", False),
     )