# Example #1
import os
from urllib.request import urlretrieve

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

# Project-specific imports; module paths assume the ViT-pytorch repo layout.
from models.modeling import VisionTransformer, CONFIGS
from utils.data_utils import get_loader
def visualize_attn(args):
    """ 
    Visualization of learned attention map.
    """
    config = CONFIGS[args.model_type]
    num_classes = 10 if args.dataset == "cifar10" else 100

    model = VisionTransformer(config, args.img_size,
                              norm_type=args.norm_type,
                              zero_head=True,
                              num_classes=num_classes,
                              vis=True)

    ckpt_file = os.path.join(args.output_dir, args.name + "_checkpoint.bin")
    # Load on CPU first so a single device suffices for visualization,
    # regardless of how many GPUs the checkpoint was trained on.
    ckpt = torch.load(ckpt_file, map_location="cpu")
    model.load_state_dict(ckpt)
    model.to(args.device)
    model.eval()
    
    _, test_loader = get_loader(args)
    sample_idx = 0
    layer_ids = [0, 3, 6, 9]
    head_id = 0
    with torch.no_grad():
        for step, batch in enumerate(test_loader):
            batch = tuple(t.to(args.device) for t in batch)
            x, y = batch
            select_x = x[sample_idx].unsqueeze(0)
            output, attn_weights = model(select_x)
            # attn_weights is a list with one tensor per layer,
            # each of shape (1, num_heads, seq_len, seq_len)
            for layer_id in layer_ids:
                vis_attn(args, attn_weights[layer_id].squeeze(0)[head_id], layer_id=layer_id)
            break  # only visualize the first sample of the first batch
    print("done.")
    exit(0)
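
# `vis_attn` is called above but not defined in this snippet. A minimal
# sketch of such a helper is given below (hypothetical implementation;
# the project's actual helper may differ):
def vis_attn(args, attn, layer_id=0):
    """Save one head's (seq_len, seq_len) attention map as a heatmap."""
    import matplotlib.pyplot as plt

    attn = attn.detach().cpu().numpy()
    fig, ax = plt.subplots(figsize=(6, 6))
    heatmap = ax.imshow(attn, cmap="viridis")
    fig.colorbar(heatmap, ax=ax)
    ax.set_title(f"layer {layer_id} attention")
    fig.savefig(os.path.join(args.output_dir, f"attn_layer{layer_id}.png"),
                bbox_inches="tight")
    plt.close(fig)
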
# Map class index -> human-readable label (one WordNet lemma per line).
with open("attention_data/ilsvrc2012_wordnet_lemmas.txt") as f:
    imagenet_labels = dict(enumerate(line.strip() for line in f))

img_url = "https://images.mypetlife.co.kr/content/uploads/2019/04/09192811/welsh-corgi-1581119_960_720.jpg"
urlretrieve(img_url, "attention_data/img.jpg")
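
# The pretrained weights used below are assumed to already be in
# attention_data/. If not, they can be fetched the same way from the
# official ViT release bucket (URL is an assumption; verify before use):
# urlretrieve("https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16-224.npz",
#             "attention_data/ViT-B_16-224.npz")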

# Prepare Model
config = CONFIGS["ViT-B_16"]
model = VisionTransformer(config,
                          num_classes=1000,
                          zero_head=False,
                          img_size=224,
                          vis=True)
# Load the official pretrained weights from the .npz release file.
model.load_from(np.load("attention_data/ViT-B_16-224.npz"))
model.eval()

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
im = Image.open("attention_data/img.jpg").convert("RGB")
x = transform(im)
print(x.size())  # torch.Size([3, 224, 224])
# %%
logits, att_mat = model(x.unsqueeze(0))
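
# Decode the prediction with the label map loaded above (illustrative usage):
probs = torch.softmax(logits, dim=-1)
top5 = torch.topk(probs[0], k=5)
for p, idx in zip(top5.values, top5.indices):
    print(f"{p.item():.4f} : {imagenet_labels[int(idx)]}")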

att_mat = torch.stack(att_mat).squeeze(1)  # (num_layers, num_heads, seq_len, seq_len)

# Average the attention weights across all heads.
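# A sketch of how this cell typically continues, implementing the head
# averaging and the attention-rollout recipe of Abnar & Zuidema (2020);
# treat it as an illustrative continuation, not the author's exact code:
import cv2  # opencv-python; used only to upsample the mask

att_mat = att_mat.detach().mean(dim=1)  # (num_layers, seq_len, seq_len)

# Add the identity to account for residual connections, then re-normalize.
residual_att = torch.eye(att_mat.size(1))
aug_att_mat = att_mat + residual_att
aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)

# Multiply the attention matrices recursively across layers.
joint_attentions = torch.zeros(aug_att_mat.size())
joint_attentions[0] = aug_att_mat[0]
for n in range(1, aug_att_mat.size(0)):
    joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n - 1])

# Attention from the [CLS] token to the patches in the last layer, reshaped
# to the patch grid (14x14 for ViT-B/16 at 224x224) and upsampled to the
# original image resolution.
grid_size = int(np.sqrt(aug_att_mat.size(-1)))
mask = joint_attentions[-1][0, 1:].reshape(grid_size, grid_size).numpy()
mask = cv2.resize(mask / mask.max(), im.size)[..., np.newaxis]
result = (mask * np.array(im)).astype("uint8")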