def convert_t5(args):
    """Convert a Hugging Face T5 checkpoint to Gluon and save the results.

    The conversion runs in stages: vocab, config, then model parameters. When
    ``args.test`` is set, the converted model is checked against the Hugging
    Face model. Finally the saved files are renamed and their statistics are
    logged.

    Args:
        args: parsed command-line namespace. This function itself reads
            ``args.dest_dir`` (output directory, created if needed),
            ``args.model_name`` (passed to ``HF_T5.from_pretrained``) and
            ``args.test`` (run ``test_conversion`` when truthy). The helpers
            it calls may read further attributes.

    Returns:
        dict: the ``converted`` mapping populated by the helpers. Its values
        are used as file paths when the final statistics are logged.
    """
    logging.info('converting T5 model from Huggingface...')
    # makedirs also creates missing parent directories, and exist_ok removes
    # the race between a separate exists() check and mkdir().
    os.makedirs(args.dest_dir, exist_ok=True)
    converted = {}
    # convert and save vocab
    convert_vocab(args, converted)
    # convert and save config
    gluon_cfg = convert_config(args, converted)
    # convert, (test), and save model
    hf_t5 = HF_T5.from_pretrained(args.model_name)
    gluon_t5 = Gluon_T5.from_cfg(gluon_cfg)
    gluon_t5 = convert_params(args, converted, hf_t5, gluon_t5)
    gluon_t5.hybridize()
    # test model if needed
    if args.test:
        test_conversion(args, hf_t5, gluon_t5)
    # rename with sha1sum
    # NOTE(review): this passes (args, converted). A `rename(save_dir)` taking a
    # single directory argument appears next to this function; confirm the
    # two-argument helper is the one in scope here.
    rename(args, converted)
    logging.info('conversion completed.')
    logging.info('file statistics:')
    # Only the paths are needed here; the keys are not used.
    for new_path in converted.values():
        logging.info('filename: {}\tsize: {}\tsha1sum: {}'.format(
            os.path.basename(new_path), os.path.getsize(new_path), sha1sum(new_path)))
    return converted
def verify_download(url, sha1_hash, overwrite):
    """Check that ``download`` fetches ``url`` with the expected SHA-1 digest.

    Both checks run inside a temporary directory that is discarded afterwards.
    First, a single direct call downloads the file. Then a two-process pool
    downloads the same URL twice to one shared path.

    Args:
        url: location of the file to fetch.
        sha1_hash: expected SHA-1 hex digest, forwarded to ``download`` and
            compared against ``sha1sum`` of each result.
        overwrite: forwarded unchanged to ``download``.

    Raises:
        AssertionError: if a downloaded file's digest differs from ``sha1_hash``.
    """
    download_kwargs = dict(sha1_hash=sha1_hash, overwrite=overwrite)
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Pass 1: one plain, in-process download.
        single_path = os.path.join(tmp_dir, 'dat0')
        download(url, path=single_path, **download_kwargs)
        assert sha1sum(single_path) == sha1_hash
        os.remove(single_path)

        # Pass 2: two pool workers each download the same URL to one path.
        shared_path = os.path.join(tmp_dir, 'dat1')
        fetch = functools.partial(download, path=shared_path, **download_kwargs)
        with multiprocessing.Pool(2) as workers:
            workers.map(fetch, [url] * 2)
        assert sha1sum(shared_path) == sha1_hash
        os.remove(shared_path)
def rename(save_dir):
    """Rename converted files with hash.

    Each regular file in ``save_dir`` is moved to a name that includes the
    first 8 hex characters of its SHA-1 digest. The hash is inserted before
    the first dot:

    - ``vocab.json`` -> ``vocab-<hash>.json``
    - ``model.tar.gz`` -> ``model-<hash>.tar.gz``
    - ``config`` -> ``config-<hash>``

    For each renamed file, the new path, the full hash and the file size are
    logged at INFO level. Entries that are not regular files (e.g.
    subdirectories) are left untouched.

    Args:
        save_dir: directory whose files are renamed in place.
    """
    for old_name in os.listdir(save_dir):
        old_path = os.path.join(save_dir, old_name)
        if not os.path.isfile(old_path):
            # Only regular files are hashed; sha1sum presumably opens its
            # argument as a file, which would fail on a directory.
            continue
        long_hash = sha1sum(old_path)
        # partition() never raises, unlike unpacking split('.') into two names,
        # which breaks on names with zero or several dots.
        file_prefix, dot, file_suffix = old_name.partition('.')
        new_name = '{}-{}{}{}'.format(file_prefix, long_hash[:8], dot, file_suffix)
        new_path = os.path.join(save_dir, new_name)
        shutil.move(old_path, new_path)
        file_size = os.path.getsize(new_path)
        logging.info('\t{} {} {}'.format(new_path, long_hash, file_size))