Beispiel #1
0
def main():
    """Save Grad-CAM / Grad-CAM++ visualisations for every image in a folder.

    Settings come from ``get_args()``; the fields read here are ``root_dir``
    (folder with the input images), ``save_dir`` (output root), ``use_att`` /
    ``att_mode`` (forwarded to the ResNet constructor), ``resume`` (optional
    checkpoint path) and ``type`` (forwarded to ``get_model_dict``).

    For each image a one-column grid is written containing, top to bottom:
    the resized input, the Grad-CAM heatmap, the Grad-CAM++ heatmap, the
    Grad-CAM overlay and the Grad-CAM++ overlay.  Results go to
    ``<save_dir>/att`` or ``<save_dir>/no_att`` depending on ``use_att`` and
    keep the input file name.
    """
    args = get_args()
    root_dir = args.root_dir

    # Only the files directly inside root_dir are needed, so take the first
    # os.walk() entry instead of walking (and storing) the whole tree.  A
    # missing directory produces no entry at all, hence the default.
    _, _, imgs = next(os.walk(root_dir), (root_dir, [], []))
    if not imgs:
        print(f'=> no images found in {root_dir}')
        return

    num_classes = 100  # CIFAR100
    model = ResNet.resnet(arch='resnet50', pretrained=False, num_classes=num_classes,
                          use_att=args.use_att, att_mode=args.att_mode)

    if args.resume:
        if os.path.isfile(args.resume):
            print(f'=> loading checkpoint {args.resume}')
            # map_location lets a checkpoint saved on a GPU be loaded on a
            # machine without one; load_state_dict copies the values into the
            # model's own parameters regardless of where they were loaded.
            checkpoint = torch.load(args.resume, map_location='cpu')
            best_acc5 = checkpoint['best_acc5']
            # strict=False: keys that do not match the model are ignored.
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            print(f"=> loaded checkpoint {args.resume} (epoch {checkpoint['epoch']})")
            print(f'=> best accuracy {best_acc5}')
        else:
            print(f'=> no checkpoint found at {args.resume}')

    # The maps are computed one image at a time; run the network in inference
    # mode so layers such as BatchNorm/Dropout behave deterministically.
    model.eval()

    model_dict = get_model_dict(model, args.type)
    normalizer = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # The extractors depend only on model_dict, so build them once instead of
    # once per image.  (If the constructors attach hooks to the model, as
    # Grad-CAM implementations commonly do, per-image construction would also
    # keep piling up hooks -- TODO confirm against the GradCAM class.)
    gradcam = GradCAM(model_dict, True)
    gradcam_pp = GradCAMpp(model_dict, True)

    # The output folder does not depend on the image either.
    save_dir = os.path.join(args.save_dir, 'att' if args.use_att else 'no_att')
    os.makedirs(save_dir, exist_ok=True)

    for img_name in imgs:
        img_path = os.path.join(root_dir, img_name)
        try:
            # The context manager closes the file handle; convert('RGB')
            # guarantees an HxWx3 array even for grayscale, palette or RGBA
            # inputs, which permute(2, 0, 1) and the 3-channel normaliser
            # below both require.
            with PIL.Image.open(img_path) as pil_img:
                pixels = np.array(pil_img.convert('RGB'))
        except OSError as err:
            # Not a readable image (stray file in the folder): skip it rather
            # than abort the whole run.
            print(f'=> skipping {img_path}: {err}')
            continue

        # HxWx3 uint8 -> 1x3xHxW float in [0, 1], resized to the network input.
        torch_img = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)
        torch_img = torch_img.float().div(255)
        torch_img = F.interpolate(torch_img, size=(224, 224), mode='bilinear',
                                  align_corners=False)

        normalized_torch_img = normalizer(torch_img)

        # The network sees the normalised tensor; the overlays are drawn on
        # the un-normalised one.
        mask, _ = gradcam(normalized_torch_img)
        heatmap, result = visualize_cam(mask, torch_img)

        mask_pp, _ = gradcam_pp(normalized_torch_img)
        heatmap_pp, result_pp = visualize_cam(mask_pp, torch_img)

        images = torch.stack(
            [torch_img.squeeze().cpu(), heatmap, heatmap_pp, result, result_pp], 0)
        images = make_grid(images, nrow=1)

        save_image(images, os.path.join(save_dir, img_name))
def grad_cam(img, model, layer):
    """Compute Grad-CAM and Grad-CAM++ visualisations of *img* for one layer.

    Args:
        img: Image accepted by ``transforms.ToTensor``.
        model: Network to explain; it is moved to the module-level ``device``
            and switched to eval mode.
        layer: Name of the target layer, forwarded as ``layer_name``.

    Returns:
        A list ``[image, heatmap, heatmap_pp, result, result_pp]`` where
        ``image`` is the un-normalised input tensor on the CPU and the other
        entries are what ``visualize_cam`` returns for the Grad-CAM and
        Grad-CAM++ masks.
    """
    # Inference mode on the target device before anything touches the model.
    model.to(device).eval()

    raw_tensor = transforms.ToTensor()(img).to(device)
    normalise = transforms.Normalize(
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Add the leading batch dimension expected by the CAM extractors.
    batch = normalise(raw_tensor).unsqueeze(0)

    cam_kwargs = {'model_type': 'resnet', 'arch': model, 'layer_name': layer}
    plain_cam = GradCAM.from_config(**cam_kwargs)
    plus_cam = GradCAMpp.from_config(**cam_kwargs)

    # The overlays are drawn on the un-normalised tensor.
    plain_mask, _ = plain_cam(batch)
    plain_heatmap, plain_overlay = visualize_cam(plain_mask, raw_tensor)

    plus_mask, _ = plus_cam(batch)
    plus_heatmap, plus_overlay = visualize_cam(plus_mask, raw_tensor)

    return [raw_tensor.cpu(), plain_heatmap, plus_heatmap,
            plain_overlay, plus_overlay]
def generate_saliency_map(img, img_name):
    """Write the Grad-CAM++ overlay of *img* to ``outputs/`` and return its path.

    Args:
        img: Image convertible with ``np.asarray`` to a height x width x
            channel array (the ``permute(2, 0, 1)`` below requires three
            dimensions); values are divided by 255.
        img_name: File name used for the saved overlay inside ``outputs``.

    Returns:
        The path of the saved image, ``os.path.join('outputs', img_name)``.

    Only the Grad-CAM++ overlay is saved, so only that map is computed.  The
    earlier version also ran plain Grad-CAM and built an image grid and timing
    values that were never used; that work was dropped.
    """
    input_size = (512, 512)

    normalizer = Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    # HxWxC -> 1xCxHxW float in [0, 1].
    torch_img = torch.from_numpy(np.asarray(img)).permute(
        2, 0, 1).unsqueeze(0).float().div(255)
    # F.interpolate is the supported replacement for the deprecated
    # F.upsample and takes the same arguments.
    torch_img = F.interpolate(torch_img,
                              size=input_size,
                              mode='bilinear',
                              align_corners=False)
    normed_torch_img = normalizer(torch_img)

    # NOTE(review): the pretrained network is re-created on every call; if
    # this function is called repeatedly, consider building it once outside.
    resnet = models.resnet101(pretrained=True)
    resnet.eval()
    model_dict = dict(type='resnet',
                      arch=resnet,
                      layer_name='layer4',
                      input_size=input_size)
    gradcam_pp = GradCAMpp(model_dict, True)

    # The network sees the normalised tensor; the overlay is drawn on the
    # un-normalised one.
    mask_pp, _ = gradcam_pp(normed_torch_img)
    _, result_pp = visualize_cam(mask_pp, torch_img)

    output_dir = 'outputs'
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, img_name)
    save_image(result_pp, output_path)

    return output_path
Beispiel #4
0
def show_map(img_path='test.jpg'):
    """Compute a Grad-CAM heatmap and overlay for one image file.

    Args:
        img_path: Path of the image to explain.  Defaults to ``'test.jpg'``,
            the file that used to be hard-coded.

    Returns:
        ``(heatmap, cam_result)`` as returned by ``visualize_cam``.
    """
    # NOTE(review): no checkpoint is loaded here, so unless ResNet18() loads
    # weights itself the map is computed for an untrained network.
    target_model = ResNet18()
    # The model is local to this call, so switching it to inference mode has
    # no effect on any caller.
    target_model.eval()
    gradcam = GradCAM.from_config(model_type='resnet', arch=target_model, layer_name='layer4')

    to_tensor = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])
    # The context manager closes the file; convert('RGB') guarantees the three
    # channels the 3-value Normalize below expects.
    with PIL.Image.open(img_path) as pil_img:
        raw_img = to_tensor(pil_img.convert('RGB'))

    normed_img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(raw_img)[None]
    mask, _ = gradcam(normed_img)
    # Draw the overlay on the un-normalised image; the normalised tensor is
    # meant for the network only and would distort the colours of the result.
    heatmap, cam_result = visualize_cam(mask, raw_img)
    return heatmap, cam_result
Beispiel #5
0
def show_map(img, model):
    """Compute a Grad-CAM heatmap and overlay of *img* for *model*.

    Args:
        img: Un-normalised image tensor accepted by ``transforms.Normalize``
            with three-channel statistics; it is not modified.
        model: Network to explain; its ``'layer4'`` is the target layer.

    Returns:
        ``(heatmap, cam_result)`` as returned by ``visualize_cam``.
    """
    gradcam = GradCAM.from_config(model_type='resnet',
                                  arch=model,
                                  layer_name='layer4')
    # Keep the raw image under its own name: the normalised, batched copy is
    # only for the network.
    normed_img = transforms.Normalize([0.485, 0.456, 0.406],
                                      [0.229, 0.224, 0.225])(img)[None]
    mask, _ = gradcam(normed_img)
    # Overlay on the un-normalised image; previously the normalised tensor was
    # passed here, which distorts the colours of the blended result.
    heatmap, cam_result = visualize_cam(mask, img)
    return heatmap, cam_result
Beispiel #6
0
    target_layer = 'conv_fc'

    # use_conv_fc = False
    # checkpoint = 'checkpoints/fc_checkpoints/epoch_26.pth'
    # target_layer = 'layer4'

    root_dir = '/home1/share/pascal_voc/VOCdevkit'
    ann_file = 'VOC2007/ImageSets/Main/trainval.txt'
    img_dir = 'VOC2007'
    imgdir = os.path.join(root_dir, img_dir, 'JPEGImages')

    dataset = VOCClsDataset(root_dir=root_dir,
                            ann_file=ann_file,
                            img_dir=img_dir,
                            phase='test')
    model = resnet50(pretrained=False, num_classes=20, use_conv_fc=use_conv_fc)
    model.load_state_dict(
        torch.load(checkpoint, map_location=lambda storage, loc: storage))

    cam = CAM(model, target_layer, dataset.label2cat)
    # cam = GradCAMPlus(model, target_layer, dataset.label2cat)

    for i in range(20):
        img, label, file_name = dataset[i]
        heatmap, cats = cam(img.unsqueeze(0), label, use_gt_label)
        visualize_cam(imgdir,
                      file_name,
                      heatmap,
                      save_dir,
                      cats,
                      is_split=is_split)
Beispiel #7
0
        weights = alpha.view(b, k, 1, 1)

        saliency_map = (weights*activations).sum(1, keepdim=True)
        saliency_map = F.relu(saliency_map)
        saliency_map = F.upsample(saliency_map, size=(h, w), mode='bilinear', align_corners=False)
        saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
        saliency_map = (saliency_map - saliency_map_min).div(saliency_map_max - saliency_map_min).data

        return saliency_map, logit

    def __call__(self, input, class_idx=None, retain_graph=False):
        """Make the instance callable by delegating to ``forward``.

        All three arguments are passed to ``self.forward`` unchanged, in the
        same order, and its return value is returned as-is.
        """
        return self.forward(input, class_idx, retain_graph)




def show_map(img, model):
    """Compute a Grad-CAM heatmap and overlay of *img* for *model*.

    Args:
        img: Un-normalised image tensor accepted by ``transforms.Normalize``
            with three-channel statistics; it is not modified.
        model: Network to explain; its ``'layer4'`` is the target layer.

    Returns:
        ``(heatmap, cam_result)`` as returned by ``visualize_cam``.
    """
    # Both sides of the former merge conflict bound the same object, so the
    # conflict markers (a syntax error) are simply resolved to the model.
    target_model = model
    gradcam = GradCAM.from_config(model_type='resnet', arch=target_model, layer_name='layer4')
    # Keep the raw image under its own name: the normalised, batched copy is
    # only for the network.
    normed_img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img)[None]
    mask, _ = gradcam(normed_img)
    # Overlay on the un-normalised image; previously the normalised tensor was
    # passed here, which distorts the colours of the blended result.
    heatmap, cam_result = visualize_cam(mask, img)
    return heatmap, cam_result
   


Beispiel #8
0
# Script fragment: `resnet`, `resnet101_dict`, `normed_torch_img`, `torch_img`
# and `img_name` are not defined in this part of the file -- presumably they
# are set up earlier; verify before running this section on its own.
resnet.load_state_dict(resnet101_dict)	# load the parameters
# Switch to inference mode and move the model to the GPU.
resnet.eval(), resnet.cuda();
###


# Maps a model name to its [Grad-CAM, Grad-CAM++] extractor pair.
cam_dict = dict()

# Both extractors target 'layer4' of the same network.
resnet_model_dict = dict(type='resnet', arch=resnet, layer_name='layer4', input_size=(224, 224))
resnet_gradcam = GradCAM(resnet_model_dict, True)
resnet_gradcampp = GradCAMpp(resnet_model_dict, True)
cam_dict['resnet'] = [resnet_gradcam, resnet_gradcampp]

# One stacked tensor per model: input, both heatmaps, both overlays.
images = []
for gradcam, gradcam_pp in cam_dict.values():
    # The extractors receive the normalised tensor; visualize_cam is given
    # CPU copies of the mask and of the un-normalised image.
    mask, _ = gradcam(normed_torch_img)
    heatmap, result = visualize_cam(mask.cpu(), torch_img.cpu())

    mask_pp, _ = gradcam_pp(normed_torch_img)
    heatmap_pp, result_pp = visualize_cam(mask_pp.cpu(), torch_img.cpu())

    images.append(torch.stack([torch_img.squeeze().cpu(), heatmap, heatmap_pp, result, result_pp], 0))


# images = make_grid(torch.cat(images, 0), nrow=5)


# Output folder for the results; created if it does not exist yet.
output_dir = 'outputs'


os.makedirs(output_dir, exist_ok=True)
output_name = img_name