示例#1
0
def setup(args):
    """Build the VisionTransformer for fine-grained classification and load weights.

    Args:
        args: parsed CLI namespace; reads model_type, split, slide_step,
            dataset, img_size, smoothing_value, pretrained_dir,
            pretrained_model, device.

    Returns:
        (args, model): the unchanged namespace and the initialized model,
        already moved to ``args.device``.

    Raises:
        ValueError: if ``args.dataset`` is not a recognized dataset name.
    """
    # Prepare model
    config = CONFIGS[args.model_type]
    config.split = args.split
    config.slide_step = args.slide_step

    # Dataset name -> number of target classes.
    # NOTE(review): "CUB_200_2011" normally has 200 classes but is mapped
    # to 4 here — presumably a custom subset; confirm before changing.
    dataset_num_classes = {
        "CUB_200_2011": 4,
        "car": 196,
        "nabirds": 555,
        "dog": 120,
        "INat2017": 5089,
    }
    try:
        num_classes = dataset_num_classes[args.dataset]
    except KeyError:
        # Previously an unknown dataset fell through the elif chain and the
        # constructor below raised an opaque NameError; fail fast instead.
        raise ValueError("Unknown dataset: %r" % (args.dataset,)) from None

    model = VisionTransformer(config,
                              args.img_size,
                              zero_head=True,
                              num_classes=num_classes,
                              smoothing_value=args.smoothing_value)
    # Always load the ImageNet-style .npz weights; optionally override with
    # a full torch checkpoint when one is supplied.
    model.load_from(np.load(args.pretrained_dir))
    if args.pretrained_model is not None:
        pretrained_model = torch.load(args.pretrained_model)['model']
        model.load_state_dict(pretrained_model)
    model.to(args.device)
    num_params = count_parameters(model)

    logger.info("{}".format(config))
    logger.info("Training parameters %s", args)
    logger.info("Total Parameter: \t%2.1fM" % num_params)
    return args, model
示例#2
0
def setup(args):
    """Build the VisionTransformer (with mixup config) and load pretrained weights.

    Args:
        args: parsed CLI namespace; reads model_type, mixup, mixup_layer,
            mixup_alpha, dataset, img_size, pretrained_dir, device.

    Returns:
        (args, model): the unchanged namespace and the initialized model,
        already moved to ``args.device``.
    """
    # Prepare model
    config = CONFIGS[args.model_type]
    config.mixup = args.mixup
    config.mixup_layer = args.mixup_layer
    config.mixup_alpha = args.mixup_alpha

    # CIFAR-10 vs CIFAR-100 are the only supported datasets here.
    num_classes = 10 if args.dataset == "cifar10" else 100

    model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes)
    try:
        model.load_from(np.load(args.pretrained_dir))
    except Exception:
        # Not an .npz weight archive — fall back to the torch-checkpoint
        # loader. The original bare `except:` also swallowed
        # KeyboardInterrupt/SystemExit; narrowed to Exception.
        load_model(args.pretrained_dir, model)
    model.to(args.device)
    num_params = count_parameters(model)

    print("{}".format(config))

    for key, value in vars(args).items():
        logging.info(f"{key}: {value}")
    print("#Trainable Parameter: \t%2.2fM" % num_params)
    print(num_params)
    return args, model
示例#3
0
def setup(args, device):
    """Create a 256-way VisionTransformer, load .npz weights, move to *device*.

    Args:
        args: parsed CLI namespace; reads model_type, img_size, pretrained_dir.
        device: target device for the model.

    Returns:
        (args, model): the unchanged namespace and the initialized model.
    """
    # Prepare model
    config = CONFIGS[args.model_type]

    # Fixed-size output head used by this variant.
    output_size = 256

    model = VisionTransformer(config, args.img_size,
                              zero_head=True, num_classes=output_size)
    pretrained_weights = np.load(args.pretrained_dir)
    model.load_from(pretrained_weights)
    model.to(device)

    num_params = count_parameters(model)
    logger.info("{}".format(config))
    logger.info("Training parameters %s", args)
    logger.info("Total Parameter: \t%2.1fM" % num_params)
    print(num_params)
    return args, model
示例#4
0
def setup(args):
    """Create a CIFAR VisionTransformer, load .npz weights, move to args.device.

    Args:
        args: parsed CLI namespace; reads model_type, dataset, img_size,
            pretrained_dir, device.

    Returns:
        (args, model): the unchanged namespace and the initialized model.
    """
    # Prepare model
    config = CONFIGS[args.model_type]

    # CIFAR-10 gets 10 classes; anything else is treated as CIFAR-100.
    if args.dataset == "cifar10":
        num_classes = 10
    else:
        num_classes = 100

    model = VisionTransformer(config, args.img_size,
                              zero_head=True, num_classes=num_classes)
    pretrained_weights = np.load(args.pretrained_dir)
    model.load_from(pretrained_weights)
    model.to(args.device)

    num_params = count_parameters(model)
    logger.info("{}".format(config))
    logger.info("Training parameters %s", args)
    logger.info("Total Parameter: \t%2.1fM" % num_params)
    return args, model
        "attention_data/ViT-B_16-224.npz")

# Load the ImageNet class index -> label mapping (one lemma per line).
imagenet_labels = dict(
    enumerate(open('attention_data/ilsvrc2012_wordnet_lemmas.txt')))

# Download a sample image (a corgi photo) to run attention visualization on.
img_url = "https://images.mypetlife.co.kr/content/uploads/2019/04/09192811/welsh-corgi-1581119_960_720.jpg"
urlretrieve(img_url, "attention_data/img.jpg")

# Prepare Model
# vis=True makes the model also return per-layer attention maps alongside
# the logits (see the unpacking of the forward call below).
config = CONFIGS["ViT-B_16"]
model = VisionTransformer(config,
                          num_classes=1000,
                          zero_head=False,
                          img_size=224,
                          vis=True)
model.load_from(np.load("attention_data/ViT-B_16-224.npz"))
model.eval()

# Standard ViT preprocessing: resize to 224x224 and normalize each channel
# to [-1, 1] (mean/std of 0.5).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
im = Image.open("attention_data/img.jpg")
x = transform(im)
x.size()
# %%
# Forward pass on a batch of one; att_mat is presumably a list of per-layer
# attention tensors (enabled by vis=True above) — confirm against the model.
logits, att_mat = model(x.unsqueeze(0))

# Stack the per-layer attention maps and drop the batch dimension.
att_mat = torch.stack(att_mat).squeeze(1)