def write_batches_to_tarfiles(
    args,
    src_fname,
    tgt_fname,
    num_tokens,
    encoder_tokenizer,
    decoder_tokenizer,
    num_files_in_tar,
    tar_file_ptr,
    tar_file_ctr,
    global_batch_ctr,
):
    """
    Writes the current fragment of the overall parallel corpus to tarfiles by:
    (1) Creating minibatches using a TranslationDataset object.
    (2) Writing each minibatch to a pickle file.
    (3) Adding pickle files to a tarfile until it reaches args.num_batches_per_tarfile.
    """

    dataset = TranslationDataset(
        dataset_src=src_fname,
        dataset_tgt=tgt_fname,
        tokens_in_batch=num_tokens,
        clean=args.clean,
        max_seq_length=args.max_seq_length,
        min_seq_length=args.min_seq_length,
        # the length-difference and length-ratio caps are set to max_seq_length here,
        # so only the absolute length bounds actually filter sentence pairs
        max_seq_length_diff=args.max_seq_length,
        max_seq_length_ratio=args.max_seq_length,
        cache_ids=False,
        cache_data_per_node=False,
        use_cache=False,
    )
    dataset.batchify(encoder_tokenizer, decoder_tokenizer)

    for _, batch in dataset.batches.items():
        global_batch_ctr += 1
        # write the batch to a temporary pickle file named after its global index
        pickle.dump(batch, open(os.path.join(args.out_dir, 'batch-%d.pkl' % (global_batch_ctr)), 'wb'))

        # rotate to a fresh tarfile once the current one holds num_batches_per_tarfile batches
        if num_files_in_tar == args.num_batches_per_tarfile:
            tar_file_ctr += 1
            tar_file_ptr.close()
            tar_file_ptr = tarfile.open(
                os.path.join(args.out_dir, 'batches.tokens.%d.%d.tar' % (num_tokens, tar_file_ctr)), 'w'
            )
            num_files_in_tar = 0

        # add the pickle to the (possibly new) tarfile, then delete the temporary file
        tar_file_ptr.add(os.path.join(args.out_dir, 'batch-%d.pkl' % (global_batch_ctr)))
        num_files_in_tar += 1
        os.remove(os.path.join(args.out_dir, 'batch-%d.pkl' % (global_batch_ctr)))

    return tar_file_ptr, global_batch_ctr, num_files_in_tar, tar_file_ctr
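# --- Example (not from the original script): a minimal sketch of how write_batches_to_tarfiles()
# above could be chained over corpus fragments, carrying the tarfile handle and counters from
# call to call. `args.num_fragments` and the fragment file names are assumptions made for
# illustration; `args`, `encoder_tokenizer`, and `decoder_tokenizer` come from the surrounding script.
import os
import tarfile

num_tokens = 16000  # hypothetical tokens-per-batch budget
tar_file_ctr = 1
global_batch_ctr = 0
num_files_in_tar = 0
tar_file_ptr = tarfile.open(
    os.path.join(args.out_dir, 'batches.tokens.%d.%d.tar' % (num_tokens, tar_file_ctr)), 'w'
)

for i in range(args.num_fragments):  # hypothetical fragment count
    src_fragment = os.path.join(args.out_dir, 'src.fragment.%d' % i)  # hypothetical naming scheme
    tgt_fragment = os.path.join(args.out_dir, 'tgt.fragment.%d' % i)
    tar_file_ptr, global_batch_ctr, num_files_in_tar, tar_file_ctr = write_batches_to_tarfiles(
        args,
        src_fragment,
        tgt_fragment,
        num_tokens,
        encoder_tokenizer,
        decoder_tokenizer,
        num_files_in_tar,
        tar_file_ptr,
        tar_file_ctr,
        global_batch_ctr,
    )

tar_file_ptr.close()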
def _setup_dataloader_from_config(self, cfg: DictConfig):
    if cfg.get("load_from_cached_dataset", False):
        # load a TranslationDataset that was already batchified and pickled
        logging.info('Loading from cached dataset %s' % (cfg.src_file_name))

        if cfg.src_file_name != cfg.tgt_file_name:
            raise ValueError("src must be equal to target for cached dataset")

        dataset = pickle.load(open(cfg.src_file_name, 'rb'))
        dataset.reverse_lang_direction = cfg.get("reverse_lang_direction", False)
    else:
        # build the dataset from raw parallel text and batchify it on the fly
        dataset = TranslationDataset(
            dataset_src=str(Path(cfg.src_file_name).expanduser()),
            dataset_tgt=str(Path(cfg.tgt_file_name).expanduser()),
            tokens_in_batch=cfg.tokens_in_batch,
            clean=cfg.get("clean", False),
            max_seq_length=cfg.get("max_seq_length", 512),
            min_seq_length=cfg.get("min_seq_length", 1),
            max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
            max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
            cache_ids=cfg.get("cache_ids", False),
            cache_data_per_node=cfg.get("cache_data_per_node", False),
            use_cache=cfg.get("use_cache", False),
            reverse_lang_direction=cfg.get("reverse_lang_direction", False),
        )
        dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)

    if cfg.shuffle:
        sampler = pt_data.RandomSampler(dataset)
    else:
        sampler = pt_data.SequentialSampler(dataset)

    # batch_size=1 because each dataset item is already a token-count-based minibatch
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=1,
        sampler=sampler,
        num_workers=cfg.get("num_workers", 2),
        pin_memory=cfg.get("pin_memory", False),
        drop_last=cfg.get("drop_last", False),
    )
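# --- Example (not from the original code): a minimal sketch of a training-data config
# accepted by _setup_dataloader_from_config() above, built with OmegaConf. Only keys the
# method actually reads are shown; the file names and numeric values are placeholders.
from omegaconf import OmegaConf

train_ds_cfg = OmegaConf.create(
    {
        "src_file_name": "train.en",        # placeholder source-side text file
        "tgt_file_name": "train.de",        # placeholder target-side text file
        "tokens_in_batch": 16000,           # batches are sized by token count, not sentence count
        "clean": True,
        "max_seq_length": 512,
        "min_seq_length": 1,
        "shuffle": True,
        "num_workers": 2,
        "pin_memory": False,
        "drop_last": False,
        "load_from_cached_dataset": False,  # set True to unpickle a pre-batchified dataset instead
        "reverse_lang_direction": False,
    }
)
# dataloader = model._setup_dataloader_from_config(train_ds_cfg)  # `model` is a hypothetical instance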
def _setup_eval_dataloader_from_config(self, cfg: DictConfig):
    src_file_name = cfg.get('src_file_name')
    tgt_file_name = cfg.get('tgt_file_name')

    if src_file_name is None or tgt_file_name is None:
        raise ValueError(
            'Validation dataloader needs both cfg.src_file_name and cfg.tgt_file_name to not be None.'
        )
    else:
        # convert src_file_name and tgt_file_name to lists of strings
        if isinstance(src_file_name, str):
            src_file_list = [src_file_name]
        elif isinstance(src_file_name, ListConfig):
            src_file_list = src_file_name
        else:
            raise ValueError("cfg.src_file_name must be a string or a list of strings")

        if isinstance(tgt_file_name, str):
            tgt_file_list = [tgt_file_name]
        elif isinstance(tgt_file_name, ListConfig):
            tgt_file_list = tgt_file_name
        else:
            raise ValueError("cfg.tgt_file_name must be a string or a list of strings")

    if len(src_file_list) != len(tgt_file_list):
        raise ValueError('The same number of filepaths must be passed in for source and target validation.')

    # build one DataLoader per validation src/tgt pair
    dataloaders = []
    prepend_idx = 0
    for idx, src_file in enumerate(src_file_list):
        if self.multilingual:
            # for multilingual models, prepend the language token id of the idx-th pair
            prepend_idx = idx
        dataset = TranslationDataset(
            dataset_src=str(Path(src_file).expanduser()),
            dataset_tgt=str(Path(tgt_file_list[idx]).expanduser()),
            tokens_in_batch=cfg.tokens_in_batch,
            clean=cfg.get("clean", False),
            max_seq_length=cfg.get("max_seq_length", 512),
            min_seq_length=cfg.get("min_seq_length", 1),
            max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
            max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
            cache_ids=cfg.get("cache_ids", False),
            cache_data_per_node=cfg.get("cache_data_per_node", False),
            use_cache=cfg.get("use_cache", False),
            reverse_lang_direction=cfg.get("reverse_lang_direction", False),
            prepend_id=self.multilingual_ids[prepend_idx] if self.multilingual else None,
        )
        dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)

        if cfg.shuffle:
            sampler = pt_data.RandomSampler(dataset)
        else:
            sampler = pt_data.SequentialSampler(dataset)

        dataloader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=1,
            sampler=sampler,
            num_workers=cfg.get("num_workers", 2),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=cfg.get("drop_last", False),
        )
        dataloaders.append(dataloader)

    return dataloaders
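# --- Example (not from the original code): a minimal sketch of a validation config for
# _setup_eval_dataloader_from_config() above, passing two src/tgt pairs so that one
# DataLoader is built per pair. File names are placeholders.
from omegaconf import OmegaConf

validation_ds_cfg = OmegaConf.create(
    {
        "src_file_name": ["newstest2013.en", "newstest2014.en"],  # becomes a ListConfig
        "tgt_file_name": ["newstest2013.de", "newstest2014.de"],
        "tokens_in_batch": 8000,
        "clean": False,
        "shuffle": False,
        "num_workers": 2,
    }
)
# dataloaders = model._setup_eval_dataloader_from_config(validation_ds_cfg)  # hypothetical `model`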
for num_tokens in tokens_in_batch:
    # build and batchify one dataset per tokens-per-batch setting
    dataset = TranslationDataset(
        dataset_src=str(Path(args.src_fname).expanduser()),
        dataset_tgt=str(Path(args.tgt_fname).expanduser()),
        tokens_in_batch=num_tokens,
        clean=args.clean,
        max_seq_length=args.max_seq_length,
        min_seq_length=args.min_seq_length,
        max_seq_length_diff=args.max_seq_length,
        max_seq_length_ratio=args.max_seq_length,
        cache_ids=False,
        cache_data_per_node=False,
        use_cache=False,
    )
    print('Batchifying ...')
    dataset.batchify(encoder_tokenizer, decoder_tokenizer)

    # time how long it takes to pickle and then unpickle the batchified dataset
    start = time.time()
    pickle.dump(dataset, open(os.path.join(args.out_dir, 'batches.tokens.%d.pkl' % (num_tokens)), 'wb'))
    end = time.time()
    print('Took %f seconds to pickle' % (end - start))

    start = time.time()
    dataset = pickle.load(open(os.path.join(args.out_dir, 'batches.tokens.%d.pkl' % (num_tokens)), 'rb'))
    end = time.time()
    print('Took %f seconds to unpickle' % (end - start))
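# --- Example (not from the original script): a sketch connecting the pickled dataset written
# above to the `load_from_cached_dataset` branch of _setup_dataloader_from_config(); both
# src_file_name and tgt_file_name must point at the same pickle. The token count, paths, and
# `model` instance are placeholders.
import os
from omegaconf import OmegaConf

cached_path = os.path.join(args.out_dir, 'batches.tokens.%d.pkl' % 16000)  # placeholder bucket
cached_cfg = OmegaConf.create(
    {
        "src_file_name": cached_path,
        "tgt_file_name": cached_path,       # must equal src_file_name for a cached dataset
        "load_from_cached_dataset": True,
        "shuffle": True,
        "num_workers": 2,
    }
)
# dataloader = model._setup_dataloader_from_config(cached_cfg)  # hypothetical `model` instance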