Example #1
0
def vit_clone(key: str):
    """Clone a pretrained timm ViT into the model built by ``AutoModel``.

    Args:
        key: Model name, used for ``timm.create_model`` (source),
            ``AutoModel.from_name`` (destination) and
            ``AutoConfig.from_name`` (input resolution).

    Returns:
        The model returned by ``clone_model``, with the positional embedding
        and class token copied over from the timm model.
    """
    # ``pretrained`` is a boolean flag; the string 'True' only worked
    # because non-empty strings happen to be truthy.
    src = timm.create_model(key, pretrained=True)
    dst = AutoModel.from_name(key)

    cfg = AutoConfig.from_name(key)

    # Dummy batch of one image at the configured resolution, handed to
    # clone_model (presumably to match src/dst layers -- TODO confirm).
    dummy = torch.randn((1, 3, cfg.input_size, cfg.input_size))
    dst = clone_model(src, dst, dummy)

    # Copy the raw embedding parameters only after clone_model has run (the
    # same order the other clone helpers use), so nothing done during the
    # transfer can overwrite them.
    # NOTE(review): squeeze(0) assumes timm's pos_embed carries a leading
    # batch dim of size 1 -- confirm against the timm version in use.
    dst.embedding.positions.data.copy_(src.pos_embed.data.squeeze(0))
    dst.embedding.cls_token.data.copy_(src.cls_token.data)

    return dst
Example #2
0
def vit_clone(key: str):
    """Clone a pretrained timm ViT into the model built by ``AutoModel``.

    Args:
        key: Model name, used for ``timm.create_model`` (source),
            ``AutoModel.from_name`` (destination) and
            ``AutoTransform.from_name`` (whose ``input_size`` sizes the
            dummy input).

    Returns:
        The model returned by ``clone_model``, with the positional embedding
        and class token copied over from the timm model.
    """
    # ``pretrained`` is a boolean flag; the string "True" only worked
    # because non-empty strings happen to be truthy.
    src = timm.create_model(key, pretrained=True)
    dst = AutoModel.from_name(key)

    cfg = AutoTransform.from_name(key)

    # ViTTokens is passed as ``dest_skip``; the cls token it holds is
    # copied by hand right after the transfer instead.
    dst = clone_model(
        src,
        dst,
        torch.randn((1, 3, cfg.input_size, cfg.input_size)),
        dest_skip=[ViTTokens],
    )

    # NOTE(review): squeeze(0) assumes timm's pos_embed carries a leading
    # batch dim of size 1 -- confirm against the timm version in use.
    dst.embedding.positions.data.copy_(src.pos_embed.data.squeeze(0))
    dst.embedding.tokens.cls.data.copy_(src.cls_token.data)

    return dst
def deit_clone(key: str):
    """Clone facebook's distilled DeiT (via torch.hub) into ``AutoModel.from_name(key)``.

    ``key`` is split on underscores. The hub model name is built by inserting
    ``distilled`` after the second component. The input resolution comes from
    the config named ``vit_`` followed by every component after the first.

    Returns:
        The model returned by ``clone_model``, with the positional embedding,
        class token and distillation token copied from the hub model.
    """
    parts = key.split('_')
    head, tail = '_'.join(parts[:2]), '_'.join(parts[2:])
    hub_key = f"{head}_distilled_{tail}"
    vit_key = 'vit_' + '_'.join(parts[1:])

    src = torch.hub.load('facebookresearch/deit:main', hub_key, pretrained=True)
    dst = AutoModel.from_name(key)
    cfg = AutoConfig.from_name(vit_key)

    # DeiTTokens is passed as ``dest_skip``; its cls/dist tokens are copied
    # by hand after the transfer.
    side = cfg.input_size
    dst = clone_model(src, dst, torch.randn((1, 3, side, side)),
                      dest_skip=[DeiTTokens])

    embedding = dst.embedding
    embedding.positions.data.copy_(src.pos_embed.data.squeeze(0))
    embedding.tokens.cls.data.copy_(src.cls_token.data)
    embedding.tokens.dist.data.copy_(src.dist_token.data)

    return dst
Example #4
0
    # Write the names of every model in the zoo as one comma-separated line.
    with open("pretrained_models.txt", "w") as f:
        f.write(",".join(list(zoo_source.keys())))

    # Optional output directory. `.mkdir` is called on it, so presumably
    # `args.o` is a pathlib.Path -- TODO confirm against the argparse setup.
    # NOTE(review): `save_dir` is only bound when `args.o` is given.
    if args.o is not None:
        save_dir = args.o
        save_dir.mkdir(exist_ok=True)
    # Pick the storage backend class by name and instantiate it.
    storages = {"local": LocalStorage, "hf": HuggingFaceStorage}
    storage = storages[args.storage]()

    if args.storage == "local":
        logging.info(f"Store root={storage.root}")

    # With override hard-coded to True, every model is cloned and stored
    # again even if `key` is already in storage.
    override = True

    bar = tqdm(zoo_source.items())
    # NOTE(review): `uploading_bar` is not used in the visible code.
    uploading_bar = tqdm()
    for key, src_def in bar:
        bar.set_description(key)
        if src_def is None:
            # No explicit source given: fall back to the timm model of the
            # same name, with pretrained weights.
            src_def = partial(timm.create_model, key, pretrained=True)
        if key not in storage or override:
            if type(src_def) is tuple:
                # A (clone_func, flag) tuple: clone_func(key) builds the
                # cloned model on its own. `flag` is unused here.
                clone_func, flag = src_def
                cloned = clone_func(key)
            else:
                # Otherwise src_def is a zero-argument factory for the source
                # model; clone_model transfers it into a fresh
                # AutoModel.from_name(key) instance.
                src, dst = src_def(), AutoModel.from_name(key)
                cloned = clone_model(src, dst)
            storage.put(key, cloned)