# Example no. 1 (score: 0)
 def _setup_dataloader_from_config(self, cfg: DictConfig):
     """Build a ``DataLoader`` for translation data as described by ``cfg``.

     Three mutually exclusive data sources are supported:
       * ``load_from_cached_dataset`` — unpickle a pre-batchified dataset;
         ``src_file_name`` and ``tgt_file_name`` must point to the same pickle.
       * ``load_from_tarred_dataset`` — stream from tar shards; this branch
         returns early with a sampler-less loader (the tarred dataset shards
         and shuffles internally via ``tar_shuffle_n``).
       * otherwise — build a ``TranslationDataset`` from parallel text files
         and batchify it with the encoder/decoder tokenizers.

     Args:
         cfg: dataloader section of the experiment config.

     Returns:
         torch.utils.data.DataLoader over the selected dataset.
         ``batch_size=1`` because the datasets pre-batch by ``tokens_in_batch``.

     Raises:
         ValueError: cached/tarred mode with src != tgt file name.
         FileNotFoundError: tarred mode without ``metadata_path``.
     """
     if cfg.get("load_from_cached_dataset", False):
         # Lazy %-args so the string is only built if INFO is enabled.
         logging.info('Loading from cached dataset %s', cfg.src_file_name)
         if cfg.src_file_name != cfg.tgt_file_name:
             raise ValueError("src must be equal to target for cached dataset")
         # NOTE(security): pickle.load can execute arbitrary code — only load
         # caches produced by a trusted preprocessing run.
         # Fix: use a context manager so the file handle is always closed.
         with open(cfg.src_file_name, 'rb') as cached_file:
             dataset = pickle.load(cached_file)
         dataset.reverse_lang_direction = cfg.get("reverse_lang_direction", False)
     elif cfg.get("load_from_tarred_dataset", False):
         logging.info('Loading from tarred dataset %s', cfg.src_file_name)
         if cfg.src_file_name != cfg.tgt_file_name:
             raise ValueError("src must be equal to target for tarred dataset")
         if cfg.get("metadata_path", None) is None:
             raise FileNotFoundError("Could not find metadata path in config")
         dataset = TarredTranslationDataset(
             text_tar_filepaths=cfg.src_file_name,
             metadata_path=cfg.metadata_path,
             encoder_tokenizer=self.encoder_tokenizer,
             decoder_tokenizer=self.decoder_tokenizer,
             shuffle_n=cfg.get("tar_shuffle_n", 100),
             shard_strategy=cfg.get("shard_strategy", "scatter"),
             global_rank=self.global_rank,
             world_size=self.world_size,
             reverse_lang_direction=cfg.get("reverse_lang_direction", False),
         )
         # Early return: iterable-style tarred data must not get a sampler.
         return torch.utils.data.DataLoader(
             dataset=dataset,
             batch_size=1,
             num_workers=cfg.get("num_workers", 2),
             pin_memory=cfg.get("pin_memory", False),
             drop_last=cfg.get("drop_last", False),
         )
     else:
         dataset = TranslationDataset(
             dataset_src=str(Path(cfg.src_file_name).expanduser()),
             dataset_tgt=str(Path(cfg.tgt_file_name).expanduser()),
             tokens_in_batch=cfg.tokens_in_batch,
             clean=cfg.get("clean", False),
             max_seq_length=cfg.get("max_seq_length", 512),
             min_seq_length=cfg.get("min_seq_length", 1),
             max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
             max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
             cache_ids=cfg.get("cache_ids", False),
             cache_data_per_node=cfg.get("cache_data_per_node", False),
             use_cache=cfg.get("use_cache", False),
             reverse_lang_direction=cfg.get("reverse_lang_direction", False),
         )
         dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)
     # Map-style datasets (cached or freshly built) take an explicit sampler.
     if cfg.shuffle:
         sampler = pt_data.RandomSampler(dataset)
     else:
         sampler = pt_data.SequentialSampler(dataset)
     return torch.utils.data.DataLoader(
         dataset=dataset,
         batch_size=1,
         sampler=sampler,
         num_workers=cfg.get("num_workers", 2),
         pin_memory=cfg.get("pin_memory", False),
         drop_last=cfg.get("drop_last", False),
     )
# Example no. 2 (score: 0)
    def _setup_dataloader_from_config(self, cfg: DictConfig):
        """Build a ``DataLoader`` for (possibly multilingual) translation data.

        Two modes, selected by ``cfg.use_tarred_dataset``:
          * tarred — one ``TarredTranslationDataset`` per metadata file; tar
            paths come either from ``cfg.tar_files`` or from each metadata
            JSON (``cfg`` wins when both are present). Tarred datasets shard
            and shuffle internally, so no sampler is attached.
          * text — one ``TranslationDataset`` per src/tgt file pair, each
            batchified with the encoder/decoder tokenizers.

        Multiple datasets are combined with ``ConcatDataset`` using the
        ``concat_sampling_*`` options. In multilingual mode each dataset
        prepends the language id from ``self.multilingual_ids``.

        Args:
            cfg: dataloader section of the experiment config.

        Returns:
            torch.utils.data.DataLoader with ``batch_size=1`` (the datasets
            pre-batch by ``tokens_in_batch``).

        Raises:
            FileNotFoundError: tarred mode without ``metadata_file``.
            ValueError: mismatched tar/metadata or src/tgt list lengths.
        """
        # Fix: evaluate the flag once instead of re-reading it at the
        # DataLoader call; it also gates sampler construction below.
        use_tarred_dataset = cfg.get("use_tarred_dataset", False)

        if use_tarred_dataset:
            if cfg.get("metadata_file") is None:
                raise FileNotFoundError(
                    "Trying to use tarred data set but could not find metadata path in config."
                )
            metadata_file_list = cfg.get('metadata_file')
            tar_files_list = cfg.get('tar_files', None)
            # Normalize scalars to one-element lists so one config form
            # covers both single- and multi-dataset setups.
            if isinstance(metadata_file_list, str):
                metadata_file_list = [metadata_file_list]
            if tar_files_list is not None and isinstance(tar_files_list, str):
                tar_files_list = [tar_files_list]
            if tar_files_list is not None and len(tar_files_list) != len(
                    metadata_file_list):
                raise ValueError(
                    'The config must have the same number of tarfile paths and metadata file paths.'
                )

            datasets = []
            for idx, metadata_file in enumerate(metadata_file_list):
                with open(metadata_file) as metadata_reader:
                    metadata = json.load(metadata_reader)
                if tar_files_list is None:
                    # Fall back to tar paths recorded in the metadata JSON.
                    tar_files = metadata.get('tar_files')
                    if tar_files is not None:
                        logging.info(
                            f'Loading from tarred dataset {tar_files}')
                else:
                    # cfg takes precedence when both sources specify paths.
                    tar_files = tar_files_list[idx]
                    if metadata.get('tar_files') is not None:
                        logging.info(
                            f'Tar file paths found in both cfg and metadata using one in cfg by default - {tar_files}'
                        )

                dataset = TarredTranslationDataset(
                    text_tar_filepaths=tar_files,
                    metadata_path=metadata_file,
                    encoder_tokenizer=self.encoder_tokenizer,
                    decoder_tokenizer=self.decoder_tokenizer,
                    shuffle_n=cfg.get("tar_shuffle_n", 100),
                    shard_strategy=cfg.get("shard_strategy", "scatter"),
                    global_rank=self.global_rank,
                    world_size=self.world_size,
                    reverse_lang_direction=cfg.get("reverse_lang_direction",
                                                   False),
                    prepend_id=self.multilingual_ids[idx]
                    if self.multilingual else None,
                )
                datasets.append(dataset)

            if len(datasets) > 1:
                dataset = ConcatDataset(
                    datasets=datasets,
                    sampling_technique=cfg.get('concat_sampling_technique'),
                    sampling_temperature=cfg.get(
                        'concat_sampling_temperature'),
                    sampling_probabilities=cfg.get(
                        'concat_sampling_probabilities'),
                    global_rank=self.global_rank,
                    world_size=self.world_size,
                )
            else:
                dataset = datasets[0]
        else:
            src_file_list = cfg.src_file_name
            tgt_file_list = cfg.tgt_file_name
            if isinstance(src_file_list, str):
                src_file_list = [src_file_list]
            if isinstance(tgt_file_list, str):
                tgt_file_list = [tgt_file_list]

            if len(src_file_list) != len(tgt_file_list):
                raise ValueError(
                    'The same number of filepaths must be passed in for source and target.'
                )

            datasets = []
            for idx, src_file in enumerate(src_file_list):
                dataset = TranslationDataset(
                    dataset_src=str(Path(src_file).expanduser()),
                    dataset_tgt=str(Path(tgt_file_list[idx]).expanduser()),
                    tokens_in_batch=cfg.tokens_in_batch,
                    clean=cfg.get("clean", False),
                    max_seq_length=cfg.get("max_seq_length", 512),
                    min_seq_length=cfg.get("min_seq_length", 1),
                    max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
                    max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
                    cache_ids=cfg.get("cache_ids", False),
                    cache_data_per_node=cfg.get("cache_data_per_node", False),
                    use_cache=cfg.get("use_cache", False),
                    reverse_lang_direction=cfg.get("reverse_lang_direction",
                                                   False),
                    prepend_id=self.multilingual_ids[idx]
                    if self.multilingual else None,
                )
                dataset.batchify(self.encoder_tokenizer,
                                 self.decoder_tokenizer)
                datasets.append(dataset)

            if len(datasets) > 1:
                dataset = ConcatDataset(
                    datasets=datasets,
                    shuffle=cfg.get('shuffle'),
                    sampling_technique=cfg.get('concat_sampling_technique'),
                    sampling_temperature=cfg.get(
                        'concat_sampling_temperature'),
                    sampling_probabilities=cfg.get(
                        'concat_sampling_probabilities'),
                    global_rank=self.global_rank,
                    world_size=self.world_size,
                )
            else:
                dataset = datasets[0]

        # Fix: only build a sampler for map-style (non-tarred) datasets.
        # Previously a Random/SequentialSampler was always constructed and
        # then discarded for tarred data; samplers are invalid for
        # iterable-style datasets.
        if use_tarred_dataset:
            sampler = None
        elif cfg.shuffle:
            sampler = pt_data.RandomSampler(dataset)
        else:
            sampler = pt_data.SequentialSampler(dataset)
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=1,
            sampler=sampler,
            num_workers=cfg.get("num_workers", 2),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=cfg.get("drop_last", False),
        )