def visualize_attn(args):
    """Visualize the learned attention maps of a trained model."""
    config = CONFIGS[args.model_type]
    num_classes = 10 if args.dataset == "cifar10" else 100

    model = VisionTransformer(config, args.img_size, norm_type=args.norm_type,
                              zero_head=True, num_classes=num_classes, vis=True)
    ckpt_file = os.path.join(args.output_dir, args.name + "_checkpoint.bin")
    ckpt = torch.load(ckpt_file)  # load on a single GPU for attention-map visualization
    model.load_state_dict(ckpt)
    model.to(args.device)
    model.eval()

    _, test_loader = get_loader(args)

    sample_idx = 0
    layer_ids = [0, 3, 6, 9]
    head_id = 0

    with torch.no_grad():
        for step, batch in enumerate(test_loader):
            batch = tuple(t.to(args.device) for t in batch)
            x, y = batch
            select_x = x[sample_idx].unsqueeze(0)
            output, attn_weights = model(select_x)
            # attn_weights is a list with one tensor per layer, each of shape
            # (1, num_heads, seq_len, seq_len)
            for layer_id in layer_ids:
                vis_attn(args, attn_weights[layer_id].squeeze(0)[head_id], layer_id=layer_id)
            break  # visualize only the first sample of the first batch
    print("done.")
    exit(0)
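# The vis_attn helper called above is not shown in this excerpt. Below is a
# minimal sketch of what such a helper could look like: it saves a single
# head's (seq_len, seq_len) attention map as a heatmap image per layer. The
# function body, file-name pattern, and matplotlib styling are assumptions for
# illustration, not the project's actual implementation.
import os

import matplotlib.pyplot as plt


def vis_attn(args, attn_map, layer_id):
    """Save one attention head's map as a heatmap image (illustrative sketch)."""
    attn_map = attn_map.detach().cpu().numpy()
    fig, ax = plt.subplots(figsize=(6, 6))
    im = ax.imshow(attn_map, cmap="viridis")
    ax.set_title(f"layer {layer_id}")
    fig.colorbar(im, ax=ax)
    # hypothetical output path built from the same args used above
    fig.savefig(os.path.join(args.output_dir, f"{args.name}_attn_layer{layer_id}.png"))
    plt.close(fig)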
imagenet_labels = dict(enumerate(open('attention_data/ilsvrc2012_wordnet_lemmas.txt')))

img_url = "https://images.mypetlife.co.kr/content/uploads/2019/04/09192811/welsh-corgi-1581119_960_720.jpg"
urlretrieve(img_url, "attention_data/img.jpg")

# Prepare model
config = CONFIGS["ViT-B_16"]
model = VisionTransformer(config, num_classes=1000, zero_head=False, img_size=224, vis=True)
model.load_from(np.load("attention_data/ViT-B_16-224.npz"))
model.eval()

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
im = Image.open("attention_data/img.jpg")
x = transform(im)
x.size()

# %%
logits, att_mat = model(x.unsqueeze(0))

# Stack the per-layer attention tensors: (num_layers, num_heads, seq_len, seq_len)
att_mat = torch.stack(att_mat).squeeze(1)

# Average the attention weights across all heads.
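# The excerpt ends at the head-averaging step. As a sketch of one common way to
# continue, the averaged maps can be combined across layers with the standard
# attention-rollout recipe (add the identity for residual connections,
# renormalize, and multiply the per-layer matrices). This continuation is an
# illustrative assumption, not necessarily the notebook's exact code.
import torch

att_mean = att_mat.mean(dim=1)                    # (num_layers, seq_len, seq_len)
residual = torch.eye(att_mean.size(-1))           # account for skip connections
aug_att = att_mean + residual
aug_att = aug_att / aug_att.sum(dim=-1, keepdim=True)

# Multiply the per-layer matrices so attention flows from the input to the top layer.
rollout = aug_att[0]
for layer_att in aug_att[1:]:
    rollout = layer_att @ rollout

# Attention from the [CLS] token to the image patches, reshaped to the patch grid
# (14x14 for a 224x224 input with 16x16 patches).
grid = int((rollout.size(-1) - 1) ** 0.5)
cls_to_patches = rollout[0, 1:].reshape(grid, grid)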