def _build_train_dataset(self, data_cfg):
    """Build the training dataset."""
    if (
        data_cfg.drop_last is False
        and data_cfg.global_batch_size > data_cfg.micro_batch_size * parallel_state.get_data_parallel_world_size()
    ):
        raise ValueError(
            f"Cannot use drop_last=False in your training data with gradient accumulation. "
            f"Found grad acc of {data_cfg.global_batch_size // (data_cfg.micro_batch_size * parallel_state.get_data_parallel_world_size())} "
            f"with global_batch_size {data_cfg.global_batch_size}, micro_batch_size {data_cfg.micro_batch_size}, "
            f"data parallel size {parallel_state.get_data_parallel_world_size()}"
        )
    datasets = []
    # Determine if we are using a single dataset or a list of datasets.
    is_src_list_config = isinstance(data_cfg.src_file_name, ListConfig)
    is_tgt_list_config = isinstance(data_cfg.tgt_file_name, ListConfig)

    if (is_src_list_config and not is_tgt_list_config) or (is_tgt_list_config and not is_src_list_config):
        raise ValueError("src_list and tgt_list must both be either a ListConfig or a string.")
    if is_src_list_config:
        if len(data_cfg.src_file_name) != len(data_cfg.tgt_file_name):
            raise ValueError("src_file_name and tgt_file_name must have the same number of elements.")
    else:
        data_cfg.src_file_name = [data_cfg.src_file_name]
        data_cfg.tgt_file_name = [data_cfg.tgt_file_name]

    for src, tgt in zip(data_cfg.src_file_name, data_cfg.tgt_file_name):
        dataset = SequenceToSequenceDataset(
            src_file_name=src,
            tgt_file_name=tgt,
            src_tokenizer=self.tokenizer,
            tgt_tokenizer=self.tokenizer,
            max_src_seq_length=data_cfg.max_src_seq_length,
            max_tgt_seq_length=data_cfg.max_tgt_seq_length,
        )
        datasets.append(dataset)

    # When multiple datasets are given, blend them with the configured sampling strategy.
    if len(datasets) > 1:
        dataset = ConcatDataset(
            datasets=datasets,
            sampling_technique=data_cfg.get('concat_sampling_technique', 'temperature'),
            sampling_temperature=data_cfg.get('concat_sampling_temperature', 5),
            sampling_probabilities=data_cfg.get(
                'concat_sampling_probabilities', [1 / len(datasets)] * len(datasets)
            ),
            global_rank=parallel_state.get_data_parallel_rank(),
            world_size=parallel_state.get_data_parallel_world_size(),
        )
        return dataset
    else:
        return datasets[0]
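# Hypothetical usage sketch (not from the original source): a minimal OmegaConf
# config shaped the way _build_train_dataset expects. The file paths are
# placeholders; OmegaConf wraps the lists below as ListConfig, which triggers
# the multi-dataset path above.
from omegaconf import OmegaConf

example_train_data_cfg = OmegaConf.create(
    {
        "src_file_name": ["train.a.src", "train.b.src"],  # one source file per dataset
        "tgt_file_name": ["train.a.tgt", "train.b.tgt"],  # matching target files
        "max_src_seq_length": 512,
        "max_tgt_seq_length": 512,
        "micro_batch_size": 16,
        "global_batch_size": 128,
        "drop_last": True,  # must be True when the global batch implies gradient accumulation
        "concat_sampling_technique": "temperature",
        "concat_sampling_temperature": 5,
    }
)
# train_ds = model._build_train_dataset(example_train_data_cfg)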
def _setup_dataloader_from_config(self, cfg: DictConfig):
    if cfg.get("use_tarred_dataset", False):
        if cfg.get("metadata_file") is None:
            raise FileNotFoundError("Trying to use a tarred dataset but could not find the metadata path in the config.")
        metadata_file_list = cfg.get('metadata_file')
        tar_files_list = cfg.get('tar_files', None)
        if isinstance(metadata_file_list, str):
            metadata_file_list = [metadata_file_list]
        if tar_files_list is not None and isinstance(tar_files_list, str):
            tar_files_list = [tar_files_list]
        if tar_files_list is not None and len(tar_files_list) != len(metadata_file_list):
            raise ValueError('The config must have the same number of tarfile paths and metadata file paths.')

        datasets = []
        for idx, metadata_file in enumerate(metadata_file_list):
            with open(metadata_file) as metadata_reader:
                metadata = json.load(metadata_reader)
            if tar_files_list is None:
                tar_files = metadata.get('tar_files')
                if tar_files is not None:
                    logging.info(f'Loading from tarred dataset {tar_files}')
            else:
                tar_files = tar_files_list[idx]
                if metadata.get('tar_files') is not None:
                    logging.info(
                        f'Tar file paths found in both cfg and metadata; using the ones from cfg by default - {tar_files}'
                    )
            dataset = TarredTranslationDataset(
                text_tar_filepaths=tar_files,
                metadata_path=metadata_file,
                encoder_tokenizer=self.encoder_tokenizer,
                decoder_tokenizer=self.decoder_tokenizer,
                shuffle_n=cfg.get("tar_shuffle_n", 100),
                shard_strategy=cfg.get("shard_strategy", "scatter"),
                global_rank=self.global_rank,
                world_size=self.world_size,
                reverse_lang_direction=cfg.get("reverse_lang_direction", False),
                prepend_id=self.multilingual_ids[idx] if self.multilingual else None,
            )
            datasets.append(dataset)

        if len(datasets) > 1:
            dataset = ConcatDataset(
                datasets=datasets,
                sampling_technique=cfg.get('concat_sampling_technique'),
                sampling_temperature=cfg.get('concat_sampling_temperature'),
                sampling_probabilities=cfg.get('concat_sampling_probabilities'),
                global_rank=self.global_rank,
                world_size=self.world_size,
            )
        else:
            dataset = datasets[0]
    else:
        src_file_list = cfg.src_file_name
        tgt_file_list = cfg.tgt_file_name
        if isinstance(src_file_list, str):
            src_file_list = [src_file_list]
        if isinstance(tgt_file_list, str):
            tgt_file_list = [tgt_file_list]
        if len(src_file_list) != len(tgt_file_list):
            raise ValueError('The same number of filepaths must be passed in for source and target.')

        datasets = []
        for idx, src_file in enumerate(src_file_list):
            dataset = TranslationDataset(
                dataset_src=str(Path(src_file).expanduser()),
                dataset_tgt=str(Path(tgt_file_list[idx]).expanduser()),
                tokens_in_batch=cfg.tokens_in_batch,
                clean=cfg.get("clean", False),
                max_seq_length=cfg.get("max_seq_length", 512),
                min_seq_length=cfg.get("min_seq_length", 1),
                max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
                max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
                cache_ids=cfg.get("cache_ids", False),
                cache_data_per_node=cfg.get("cache_data_per_node", False),
                use_cache=cfg.get("use_cache", False),
                reverse_lang_direction=cfg.get("reverse_lang_direction", False),
                prepend_id=self.multilingual_ids[idx] if self.multilingual else None,
            )
            dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)
            datasets.append(dataset)

        if len(datasets) > 1:
            dataset = ConcatDataset(
                datasets=datasets,
                shuffle=cfg.get('shuffle'),
                sampling_technique=cfg.get('concat_sampling_technique'),
                sampling_temperature=cfg.get('concat_sampling_temperature'),
                sampling_probabilities=cfg.get('concat_sampling_probabilities'),
                global_rank=self.global_rank,
                world_size=self.world_size,
            )
        else:
            dataset = datasets[0]

    if cfg.shuffle:
        sampler = pt_data.RandomSampler(dataset)
    else:
        sampler = pt_data.SequentialSampler(dataset)

    # Batching is handled by the dataset itself (tokens_in_batch / batchify), so the
    # loader uses batch_size=1; tarred datasets shard and shuffle internally, so no
    # sampler is used for them.
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=1,
        sampler=None if cfg.get("use_tarred_dataset", False) else sampler,
        num_workers=cfg.get("num_workers", 2),
        pin_memory=cfg.get("pin_memory", False),
        drop_last=cfg.get("drop_last", False),
    )
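# Hypothetical config sketch (not from the original source) for the tarred branch
# of _setup_dataloader_from_config. `metadata_file` points at JSON metadata, one
# entry per tarred dataset; `tar_files` is optional and, when present, overrides
# the 'tar_files' key stored in the metadata. All paths are placeholders.
from omegaconf import OmegaConf

example_tarred_cfg = OmegaConf.create(
    {
        "use_tarred_dataset": True,
        "metadata_file": ["metadata.json"],        # or several, one per tarred dataset
        "tar_files": ["parallel.batches.{0..9}.tar"],  # optional override of metadata
        "shuffle": True,  # read here, but the sampler is discarded for tarred data
        "tar_shuffle_n": 100,
        "shard_strategy": "scatter",
        "num_workers": 2,
        "pin_memory": False,
        "drop_last": False,
    }
)
# loader = model._setup_dataloader_from_config(example_tarred_cfg)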
def _build_eval_dataset(self, data_cfg, mode='train'):
    """Build the evaluation dataset."""
    if data_cfg.global_batch_size > data_cfg.micro_batch_size * parallel_state.get_data_parallel_world_size():
        raise ValueError(
            f'You are trying to use "implicit gradient accumulation" of '
            f'{data_cfg.global_batch_size // (data_cfg.micro_batch_size * parallel_state.get_data_parallel_world_size())} '
            f'in your validation/test datasets. This is not supported. Please set global_batch_size equal to '
            f'micro_batch_size * data_parallel_world_size.'
        )
    datasets = []
    # Determine if we are using a single dataset or a list of datasets.
    is_src_list_config = isinstance(data_cfg.src_file_name, ListConfig)
    is_tgt_list_config = isinstance(data_cfg.tgt_file_name, ListConfig)
    is_names_list_config = False
    if hasattr(data_cfg, "names"):
        if isinstance(data_cfg.names, ListConfig):
            is_names_list_config = True

    if (is_src_list_config and not is_tgt_list_config) or (is_tgt_list_config and not is_src_list_config):
        raise ValueError("src_list and tgt_list must both be either a ListConfig or a string.")
    if is_src_list_config:
        if len(data_cfg.src_file_name) != len(data_cfg.tgt_file_name):
            raise ValueError("src_file_name and tgt_file_name must have the same number of elements.")
        if is_names_list_config and len(data_cfg.names) != len(data_cfg.src_file_name):
            raise ValueError(
                "If you are providing names for each src/tgt file, they must have the same number of elements."
            )
    else:
        data_cfg.src_file_name = [data_cfg.src_file_name]
        data_cfg.tgt_file_name = [data_cfg.tgt_file_name]

    for src, tgt in zip(data_cfg.src_file_name, data_cfg.tgt_file_name):
        dataset = SequenceToSequenceDataset(
            src_file_name=src,
            tgt_file_name=tgt,
            src_tokenizer=self.tokenizer,
            tgt_tokenizer=self.tokenizer,
            max_src_seq_length=data_cfg.max_src_seq_length,
            max_tgt_seq_length=data_cfg.max_tgt_seq_length,
        )
        datasets.append(dataset)

    # Only blend datasets in training mode; eval/test keeps one dataset per file pair.
    if mode == 'train' and len(datasets) > 1:
        dataset = ConcatDataset(
            datasets=datasets,
            sampling_technique=data_cfg.get('concat_sampling_technique', 'temperature'),
            sampling_temperature=data_cfg.get('concat_sampling_temperature', 5),
            sampling_probabilities=data_cfg.get(
                'concat_sampling_probabilities', [1 / len(datasets)] * len(datasets)
            ),
            global_rank=parallel_state.get_data_parallel_rank(),
            world_size=parallel_state.get_data_parallel_world_size(),
        )
        return dataset
    return datasets
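# Hypothetical sketch (not from the original source): an eval config satisfying the
# check above, assuming a data parallel world size of 8, so that
# global_batch_size == micro_batch_size * data_parallel_world_size and no implicit
# gradient accumulation is requested. `names` optionally labels each file pair.
from omegaconf import OmegaConf

example_eval_data_cfg = OmegaConf.create(
    {
        "src_file_name": ["val.src"],
        "tgt_file_name": ["val.tgt"],
        "names": ["validation"],   # optional, one label per src/tgt pair
        "max_src_seq_length": 512,
        "max_tgt_seq_length": 512,
        "micro_batch_size": 16,
        "global_batch_size": 128,  # 16 * 8 data-parallel ranks
    }
)
# val_datasets = model._build_eval_dataset(example_eval_data_cfg, mode='validation')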