import os
import pickle
import tarfile

# TranslationDataset and MTDataPreproc are assumed to come from NeMo's machine translation
# collection; adjust the import paths to match the NeMo version in use.
from nemo.collections.nlp.data.machine_translation.machine_translation_dataset import TranslationDataset
from nemo.collections.nlp.data.machine_translation.preproc_mt_data import MTDataPreproc


def write_parallel_batches_to_tarfiles(
    out_dir,
    num_batches_per_tarfile,
    clean,
    max_seq_length,
    min_seq_length,
    src_fname,
    tgt_fname,
    num_tokens,
    encoder_tokenizer_name,
    encoder_tokenizer_model,
    encoder_tokenizer_r2l,
    encoder_bpe_dropout,
    encoder_model_name,
    decoder_tokenizer_name,
    decoder_tokenizer_model,
    decoder_bpe_dropout,
    decoder_model_name,
    decoder_tokenizer_r2l,
    fragment_index,
):
    """
    Writes the current fragment of the overall parallel corpus to tarfiles by:
    (1) Creating minibatches with a TranslationDataset object.
    (2) Writing each minibatch to a pickle file.
    (3) Adding pickle files to a tarfile until it reaches num_batches_per_tarfile.
    """
    dataset = TranslationDataset(
        dataset_src=src_fname,
        dataset_tgt=tgt_fname,
        tokens_in_batch=num_tokens,
        clean=clean,
        max_seq_length=max_seq_length,
        min_seq_length=min_seq_length,
        max_seq_length_diff=max_seq_length,
        max_seq_length_ratio=max_seq_length,
        cache_ids=False,
        cache_data_per_node=False,
        use_cache=False,
    )
    # Build the encoder/decoder tokenizers for this fragment.
    encoder_tokenizer, decoder_tokenizer = MTDataPreproc.get_enc_dec_tokenizers(
        encoder_tokenizer_name=encoder_tokenizer_name,
        encoder_tokenizer_model=encoder_tokenizer_model,
        encoder_bpe_dropout=encoder_bpe_dropout,
        encoder_model_name=encoder_model_name,
        encoder_r2l=encoder_tokenizer_r2l,
        decoder_tokenizer_name=decoder_tokenizer_name,
        decoder_tokenizer_model=decoder_tokenizer_model,
        decoder_bpe_dropout=decoder_bpe_dropout,
        decoder_model_name=decoder_model_name,
        decoder_r2l=decoder_tokenizer_r2l,
    )
    dataset.batchify(encoder_tokenizer, decoder_tokenizer)

    tar_file_ctr = 0
    tar_file_path = os.path.join(
        out_dir, 'fragment-%s-batches.tokens.%d.%d.tar' % (fragment_index, num_tokens, tar_file_ctr)
    )
    tar_file_ptr = tarfile.open(tar_file_path, 'w')
    total_batch_ctr = 0
    batch_ctr = 0
    for _, batch in dataset.batches.items():
        total_batch_ctr += 1
        batch_ctr += 1
        # Write the batch to a pickle file, add it to the current tarfile, then remove the pickle.
        pkl_file_path = os.path.join(out_dir, 'fragment-%s-batch-%d.pkl' % (fragment_index, total_batch_ctr))
        with open(pkl_file_path, 'wb') as pkl_f:
            pickle.dump(batch, pkl_f)
        tar_file_ptr.add(pkl_file_path)
        os.remove(pkl_file_path)

        if batch_ctr == num_batches_per_tarfile:
            # The current tarfile is full: close it and start a new one.
            tar_file_ctr += 1
            tar_file_ptr.close()
            tar_file_path = os.path.join(
                out_dir, 'fragment-%s-batches.tokens.%d.%d.tar' % (fragment_index, num_tokens, tar_file_ctr)
            )
            tar_file_ptr = tarfile.open(tar_file_path, 'w')
            batch_ctr = 0

    # Return the path of the last tarfile, which may hold fewer than num_batches_per_tarfile batches.
    remainder_tar_file_path = tar_file_ptr.name
    tar_file_ptr.close()
    return total_batch_ctr, remainder_tar_file_path
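# A minimal usage sketch for the fragment-level variant above: it fans fragments out to a
# worker pool, since each call builds its own tokenizers and tarfiles. It assumes this
# fragment-level definition is the `write_parallel_batches_to_tarfiles` visible to the pool
# workers (i.e., the incremental variant further below lives in a different module/version).
# All other values (fragment paths, tokenizer names/models, batch and length limits, pool
# size) are placeholder assumptions, not values taken from this repo.
def _example_tar_fragments(out_dir, num_fragments):
    from multiprocessing import Pool

    worker_args = [
        (
            out_dir,                    # out_dir
            100,                        # num_batches_per_tarfile
            True,                       # clean
            512,                        # max_seq_length
            1,                          # min_seq_length
            'train.src.%d' % i,         # src_fname (placeholder fragment path)
            'train.tgt.%d' % i,         # tgt_fname (placeholder fragment path)
            16000,                      # num_tokens per minibatch
            'yttm',                     # encoder_tokenizer_name (placeholder)
            'tokenizer.encoder.model',  # encoder_tokenizer_model (placeholder)
            False,                      # encoder_tokenizer_r2l
            0.1,                        # encoder_bpe_dropout
            None,                       # encoder_model_name
            'yttm',                     # decoder_tokenizer_name (placeholder)
            'tokenizer.decoder.model',  # decoder_tokenizer_model (placeholder)
            0.1,                        # decoder_bpe_dropout
            None,                       # decoder_model_name
            False,                      # decoder_tokenizer_r2l
            i,                          # fragment_index
        )
        for i in range(num_fragments)
    ]
    with Pool(processes=4) as pool:
        # Each worker returns (total_batch_ctr, remainder_tar_file_path) for its fragment.
        return pool.starmap(write_parallel_batches_to_tarfiles, worker_args)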
def write_parallel_batches_to_tarfiles(
    out_dir,
    num_batches_per_tarfile,
    clean,
    max_seq_length,
    min_seq_length,
    src_fname,
    tgt_fname,
    num_tokens,
    encoder_tokenizer,
    decoder_tokenizer,
    num_files_in_tar,
    tar_file_ptr,
    tar_file_ctr,
    global_batch_ctr,
    pkl_file_prefix,
):
    """
    Writes the current fragment of the overall parallel corpus to tarfiles by:
    (1) Creating minibatches with a TranslationDataset object.
    (2) Writing each minibatch to a pickle file.
    (3) Adding pickle files to the currently open tarfile until it reaches num_batches_per_tarfile.

    Unlike the fragment-level variant above, this variant takes pre-built tokenizers and an
    already-open tarfile handle, and returns the updated handle and counters so it can be
    called repeatedly over successive corpus shards.
    """
    dataset = TranslationDataset(
        dataset_src=src_fname,
        dataset_tgt=tgt_fname,
        tokens_in_batch=num_tokens,
        clean=clean,
        max_seq_length=max_seq_length,
        min_seq_length=min_seq_length,
        max_seq_length_diff=max_seq_length,
        max_seq_length_ratio=max_seq_length,
        cache_ids=False,
        cache_data_per_node=False,
        use_cache=False,
    )
    dataset.batchify(encoder_tokenizer, decoder_tokenizer)

    for _, batch in dataset.batches.items():
        global_batch_ctr += 1
        pkl_file_path = os.path.join(out_dir, '%s-batch-%d.pkl' % (pkl_file_prefix, global_batch_ctr))
        with open(pkl_file_path, 'wb') as pkl_f:
            pickle.dump(batch, pkl_f)

        if num_files_in_tar == num_batches_per_tarfile:
            # The current tarfile is full: close it and open the next one.
            tar_file_ctr += 1
            tar_file_ptr.close()
            tar_file_ptr = tarfile.open(
                os.path.join(out_dir, '%s-batches.tokens.%d.%d.tar' % (pkl_file_prefix, num_tokens, tar_file_ctr)),
                'w',
            )
            num_files_in_tar = 0

        # Add the pickle to the current tarfile, then remove it from disk.
        tar_file_ptr.add(pkl_file_path)
        num_files_in_tar += 1
        os.remove(pkl_file_path)

    return tar_file_ptr, global_batch_ctr, num_files_in_tar, tar_file_ctr
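# A minimal usage sketch for the incremental variant above: it streams several corpus shards
# into a rolling set of tarfiles through one open handle. Shard paths, tokenizer objects, and
# the batch/size settings are placeholder assumptions, not values taken from this repo.
def _example_tar_shards(out_dir, src_shards, tgt_shards, encoder_tokenizer, decoder_tokenizer):
    num_tokens = 16000             # tokens per minibatch (placeholder)
    num_batches_per_tarfile = 100  # batches per tarfile (placeholder)
    pkl_file_prefix = 'parallel'   # prefix for pickle/tar file names (placeholder)

    # Open the first tarfile; the function rotates to a new one whenever it fills up.
    tar_file_ctr = 1
    tar_file_ptr = tarfile.open(
        os.path.join(out_dir, '%s-batches.tokens.%d.%d.tar' % (pkl_file_prefix, num_tokens, tar_file_ctr)), 'w'
    )
    num_files_in_tar, global_batch_ctr = 0, 0

    for src_fname, tgt_fname in zip(src_shards, tgt_shards):
        tar_file_ptr, global_batch_ctr, num_files_in_tar, tar_file_ctr = write_parallel_batches_to_tarfiles(
            out_dir,
            num_batches_per_tarfile,
            clean=True,
            max_seq_length=512,
            min_seq_length=1,
            src_fname=src_fname,
            tgt_fname=tgt_fname,
            num_tokens=num_tokens,
            encoder_tokenizer=encoder_tokenizer,
            decoder_tokenizer=decoder_tokenizer,
            num_files_in_tar=num_files_in_tar,
            tar_file_ptr=tar_file_ptr,
            tar_file_ctr=tar_file_ctr,
            global_batch_ctr=global_batch_ctr,
            pkl_file_prefix=pkl_file_prefix,
        )

    # The final tarfile may hold fewer than num_batches_per_tarfile batches.
    tar_file_ptr.close()
    return global_batch_ctr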