예제 #1
0
def convert_t5(args):
    """Convert a Huggingface T5 model to Gluon and report the produced files.

    The steps run in a fixed order: vocabulary, config, parameters, an
    optional conversion test, and finally the sha1sum-based rename. Every
    helper is handed the same ``converted`` dict; after ``rename`` its values
    are used as paths of the files to report (presumably the helpers register
    the files they write in it -- confirm against their definitions).

    Parameters
    ----------
    args
        Namespace-like object. This function reads ``args.dest_dir`` (output
        directory, created if missing), ``args.model_name`` (passed to
        ``HF_T5.from_pretrained``) and ``args.test`` (truthy to run
        ``test_conversion``). The helpers receive ``args`` as well.

    Returns
    -------
    dict
        The ``converted`` dict, whose values are the final file paths.
    """
    logging.info('converting T5 model from Huggingface...')
    # makedirs(exist_ok=True) also creates missing parent directories and does
    # not race with another process creating the directory between a separate
    # existence check and the creation call.
    os.makedirs(args.dest_dir, exist_ok=True)
    converted = {}
    # convert and save vocab
    convert_vocab(args, converted)
    # convert and save config
    gluon_cfg = convert_config(args, converted)
    # convert, (test), and save model
    hf_t5 = HF_T5.from_pretrained(args.model_name)
    gluon_t5 = Gluon_T5.from_cfg(gluon_cfg)
    gluon_t5 = convert_params(args, converted, hf_t5, gluon_t5)
    gluon_t5.hybridize()
    # test model if needed
    if args.test:
        test_conversion(args, hf_t5, gluon_t5)
    # rename with sha1sum
    rename(args, converted)
    logging.info('conversion completed.')
    logging.info('file statistics:')
    # Only the paths are reported, so iterate over the values directly.
    for new_path in converted.values():
        logging.info('filename: {}\tsize: {}\tsha1sum: {}'.format(
            os.path.basename(new_path), os.path.getsize(new_path),
            sha1sum(new_path)))
    return converted
예제 #2
0
def verify_download(url, sha1_hash, overwrite):
    """Check that ``download`` yields a file whose sha1sum equals ``sha1_hash``.

    Two scenarios are exercised inside a temporary directory: one plain call,
    then two pool workers downloading the same URL to the same path at once.
    After each scenario the file's sha1sum is asserted and the file removed.

    Parameters
    ----------
    url
        Location handed to ``download``.
    sha1_hash
        Expected sha1sum of the downloaded file; also forwarded to ``download``.
    overwrite
        Forwarded unchanged to ``download``.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Scenario 1: a single download in the current process.
        single_target = os.path.join(tmp_dir, 'dat0')
        download(url, sha1_hash=sha1_hash, path=single_target,
                 overwrite=overwrite)
        assert sha1sum(single_target) == sha1_hash
        os.remove(single_target)

        # Scenario 2: two worker processes fetch the same URL to one path.
        shared_target = os.path.join(tmp_dir, 'dat1')
        fetch = functools.partial(download, sha1_hash=sha1_hash,
                                  path=shared_target, overwrite=overwrite)
        with multiprocessing.Pool(2) as workers:
            workers.map(fetch, [url, url])
        assert sha1sum(shared_target) == sha1_hash
        os.remove(shared_target)
예제 #3
0
def rename(save_dir):
    """Rename every entry of ``save_dir`` so its name embeds a short hash.

    For each name returned by ``os.listdir(save_dir)`` the value of
    ``sha1sum`` is computed and its first 8 characters are inserted before
    the last extension: ``model.params`` becomes ``model-<hash8>.params``.
    Names with several dots keep everything before the last dot as the prefix
    (``model.tar.gz`` -> ``model.tar-<hash8>.gz``); names without a dot simply
    get ``-<hash8>`` appended. The new path, the full hash and the file size
    are logged for each entry.

    Parameters
    ----------
    save_dir
        Directory whose entries are renamed in place.
    """
    for old_name in os.listdir(save_dir):
        old_path = os.path.join(save_dir, old_name)
        long_hash = sha1sum(old_path)
        short_hash = long_hash[:8]
        # Split on the LAST dot only. Unpacking str.split('.') into two names
        # raised ValueError for any name that did not contain exactly one dot.
        file_prefix, dot, file_suffix = old_name.rpartition('.')
        if dot:
            new_name = '{file_prefix}-{short_hash}.{file_suffix}'.format(
                file_prefix=file_prefix,
                short_hash=short_hash,
                file_suffix=file_suffix)
        else:
            # No extension: rpartition puts the whole name in the last slot.
            new_name = '{file_prefix}-{short_hash}'.format(
                file_prefix=old_name,
                short_hash=short_hash)
        new_path = os.path.join(save_dir, new_name)
        shutil.move(old_path, new_path)
        file_size = os.path.getsize(new_path)
        logging.info('\t{} {} {}'.format(new_path, long_hash, file_size))