Example #1
0
def main():
	"""Render Grad-CAM and Grad-CAM++ visualizations for every image in args.root_dir.

	For each image, a one-column grid (input, CAM heatmap, CAM++ heatmap,
	CAM overlay, CAM++ overlay) is written to <args.save_dir>/{att,no_att}/<name>.
	"""
	args = get_args()
	root_dir = args.root_dir
	# Files directly under root_dir: os.walk's first tuple is (root, dirs, files).
	imgs = list(os.walk(root_dir))[0][2]

	num_classes = 100 # CIFAR100
	model = ResNet.resnet(arch='resnet50', pretrained=False, num_classes=num_classes,
		use_att=args.use_att, att_mode=args.att_mode)

	if args.resume:
		if os.path.isfile(args.resume):
			print(f'=> loading checkpoint {args.resume}')
			checkpoint = torch.load(args.resume)
			best_acc5 = checkpoint['best_acc5']
			# strict=False tolerates key mismatches (e.g. attention modules
			# toggled by --use_att).
			model.load_state_dict(checkpoint['state_dict'], strict=False)
			print(f"=> loaded checkpoint {args.resume} (epoch {checkpoint['epoch']})")
			print(f'=> best accuracy {best_acc5}')
		else:
			print(f'=> no checkpoint found at {args.resume}')

	model_dict = get_model_dict(model, args.type)
	normalizer = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

	# Fix: build the CAM extractors once.  The original re-created them inside
	# the image loop, re-registering model hooks on every iteration.
	gradcam = GradCAM(model_dict, True)
	gradcam_pp = GradCAMpp(model_dict, True)

	# Fix: the output directory depends only on args, so resolve and create it
	# once instead of per image (the original also had a dead
	# `save_dir = args.save_dir` that was always shadowed before use).
	save_dir = os.path.join(args.save_dir, 'att' if args.use_att else 'no_att')
	os.makedirs(save_dir, exist_ok=True)

	for img_name in imgs:
		img_path = os.path.join(root_dir, img_name)
		# Fix: force RGB so grayscale/RGBA files do not break the permute below.
		pil_img = PIL.Image.open(img_path).convert('RGB')

		# HWC uint8 -> NCHW float in [0, 1], resized to the model's input size.
		torch_img = torch.from_numpy(np.asarray(pil_img))
		torch_img = torch_img.permute(2, 0, 1).unsqueeze(0)
		torch_img = torch_img.float().div(255)
		torch_img = F.interpolate(torch_img, size=(224, 224), mode='bilinear', align_corners=False)

		normalized_torch_img = normalizer(torch_img)

		mask, _ = gradcam(normalized_torch_img)
		heatmap, result = visualize_cam(mask, torch_img)

		mask_pp, _ = gradcam_pp(normalized_torch_img)
		heatmap_pp, result_pp = visualize_cam(mask_pp, torch_img)

		images = torch.stack([torch_img.squeeze().cpu(), heatmap, heatmap_pp, result, result_pp], 0)
		images = make_grid(images, nrow=1)

		save_image(images, os.path.join(save_dir, img_name))
def main():
    """Run Grad-CAM++ on a fixed set of CT images and save a combined grid.

    For each image the grid row holds: the CT slice, the lung segmentation,
    the CAM++ heatmap, and the CAM++ overlay.  The result is written to
    gradcam_output/output.png and displayed.
    """
    device = utils.get_device()

    # Load trained model.
    net = mobilenet_v2(task = 'classification', moco = False, ctonly = False).to(device)
    # Fix: map_location lets GPU-saved checkpoints load on CPU-only machines.
    state_dict = torch.load("./model.pth", map_location=device)
    net.load_state_dict(state_dict)
    net.eval() # eval mode: disables dropout, uses running BN statistics

    # Build the Grad-CAM++ extractor once for the whole image set.
    model_dict = dict(type='mobilenet', arch=net, layer_name='features_18', input_size=(480, 480))
    gradcampp = GradCAMpp(model_dict, True)

    # Fixed demo set: three COVID and three non-COVID CT slices.
    raw_images = ['CT_COVID/2020.01.24.919183-p27-132.png',
                  'CT_COVID/2020.02.22.20024927-p18-66%2.png',
                  'CT_COVID/2020.02.26.20026989-p34-114_1%1.png',
                  'CT_NonCOVID/673.png',
                  'CT_NonCOVID/1262.png',
                  'CT_NonCOVID/40%1.jpg']

    images = []
    for raw_image in raw_images:
        # NOTE(review): `args` is not a parameter — assumes a module-level
        # argparse namespace; confirm it is defined at import time.
        ctimagename = os.path.join(args.datapath, 'Images-processed', raw_image)
        lungimagename = os.path.join(args.datapath, 'lung_segmentation', raw_image)
        ctimage = Image.open(ctimagename).convert('L')
        lungimage = Image.open(lungimagename).convert('L')
        ctimage = ctprocessing(ctimage)
        # Grayscale image -> (1, 1, H, W) float tensor on the target device.
        ctimage = torch.from_numpy(ctimage).float().unsqueeze(0).unsqueeze(0).to(device)

        lungimage = np.asarray(lungimage)
        lungimage = lungimage.astype(np.float32)/255.
        lungimage = skt.resize(lungimage, (480,480), mode='constant', anti_aliasing=False) # resize
        lungimage = torch.from_numpy(lungimage).float().unsqueeze(0).unsqueeze(0).to(device)

        mask_pp, _ = gradcampp(ctimage, lungimage)
        heatmap_pp, result_pp = visualize_cam(mask_pp.cpu(), ctimage)
        # Single-channel tensors are expanded to 3 channels so they stack
        # with the RGB heatmap/overlay.
        images.append(
            torch.stack([
                ctimage.squeeze().cpu().unsqueeze(0).expand(3,-1,-1),
                lungimage.squeeze().cpu().unsqueeze(0).expand(3,-1,-1),
                heatmap_pp, result_pp], 0))

    images = make_grid(torch.cat(images, 0), nrow=4)

    output_dir = 'gradcam_output'
    os.makedirs(output_dir, exist_ok=True)
    output_name =  'output.png'
    output_path = os.path.join(output_dir, output_name)

    save_image(images, output_path)

    result_img = Image.open(output_path)
    result_img.show()
def generate_saliency_map(img, img_name):
    """Compute a Grad-CAM++ saliency overlay for *img* with an ImageNet
    ResNet-101 and save it to outputs/<img_name>.

    Parameters:
        img: a PIL image (or HWC array) in RGB.
        img_name: file name used for the saved overlay.

    Returns:
        The path of the saved overlay image.
    """
    normalizer = Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    # HWC uint8 -> NCHW float in [0, 1].
    torch_img = torch.from_numpy(np.asarray(img)).permute(
        2, 0, 1).unsqueeze(0).float().div(255)
    # Fix: F.upsample is deprecated; F.interpolate is the supported equivalent.
    torch_img = F.interpolate(torch_img,
                              size=(512, 512),
                              mode='bilinear',
                              align_corners=False)
    normed_torch_img = normalizer(torch_img)

    resnet = models.resnet101(pretrained=True)
    resnet.eval()
    model_dict = dict(type='resnet',
                      arch=resnet,
                      layer_name='layer4',
                      input_size=(512, 512))
    gradcam = GradCAM(model_dict, True)
    gradcam_pp = GradCAMpp(model_dict, True)

    mask, _ = gradcam(normed_torch_img)
    heatmap, result = visualize_cam(mask, torch_img)
    mask_pp, _ = gradcam_pp(normed_torch_img)
    heatmap_pp, result_pp = visualize_cam(mask_pp, torch_img)

    # Fix: removed dead code — the original also built a make_grid image,
    # an empty cam_dict and a timing duration, none of which were used.
    # Only the Grad-CAM++ overlay is persisted.
    output_dir = 'outputs'
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, img_name)
    save_image(result_pp, output_path)

    return output_path
Example #4
0
def Grad_Cam(model, train_datasets):
  """Visualize Grad-CAM and Grad-CAM++ for 10 random samples of a dataset.

  Parameters:
      model: a CNN with a `.features` attribute used as the target layer.
      train_datasets: an indexable dataset yielding (image_tensor, label).

  Returns:
      A PIL image: a grid with one row per sample
      (input, CAM heatmap, CAM++ heatmap, CAM overlay, CAM++ overlay).
  """
  device = torch.device("cuda:0" if torch.cuda.is_available()  else "cpu")
  model.eval()
  target_layer = model.features
  gradcam = GradCAM(model, target_layer)
  gradcam_pp = GradCAMpp(model, target_layer)
  images = []
  for _ in range(10):
      # Fix: sample within the actual dataset size instead of the
      # hard-coded upper bound 212, which breaks for other datasets.
      index = random.randrange(len(train_datasets))
      first_inputs, _ = train_datasets[index]
      inputs = first_inputs.to(device).unsqueeze(0)
      mask, _ = gradcam(inputs)
      heatmap, result = visualize_cam(mask, first_inputs)

      mask_pp, _ = gradcam_pp(inputs)
      heatmap_pp, result_pp = visualize_cam(mask_pp, first_inputs)

      images.extend([first_inputs.cpu(), heatmap, heatmap_pp, result, result_pp])
  grid_image = make_grid(images, nrow=5)

  return transforms.ToPILImage()(grid_image)
Example #5
0
def make_plot_and_save(input_img,
                       img_name,
                       no_norm_image,
                       segm,
                       model,
                       train_or_val,
                       epoch=None,
                       vis_prefix=None):
    """Plot Grad-CAM / Grad-CAM++ results and SAM attention maps for one image.

    Builds a 2x6 matplotlib figure (mask, original, CAM, CAM++, and four SAM
    maps in relative and absolute scaling), optionally saves it under
    vis/<vis_prefix>/<train_or_val>/ and logs it to wandb when running on the
    server.

    Parameters:
        input_img: normalized input tensor fed to the CAM extractors.
        img_name: identifier used in the plot title / file name / wandb key.
        no_norm_image: un-normalized image tensor used for the overlays.
        segm: segmentation mask tensor shown in the first panel.
        model: network with a `.layer4` attribute used as the CAM target layer.
        train_or_val: split label for titles and paths.
        epoch: optional wandb step.
        vis_prefix: optional subdirectory for saving the figure to disk.
    """
    global is_server
    # get Grad-CAM results and prepare them to show on the plot
    target_layer = model.layer4
    gradcam = GradCAM(model, target_layer=target_layer)
    gradcam_pp = GradCAMpp(model, target_layer=target_layer)

    # sam_output shapes:
    # [1, 1, 56, 56]x3 , [1, 1, 28, 28]x4 [1, 1, 14, 14]x6 , [1, 1, 7, 7]x3
    mask, no_norm_mask, logit, sam_output = gradcam(input_img)

    # One SAM map per resolution group (indices 0, 3, 7, 13).
    sam1_show = torch.squeeze(sam_output[0].cpu()).detach().numpy()
    sam4_show = torch.squeeze(sam_output[3].cpu()).detach().numpy()
    sam8_show = torch.squeeze(sam_output[7].cpu()).detach().numpy()
    sam14_show = torch.squeeze(sam_output[13].cpu()).detach().numpy()

    heatmap, result = visualize_cam(mask, no_norm_image)

    # CHW -> HWC for imshow.
    result_show = np.moveaxis(torch.squeeze(result).detach().numpy(), 0, -1)

    mask_pp, no_norm_mask_pp, logit_pp, sam_output_pp = gradcam_pp(input_img)
    heatmap_pp, result_pp = visualize_cam(mask_pp, no_norm_image)

    result_pp_show = np.moveaxis(
        torch.squeeze(result_pp).detach().numpy(), 0, -1)

    # prepare mask and original image to show on the plot
    segm_show = torch.squeeze(segm.cpu()).detach().numpy()
    segm_show = np.moveaxis(segm_show, 0, 2)
    input_show = np.moveaxis(
        torch.squeeze(no_norm_image).detach().numpy(), 0, -1)

    # draw and save the plot
    plt.close('all')
    fig, axs = plt.subplots(nrows=2, ncols=6, figsize=(24, 9))
    plt.suptitle(f'{train_or_val}-Image: {img_name}')
    axs[1][0].imshow(segm_show)
    axs[1][0].set_title('Mask')
    axs[0][0].imshow(input_show)
    axs[0][0].set_title('Original Image')

    axs[0][1].imshow(result_show)
    axs[0][1].set_title('Grad-CAM')
    axs[1][1].imshow(result_pp_show)
    axs[1][1].set_title('Grad-CAM++')

    # Row 0: absolute scaling (fixed [0, 1] range); row 1: relative scaling.
    axs[1][2].imshow(sam1_show, cmap='gray')
    axs[1][2].set_title('SAM-1 relative')
    axs[0][2].imshow(sam1_show, vmin=0., vmax=1., cmap='gray')
    axs[0][2].set_title('SAM-1 absolute')

    axs[1][3].imshow(sam4_show, cmap='gray')
    axs[1][3].set_title('SAM-4 relative')
    axs[0][3].imshow(sam4_show, vmin=0., vmax=1., cmap='gray')
    axs[0][3].set_title('SAM-4 absolute')

    axs[1][4].imshow(sam8_show, cmap='gray')
    axs[1][4].set_title('SAM-8 relative')
    axs[0][4].imshow(sam8_show, vmin=0., vmax=1., cmap='gray')
    axs[0][4].set_title('SAM-8 absolute')

    axs[1][5].imshow(sam14_show, cmap='gray')
    axs[1][5].set_title('SAM-14 relative')
    axs[0][5].imshow(sam14_show, vmin=0., vmax=1., cmap='gray')
    axs[0][5].set_title('SAM-14 absolute')

    # Fix: save BEFORE plt.show() — on interactive backends show() consumes
    # the figure, so the original order could write a blank image.
    if vis_prefix is not None:
        plt.savefig(f'vis/{vis_prefix}/{train_or_val}/{img_name}.png',
                    bbox_inches='tight')
    plt.show()
    if is_server:
        if epoch is not None:
            wandb.log({f'{train_or_val}/{img_name}': fig}, step=epoch)
        else:
            wandb.log({f'{train_or_val}/{img_name}': fig})
Example #6
0
# Lets got ahead and run the network and get back the saliency map

# In[11]:

# Forward pass through the saliency network; returns the combined map,
# per-layer maps, and an unused third value.
# NOTE(review): get_salmap / in_tensor are defined earlier in the notebook —
# their exact shapes cannot be confirmed from this chunk.
csmap, smaps, _ = get_salmap(in_tensor)

# ## Run With Grad-CAM or Grad-CAM++

# Let's go ahead and push our network model into the Grad-CAM library.
#
# **NOTE** much of this code is borrowed from the Pytorch GradCAM package.

# In[12]:

# Build a Grad-CAM++ extractor hooked on the last ResNet block (layer4).
resnet_gradcampp4 = GradCAMpp.from_config(model_type='resnet',
                                          arch=model,
                                          layer_name='layer4')

# Let's get our original input image back. We will just use this one for visualization.

# In[13]:

# Reload the image WITHOUT normalization (norm=False) so overlays use the
# original colors, then resize it to match the network input.
raw_tensor = misc.LoadImageToTensor(load_image_name, device, norm=False)
raw_tensor = F.interpolate(raw_tensor,
                           size=(in_height, in_width),
                           mode='bilinear',
                           align_corners=False)

# We create an object to get back the mask of the saliency map

# In[14]:
    def eval(self, gradcam=False, rise=False, test_on_val=False):
        """The function for the meta-eval phase.

        Runs 600 meta-test episodes, periodically (every 5th batch) dumping
        similarity maps, norms and optional saliency maps (Grad-CAM,
        Grad-CAM++, guided backprop, RISE/ScoreCAM) to ../results/.

        Args:
            gradcam: if True, also compute and save Grad-CAM(++)/guided maps.
            rise: if True, also compute ScoreCAM maps.
            test_on_val: evaluate on the validation split instead of test.

        Returns:
            Mean test accuracy over the 600 episodes.
        """
        # Load the logs
        if os.path.exists(osp.join(self.args.save_path, 'trlog')):
            trlog = torch.load(osp.join(self.args.save_path, 'trlog'))
        else:
            trlog = None

        # Fixed seeds so the sampled episodes are reproducible.
        torch.manual_seed(1)
        np.random.seed(1)
        # Load meta-test set
        test_set = Dataset('val' if test_on_val else 'test', self.args)
        sampler = CategoriesSampler(test_set.label, 600, self.args.way,
                                    self.args.shot + self.args.val_query)
        loader = DataLoader(test_set,
                            batch_sampler=sampler,
                            num_workers=8,
                            pin_memory=True)

        # Set test accuracy recorder
        test_acc_record = np.zeros((600, ))

        # Load model for meta-test phase
        if self.args.eval_weights is not None:
            weights = self.addOrRemoveModule(
                self.model,
                torch.load(self.args.eval_weights)['params'])
            self.model.load_state_dict(weights)
        else:
            # Default: best-accuracy checkpoint from the training run.
            self.model.load_state_dict(
                torch.load(osp.join(self.args.save_path,
                                    'max_acc' + '.pth'))['params'])
        # Set model to eval mode
        self.model.eval()

        # Set accuracy averager
        ave_acc = Averager()

        # Generate labels
        # Query labels: [0..way-1] repeated val_query times.
        label = torch.arange(self.args.way).repeat(self.args.val_query)
        if torch.cuda.is_available():
            label = label.type(torch.cuda.LongTensor)
        else:
            label = label.type(torch.LongTensor)
        # Support labels: [0..way-1] repeated shot times.
        label_shot = torch.arange(self.args.way).repeat(self.args.shot)
        if torch.cuda.is_available():
            label_shot = label_shot.type(torch.cuda.LongTensor)
        else:
            label_shot = label_shot.type(torch.LongTensor)

        if gradcam:
            # Expose encoder internals under the attribute names the CAM
            # libraries expect (model.layer3 / model.features).
            self.model.layer3 = self.model.encoder.layer3
            model_dict = dict(type="resnet",
                              arch=self.model,
                              layer_name='layer3')
            grad_cam = GradCAM(model_dict, True)
            grad_cam_pp = GradCAMpp(model_dict, True)
            self.model.features = self.model.encoder
            guided = GuidedBackprop(self.model)
        if rise:
            self.model.layer3 = self.model.encoder.layer3
            score_mod = ScoreCam(self.model)

        # Start meta-test
        for i, batch in enumerate(loader, 1):
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
            # First way*shot samples are the support set, the rest queries.
            k = self.args.way * self.args.shot
            data_shot, data_query = data[:k], data[k:]

            if i % 5 == 0:
                # Every 5th batch: run the verbose path that also saves
                # intermediate tensors for later visualization.
                suff = "_val" if test_on_val else ""

                if self.args.rep_vec or self.args.cross_att:
                    # NOTE(review): `acc` here comes from the PREVIOUS loop
                    # iteration (set at the bottom); it is defined because
                    # i % 5 == 0 can't happen on the first iteration.
                    print('batch {}: {:.2f}({:.2f})'.format(
                        i,
                        ave_acc.item() * 100, acc * 100))

                    if self.args.cross_att:
                        label_one_hot = self.one_hot(label).to(label.device)
                        _, _, logits, simMapQuer, simMapShot, normQuer, normShot = self.model(
                            (data_shot, label_shot, data_query),
                            ytest=label_one_hot,
                            retSimMap=True)
                    else:
                        logits, simMapQuer, simMapShot, normQuer, normShot, fast_weights = self.model(
                            (data_shot, label_shot, data_query),
                            retSimMap=True)

                    torch.save(
                        simMapQuer,
                        "../results/{}/{}_simMapQuer{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))
                    torch.save(
                        simMapShot,
                        "../results/{}/{}_simMapShot{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))
                    torch.save(
                        data_query, "../results/{}/{}_dataQuer{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))
                    torch.save(
                        data_shot, "../results/{}/{}_dataShot{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))
                    torch.save(
                        normQuer, "../results/{}/{}_normQuer{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))
                    torch.save(
                        normShot, "../results/{}/{}_normShot{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))
                else:
                    logits, normQuer, normShot, fast_weights = self.model(
                        (data_shot, label_shot, data_query),
                        retFastW=True,
                        retNorm=True)
                    torch.save(
                        normQuer, "../results/{}/{}_normQuer{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))
                    torch.save(
                        normShot, "../results/{}/{}_normShot{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))

                if gradcam:
                    # NOTE(review): in the cross_att branch above,
                    # fast_weights is never assigned — if cross_att and
                    # gradcam are both enabled this raises NameError (or
                    # silently reuses a stale value). Confirm intended combo.
                    print("Saving gradmaps", i)
                    allMasks, allMasks_pp, allMaps = [], [], []
                    # Saliency maps are computed one query sample at a time.
                    for l in range(len(data_query)):
                        allMasks.append(
                            grad_cam(data_query[l:l + 1], fast_weights, None))
                        allMasks_pp.append(
                            grad_cam_pp(data_query[l:l + 1], fast_weights,
                                        None))
                        allMaps.append(
                            guided.generate_gradients(data_query[l:l + 1],
                                                      fast_weights))
                    allMasks = torch.cat(allMasks, dim=0)
                    allMasks_pp = torch.cat(allMasks_pp, dim=0)
                    allMaps = torch.cat(allMaps, dim=0)

                    torch.save(
                        allMasks, "../results/{}/{}_gradcamQuer{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))
                    torch.save(
                        allMasks_pp,
                        "../results/{}/{}_gradcamppQuer{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))
                    torch.save(
                        allMaps, "../results/{}/{}_guidedQuer{}{}.th".format(
                            self.args.exp_id, self.args.model_id, i, suff))

                if rise:
                    # NOTE(review): allScore is built but never saved —
                    # presumably an unfinished feature; confirm.
                    print("Saving risemaps", i)
                    allScore = []
                    for l in range(len(data_query)):
                        allScore.append(
                            score_mod(data_query[l:l + 1], fast_weights))

            else:
                # Fast path: plain forward pass for accuracy only.
                if self.args.cross_att:
                    label_one_hot = self.one_hot(label).to(label.device)
                    _, _, logits = self.model(
                        (data_shot, label_shot, data_query),
                        ytest=label_one_hot)
                else:
                    logits = self.model((data_shot, label_shot, data_query))

            acc = count_acc(logits, label)
            ave_acc.add(acc)
            test_acc_record[i - 1] = acc

        # Calculate the confidence interval, update the logs
        m, pm = compute_confidence_interval(test_acc_record)
        if trlog is not None:
            print('Val Best Epoch {}, Acc {:.4f}, Test Acc {:.4f}'.format(
                trlog['max_acc_epoch'], trlog['max_acc'], ave_acc.item()))
        print('Test Acc {:.4f} + {:.4f}'.format(m, pm))

        return m
Example #8
0
# Load the pretrained net-a weights and transplant the overlapping
# parameters into the ResNet-101 (net-b) state dict.
state_dict = torch.load('./model/2real/extractor_8.pth')

# Keep only the keys net-b also has; drop everything net-b does not need.
# (Fix: removed a dead `new_state_dict = OrderedDict()` that was
# immediately overwritten by the comprehension below.)
new_state_dict = {k:v for k,v in state_dict.items() if k in resnet101_dict}
resnet101_dict.update(new_state_dict)  # merge the transplanted parameters
resnet.load_state_dict(resnet101_dict)  # load the merged parameters
resnet.eval(), resnet.cuda();
###


# Registry of (GradCAM, GradCAMpp) pairs keyed by model name.
cam_dict = dict()

resnet_model_dict = dict(type='resnet', arch=resnet, layer_name='layer4', input_size=(224, 224))
resnet_gradcam = GradCAM(resnet_model_dict, True)
resnet_gradcampp = GradCAMpp(resnet_model_dict, True)
cam_dict['resnet'] = [resnet_gradcam, resnet_gradcampp]

images = []
for gradcam, gradcam_pp in cam_dict.values():
    mask, _ = gradcam(normed_torch_img)
    heatmap, result = visualize_cam(mask.cpu(), torch_img.cpu())

    mask_pp, _ = gradcam_pp(normed_torch_img)
    heatmap_pp, result_pp = visualize_cam(mask_pp.cpu(), torch_img.cpu())

    # One row per model: input, CAM heatmap, CAM++ heatmap, CAM overlay, CAM++ overlay.
    images.append(torch.stack([torch_img.squeeze().cpu(), heatmap, heatmap_pp, result, result_pp], 0))


# images = make_grid(torch.cat(images, 0), nrow=5)