def write_batches_to_tarfiles(
    args,
    src_fname,
    tgt_fname,
    num_tokens,
    encoder_tokenizer,
    decoder_tokenizer,
    num_files_in_tar,
    tar_file_ptr,
    tar_file_ctr,
    global_batch_ctr,
):
    """
    Writes the current fragment of the overall parallel corpus to tarfiles by:
    (1) Creating minibatches using a TranslationDataset object.
    (2) Writing each minibatch to a pickle file.
    (3) Adding pickle files to a tarfile until it reaches args.num_batches_per_tarfile.
    """
    dataset = TranslationDataset(
        dataset_src=src_fname,
        dataset_tgt=tgt_fname,
        tokens_in_batch=num_tokens,
        clean=args.clean,
        max_seq_length=args.max_seq_length,
        min_seq_length=args.min_seq_length,
        max_seq_length_diff=args.max_seq_length,
        max_seq_length_ratio=args.max_seq_length,
        cache_ids=False,
        cache_data_per_node=False,
        use_cache=False,
    )
    dataset.batchify(encoder_tokenizer, decoder_tokenizer)

    for _, batch in dataset.batches.items():
        global_batch_ctr += 1
        # Pickle each minibatch to a temporary file before adding it to the tar.
        pickle.dump(batch, open(os.path.join(args.out_dir, 'batch-%d.pkl' % (global_batch_ctr)), 'wb'))

        # Rotate to a new tarfile once the current one holds args.num_batches_per_tarfile batches.
        if num_files_in_tar == args.num_batches_per_tarfile:
            tar_file_ctr += 1
            tar_file_ptr.close()
            tar_file_ptr = tarfile.open(
                os.path.join(args.out_dir, 'batches.tokens.%d.%d.tar' % (num_tokens, tar_file_ctr)), 'w'
            )
            num_files_in_tar = 0

        tar_file_ptr.add(os.path.join(args.out_dir, 'batch-%d.pkl' % (global_batch_ctr)))
        num_files_in_tar += 1
        # The pickle has been copied into the tar, so the temporary file can be removed.
        os.remove(os.path.join(args.out_dir, 'batch-%d.pkl' % (global_batch_ctr)))

    return tar_file_ptr, global_batch_ctr, num_files_in_tar, tar_file_ctr
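# A minimal driver sketch (not part of the original script), assuming the tokenizers,
# an args namespace with out_dir / num_batches_per_tarfile / clean / max_seq_length /
# min_seq_length, and a list of (src, tgt) fragment files already exist. It shows how
# the tarfile handle and counters returned above are threaded through successive calls.
import os
import tarfile


def tar_all_fragments(args, fragment_pairs, num_tokens, encoder_tokenizer, decoder_tokenizer):
    tar_file_ctr = 1
    tar_file_ptr = tarfile.open(
        os.path.join(args.out_dir, 'batches.tokens.%d.%d.tar' % (num_tokens, tar_file_ctr)), 'w'
    )
    num_files_in_tar = 0
    global_batch_ctr = 0
    for src_fragment, tgt_fragment in fragment_pairs:
        tar_file_ptr, global_batch_ctr, num_files_in_tar, tar_file_ctr = write_batches_to_tarfiles(
            args,
            src_fragment,
            tgt_fragment,
            num_tokens,
            encoder_tokenizer,
            decoder_tokenizer,
            num_files_in_tar,
            tar_file_ptr,
            tar_file_ctr,
            global_batch_ctr,
        )
    tar_file_ptr.close()
    return global_batch_ctr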
def _setup_dataloader_from_config(self, cfg: DictConfig):
    if cfg.get("load_from_cached_dataset", False):
        logging.info('Loading from cached dataset %s' % (cfg.src_file_name))

        if cfg.src_file_name != cfg.tgt_file_name:
            raise ValueError("src must be equal to target for cached dataset")

        dataset = pickle.load(open(cfg.src_file_name, 'rb'))
        dataset.reverse_lang_direction = cfg.get("reverse_lang_direction", False)
    elif cfg.get("load_from_tarred_dataset", False):
        logging.info('Loading from tarred dataset %s' % (cfg.src_file_name))

        if cfg.src_file_name != cfg.tgt_file_name:
            raise ValueError("src must be equal to target for tarred dataset")
        if cfg.get("metadata_path", None) is None:
            raise FileNotFoundError("Could not find metadata path in config")

        dataset = TarredTranslationDataset(
            text_tar_filepaths=cfg.src_file_name,
            metadata_path=cfg.metadata_path,
            encoder_tokenizer=self.encoder_tokenizer,
            decoder_tokenizer=self.decoder_tokenizer,
            shuffle_n=cfg.get("tar_shuffle_n", 100),
            shard_strategy=cfg.get("shard_strategy", "scatter"),
            global_rank=self.global_rank,
            world_size=self.world_size,
            reverse_lang_direction=cfg.get("reverse_lang_direction", False),
        )
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=1,
            num_workers=cfg.get("num_workers", 2),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=cfg.get("drop_last", False),
        )
    else:
        dataset = TranslationDataset(
            dataset_src=str(Path(cfg.src_file_name).expanduser()),
            dataset_tgt=str(Path(cfg.tgt_file_name).expanduser()),
            tokens_in_batch=cfg.tokens_in_batch,
            clean=cfg.get("clean", False),
            max_seq_length=cfg.get("max_seq_length", 512),
            min_seq_length=cfg.get("min_seq_length", 1),
            max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
            max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
            cache_ids=cfg.get("cache_ids", False),
            cache_data_per_node=cfg.get("cache_data_per_node", False),
            use_cache=cfg.get("use_cache", False),
            reverse_lang_direction=cfg.get("reverse_lang_direction", False),
        )
        dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)

    if cfg.shuffle:
        sampler = pt_data.RandomSampler(dataset)
    else:
        sampler = pt_data.SequentialSampler(dataset)

    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=1,
        sampler=sampler,
        num_workers=cfg.get("num_workers", 2),
        pin_memory=cfg.get("pin_memory", False),
        drop_last=cfg.get("drop_last", False),
    )
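# Hedged example of the DictConfig this method expects. The keys mirror the cfg
# accesses above; the paths and values are placeholders, not taken from the
# original source. The commented-out call assumes `model` is an instance of the
# class this method belongs to, with its tokenizers already built.
from omegaconf import OmegaConf

train_ds_cfg = OmegaConf.create(
    {
        "load_from_cached_dataset": False,   # pickled, pre-batchified TranslationDataset
        "load_from_tarred_dataset": True,    # tarred batches plus a metadata file
        "src_file_name": "/data/tarred/batches.tokens.16000.1.tar",
        "tgt_file_name": "/data/tarred/batches.tokens.16000.1.tar",  # must equal src for cached/tarred data
        "metadata_path": "/data/tarred/metadata.json",
        "tokens_in_batch": 16000,
        "shuffle": True,
        "num_workers": 2,
        "pin_memory": False,
        "drop_last": False,
    }
)
# train_dataloader = model._setup_dataloader_from_config(train_ds_cfg)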
def _setup_dataloader_from_config(self, cfg: DictConfig):
    if cfg.get("load_from_cached_dataset", False):
        logging.info('Loading from cached dataset %s' % (cfg.src_file_name))

        if cfg.src_file_name != cfg.tgt_file_name:
            raise ValueError("src must be equal to target for cached dataset")

        dataset = pickle.load(open(cfg.src_file_name, 'rb'))
        dataset.reverse_lang_direction = cfg.get("reverse_lang_direction", False)
    else:
        dataset = TranslationDataset(
            dataset_src=str(Path(cfg.src_file_name).expanduser()),
            dataset_tgt=str(Path(cfg.tgt_file_name).expanduser()),
            tokens_in_batch=cfg.tokens_in_batch,
            clean=cfg.get("clean", False),
            max_seq_length=cfg.get("max_seq_length", 512),
            min_seq_length=cfg.get("min_seq_length", 1),
            max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
            max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
            cache_ids=cfg.get("cache_ids", False),
            cache_data_per_node=cfg.get("cache_data_per_node", False),
            use_cache=cfg.get("use_cache", False),
            reverse_lang_direction=cfg.get("reverse_lang_direction", False),
        )
        dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)

    if cfg.shuffle:
        sampler = pt_data.RandomSampler(dataset)
    else:
        sampler = pt_data.SequentialSampler(dataset)

    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=1,
        sampler=sampler,
        num_workers=cfg.get("num_workers", 2),
        pin_memory=cfg.get("pin_memory", False),
        drop_last=cfg.get("drop_last", False),
    )
def _setup_eval_dataloader_from_config(self, cfg: DictConfig):
    src_file_name = cfg.get('src_file_name')
    tgt_file_name = cfg.get('tgt_file_name')

    if src_file_name is None or tgt_file_name is None:
        raise ValueError(
            'Validation dataloader needs both cfg.src_file_name and cfg.tgt_file_name to not be None.'
        )
    else:
        # convert src_file_name and tgt_file_name to list of strings
        if isinstance(src_file_name, str):
            src_file_list = [src_file_name]
        elif isinstance(src_file_name, ListConfig):
            src_file_list = src_file_name
        else:
            raise ValueError("cfg.src_file_name must be string or list of strings")

        if isinstance(tgt_file_name, str):
            tgt_file_list = [tgt_file_name]
        elif isinstance(tgt_file_name, ListConfig):
            tgt_file_list = tgt_file_name
        else:
            raise ValueError("cfg.tgt_file_name must be string or list of strings")

    if len(src_file_list) != len(tgt_file_list):
        raise ValueError('The same number of filepaths must be passed in for source and target validation.')

    dataloaders = []
    prepend_idx = 0
    for idx, src_file in enumerate(src_file_list):
        if self.multilingual:
            prepend_idx = idx
        dataset = TranslationDataset(
            dataset_src=str(Path(src_file).expanduser()),
            dataset_tgt=str(Path(tgt_file_list[idx]).expanduser()),
            tokens_in_batch=cfg.tokens_in_batch,
            clean=cfg.get("clean", False),
            max_seq_length=cfg.get("max_seq_length", 512),
            min_seq_length=cfg.get("min_seq_length", 1),
            max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
            max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
            cache_ids=cfg.get("cache_ids", False),
            cache_data_per_node=cfg.get("cache_data_per_node", False),
            use_cache=cfg.get("use_cache", False),
            reverse_lang_direction=cfg.get("reverse_lang_direction", False),
            prepend_id=self.multilingual_ids[prepend_idx] if self.multilingual else None,
        )
        dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)

        if cfg.shuffle:
            sampler = pt_data.RandomSampler(dataset)
        else:
            sampler = pt_data.SequentialSampler(dataset)

        dataloader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=1,
            sampler=sampler,
            num_workers=cfg.get("num_workers", 2),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=cfg.get("drop_last", False),
        )
        dataloaders.append(dataloader)

    return dataloaders
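# Illustrative validation config for the method above, assuming a multilingual
# setup with two language pairs. The paths are placeholders; the only hard
# requirement encoded in the code is that src/tgt are strings or equally long
# lists, and that one dataloader is returned per (src, tgt) pair.
from omegaconf import OmegaConf

validation_ds_cfg = OmegaConf.create(
    {
        "src_file_name": ["/data/val.en-de.src", "/data/val.en-fr.src"],
        "tgt_file_name": ["/data/val.en-de.ref", "/data/val.en-fr.ref"],
        "tokens_in_batch": 512,
        "shuffle": False,
        "num_workers": 2,
        "pin_memory": False,
        "drop_last": False,
    }
)
# val_dataloaders = model._setup_eval_dataloader_from_config(validation_ds_cfg)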
def _setup_dataloader_from_config(self, cfg: DictConfig):
    if cfg.get("use_tarred_dataset", False):
        if cfg.get("metadata_file") is None:
            raise FileNotFoundError("Trying to use tarred data set but could not find metadata path in config.")

        metadata_file_list = cfg.get('metadata_file')
        tar_files_list = cfg.get('tar_files', None)
        if isinstance(metadata_file_list, str):
            metadata_file_list = [metadata_file_list]
        if tar_files_list is not None and isinstance(tar_files_list, str):
            tar_files_list = [tar_files_list]
        if tar_files_list is not None and len(tar_files_list) != len(metadata_file_list):
            raise ValueError('The config must have the same number of tarfile paths and metadata file paths.')

        datasets = []
        for idx, metadata_file in enumerate(metadata_file_list):
            with open(metadata_file) as metadata_reader:
                metadata = json.load(metadata_reader)

            if tar_files_list is None:
                tar_files = metadata.get('tar_files')
                if tar_files is not None:
                    logging.info(f'Loading from tarred dataset {tar_files}')
            else:
                tar_files = tar_files_list[idx]
                if metadata.get('tar_files') is not None:
                    logging.info(
                        f'Tar file paths found in both cfg and metadata using one in cfg by default - {tar_files}'
                    )

            dataset = TarredTranslationDataset(
                text_tar_filepaths=tar_files,
                metadata_path=metadata_file,
                encoder_tokenizer=self.encoder_tokenizer,
                decoder_tokenizer=self.decoder_tokenizer,
                shuffle_n=cfg.get("tar_shuffle_n", 100),
                shard_strategy=cfg.get("shard_strategy", "scatter"),
                global_rank=self.global_rank,
                world_size=self.world_size,
                reverse_lang_direction=cfg.get("reverse_lang_direction", False),
                prepend_id=self.multilingual_ids[idx] if self.multilingual else None,
            )
            datasets.append(dataset)

        if len(datasets) > 1:
            dataset = ConcatDataset(
                datasets=datasets,
                sampling_technique=cfg.get('concat_sampling_technique'),
                sampling_temperature=cfg.get('concat_sampling_temperature'),
                sampling_probabilities=cfg.get('concat_sampling_probabilities'),
                global_rank=self.global_rank,
                world_size=self.world_size,
            )
        else:
            dataset = datasets[0]
    else:
        src_file_list = cfg.src_file_name
        tgt_file_list = cfg.tgt_file_name
        if isinstance(src_file_list, str):
            src_file_list = [src_file_list]
        if isinstance(tgt_file_list, str):
            tgt_file_list = [tgt_file_list]

        if len(src_file_list) != len(tgt_file_list):
            raise ValueError('The same number of filepaths must be passed in for source and target.')

        datasets = []
        for idx, src_file in enumerate(src_file_list):
            dataset = TranslationDataset(
                dataset_src=str(Path(src_file).expanduser()),
                dataset_tgt=str(Path(tgt_file_list[idx]).expanduser()),
                tokens_in_batch=cfg.tokens_in_batch,
                clean=cfg.get("clean", False),
                max_seq_length=cfg.get("max_seq_length", 512),
                min_seq_length=cfg.get("min_seq_length", 1),
                max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
                max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
                cache_ids=cfg.get("cache_ids", False),
                cache_data_per_node=cfg.get("cache_data_per_node", False),
                use_cache=cfg.get("use_cache", False),
                reverse_lang_direction=cfg.get("reverse_lang_direction", False),
                prepend_id=self.multilingual_ids[idx] if self.multilingual else None,
            )
            dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)
            datasets.append(dataset)

        if len(datasets) > 1:
            dataset = ConcatDataset(
                datasets=datasets,
                shuffle=cfg.get('shuffle'),
                sampling_technique=cfg.get('concat_sampling_technique'),
                sampling_temperature=cfg.get('concat_sampling_temperature'),
                sampling_probabilities=cfg.get('concat_sampling_probabilities'),
                global_rank=self.global_rank,
                world_size=self.world_size,
            )
        else:
            dataset = datasets[0]

    if cfg.shuffle:
        sampler = pt_data.RandomSampler(dataset)
    else:
        sampler = pt_data.SequentialSampler(dataset)

    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=1,
        sampler=None if cfg.get("use_tarred_dataset", False) else sampler,
        num_workers=cfg.get("num_workers", 2),
        pin_memory=cfg.get("pin_memory", False),
        drop_last=cfg.get("drop_last", False),
    )
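# Hedged example of a tarred, multi-corpus training config for the method above.
# Paths and sampling values are placeholders; the keys simply mirror the cfg
# accesses in the function. Note that when use_tarred_dataset is True the
# returned DataLoader gets sampler=None, so cfg.shuffle only affects the
# non-tarred branch.
from omegaconf import OmegaConf

train_ds_cfg = OmegaConf.create(
    {
        "use_tarred_dataset": True,
        "metadata_file": ["/data/en_de/metadata.json", "/data/en_fr/metadata.json"],
        "tar_files": None,  # optional; if None, tar paths are read from each metadata file
        "concat_sampling_technique": "temperature",
        "concat_sampling_temperature": 5,
        "concat_sampling_probabilities": None,
        "shard_strategy": "scatter",
        "tar_shuffle_n": 100,
        "shuffle": True,
        "num_workers": 2,
        "pin_memory": False,
        "drop_last": False,
    }
)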
encoder_tokenizer = get_tokenizer(
    tokenizer_name='yttm', tokenizer_model=encoder_tokenizer_model, bpe_dropout=args.bpe_dropout
)
decoder_tokenizer = get_tokenizer(
    tokenizer_name='yttm', tokenizer_model=decoder_tokenizer_model, bpe_dropout=args.bpe_dropout
)

tokens_in_batch = [int(item) for item in args.tokens_in_batch.split(',')]
for num_tokens in tokens_in_batch:
    dataset = TranslationDataset(
        dataset_src=str(Path(args.src_fname).expanduser()),
        dataset_tgt=str(Path(args.tgt_fname).expanduser()),
        tokens_in_batch=num_tokens,
        clean=args.clean,
        max_seq_length=args.max_seq_length,
        min_seq_length=args.min_seq_length,
        max_seq_length_diff=args.max_seq_length,
        max_seq_length_ratio=args.max_seq_length,
        cache_ids=False,
        cache_data_per_node=False,
        use_cache=False,
    )
    print('Batchifying ...')
    dataset.batchify(encoder_tokenizer, decoder_tokenizer)
    start = time.time()
    # Pickle the whole pre-batchified dataset for this tokens-per-batch setting.
    pickle.dump(
        dataset, open(os.path.join(args.out_dir, 'batches.tokens.%d.pkl' % (num_tokens)), 'wb')
    )
    end = time.time()
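# The script fragment above reads several attributes from `args`. This argparse
# sketch is not part of the original file; it only declares the flags that are
# actually referenced in this document, and the defaults/help strings are
# assumptions.
import argparse

parser = argparse.ArgumentParser(description='Pre-batchify a parallel corpus into pickled/tarred batch files')
parser.add_argument('--src_fname', type=str, required=True, help='source-side text file')
parser.add_argument('--tgt_fname', type=str, required=True, help='target-side text file')
parser.add_argument('--out_dir', type=str, required=True, help='directory for the pickled/tarred batches')
parser.add_argument('--tokens_in_batch', type=str, default='16000', help='comma-separated tokens-per-batch values')
parser.add_argument('--clean', action='store_true', help='drop noisy sentence pairs before batching')
parser.add_argument('--max_seq_length', type=int, default=512)
parser.add_argument('--min_seq_length', type=int, default=1)
parser.add_argument('--bpe_dropout', type=float, default=0.0)
parser.add_argument('--num_batches_per_tarfile', type=int, default=1000, help='used by write_batches_to_tarfiles')
args = parser.parse_args()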